File indexing completed on 2025-08-05 08:10:17
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp"
0010
0011 namespace Acts {
0012
0013 ExaTrkXPipeline::ExaTrkXPipeline(
0014 std::shared_ptr<GraphConstructionBase> graphConstructor,
0015 std::vector<std::shared_ptr<EdgeClassificationBase>> edgeClassifiers,
0016 std::shared_ptr<TrackBuildingBase> trackBuilder,
0017 std::unique_ptr<const Acts::Logger> logger)
0018 : m_logger(std::move(logger)),
0019 m_graphConstructor(graphConstructor),
0020 m_edgeClassifiers(edgeClassifiers),
0021 m_trackBuilder(trackBuilder) {
0022 if (!m_graphConstructor) {
0023 throw std::invalid_argument("Missing graph construction module");
0024 }
0025 if (!m_trackBuilder) {
0026 throw std::invalid_argument("Missing track building module");
0027 }
0028 if (m_edgeClassifiers.empty() or
0029 not std::all_of(m_edgeClassifiers.begin(), m_edgeClassifiers.end(),
0030 [](const auto &a) { return static_cast<bool>(a); })) {
0031 throw std::invalid_argument("Missing graph construction module");
0032 }
0033 }
0034
0035 std::vector<std::vector<int>> ExaTrkXPipeline::run(
0036 std::vector<float> &features, std::vector<int> &spacepointIDs,
0037 int deviceHint, const ExaTrkXHook &hook, ExaTrkXTiming *timing) const {
0038 auto t0 = std::chrono::high_resolution_clock::now();
0039 auto [nodes, edges] =
0040 (*m_graphConstructor)(features, spacepointIDs.size(), deviceHint);
0041 auto t1 = std::chrono::high_resolution_clock::now();
0042
0043 if (timing != nullptr) {
0044 timing->graphBuildingTime = t1 - t0;
0045 }
0046
0047 hook(nodes, edges, {});
0048
0049 std::any edge_weights;
0050 timing->classifierTimes.clear();
0051
0052 for (auto edgeClassifier : m_edgeClassifiers) {
0053 t0 = std::chrono::high_resolution_clock::now();
0054 auto [newNodes, newEdges, newWeights] =
0055 (*edgeClassifier)(std::move(nodes), std::move(edges), deviceHint);
0056 t1 = std::chrono::high_resolution_clock::now();
0057
0058 if (timing != nullptr) {
0059 timing->classifierTimes.push_back(t1 - t0);
0060 }
0061
0062 nodes = std::move(newNodes);
0063 edges = std::move(newEdges);
0064 edge_weights = std::move(newWeights);
0065
0066 hook(nodes, edges, edge_weights);
0067 }
0068
0069 t0 = std::chrono::high_resolution_clock::now();
0070 auto res =
0071 (*m_trackBuilder)(std::move(nodes), std::move(edges),
0072 std::move(edge_weights), spacepointIDs, deviceHint);
0073 t1 = std::chrono::high_resolution_clock::now();
0074
0075 if (timing != nullptr) {
0076 timing->trackBuildingTime = t1 - t0;
0077 }
0078
0079 return res;
0080 }
0081
0082 }