Source code for grafei.model.lca_to_adjacency

##########################################################################
# basf2 (Belle II Analysis Software Framework)                           #
# Author: The Belle II Collaboration                                     #
#                                                                        #
# See git log for contributors and copyright holders.                    #
# This file is licensed under LGPL-3.0, see LICENSE.md.                  #
##########################################################################


import torch as t
import numpy as np
from collections import Counter
from itertools import permutations
from .tree_utils import is_valid_tree


class InvalidLCAMatrix(Exception):
    """
    Error signalling that an LCA matrix is malformed or does not describe a tree.
    """
class Node:
    """
    Node of a tree rebuilt from an LCA matrix.

    Args:
        level (int): Level of the node in the tree.
        children (list[Node]): Child nodes; the passed list object is stored as-is (not copied).
        lca_index (int): Row/column index in the LCAS matrix, or None if the node has none.
        lcas_level (int): Level value taken from the LCAS matrix.
    """

    def __init__(self, level, children, lca_index=None, lcas_level=0):
        """
        Initialization
        """
        #: Parent nodes
        self.parent = None
        #: BFS index (-1 until assigned)
        self.bfs_index = -1
        #: LCA level
        self.level = level
        #: Node children
        self.children = children
        #: LCA index
        self.lca_index = lca_index
        #: LCAS level
        self.lcas_level = lcas_level
def _get_ancestor(node):
    """
    Trail search for the highest ancestor of a node.

    Follows ``parent`` links until a node without a parent is reached.

    Args:
        node (Node): Starting node.

    Returns:
        Node: The topmost ancestor (``node`` itself if it has no parent).
    """
    ancestor = node
    while ancestor.parent is not None:
        ancestor = ancestor.parent
    return ancestor


def _nodes_in_ancestors_children(parent, node1, node2):
    """
    Checks if any node in parent's line of descent is also an ancestor of both node1 and node2.

    The whole subtree below ``parent`` is searched for a node that has both ``node1`` and
    ``node2`` among its direct children (``parent`` itself is not tested).

    Args:
        parent (Node): Root of the subtree to search.
        node1 (Node): First node.
        node2 (Node): Second node.

    Returns:
        bool: True if such a descendant exists, False otherwise.
    """
    for child in parent.children:
        if (node1 in child.children) and (node2 in child.children):
            return True
        # The result of the recursive search must be propagated; otherwise only the
        # direct children of ``parent`` would ever be inspected.
        if _nodes_in_ancestors_children(child, node1, node2):
            return True
    return False


def _pull_down(node):
    """
    Works up the node's history, pulling down a level any nodes whose children are all more than one level below.
    Performs the operation in place.

    Every node from ``node`` up to the root that has children gets
    ``level = max(child levels) + 1``; nodes without children keep their level.

    Args:
        node (Node): Node to start from (typically a leaf).
    """
    current = node
    while current is not None:
        if current.children:
            current.level = max(child.level for child in current.children) + 1
        current = current.parent


def _breadth_first_enumeration(root, queue, adjacency_matrix):
    """
    Groups the nodes of the tree below ``root`` (inclusive) by their level.

    Nodes are visited recursively, parent before children, and appended to ``queue[level]``.

    Args:
        root (Node): Root of the (sub)tree to enumerate.
        queue (dict[int, list[Node]]): Mapping level -> nodes, filled in place.
        adjacency_matrix: Unused; kept for interface compatibility.

    Returns:
        dict[int, list[Node]]: The filled ``queue``.
    """
    queue.setdefault(root.level, []).append(root)
    for child in root.children:
        _breadth_first_enumeration(child, queue, adjacency_matrix)
    return queue


def _breadth_first_adjacency(root, adjacency_matrix):
    """
    Assigns breadth-first indices to all nodes and fills the adjacency matrix in place.

    Nodes are numbered level by level, from the root's level down to level 1. Each
    parent/child pair is then marked with 1 in both directions of ``adjacency_matrix``;
    all other entries are left untouched.

    Args:
        root (Node): Root of the tree.
        adjacency_matrix: Square matrix with at least one row per tree node; modified in place.
    """
    queue = _breadth_first_enumeration(root, {}, adjacency_matrix)

    # First pass: assign bfs ids, highest level (root) first
    index = 0
    for level in range(root.level, 0, -1):
        for node in queue[level]:
            node.bfs_index = index
            index += 1

    # Second pass: fill in the symmetric adjacencies
    for level in range(root.level, 0, -1):
        for node in queue[level]:
            for child in node.children:
                adjacency_matrix[node.bfs_index, child.bfs_index] = 1
                adjacency_matrix[child.bfs_index, node.bfs_index] = 1


