Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-08-05 08:09:49

0001 // This file is part of the Acts project.
0002 //
0003 // Copyright (C) 2022 CERN for the benefit of the Acts project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
0008 
0009 #include "ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp"
0010 
0011 #include "Acts/Definitions/Units.hpp"
0012 #include "Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp"
0013 #include "Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp"
0014 #include "Acts/Utilities/Zip.hpp"
0015 #include "ActsExamples/EventData/Index.hpp"
0016 #include "ActsExamples/EventData/IndexSourceLink.hpp"
0017 #include "ActsExamples/EventData/ProtoTrack.hpp"
0018 #include "ActsExamples/EventData/SimSpacePoint.hpp"
0019 #include "ActsExamples/Framework/WhiteBoard.hpp"
0020 
0021 #include <numeric>
0022 
0023 using namespace ActsExamples;
0024 using namespace Acts::UnitLiterals;
0025 
0026 namespace {
0027 
0028 class ExamplesEdmHook : public Acts::ExaTrkXHook {
0029   double m_targetPT = 0.5_GeV;
0030   std::size_t m_targetSize = 3;
0031 
0032   std::unique_ptr<const Acts::Logger> m_logger;
0033   std::unique_ptr<Acts::TorchTruthGraphMetricsHook> m_truthGraphHook;
0034   std::unique_ptr<Acts::TorchTruthGraphMetricsHook> m_targetGraphHook;
0035   std::unique_ptr<Acts::TorchGraphStoreHook> m_graphStoreHook;
0036 
0037   const Acts::Logger& logger() const { return *m_logger; }
0038 
0039   struct HitInfo {
0040     std::size_t spacePointIndex;
0041     int32_t hitIndex;
0042   };
0043 
0044  public:
0045   ExamplesEdmHook(const SimSpacePointContainer& spacepoints,
0046                   const IndexMultimap<Index>& measHitMap,
0047                   const SimHitContainer& simhits,
0048                   const SimParticleContainer& particles,
0049                   std::size_t targetMinHits, double targetMinPT,
0050                   const Acts::Logger& logger)
0051       : m_targetPT(targetMinPT),
0052         m_targetSize(targetMinHits),
0053         m_logger(logger.clone("MetricsHook")) {
0054     // Associate tracks to graph, collect momentum
0055     std::unordered_map<ActsFatras::Barcode, std::vector<HitInfo>> tracks;
0056 
0057     for (auto i = 0ul; i < spacepoints.size(); ++i) {
0058       const auto measId = spacepoints[i]
0059                               .sourceLinks()[0]
0060                               .template get<IndexSourceLink>()
0061                               .index();
0062 
0063       auto [a, b] = measHitMap.equal_range(measId);
0064       for (auto it = a; it != b; ++it) {
0065         const auto& hit = *simhits.nth(it->second);
0066 
0067         tracks[hit.particleId()].push_back({i, hit.index()});
0068       }
0069     }
0070 
0071     // Collect edges for truth graph and target graph
0072     std::vector<int64_t> truthGraph;
0073     std::vector<int64_t> targetGraph;
0074 
0075     for (auto& [pid, track] : tracks) {
0076       // Sort by hit index, so the edges are connected correctly
0077       std::sort(track.begin(), track.end(), [](const auto& a, const auto& b) {
0078         return a.hitIndex < b.hitIndex;
0079       });
0080 
0081       auto found = particles.find(pid);
0082       if (found == particles.end()) {
0083         ACTS_WARNING("Did not find " << pid << ", skip track");
0084         continue;
0085       }
0086 
0087       for (auto i = 0ul; i < track.size() - 1; ++i) {
0088         truthGraph.push_back(track[i].spacePointIndex);
0089         truthGraph.push_back(track[i + 1].spacePointIndex);
0090 
0091         if (found->transverseMomentum() > m_targetPT &&
0092             track.size() >= m_targetSize) {
0093           targetGraph.push_back(track[i].spacePointIndex);
0094           targetGraph.push_back(track[i + 1].spacePointIndex);
0095         }
0096       }
0097     }
0098 
0099     m_truthGraphHook = std::make_unique<Acts::TorchTruthGraphMetricsHook>(
0100         truthGraph, logger.clone());
0101     m_targetGraphHook = std::make_unique<Acts::TorchTruthGraphMetricsHook>(
0102         targetGraph, logger.clone());
0103     m_graphStoreHook = std::make_unique<Acts::TorchGraphStoreHook>();
0104   }
0105 
0106   ~ExamplesEdmHook() {}
0107 
0108   auto storedGraph() const { return m_graphStoreHook->storedGraph(); }
0109 
0110   void operator()(const std::any& nodes, const std::any& edges,
0111                   const std::any& weights) const override {
0112     ACTS_INFO("Metrics for total graph:");
0113     (*m_truthGraphHook)(nodes, edges, weights);
0114     ACTS_INFO("Metrics for target graph (pT > "
0115               << m_targetPT / Acts::UnitConstants::GeV
0116               << " GeV, nHits >= " << m_targetSize << "):");
0117     (*m_targetGraphHook)(nodes, edges, weights);
0118     (*m_graphStoreHook)(nodes, edges, weights);
0119   }
0120 };
0121 
0122 }  // namespace
0123 
0124 ActsExamples::TrackFindingAlgorithmExaTrkX::TrackFindingAlgorithmExaTrkX(
0125     Config config, Acts::Logging::Level level)
0126     : ActsExamples::IAlgorithm("TrackFindingMLBasedAlgorithm", level),
0127       m_cfg(std::move(config)),
0128       m_pipeline(m_cfg.graphConstructor, m_cfg.edgeClassifiers,
0129                  m_cfg.trackBuilder, logger().clone()) {
0130   if (m_cfg.inputSpacePoints.empty()) {
0131     throw std::invalid_argument("Missing spacepoint input collection");
0132   }
0133   if (m_cfg.outputProtoTracks.empty()) {
0134     throw std::invalid_argument("Missing protoTrack output collection");
0135   }
0136 
0137   // Sanitizer run with dummy input to detect configuration issues
0138   // TODO This would be quite helpful I think, but currently it does not work
0139   // in general because the stages do not expose the number of node features.
0140   // However, this must be addressed anyway when we also want to allow to
0141   // configure this more flexible with e.g. cluster information as input. So
0142   // for now, we disable this.
0143 #if 0
0144   if( m_cfg.sanitize ) {
0145   Eigen::VectorXf dummyInput = Eigen::VectorXf::Random(3 * 15);
0146   std::vector<float> dummyInputVec(dummyInput.data(),
0147                                    dummyInput.data() + dummyInput.size());
0148   std::vector<int> spacepointIDs;
0149   std::iota(spacepointIDs.begin(), spacepointIDs.end(), 0);
0150   
0151   runPipeline(dummyInputVec, spacepointIDs);
0152   }
0153 #endif
0154 
0155   m_inputSpacePoints.initialize(m_cfg.inputSpacePoints);
0156   m_inputClusters.maybeInitialize(m_cfg.inputClusters);
0157   m_outputProtoTracks.initialize(m_cfg.outputProtoTracks);
0158 
0159   m_inputSimHits.maybeInitialize(m_cfg.inputSimHits);
0160   m_inputParticles.maybeInitialize(m_cfg.inputParticles);
0161   m_inputMeasurementMap.maybeInitialize(m_cfg.inputMeasurementSimhitsMap);
0162 
0163   m_outputGraph.maybeInitialize(m_cfg.outputGraph);
0164 
0165   // reserve space for timing
0166   m_timing.classifierTimes.resize(
0167       m_cfg.edgeClassifiers.size(),
0168       decltype(m_timing.classifierTimes)::value_type{0.f});
0169 }
0170 
/// Named offsets into the feature vector of a single spacepoint, so the
/// individual features can be addressed by name instead of by bare index
enum feat : std::size_t {
  eR = 0,
  ePhi = 1,
  eZ = 2,
  eCellCount = 3,
  eCellSum = 4,
  eClusterX = 5,
  eClusterY = 6
};
0181 
0182 ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(
0183     const ActsExamples::AlgorithmContext& ctx) const {
0184   // Read input data
0185   auto spacepoints = m_inputSpacePoints(ctx);
0186 
0187   auto hook = std::make_unique<Acts::ExaTrkXHook>();
0188   if (m_inputSimHits.isInitialized() && m_inputMeasurementMap.isInitialized()) {
0189     hook = std::make_unique<ExamplesEdmHook>(
0190         spacepoints, m_inputMeasurementMap(ctx), m_inputSimHits(ctx),
0191         m_inputParticles(ctx), m_cfg.targetMinHits, m_cfg.targetMinPT,
0192         logger());
0193   }
0194 
0195   std::optional<ClusterContainer> clusters;
0196   if (m_inputClusters.isInitialized()) {
0197     clusters = m_inputClusters(ctx);
0198   }
0199 
0200   // Convert Input data to a list of size [num_measurements x
0201   // measurement_features]
0202   const std::size_t numSpacepoints = spacepoints.size();
0203   const std::size_t numFeatures = clusters ? 7 : 3;
0204   ACTS_INFO("Received " << numSpacepoints << " spacepoints");
0205 
0206   std::vector<float> features(numSpacepoints * numFeatures);
0207   std::vector<int> spacepointIDs;
0208 
0209   spacepointIDs.reserve(spacepoints.size());
0210 
0211   double sumCells = 0.0;
0212   double sumActivation = 0.0;
0213 
0214   for (auto i = 0ul; i < numSpacepoints; ++i) {
0215     const auto& sp = spacepoints[i];
0216 
0217     // I would prefer to use a std::span or boost::span here once available
0218     float* featurePtr = features.data() + i * numFeatures;
0219 
0220     // For now just take the first index since does require one single index
0221     // per spacepoint
0222     const auto& sl = sp.sourceLinks()[0].template get<IndexSourceLink>();
0223     spacepointIDs.push_back(sl.index());
0224 
0225     featurePtr[eR] = std::hypot(sp.x(), sp.y()) / m_cfg.rScale;
0226     featurePtr[ePhi] = std::atan2(sp.y(), sp.x()) / m_cfg.phiScale;
0227     featurePtr[eZ] = sp.z() / m_cfg.zScale;
0228 
0229     if (clusters) {
0230       const auto& cluster = clusters->at(sl.index());
0231       const auto& chnls = cluster.channels;
0232 
0233       featurePtr[eCellCount] = cluster.channels.size() / m_cfg.cellCountScale;
0234       featurePtr[eCellSum] =
0235           std::accumulate(chnls.begin(), chnls.end(), 0.0,
0236                           [](double s, const Cluster::Cell& c) {
0237                             return s + c.activation;
0238                           }) /
0239           m_cfg.cellSumScale;
0240       featurePtr[eClusterX] = cluster.sizeLoc0 / m_cfg.clusterXScale;
0241       featurePtr[eClusterY] = cluster.sizeLoc1 / m_cfg.clusterYScale;
0242 
0243       sumCells += featurePtr[eCellCount];
0244       sumActivation += featurePtr[eCellSum];
0245     }
0246   }
0247 
0248   ACTS_DEBUG("Avg cell count: " << sumCells / spacepoints.size());
0249   ACTS_DEBUG("Avg activation: " << sumActivation / sumCells);
0250 
0251   // Run the pipeline
0252   const auto trackCandidates = [&]() {
0253     const int deviceHint = -1;
0254     std::lock_guard<std::mutex> lock(m_mutex);
0255 
0256     Acts::ExaTrkXTiming timing;
0257     auto res =
0258         m_pipeline.run(features, spacepointIDs, deviceHint, *hook, &timing);
0259 
0260     m_timing.graphBuildingTime(timing.graphBuildingTime.count());
0261 
0262     assert(timing.classifierTimes.size() == m_timing.classifierTimes.size());
0263     for (auto [aggr, a] :
0264          Acts::zip(m_timing.classifierTimes, timing.classifierTimes)) {
0265       aggr(a.count());
0266     }
0267 
0268     m_timing.trackBuildingTime(timing.trackBuildingTime.count());
0269 
0270     return res;
0271   }();
0272 
0273   ACTS_DEBUG("Done with pipeline, received " << trackCandidates.size()
0274                                              << " candidates");
0275 
0276   // Make the prototracks
0277   std::vector<ProtoTrack> protoTracks;
0278   protoTracks.reserve(trackCandidates.size());
0279 
0280   int nShortTracks = 0;
0281 
0282   for (auto& x : trackCandidates) {
0283     if (m_cfg.filterShortTracks && x.size() < 3) {
0284       nShortTracks++;
0285       continue;
0286     }
0287 
0288     ProtoTrack onetrack;
0289     onetrack.reserve(x.size());
0290 
0291     std::copy(x.begin(), x.end(), std::back_inserter(onetrack));
0292     protoTracks.push_back(std::move(onetrack));
0293   }
0294 
0295   ACTS_INFO("Removed " << nShortTracks << " with less then 3 hits");
0296   ACTS_INFO("Created " << protoTracks.size() << " proto tracks");
0297   m_outputProtoTracks(ctx, std::move(protoTracks));
0298 
0299   if (auto dhook = dynamic_cast<ExamplesEdmHook*>(&*hook);
0300       dhook && m_outputGraph.isInitialized()) {
0301     auto graph = dhook->storedGraph();
0302     std::transform(
0303         graph.first.begin(), graph.first.end(), graph.first.begin(),
0304         [&](const auto& a) -> int64_t { return spacepointIDs.at(a); });
0305     m_outputGraph(ctx, std::move(graph));
0306   }
0307 
0308   return ActsExamples::ProcessCode::SUCCESS;
0309 }
0310 
0311 ActsExamples::ProcessCode TrackFindingAlgorithmExaTrkX::finalize() {
0312   namespace ba = boost::accumulators;
0313 
0314   ACTS_INFO("Exa.TrkX timing info");
0315   {
0316     const auto& t = m_timing.graphBuildingTime;
0317     ACTS_INFO("- graph building: " << ba::mean(t) << " +- "
0318                                    << std::sqrt(ba::variance(t)));
0319   }
0320   for (const auto& t : m_timing.classifierTimes) {
0321     ACTS_INFO("- classifier:     " << ba::mean(t) << " +- "
0322                                    << std::sqrt(ba::variance(t)));
0323   }
0324   {
0325     const auto& t = m_timing.trackBuildingTime;
0326     ACTS_INFO("- track building: " << ba::mean(t) << " +- "
0327                                    << std::sqrt(ba::variance(t)));
0328   }
0329 
0330   return {};
0331 }