# Source code for grafei.modules.GraFEIModule

##########################################################################
# basf2 (Belle II Analysis Software Framework)                           #
# Author: The Belle II Collaboration                                     #
#                                                                        #
# See git log for contributors and copyright holders.                    #
# This file is licensed under LGPL-3.0, see LICENSE.md.                  #
##########################################################################


import itertools
import numpy as np
import yaml
import warnings
import basf2 as b2
from ROOT import Belle2
from variables import variables as vm
import torch
from torch_geometric.data import Batch
from grafei.modules.LCASaverModule import get_object_list, write_hist
from grafei.model.geometric_network import GraFEIModel
from grafei.model.normalize_features import normalize_features
from grafei.model.edge_features import compute_edge_features
from grafei.model.lca_to_adjacency import lca_to_adjacency, InvalidLCAMatrix, select_good_decay
from grafei.model.tree_utils import masses_to_classes

warnings.filterwarnings(
    action="ignore", category=RuntimeWarning, message="Mean of empty slice.*"
)


class GraFEIModule(b2.Module):
    """
    Applies graFEI model to a particle list in basf2.
    GraFEI information is stored as extraInfos.

    Args:
        particle_list (str): Name of particle list.
        cfg_path (str): Path to config file. If `None` the config file in the global tag is used.
        param_file (str): Path to parameter file containing the model.
            If `None` the parameter file in the global tag is used.
        sig_side_lcas (list): List containing LCAS matrix of signal-side.
        sig_side_masses (list): List containing mass hypotheses of signal-side.
        gpu (bool): Whether to run on a GPU.
        payload_config_name (str): Name of config file payload.
            The default should be kept, except in basf2 examples.
        payload_model_name (str): Name of model file payload.
            The default should be kept, except in basf2 examples.
    """

    def __init__(
        self,
        particle_list,
        cfg_path=None,
        param_file=None,
        sig_side_lcas=None,
        sig_side_masses=None,
        gpu=False,
        payload_config_name="graFEIConfigFile",
        payload_model_name="graFEIModelFile",
    ):
        """
        Initialization. Only stores the settings; the model itself is built in ``initialize``.
        """
        super().__init__()
        #: Input particle list
        self.particle_list = particle_list
        #: Config yaml file path
        self.cfg_path = cfg_path
        #: PyTorch parameter file path
        self.param_file = param_file
        #: Chosen sig-side LCAS
        self.sig_side_lcas = torch.tensor(sig_side_lcas) if sig_side_lcas else None
        #: Chosen sig-side mass hypotheses
        self.sig_side_masses = sig_side_masses
        #: If running on GPU
        self.gpu = gpu
        #: Config file name in the payload
        self.payload_config_name = payload_config_name
        #: Model file name in the payload
        self.payload_model_name = payload_model_name

    def initialize(self):
        """
        Called at the beginning of processing.

        Resolves the config and model files (from the database payloads when no path
        was given), reads the YAML config, builds the graFEI model, loads its weights
        and moves it to the selected device.
        """
        # Get weights and configs from the DB if they are not provided from the user
        if not self.cfg_path:
            config = Belle2.DBAccessorBase(
                Belle2.DBStoreEntry.c_RawFile, self.payload_config_name, True
            )
            self.cfg_path = config.getFilename()
        if not self.param_file:
            model = Belle2.DBAccessorBase(
                Belle2.DBStoreEntry.c_RawFile, self.payload_model_name, True
            )
            self.param_file = model.getFilename()

        #: Figure out if we're running on data or MC
        self.storeTrueInfo = Belle2.Environment.Instance().isMC()
        #: Figure out which device all this is running on - CPU or GPU
        self.device = torch.device(
            "cuda" if (self.gpu and torch.cuda.is_available()) else "cpu"
        )

        # Load configs (the file handle is closed as soon as it has been parsed)
        with open(self.cfg_path, "r") as cfg_file:
            #: Config file
            self.configs = yaml.safe_load(cfg_file)

        #: Top MC particle
        self.mc_particle = None
        #: Max LCAS level
        self.max_level = None

        # B or Ups reco? 0 = Ups, 1 = B0, 2 = B+
        if self.configs["model"]["B_reco"] == 0:
            self.mc_particle = "Upsilon(4S):MC"
            self.max_level = 6
        elif self.configs["model"]["B_reco"] == 1:
            self.mc_particle = "B0:MC"
            self.max_level = 5
        elif self.configs["model"]["B_reco"] == 2:
            self.mc_particle = "B+:MC"
            self.max_level = 5
        else:
            b2.B2FATAL("The B_reco setting in the config file is incorrect.")

        #: Normalize features
        self.normalize = self.configs["dataset"]["config"]["normalize"]

        #: Mixed precision
        self.use_amp = self.configs["train"][
            "mixed_precision"
        ] and self.device == torch.device("cuda")

        #: Node features
        self.node_features = self.configs["dataset"]["config"]["features"]
        #: Edge features
        self.edge_features = self.configs["dataset"]["config"]["edge_features"]
        #: Global features
        self.glob_features = self.configs["dataset"]["config"]["global_features"]

        # Naming convention
        self.node_features = [f"feat_{name}" for name in self.node_features] if self.node_features else []
        self.edge_features = [f"edge_{name}" for name in self.edge_features] if self.edge_features else []
        self.glob_features = [f"glob_{name}" for name in self.glob_features] if self.glob_features else []
        #: Discarded node features
        self.discarded_features = ["feat_x", "feat_y", "feat_z", "feat_px", "feat_py", "feat_p"]

        # Extract the number of features
        n_infeatures = len(self.node_features)
        e_infeatures = len(self.edge_features)
        g_infeatures = len(self.glob_features)

        #: The model
        # The correct edge_classes is taken from the config file
        self.model = GraFEIModel(
            nfeat_in_dim=n_infeatures,
            efeat_in_dim=e_infeatures,
            gfeat_in_dim=g_infeatures,
            **self.configs["model"],
        )

        # Load parameters' values
        self.model.load_state_dict(
            torch.load(self.param_file, map_location=self.device)["model"]
        )

        # Activate evaluation mode
        self.model.eval()
        # Push model to GPU in case
        self.model.to(self.device)

        # Pass the model itself (not a one-element set) so the log shows its structure
        b2.B2DEBUG(10, "Model structure:\n", self.model)

    def event(self):
        """
        Called for each event.

        For every candidate in ``self.particle_list`` the final-state daughters are
        turned into a fully connected graph and the graFEI model is evaluated on it.
        The predictions (LCAS matrix, mass hypotheses, probabilities and derived
        counts) are stored as extraInfo on the candidate. Each final-state particle
        additionally receives ``graFEI_sigSide`` (1 = sig-side, 0 = tag-side,
        -1 = not assigned). On MC, ``graFEI_truth_*`` extraInfos are stored as well.
        Candidates with fewer than two final-state particles only get the pre-fit
        counts and are otherwise skipped with a warning.
        """
        b2.B2DEBUG(10, "---- Processing new event ----")

        # Get the B candidate list
        candidate_list = get_object_list(Belle2.PyStoreObj(self.particle_list).obj())

        # Variable names to evaluate, i.e. the feature names without the "feat_" prefix.
        # These do not depend on the particle, so compute them once per event.
        node_variables = [feat[feat.find("feat_") + 5:] for feat in self.node_features]
        discarded_variables = [feat[feat.find("feat_") + 5:] for feat in self.discarded_features]

        # Get the particle candidate(s)
        for candidate in candidate_list:
            # Get FSPs
            p_list = get_object_list(candidate.getFinalStateDaughters())

            # Number of FSPs
            n_nodes = len(p_list)

            # Particle nature
            masses = np.array([abs(p.getPDGCode()) for p in p_list])

            # Number of charged and photons
            graFEI_nFSP = n_nodes
            graFEI_nPhotons_preFit = (masses == 22).sum()
            graFEI_nCharged_preFit = graFEI_nFSP - graFEI_nPhotons_preFit
            graFEI_nElectrons_preFit = (masses == 11).sum()
            graFEI_nMuons_preFit = (masses == 13).sum()
            graFEI_nPions_preFit = (masses == 211).sum()
            graFEI_nKaons_preFit = (masses == 321).sum()
            graFEI_nProtons_preFit = (masses == 2212).sum()
            graFEI_nLeptons_preFit = graFEI_nElectrons_preFit + graFEI_nMuons_preFit
            graFEI_nOthers_preFit = graFEI_nCharged_preFit - (
                graFEI_nLeptons_preFit
                + graFEI_nPions_preFit
                + graFEI_nKaons_preFit
                + graFEI_nProtons_preFit
            )

            candidate.addExtraInfo("graFEI_nFSP", graFEI_nFSP)
            candidate.addExtraInfo("graFEI_nCharged_preFit", graFEI_nCharged_preFit)
            candidate.addExtraInfo("graFEI_nPhotons_preFit", graFEI_nPhotons_preFit)
            candidate.addExtraInfo("graFEI_nElectrons_preFit", graFEI_nElectrons_preFit)
            candidate.addExtraInfo("graFEI_nMuons_preFit", graFEI_nMuons_preFit)
            candidate.addExtraInfo("graFEI_nPions_preFit", graFEI_nPions_preFit)
            candidate.addExtraInfo("graFEI_nKaons_preFit", graFEI_nKaons_preFit)
            candidate.addExtraInfo("graFEI_nProtons_preFit", graFEI_nProtons_preFit)
            candidate.addExtraInfo("graFEI_nLeptons_preFit", graFEI_nLeptons_preFit)
            candidate.addExtraInfo("graFEI_nOthers_preFit", graFEI_nOthers_preFit)

            # Trivial decay tree
            if n_nodes < 2:
                b2.B2WARNING(
                    f"Skipping candidate with {n_nodes} reconstructed FSPs"
                )
                continue

            # Initialize node features array
            x_nodes = np.empty((n_nodes, len(self.node_features)))
            x_dis = np.empty((n_nodes, len(self.discarded_features)))

            # Fill node features array
            for p, particle in enumerate(p_list):
                for f, var in enumerate(node_variables):
                    x_nodes[p, f] = vm.evaluate(var, particle)
                for f, var in enumerate(discarded_variables):
                    x_dis[p, f] = vm.evaluate(var, particle)
            b2.B2DEBUG(11, "Node features:\n", x_nodes)

            # Fill edge features array
            x_edges = (
                compute_edge_features(
                    self.edge_features,
                    self.node_features + self.discarded_features,
                    np.concatenate([x_nodes, x_dis], axis=1),
                )
                if self.edge_features
                else []
            )
            # Fully connected directed graph: every ordered pair of distinct nodes
            edge_index = torch.tensor(
                list(itertools.permutations(range(n_nodes), 2)), dtype=torch.long
            )
            b2.B2DEBUG(11, "Edge features:\n", x_edges)

            # Fill global features
            # TODO: get them from basf2
            x_global = (
                np.array([[n_nodes]], dtype=float)
                if self.glob_features
                else []
            )
            b2.B2DEBUG(11, "Global features:\n", x_global)

            # Fill tensor to assign each node to a graph (trivial since we have only one graph per decay)
            torch_batch = torch.zeros(size=[n_nodes], dtype=torch.long)

            # Set nans to zero, this is a surrogate value, may change in future
            np.nan_to_num(x_nodes, copy=False)
            np.nan_to_num(x_edges, copy=False)
            np.nan_to_num(x_global, copy=False)

            # Normalize any features that should be
            if self.normalize is not None:
                normalize_features(
                    self.normalize,
                    self.node_features,
                    x_nodes,
                    self.edge_features,
                    x_edges,
                    self.glob_features,
                    x_global,
                )

            # Convert everything to torch tensors and/or send to some device in case
            x = torch.tensor(x_nodes, dtype=torch.float).to(self.device)
            edge_index = edge_index.t().contiguous().to(self.device)
            edge_attr = torch.tensor(x_edges, dtype=torch.float).to(self.device)
            u = torch.tensor(x_global, dtype=torch.float).to(self.device)
            torch_batch = torch_batch.to(self.device)

            # Create Batch object to be passed to model
            batch = Batch(
                x=x, edge_index=edge_index, edge_attr=edge_attr, u=u, batch=torch_batch
            )

            # Evaluate model
            # NOTE: self.use_amp is computed in initialize but mixed precision is not
            # applied here; inference always runs in the model's default precision.
            with torch.no_grad():
                x_pred, e_pred, _u_pred = self.model(batch)

            # Bring the outputs back to the CPU: the post-processing below mixes them
            # with numpy arrays and CPU tensors (e.g. .numpy() calls and the true LCA
            # matrix), which fails for CUDA tensors. This is a no-op on CPU runs.
            x_pred = x_pred.cpu()
            e_pred = e_pred.cpu()
            edge_index = edge_index.cpu()

            # Select edges from predictions
            edge_probs = torch.softmax(e_pred, dim=1)
            edge_probability, predicted_LCA = edge_probs.max(dim=1)

            # Select masses from predictions
            mass_probs = torch.softmax(x_pred, dim=1)
            mass_probability, predicted_masses = mass_probs.max(dim=1)
            b2.B2DEBUG(10, "Predicted mass classes:\n", predicted_masses)
            b2.B2DEBUG(11, "Mass class probabilities:\n", mass_probability)

            # Count number of predicted particles for each mass hypothesis
            graFEI_nPhotons_postFit = (predicted_masses == 6).sum()
            graFEI_nCharged_postFit = graFEI_nFSP - graFEI_nPhotons_postFit
            graFEI_nElectrons_postFit = (predicted_masses == 1).sum()
            graFEI_nMuons_postFit = (predicted_masses == 2).sum()
            graFEI_nPions_postFit = (predicted_masses == 3).sum()
            graFEI_nKaons_postFit = (predicted_masses == 4).sum()
            graFEI_nProtons_postFit = (predicted_masses == 5).sum()
            graFEI_nLeptons_postFit = graFEI_nElectrons_postFit + graFEI_nMuons_postFit
            graFEI_nOthers_postFit = (predicted_masses == 0).sum()

            # Get square matrices (n_nodes x n_nodes, diagonal left at 0)
            edge_probability_square = torch.sparse_coo_tensor(
                edge_index, edge_probability
            ).to_dense()
            predicted_LCA_square = torch.sparse_coo_tensor(
                edge_index, predicted_LCA, dtype=int
            ).to_dense()
            b2.B2DEBUG(10, "Predicted LCA:\n", predicted_LCA_square)
            b2.B2DEBUG(11, "Edge class probabilities:\n", edge_probability_square)

            # Remove symmetric elements from probability (keep strictly lower triangle)
            edge_probability_unique = edge_probability_square[
                edge_probability_square.tril(diagonal=-1) > 0
            ]

            # Get particles predicted as matched by the model:
            # a particle is matched if its LCA row has at least one non-zero entry
            predicted_matched = (predicted_LCA_square != 0).any(dim=1).numpy()
            b2.B2DEBUG(10, "Predicted matched particles:\n", predicted_matched)
            # Same but ignoring photons
            predicted_matched_noPhotons = predicted_matched[masses != 22]

            # Get number of predicted as unmatched
            graFEI_nPredictedUnmatched = (~predicted_matched).sum()
            graFEI_nPredictedUnmatched_noPhotons = (
                (~predicted_matched_noPhotons).sum()
                if predicted_matched_noPhotons.size != 0
                else 0
            )

            # Get LCA of predicted matched only
            predicted_LCA_square_matched = predicted_LCA_square[predicted_matched]
            predicted_LCA_square_matched = predicted_LCA_square_matched[:, predicted_matched]

            # Get predicted masses of predicted matched only
            predicted_masses_matched = predicted_masses[predicted_matched]

            # Check if LCA describes a tree graph
            graFEI_validTree = 0
            if not torch.all(predicted_LCA_square == 0):
                try:
                    adjacency = lca_to_adjacency(predicted_LCA_square_matched)
                    graFEI_validTree = 1
                except InvalidLCAMatrix:
                    # Not a tree: graFEI_validTree stays 0 and is reported as such
                    pass

            # Check if event is good, depending on the chosen sig-side LCA matrix/masses
            graFEI_goodEvent = 0
            if graFEI_validTree:
                # Check if the event is good
                good_decay, root_level, sig_side_fsps = select_good_decay(
                    predicted_LCA_square_matched,
                    predicted_masses_matched,
                    self.sig_side_lcas,
                    self.sig_side_masses,
                )
                graFEI_goodEvent = int((self.max_level == root_level) and good_decay)

                if graFEI_goodEvent:
                    # Find sig- and tag-side FSPs (1 = sig-side, 0 = tag-side)
                    p_list_matched = [
                        particle
                        for particle, matched in zip(p_list, predicted_matched)
                        if matched
                    ]
                    for i, particle in enumerate(p_list_matched):
                        particle.addExtraInfo("graFEI_sigSide", 1 if i in sig_side_fsps else 0)

                b2.B2DEBUG(11, "This LCA describes a valid tree")
                b2.B2DEBUG(
                    11,
                    "Predicted LCA on matched particles:\n",
                    predicted_LCA_square_matched,
                )
                b2.B2DEBUG(11, "Adjacency matrix:\n", adjacency)

            # Particles not assigned to B decays get -1
            for particle in p_list:
                if not particle.hasExtraInfo("graFEI_sigSide"):
                    particle.addExtraInfo("graFEI_sigSide", -1)

            # Define B probabilities
            graFEI_probEdgeProd = edge_probability_unique.prod().item()
            graFEI_probEdgeMean = edge_probability_unique.mean().item()
            # NOTE(review): the root is taken over n_nodes, not over the number of
            # unique edges (n_nodes * (n_nodes - 1) / 2) — confirm this is intended.
            graFEI_probEdgeGeom = torch.pow(edge_probability_unique.prod(), 1 / n_nodes).item()

            # Add extra info for each B candidate
            candidate.addExtraInfo("graFEI_probEdgeProd", graFEI_probEdgeProd)
            candidate.addExtraInfo("graFEI_probEdgeMean", graFEI_probEdgeMean)
            candidate.addExtraInfo("graFEI_probEdgeGeom", graFEI_probEdgeGeom)
            candidate.addExtraInfo("graFEI_validTree", graFEI_validTree)
            candidate.addExtraInfo("graFEI_goodEvent", graFEI_goodEvent)
            candidate.addExtraInfo("graFEI_nPhotons_postFit", graFEI_nPhotons_postFit)
            candidate.addExtraInfo("graFEI_nCharged_postFit", graFEI_nCharged_postFit)
            candidate.addExtraInfo("graFEI_nElectrons_postFit", graFEI_nElectrons_postFit)
            candidate.addExtraInfo("graFEI_nMuons_postFit", graFEI_nMuons_postFit)
            candidate.addExtraInfo("graFEI_nPions_postFit", graFEI_nPions_postFit)
            candidate.addExtraInfo("graFEI_nKaons_postFit", graFEI_nKaons_postFit)
            candidate.addExtraInfo("graFEI_nProtons_postFit", graFEI_nProtons_postFit)
            candidate.addExtraInfo("graFEI_nLeptons_postFit", graFEI_nLeptons_postFit)
            candidate.addExtraInfo("graFEI_nOthers_postFit", graFEI_nOthers_postFit)
            candidate.addExtraInfo("graFEI_nPredictedUnmatched", graFEI_nPredictedUnmatched)
            candidate.addExtraInfo("graFEI_nPredictedUnmatched_noPhotons", graFEI_nPredictedUnmatched_noPhotons)

            # Add MC truth information
            if self.storeTrueInfo:
                # Get the true IDs of the ancestors (if it's a B)
                parentID = np.array([vm.evaluate("ancestorBIndex", p) for p in p_list], dtype=int)
                b2.B2DEBUG(10, "Ancestor true ID:\n", parentID)

                # Get particle indices
                p_indices = np.array(
                    [
                        p.getMCParticle().getArrayIndex() if parentID[i] >= 0 else -1
                        for (i, p) in enumerate(p_list)
                    ]
                )
                # Get particle masses
                p_masses = masses_to_classes(
                    np.array(
                        [
                            p.getMCParticle().getPDG() if parentID[i] >= 0 else -1
                            for (i, p) in enumerate(p_list)
                        ]
                    )
                )
                b2.B2DEBUG(10, "True mass classes:\n", p_masses)
                # And primary information
                evt_primary = np.array(
                    [
                        p.getMCParticle().isPrimaryParticle()
                        if parentID[i] >= 0
                        else False
                        for (i, p) in enumerate(p_list)
                    ]
                )
                b2.B2DEBUG(10, "Is primary particle:\n", evt_primary)

                # Get unique B indices associated to each predicted matched particle which is also a primary
                # The idea is that if a primary particle coming from the other B is categorized as unmatched,
                # then it's ok and the decay could still have a perfectLCA
                B_indices = parentID[np.logical_and(evt_primary, predicted_matched)]
                b2.B2DEBUG(
                    10, "Ancestor ID of predicted matched particles:\n", B_indices
                )
                B_indices = list(set(B_indices))

                # Initialize truth-matching variables
                graFEI_truth_perfectLCA = 0  # 1 if LCA perfectly reconstructed
                graFEI_truth_isSemileptonic = -1  # 0 if hadronic, 1 is semileptonic, -1 if not matched
                graFEI_truth_nFSP = -1  # Number of true FSPs
                # Check if all the masses are predicted correctly
                graFEI_truth_perfectMasses = int((predicted_masses.numpy() == p_masses).all())
                graFEI_truth_nPhotons = (p_masses == 6).sum()
                graFEI_truth_nElectrons = (p_masses == 1).sum()
                graFEI_truth_nMuons = (p_masses == 2).sum()
                graFEI_truth_nPions = (p_masses == 3).sum()
                graFEI_truth_nKaons = (p_masses == 4).sum()
                graFEI_truth_nProtons = (p_masses == 5).sum()
                graFEI_truth_nOthers = (p_masses == 0).sum()

                # Get the generated B's
                gen_list = Belle2.PyStoreObj(self.mc_particle)

                # Iterate over generated Ups
                if self.mc_particle == "Upsilon(4S):MC" and gen_list.getListSize() > 1:
                    b2.B2WARNING(
                        f"Found {gen_list.getListSize()} true Upsilon(4S) in the generated MC (??)")

                if gen_list.getListSize() > 0:
                    # Here we look if the candidate has a perfectly reconstructed LCA
                    for genP in gen_list.obj():
                        mcp = genP.getMCParticle()
                        # If storing true info on B decays and we have matched particles coming
                        # from different Bs the decay will not have a perfectLCA
                        if self.mc_particle != "Upsilon(4S):MC" and len(B_indices) != 1:
                            break

                        # Get array index of MC particle
                        array_index = mcp.getArrayIndex()

                        # If we are reconstructing Bs, skip the other in the event
                        if self.mc_particle != "Upsilon(4S):MC" and array_index != B_indices[0]:
                            continue

                        # Write leaf history
                        (
                            leaf_hist,
                            levels,
                            _,
                            _,
                            semilep_flag,
                        ) = write_hist(
                            particle=mcp,
                            leaf_hist={},
                            levels={},
                            hist=[],
                            pdg={},
                            leaf_pdg={},
                            semilep_flag=False,
                        )

                        # Skip B decays with trivial LCA (should be always false except for B -> nunu ?)
                        if len(leaf_hist) < 2:
                            continue

                        # Initialize LCA...
                        true_LCA_square = np.zeros(
                            [len(leaf_hist), len(leaf_hist)], dtype=int
                        )

                        # Number of true FSPs
                        graFEI_truth_nFSP = len(leaf_hist)

                        # ... and fill it! The LCA level of a pair is the level of the
                        # last ancestor shared by both leaf histories.
                        for (ix, leaf_x), (iy, leaf_y) in itertools.combinations(enumerate(leaf_hist), 2):
                            intersection = [
                                i for i in leaf_hist[leaf_x] if i in leaf_hist[leaf_y]
                            ]
                            true_LCA_square[ix, iy] = levels[intersection[-1]]
                            true_LCA_square[iy, ix] = levels[intersection[-1]]

                        # Position of each true leaf in the LCA matrix
                        leaf_position = {leaf: pos for pos, leaf in enumerate(leaf_hist)}

                        # Get LCA indices in order that the leaves appear in reconstructed particles
                        # Secondaries aren't in the LCA leaves list so they get a 0
                        locs = np.array(
                            [leaf_position.get(i, 0) for i in p_indices],
                            dtype=int,
                        )

                        # Insert dummy rows for secondaries
                        true_LCA_square = true_LCA_square[locs, :][:, locs]

                        # Set everything that's not primary (unmatched and secondaries) rows/cols to 0
                        # Note we only consider the subset of leaves that made it into x_rows
                        x_rows = np.array(
                            [
                                vm.evaluate("ancestorBIndex", p) == array_index
                                for p in p_list
                            ]
                        ) if self.mc_particle != "Upsilon(4S):MC" else evt_primary

                        primaries_from_right_cand = np.logical_and(evt_primary, x_rows)

                        # Set the columns (a 1-D mask broadcasts along the last axis)
                        true_LCA_square = np.where(
                            primaries_from_right_cand, true_LCA_square, 0
                        )
                        # Set the rows
                        true_LCA_square = np.where(
                            primaries_from_right_cand[:, None], true_LCA_square, 0
                        )

                        # Convert LCA to tensor
                        true_LCA_square = torch.tensor(true_LCA_square, dtype=int)
                        b2.B2DEBUG(10, "True LCA:\n", true_LCA_square)

                        # Check if perfect LCA (both matrices are on the CPU)
                        if (true_LCA_square == predicted_LCA_square).all():
                            graFEI_truth_perfectLCA = 1
                            b2.B2DEBUG(10, "LCA perfectly reconstructed!")

                        # Assign semileptonic flag
                        graFEI_truth_isSemileptonic = int(semilep_flag)

                # Perfect event = perfectLCA and perfectMasses
                graFEI_truth_perfectEvent = int(graFEI_truth_perfectLCA and graFEI_truth_perfectMasses)

                # Write extra info
                candidate.addExtraInfo("graFEI_truth_perfectLCA", graFEI_truth_perfectLCA)
                candidate.addExtraInfo("graFEI_truth_perfectMasses", graFEI_truth_perfectMasses)
                candidate.addExtraInfo("graFEI_truth_perfectEvent", graFEI_truth_perfectEvent)
                candidate.addExtraInfo("graFEI_truth_isSemileptonic", graFEI_truth_isSemileptonic)
                candidate.addExtraInfo("graFEI_truth_nFSP", graFEI_truth_nFSP)
                candidate.addExtraInfo("graFEI_truth_nPhotons", graFEI_truth_nPhotons)
                candidate.addExtraInfo("graFEI_truth_nElectrons", graFEI_truth_nElectrons)
                candidate.addExtraInfo("graFEI_truth_nMuons", graFEI_truth_nMuons)
                candidate.addExtraInfo("graFEI_truth_nPions", graFEI_truth_nPions)
                candidate.addExtraInfo("graFEI_truth_nKaons", graFEI_truth_nKaons)
                candidate.addExtraInfo("graFEI_truth_nProtons", graFEI_truth_nProtons)
                candidate.addExtraInfo("graFEI_truth_nOthers", graFEI_truth_nOthers)