def _reconstruct(lca_matrix):
    """
    Does the actual heavy lifting of the adjacency matrix reconstruction.

    Traverses the LCA matrix level-by-level, starting at one. For each level, new nodes are inserted
    if the LCA matrix contains entries with this level number. The newly created node(s) are then
    connected to the lower leaves or sub-graphs. This function may produce reconstructions that
    are valid graphs but not trees.

    Missing intermediate levels in the matrix are corrected: the raw matrix value
    (``current_level``) only selects entries, while the gap-free index ``idx`` determines the
    level of created nodes.

    Args:
        lca_matrix: Square LCA matrix supporting ``.shape``, ``.unique()`` and 2d indexing
            (a torch Tensor in the callers of this module). The value 0 must occur in it,
            since it is removed as the leaf level.

    Returns:
        tuple: (root Node, total number of nodes including the leaves).

    Raises:
        InvalidLCAMatrix: If an off-diagonal entry is <= 0 or the entries are inconsistent with a tree.
    """
    n = lca_matrix.shape[0]
    total_nodes = n
    # Distinct levels in ascending order; level 0 denotes the leaves themselves and is skipped
    levels = sorted(lca_matrix.unique().tolist())
    levels.remove(0)

    # Create nodes for all leaves
    leaves = [Node(1, [], lca_index=i) for i in range(n)]

    for idx, current_level in enumerate(levels, 1):
        new_level = idx + 1
        for column in range(n):
            # The LCA matrix is symmetric, hence only the strictly lower triangle is checked
            for row in range(column + 1, n):
                entry = lca_matrix[row, column]
                if entry <= 0:
                    raise InvalidLCAMatrix
                if entry != current_level:
                    continue

                a_node = leaves[column]
                another_node = leaves[row]

                an_ancestor = _get_ancestor(a_node)
                a_level = an_ancestor.level
                another_ancestor = _get_ancestor(another_node)
                another_level = another_ancestor.level

                if a_level == another_level == new_level:
                    # Both already have an ancestor at this level: it must be the same one, and no
                    # lower node below it may already hold both leaves as children
                    if (
                        an_ancestor is not another_ancestor
                        or _nodes_in_ancestors_children(an_ancestor, a_node, another_node)
                    ):
                        raise InvalidLCAMatrix
                elif a_level > new_level or another_level > new_level:
                    # An ancestor above the level being inspected means the matrix is inconsistent
                    raise InvalidLCAMatrix
                elif a_level < new_level and another_level < new_level:
                    # No ancestor at this level yet: create one and connect both sub-trees to it
                    parent = Node(new_level, [an_ancestor, another_ancestor], lcas_level=current_level)
                    an_ancestor.parent = parent
                    another_ancestor.parent = parent
                    total_nodes += 1
                elif another_level < new_level and a_level == new_level:
                    # The left side already has the parent at this level: attach the right sub-tree
                    another_ancestor.parent = an_ancestor
                    an_ancestor.children.append(another_ancestor)
                elif a_level < new_level and another_level == new_level:
                    # Same for the right side
                    an_ancestor.parent = another_ancestor
                    another_ancestor.children.append(an_ancestor)
                else:
                    raise InvalidLCAMatrix

    # The LCAs are not guaranteed to be "lowest" ancestors: pull down every node whose
    # children are all more than one level below it
    for leaf in leaves:
        _pull_down(leaf)

    root = _get_ancestor(leaves[0])
    return root, total_nodes
