File indexing completed on 2025-08-05 08:10:17
0001
0002
0003
0004
0005
0006
0007
0008
#include "Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp"

#include "Acts/Plugins/ExaTrkX/detail/TensorVectorConversion.hpp"

#include <algorithm>
#include <any>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <vector>

#include <torch/torch.h>
0014
0015 namespace {
0016
0017 auto cantorize(std::vector<int64_t> edgeIndex, const Acts::Logger& logger) {
0018
0019
0020 std::vector<Acts::detail::CantorEdge<int64_t>> cantorEdgeIndex;
0021 cantorEdgeIndex.reserve(edgeIndex.size() / 2);
0022 for (auto it = edgeIndex.begin(); it != edgeIndex.end(); it += 2) {
0023 cantorEdgeIndex.emplace_back(*it, *std::next(it));
0024 }
0025
0026 std::sort(cantorEdgeIndex.begin(), cantorEdgeIndex.end());
0027
0028 auto new_end = std::unique(cantorEdgeIndex.begin(), cantorEdgeIndex.end());
0029 if (new_end != cantorEdgeIndex.end()) {
0030 ACTS_WARNING("Graph not unique ("
0031 << std::distance(new_end, cantorEdgeIndex.end())
0032 << " duplicates)");
0033 cantorEdgeIndex.erase(new_end, cantorEdgeIndex.end());
0034 }
0035
0036 return cantorEdgeIndex;
0037 }
0038
0039 }
0040
0041 Acts::TorchTruthGraphMetricsHook::TorchTruthGraphMetricsHook(
0042 const std::vector<int64_t>& truthGraph,
0043 std::unique_ptr<const Acts::Logger> l)
0044 : m_logger(std::move(l)) {
0045 m_truthGraphCantor = cantorize(truthGraph, logger());
0046 }
0047
0048 void Acts::TorchTruthGraphMetricsHook::operator()(const std::any&,
0049 const std::any& edges,
0050 const std::any&) const {
0051
0052 const auto edgeIndex = Acts::detail::tensor2DToVector<int64_t>(
0053 std::any_cast<torch::Tensor>(edges).t());
0054
0055 auto predGraphCantor = cantorize(edgeIndex, logger());
0056
0057
0058 std::vector<Acts::detail::CantorEdge<int64_t>> intersection;
0059 intersection.reserve(
0060 std::max(predGraphCantor.size(), m_truthGraphCantor.size()));
0061
0062 std::set_intersection(predGraphCantor.begin(), predGraphCantor.end(),
0063 m_truthGraphCantor.begin(), m_truthGraphCantor.end(),
0064 std::back_inserter(intersection));
0065
0066 ACTS_DEBUG("Intersection size " << intersection.size());
0067 const float intersectionSizeFloat = intersection.size();
0068 const float eff = intersectionSizeFloat / m_truthGraphCantor.size();
0069 const float pur = intersectionSizeFloat / predGraphCantor.size();
0070
0071 ACTS_INFO("Efficiency=" << eff << ", purity=" << pur);
0072 }