Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-08-05 08:10:17

0001 // This file is part of the Acts project.
0002 //
0003 // Copyright (C) 2023 CERN for the benefit of the Acts project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
0008 
0009 #include "Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp"
0010 
0011 namespace Acts {
0012 
0013 ExaTrkXPipeline::ExaTrkXPipeline(
0014     std::shared_ptr<GraphConstructionBase> graphConstructor,
0015     std::vector<std::shared_ptr<EdgeClassificationBase>> edgeClassifiers,
0016     std::shared_ptr<TrackBuildingBase> trackBuilder,
0017     std::unique_ptr<const Acts::Logger> logger)
0018     : m_logger(std::move(logger)),
0019       m_graphConstructor(graphConstructor),
0020       m_edgeClassifiers(edgeClassifiers),
0021       m_trackBuilder(trackBuilder) {
0022   if (!m_graphConstructor) {
0023     throw std::invalid_argument("Missing graph construction module");
0024   }
0025   if (!m_trackBuilder) {
0026     throw std::invalid_argument("Missing track building module");
0027   }
0028   if (m_edgeClassifiers.empty() or
0029       not std::all_of(m_edgeClassifiers.begin(), m_edgeClassifiers.end(),
0030                       [](const auto &a) { return static_cast<bool>(a); })) {
0031     throw std::invalid_argument("Missing graph construction module");
0032   }
0033 }
0034 
0035 std::vector<std::vector<int>> ExaTrkXPipeline::run(
0036     std::vector<float> &features, std::vector<int> &spacepointIDs,
0037     int deviceHint, const ExaTrkXHook &hook, ExaTrkXTiming *timing) const {
0038   auto t0 = std::chrono::high_resolution_clock::now();
0039   auto [nodes, edges] =
0040       (*m_graphConstructor)(features, spacepointIDs.size(), deviceHint);
0041   auto t1 = std::chrono::high_resolution_clock::now();
0042 
0043   if (timing != nullptr) {
0044     timing->graphBuildingTime = t1 - t0;
0045   }
0046 
0047   hook(nodes, edges, {});
0048 
0049   std::any edge_weights;
0050   timing->classifierTimes.clear();
0051 
0052   for (auto edgeClassifier : m_edgeClassifiers) {
0053     t0 = std::chrono::high_resolution_clock::now();
0054     auto [newNodes, newEdges, newWeights] =
0055         (*edgeClassifier)(std::move(nodes), std::move(edges), deviceHint);
0056     t1 = std::chrono::high_resolution_clock::now();
0057 
0058     if (timing != nullptr) {
0059       timing->classifierTimes.push_back(t1 - t0);
0060     }
0061 
0062     nodes = std::move(newNodes);
0063     edges = std::move(newEdges);
0064     edge_weights = std::move(newWeights);
0065 
0066     hook(nodes, edges, edge_weights);
0067   }
0068 
0069   t0 = std::chrono::high_resolution_clock::now();
0070   auto res =
0071       (*m_trackBuilder)(std::move(nodes), std::move(edges),
0072                         std::move(edge_weights), spacepointIDs, deviceHint);
0073   t1 = std::chrono::high_resolution_clock::now();
0074 
0075   if (timing != nullptr) {
0076     timing->trackBuildingTime = t1 - t0;
0077   }
0078 
0079   return res;
0080 }
0081 
0082 }  // namespace Acts