def lca_to_adjacency(lca_matrix):
    """
    Converts a tree's LCA matrix representation into the corresponding two-dimensional adjacency
    matrix (N, N), with M < N.

    The LCA matrix is a square matrix (M, M) in which each row/column corresponds to a leaf of the
    tree and each entry is the level of the lowest common ancestor (LCA) of the two leaves. The
    diagonal must be zero, and the levels are enumerated top-down from the root.

    .. seealso::
        The pseudocode for LCA to tree conversion is described in
        `Kahn et al <https://iopscience.iop.org/article/10.1088/2632-2153/ac8de0>`_.

    :param lca_matrix: 2-dimensional LCA matrix (M, M).
    :type lca_matrix: `Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_

    :return: 2-dimensional matrix (N, N) encoding the graph's node adjacencies. Linked nodes have values unequal to zero.
    :rtype: `Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_

    Raises:
        TypeError: If the input cannot be converted to a torch Tensor.
        InvalidLCAMatrix: If the passed LCA matrix is malformed (e.g. not 2d, not square, empty,
            not symmetric, non-zero diagonal) or does not encode a tree.
    """
    # Ensure input is a torch tensor or can be converted to one
    if not isinstance(lca_matrix, t.Tensor):
        try:
            lca_matrix = t.Tensor(lca_matrix)
        except TypeError as err:
            raise TypeError(f"Input type must be compatible with torch Tensor: {err}") from err

    # Ensure two dimensions
    if len(lca_matrix.shape) != 2:
        raise InvalidLCAMatrix("LCA matrix must be 2-dimensional")

    # Ensure that it is square and not empty (a tree needs at least one leaf)
    n, m = lca_matrix.shape
    if n != m:
        raise InvalidLCAMatrix("LCA matrix must be square")
    if n == 0:
        raise InvalidLCAMatrix("LCA matrix must not be empty")

    # Check symmetry
    if not (lca_matrix == lca_matrix.T).all():
        raise InvalidLCAMatrix("LCA matrix must be symmetric")

    # A leaf is its own LCA at level 0. The reconstruction relies on this: it removes level 0
    # and otherwise never inspects the diagonal.
    if (lca_matrix.diagonal() != 0).any():
        raise InvalidLCAMatrix("LCA matrix must have a zero diagonal")

    try:
        root, total_nodes = _reconstruct(lca_matrix)
    except IndexError as err:
        raise InvalidLCAMatrix from err

    # Allocate the adjacency matrix
    adjacency_matrix = t.zeros((total_nodes, total_nodes), dtype=t.int64)
    try:
        _breadth_first_adjacency(root, adjacency_matrix)
    except IndexError as err:
        raise InvalidLCAMatrix from err

    # Check whether what we reconstructed is actually a tree - it might be a regular graph, for example
    if not is_valid_tree(adjacency_matrix):
        raise InvalidLCAMatrix("LCA matrix does not encode a tree")

    return adjacency_matrix
def _get_fsps_of_node(node):
    """
    Collects the LCA indices of all final state particles (leaves) below a node.

    A node carrying an ``lca_index`` (including index 0, hence the explicit ``is not None``
    test) is itself a final state particle. Otherwise its children are searched recursively.

    Args:
        node (Node): Node to be inspected.

    Returns:
        indices (list): Unique LCA-matrix indices of the final state particles reachable from node.
    """
    if node.lca_index is not None:
        collected = [node.lca_index]
    else:
        collected = [idx for child in node.children for idx in _get_fsps_of_node(child)]
    return list(set(collected))
