Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-08-05 08:10:17

0001 // This file is part of the Acts project.
0002 //
0003 // Copyright (C) 2023 CERN for the benefit of the Acts project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
0008 
0009 #include "Acts/Plugins/ExaTrkX/CugraphTrackBuilding.hpp"
0010 
0011 #include <map>
0012 
0013 #include <torch/script.h>
0014 
0015 #include "weaklyConnectedComponentsCugraph.hpp"
0016 
0017 namespace Acts {
0018 
0019 std::vector<std::vector<int>> CugraphTrackBuilding::operator()(
0020     std::any, std::any edges, std::any edge_weights,
0021     std::vector<int> &spacepointIDs, int) {
0022   auto numSpacepoints = spacepointIDs.size();
0023   auto edgesAfterFiltering = std::any_cast<std::vector<int64_t>>(edges);
0024   auto numEdgesAfterF = edgesAfterFiltering.size() / 2;
0025   auto gOutputCTen = std::any_cast<at::Tensor>(edge_weights);
0026 
0027   if (numEdgesAfterF == 0) {
0028     return {};
0029   }
0030 
0031   // ************
0032   // Track Labeling with cugraph::connected_components
0033   // ************
0034   std::vector<int32_t> rowIndices;
0035   std::vector<int32_t> colIndices;
0036   std::vector<float> edgeWeights;
0037   std::vector<int32_t> trackLabels(numSpacepoints);
0038   std::copy(edgesAfterFiltering.begin(),
0039             edgesAfterFiltering.begin() + numEdgesAfterF,
0040             std::back_insert_iterator(rowIndices));
0041   std::copy(edgesAfterFiltering.begin() + numEdgesAfterF,
0042             edgesAfterFiltering.end(), std::back_insert_iterator(colIndices));
0043   std::copy(gOutputCTen.data_ptr<float>(),
0044             gOutputCTen.data_ptr<float>() + numEdgesAfterF,
0045             std::back_insert_iterator(edgeWeights));
0046 
0047   ACTS_VERBOSE("run weaklyConnectedComponents");
0048   weaklyConnectedComponents<int32_t, int32_t, float>(
0049       rowIndices, colIndices, edgeWeights, trackLabels, logger());
0050 
0051   ACTS_DEBUG("size of components: " << trackLabels.size());
0052   if (trackLabels.size() == 0) {
0053     return {};
0054   }
0055 
0056   std::vector<std::vector<int>> trackCandidates;
0057   trackCandidates.clear();
0058 
0059   int existTrkIdx = 0;
0060   // map labeling from MCC to customized track id.
0061   std::map<int, int> trackLableToIds;
0062 
0063   for (auto idx = 0ul; idx < numSpacepoints; ++idx) {
0064     int trackLabel = trackLabels[idx];
0065     int spacepointID = spacepointIDs[idx];
0066 
0067     int trkId;
0068     if (trackLableToIds.find(trackLabel) != trackLableToIds.end()) {
0069       trkId = trackLableToIds[trackLabel];
0070       trackCandidates[trkId].push_back(spacepointID);
0071     } else {
0072       // a new track, assign the track id
0073       // and create a vector
0074       trkId = existTrkIdx;
0075       trackCandidates.push_back(std::vector<int>{trkId});
0076       trackLableToIds[trackLabel] = trkId;
0077       existTrkIdx++;
0078     }
0079   }
0080 
0081   return trackCandidates;
0082 }
0083 
0084 }  // namespace Acts