# Source code for grafei.model.tree_utils

##########################################################################
# basf2 (Belle II Analysis Software Framework)                           #
# Author: The Belle II Collaboration                                     #
#                                                                        #
# See git log for contributors and copyright holders.                    #
# This file is licensed under LGPL-3.0, see LICENSE.md.                  #
##########################################################################


import torch as t
import numpy as np


def masses_to_classes(array):
    """
    Converts mass hypotheses to classes used in cross-entropy computation.

    Classes are:

    .. math::
        e \\to 1\\\\
        \\mu \\to 2\\\\
        \\pi \\to 3\\\\
        K \\to 4\\\\
        p \\to 5\\\\
        \\gamma \\to 6\\\\
        \\text{others} \\to 0

    The sign of the PDG code is ignored, so particles and anti-particles share a class.
    The input is not modified; a new array is returned.

    Args:
        array (numpy.ndarray): Array containing PDG mass codes.

    Returns:
        numpy.ndarray: Array containing mass hypotheses converted to classes.
    """
    # Absolute PDG code -> class label
    pdg_to_class = {
        11: 1,  # Electrons are class 1
        13: 2,  # Muons are class 2
        211: 3,  # Pions are class 3
        321: 4,  # Kaons are class 4
        2212: 5,  # Protons are class 5
        22: 6,  # Photons are class 6
    }

    # All elements become negative: the positive class labels written below can then
    # never be mistaken for a PDG code that still has to be converted, and everything
    # left unmatched can be cleared with a single mask at the end.
    classes = -1 * np.abs(array)

    for pdg_code, label in pdg_to_class.items():
        classes[classes == -pdg_code] = label

    classes[classes <= 0] = 0  # Others are 0

    return classes
def _check_undirected(adjacency_matrix):
    """
    Checks whether an adjacency matrix-encoded graph is undirected, i.e. symmetric.

    Args:
        adjacency_matrix: 2-dimensional matrix encoding the graph's node adjacencies.

    Returns:
        False if the matrix is not square, otherwise the element-wise comparison
        of the matrix with its transpose, reduced with ``all()``.
    """
    n, m = adjacency_matrix.shape
    if n != m:
        return False

    return (adjacency_matrix == adjacency_matrix.T).all()


def _connectedness_dfs(adjacency_matrix, index, reached):
    """
    Actual depth-first search of graph connectedness.
    Starting from the node marked by index, every node that can be reached through
    non-zero entries is marked in ``reached`` (modified in place).
    If the graph is not connected, the reachability `reached` mask will contain zero elements.

    The search uses an explicit stack instead of recursion, so deep graphs
    cannot exceed Python's recursion limit.

    Args:
        adjacency_matrix: Square matrix (assumed numpy.ndarray or torch.Tensor), non-zero entries are edges.
        index (int): Node the search starts from.
        reached: 1-dimensional mask of already visited nodes, updated in place.
    """
    n = adjacency_matrix.shape[0]
    reached[index] = 1

    pending = [index]
    while pending:
        node = pending.pop()
        # One bulk conversion per visited node instead of one matrix lookup per element
        row = adjacency_matrix[node].tolist()

        # Traverse through all adjacent nodes and avoid nodes already marked as reachable
        for column in range(n):
            if row[column] != 0 and not reached[column]:
                # Mark on push so that a node is never queued twice
                reached[column] = 1
                pending.append(column)