def select_good_decay(predicted_lcas, predicted_masses, sig_side_lcas=None, sig_side_masses=None):
    """
    Checks if the given LCAS matrix is found in the reconstructed LCAS matrix and the mass hypotheses are correct.

    .. warning::
        You have to make sure to call this function only for valid tree structures encoded in ``predicted_lcas``,
        otherwise it will throw an exception.

    Mass hypotheses are indicated by letters. The following convention is used:

    .. math::
        'e' \\to e \\\\
        'i' \\to \\pi \\\\
        'k' \\to K \\\\
        'p' \\to p \\\\
        'm' \\to \\mu \\\\
        'g' \\to \\gamma \\\\
        'o' \\to \\text{others}

    .. warning::
        The order of mass hypotheses should match that of the final state particles in the LCAS.

    :param predicted_lcas: LCAS matrix.
    :type predicted_lcas: `Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_
    :param predicted_masses: Predicted mass classes, one per final state particle in LCAS order.
        It is indexed with a list of indices and compared element-wise to integer mass classes,
        so an integer Tensor is expected (NOTE(review): older docs said ``list[str]`` - confirm).
    :type predicted_masses: `Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_
    :param sig_side_lcas: LCAS matrix of your signal-side. A single FSP is given as ``[[0]]``.
    :type sig_side_lcas: `Tensor <https://pytorch.org/docs/stable/tensors.html#torch.Tensor>`_
    :param sig_side_masses: List of mass hypotheses for your FSPs.
    :type sig_side_masses: list[str]

    Returns:
        bool, int, list: True if LCAS and masses match, LCAS level of root node,
        LCA indices of FSPs belonging to the signal side ([-1] if LCAS does not match decay string).

    Raises:
        InvalidLCAMatrix: If the chosen signal-side LCAS or mass hypotheses are invalid.
    """
    # Reconstruct decay chain
    root, _ = _reconstruct(predicted_lcas)

    # If root is neither Ups nor B then decay is not good
    if root.lcas_level not in [5, 6]:
        return (False, root.lcas_level, [-1])

    # If root is B don't go any further (function is supposed to check whether signal-side on Ups decay is good)
    if root.lcas_level == 5:
        return (True, 5, list(range(predicted_lcas.shape[0])))

    # If chosen LCAS or masses are None then decay is not good
    if sig_side_lcas is None or sig_side_masses is None:
        return (False, root.lcas_level, [-1])

    # Check if the LCA matrix/masses you chose are valid
    try:
        single_entry = sig_side_lcas.item()
    except (ValueError, RuntimeError):
        # More than one element: numpy raises ValueError here, torch raises RuntimeError
        try:
            lca_to_adjacency(sig_side_lcas)
        except InvalidLCAMatrix as err:
            raise InvalidLCAMatrix("You chose an invalid LCA matrix") from err
        more_fsps = True
    else:
        if single_entry != 0:
            raise InvalidLCAMatrix("If you have only one sig-side FSP, the LCA matrix should be [[0]]")
        more_fsps = False

    # Check if the number of FSPs in the LCA is the same as the number of mass hypotheses
    if sig_side_lcas.shape[0] != len(sig_side_masses):
        raise InvalidLCAMatrix("The dimension of the LCA matrix you chose does not match with the number of mass hypotheses")

    # Mass hypothesis letter -> mass class integer
    mass_classes = {"e": 1, "m": 2, "i": 3, "k": 4, "p": 5, "g": 6, "o": 0}

    # Check if mass hypotheses are allowed
    for hypothesis in set(sig_side_masses):
        if hypothesis not in mass_classes:
            # Not properly an InvalidLCAMatrix case, but no dedicated exception exists
            raise InvalidLCAMatrix("Allowed mass hypotheses are 'i', 'o', 'g', 'k', 'm', 'e', 'p'")

    # Convert mass hypotheses to integer classes
    sig_side_masses = t.tensor([mass_classes[h] for h in sig_side_masses], dtype=t.int64)

    # Case 1: only one FSP in the signal-side
    if not more_fsps:
        # There should be two nodes: one '5' and one '0'
        if Counter(child.lcas_level for child in root.children) != Counter({5: 1, 0: 1}):
            return (False, root.lcas_level, [-1])

        # Get FSP index in LCA (the child at LCAS level 0)
        fsp_idx = next(child.lca_index for child in root.children if child.lcas_level == 0)

        # Check mass hypothesis
        if predicted_masses[fsp_idx] != sig_side_masses[0]:
            return (False, root.lcas_level, [-1])

        return (True, root.lcas_level, [fsp_idx])

    # Case 2: more FSPs in the signal-side
    # There should be two nodes labelled as '5'
    if Counter(child.lcas_level for child in root.children) != Counter({5: 2}):
        return (False, root.lcas_level, [-1])

    # At least one of the two B's should decay into the nodes given by the chosen LCAS/masses
    n_sig_fsps = sig_side_lcas.shape[0]
    for b_candidate in root.children:
        indices = _get_fsps_of_node(b_candidate)

        # The number of FSPs in the chosen sig-side must match that of this B
        if len(indices) != n_sig_fsps:
            continue

        sub_lca = predicted_lcas[indices][:, indices]
        sub_masses = predicted_masses[indices]

        # The chosen sig-side LCAS/masses could have a different ordering: try all permutations
        for permutation in permutations(range(n_sig_fsps)):
            permutation = list(permutation)
            permuted_sig_side_lca = sig_side_lcas[permutation][:, permutation]
            permuted_sig_side_masses = sig_side_masses[permutation]

            # If one of the permutations works, the decay is good
            if (permuted_sig_side_lca == sub_lca).all() and (permuted_sig_side_masses == sub_masses).all():
                return (True, root.lcas_level, indices)

    # If we get here the decay is not good
    return (False, root.lcas_level, [-1])