Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-08-05 08:10:17

0001 // This file is part of the Acts project.
0002 //
0003 // Copyright (C) 2023 CERN for the benefit of the Acts project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
0008 
0009 #include "Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp"
0010 
0011 #include "Acts/Plugins/ExaTrkX/detail/TensorVectorConversion.hpp"
0012 
0013 #include <torch/torch.h>
0014 
0015 namespace {
0016 
0017 auto cantorize(std::vector<int64_t> edgeIndex, const Acts::Logger& logger) {
0018   // Use cantor pairing to store truth graph, so we can easily use set
0019   // operations to compute efficiency and purity
0020   std::vector<Acts::detail::CantorEdge<int64_t>> cantorEdgeIndex;
0021   cantorEdgeIndex.reserve(edgeIndex.size() / 2);
0022   for (auto it = edgeIndex.begin(); it != edgeIndex.end(); it += 2) {
0023     cantorEdgeIndex.emplace_back(*it, *std::next(it));
0024   }
0025 
0026   std::sort(cantorEdgeIndex.begin(), cantorEdgeIndex.end());
0027 
0028   auto new_end = std::unique(cantorEdgeIndex.begin(), cantorEdgeIndex.end());
0029   if (new_end != cantorEdgeIndex.end()) {
0030     ACTS_WARNING("Graph not unique ("
0031                  << std::distance(new_end, cantorEdgeIndex.end())
0032                  << " duplicates)");
0033     cantorEdgeIndex.erase(new_end, cantorEdgeIndex.end());
0034   }
0035 
0036   return cantorEdgeIndex;
0037 }
0038 
0039 }  // namespace
0040 
0041 Acts::TorchTruthGraphMetricsHook::TorchTruthGraphMetricsHook(
0042     const std::vector<int64_t>& truthGraph,
0043     std::unique_ptr<const Acts::Logger> l)
0044     : m_logger(std::move(l)) {
0045   m_truthGraphCantor = cantorize(truthGraph, logger());
0046 }
0047 
0048 void Acts::TorchTruthGraphMetricsHook::operator()(const std::any&,
0049                                                   const std::any& edges,
0050                                                   const std::any&) const {
0051   // We need to transpose the edges here for the right memory layout
0052   const auto edgeIndex = Acts::detail::tensor2DToVector<int64_t>(
0053       std::any_cast<torch::Tensor>(edges).t());
0054 
0055   auto predGraphCantor = cantorize(edgeIndex, logger());
0056 
0057   // Calculate intersection
0058   std::vector<Acts::detail::CantorEdge<int64_t>> intersection;
0059   intersection.reserve(
0060       std::max(predGraphCantor.size(), m_truthGraphCantor.size()));
0061 
0062   std::set_intersection(predGraphCantor.begin(), predGraphCantor.end(),
0063                         m_truthGraphCantor.begin(), m_truthGraphCantor.end(),
0064                         std::back_inserter(intersection));
0065 
0066   ACTS_DEBUG("Intersection size " << intersection.size());
0067   const float intersectionSizeFloat = intersection.size();
0068   const float eff = intersectionSizeFloat / m_truthGraphCantor.size();
0069   const float pur = intersectionSizeFloat / predGraphCantor.size();
0070 
0071   ACTS_INFO("Efficiency=" << eff << ", purity=" << pur);
0072 }