def _check_connectedness(adjacency_matrix, allow_disconnected_leaves=False):
    """
    Checks whether all sub-graphs of an adjacency matrix-encoded graph are connected,
    i.e. have at least one edge linking them.

    The search starts at node 0, so the matrix must contain at least one node.

    Args:
        adjacency_matrix: 2-dimensional matrix encoding the graph's node adjacencies.
        allow_disconnected_leaves (bool): If True, nodes whose row sums to zero
            are counted as reached even if the search did not visit them.

    Returns:
        False if the matrix is not square, otherwise a 0-dimensional tensor telling
        whether every node has been reached.
    """
    n, m = adjacency_matrix.shape
    if n != m:
        return False

    reached = t.zeros(n, dtype=t.uint8)
    _connectedness_dfs(adjacency_matrix, 0, reached)

    if allow_disconnected_leaves:
        # The mask is a numpy array when the adjacency matrix is one;
        # torch.logical_or only accepts tensors living on the same device.
        isolated = t.as_tensor(adjacency_matrix.sum(axis=1) == 0, device=reached.device)
        reached = t.logical_or(reached, isolated)

    return reached.all()


def _acyclic_dfs(adjacency_matrix, index, parent, reached):
    """
    Actual depth-first search of graph cycles.
    Starting from the node marked by index, nodes are visited in depth-first order
    and marked in ``reached`` (modified in place).
    If a node is found in a trail that has been previously marked as already reached
    this indicates a cycle.

    The search uses an explicit stack instead of recursion, so deep graphs
    cannot exceed Python's recursion limit. Neighbours are inspected in the
    same order as a recursive search would inspect them.

    Args:
        adjacency_matrix: Square matrix (assumed numpy.ndarray or torch.Tensor), non-zero entries are edges.
        index (int): Node the search starts from.
        parent (int): Node from which ``index`` was entered (-1 for the root).
        reached: 1-dimensional mask of already visited nodes, updated in place.

    Returns:
        bool: False as soon as a cycle is found, True otherwise.
    """
    n = adjacency_matrix.shape[0]
    reached[index] = 1

    # Each frame: (node, node it was entered from, first neighbour still to inspect)
    frames = [(index, parent, 0)]
    while frames:
        node, came_from, first = frames.pop()
        row = adjacency_matrix[node].tolist()

        for neighbour in range(first, n):
            # the passed adjacency matrix may contain self-references
            # while technically not acyclic, these are allowed
            if neighbour == node:
                continue

            if row[neighbour] == 0:
                continue

            if reached[neighbour]:
                # An already visited neighbour other than the node we came from closes a cycle
                if neighbour != came_from:
                    return False
                continue

            # Descend: remember where to resume in the current node, then visit the neighbour
            reached[neighbour] = 1
            frames.append((node, came_from, neighbour + 1))
            frames.append((neighbour, node, 0))
            break

    return True


def _check_acyclic(adjacency_matrix):
    """
    Checks whether the graph encoded by the passed adjacency matrix is acyclic,
    i.e. all non-empty trails in the graph do not contain repetitions.
    Node self-references are legal and simply ignored.

    The search starts at node 0 and only follows edges reachable from it.

    Args:
        adjacency_matrix: 2-dimensional matrix encoding the graph's node adjacencies.

    Returns:
        bool: False if the matrix is not square or a cycle was found, True otherwise.
    """
    n, m = adjacency_matrix.shape
    if n != m:
        return False

    reached = t.zeros(n, dtype=t.uint8)

    return _acyclic_dfs(adjacency_matrix, 0, -1, reached)
def is_valid_tree(adjacency_matrix):
    """
    Checks whether the graph encoded by the passed adjacency matrix encodes a valid tree,
    i.e. an undirected, acyclic and connected graph.

    The checks are evaluated in the order undirected, connected, acyclic and
    stop at the first one that fails.

    Args:
        adjacency_matrix (numpy.ndarray): 2-dimensional matrix (N, N) encoding the graph's node adjacencies.
            Linked nodes should have value unequal to zero.

    Returns:
        bool: True if the encoded graph is a tree, False otherwise.
    """
    # The individual checks may hand back numpy or torch scalars; convert so that
    # callers always receive the plain bool promised above.
    return bool(
        _check_undirected(adjacency_matrix)
        and _check_connectedness(adjacency_matrix)
        and _check_acyclic(adjacency_matrix)
    )