Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-08-05 08:10:17

0001 // This file is part of the Acts project.
0002 //
0003 // Copyright (C) 2023 CERN for the benefit of the Acts project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
0008 
0009 #include "Acts/Plugins/ExaTrkX/BoostTrackBuilding.hpp"
0010 
0011 #include "Acts/Utilities/Zip.hpp"
0012 
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

#include <boost/beast/core/span.hpp>
#include <boost/graph/adjacency_list.hpp>
#include <boost/graph/connected_components.hpp>
#include <torch/torch.h>
0019 
0020 namespace {
0021 template <typename vertex_t, typename weight_t>
0022 auto weaklyConnectedComponents(vertex_t numNodes,
0023                                boost::beast::span<vertex_t>& rowIndices,
0024                                boost::beast::span<vertex_t>& colIndices,
0025                                boost::beast::span<weight_t>& edgeWeights,
0026                                std::vector<vertex_t>& trackLabels) {
0027   using Graph =
0028       boost::adjacency_list<boost::vecS,         // edge list
0029                             boost::vecS,         // vertex list
0030                             boost::undirectedS,  // directedness
0031                             boost::no_property,  // property of vertices
0032                             weight_t             // property of edges
0033                             >;
0034 
0035   Graph g(numNodes);
0036 
0037   for (const auto [row, col, weight] :
0038        Acts::zip(rowIndices, colIndices, edgeWeights)) {
0039     boost::add_edge(row, col, weight, g);
0040   }
0041 
0042   return boost::connected_components(g, &trackLabels[0]);
0043 }
0044 }  // namespace
0045 
0046 namespace Acts {
0047 
0048 std::vector<std::vector<int>> BoostTrackBuilding::operator()(
0049     std::any nodes, std::any edges, std::any weights,
0050     std::vector<int>& spacepointIDs, int) {
0051   ACTS_DEBUG("Start track building");
0052   const auto edgeTensor = std::any_cast<torch::Tensor>(edges).to(torch::kCPU);
0053   const auto edgeWeightTensor =
0054       std::any_cast<torch::Tensor>(weights).to(torch::kCPU);
0055 
0056   assert(edgeTensor.size(0) == 2);
0057   assert(edgeTensor.size(1) == edgeWeightTensor.size(0));
0058 
0059   const auto numSpacepoints = spacepointIDs.size();
0060   const auto numEdges = static_cast<std::size_t>(edgeWeightTensor.size(0));
0061 
0062   if (numEdges == 0) {
0063     ACTS_WARNING("No edges remained after edge classification");
0064     return {};
0065   }
0066 
0067   using vertex_t = int64_t;
0068   using weight_t = float;
0069 
0070   boost::beast::span<vertex_t> rowIndices(edgeTensor.data_ptr<vertex_t>(),
0071                                           numEdges);
0072   boost::beast::span<vertex_t> colIndices(
0073       edgeTensor.data_ptr<vertex_t>() + numEdges, numEdges);
0074   boost::beast::span<weight_t> edgeWeights(edgeWeightTensor.data_ptr<float>(),
0075                                            numEdges);
0076 
0077   std::vector<vertex_t> trackLabels(numSpacepoints);
0078 
0079   auto numberLabels = weaklyConnectedComponents<vertex_t, weight_t>(
0080       numSpacepoints, rowIndices, colIndices, edgeWeights, trackLabels);
0081 
0082   ACTS_VERBOSE("Number of track labels: " << trackLabels.size());
0083   ACTS_VERBOSE("Number of unique track labels: " << [&]() {
0084     std::vector<vertex_t> sorted(trackLabels);
0085     std::sort(sorted.begin(), sorted.end());
0086     sorted.erase(std::unique(sorted.begin(), sorted.end()), sorted.end());
0087     return sorted.size();
0088   }());
0089 
0090   if (trackLabels.size() == 0) {
0091     return {};
0092   }
0093 
0094   std::vector<std::vector<int>> trackCandidates(numberLabels);
0095 
0096   for (const auto [label, id] : Acts::zip(trackLabels, spacepointIDs)) {
0097     trackCandidates[label].push_back(id);
0098   }
0099 
0100   return trackCandidates;
0101 }
0102 
0103 }  // namespace Acts