File indexing completed on 2025-08-05 08:09:49
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp"
0010
0011 #include "Acts/Definitions/Units.hpp"
0012 #include "Acts/Plugins/ExaTrkX/TorchGraphStoreHook.hpp"
0013 #include "Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp"
0014 #include "Acts/Utilities/Zip.hpp"
0015 #include "ActsExamples/EventData/Index.hpp"
0016 #include "ActsExamples/EventData/IndexSourceLink.hpp"
0017 #include "ActsExamples/EventData/ProtoTrack.hpp"
0018 #include "ActsExamples/EventData/SimSpacePoint.hpp"
0019 #include "ActsExamples/Framework/WhiteBoard.hpp"
0020
0021 #include <numeric>
0022
0023 using namespace ActsExamples;
0024 using namespace Acts::UnitLiterals;
0025
0026 namespace {
0027
0028 class ExamplesEdmHook : public Acts::ExaTrkXHook {
0029 double m_targetPT = 0.5_GeV;
0030 std::size_t m_targetSize = 3;
0031
0032 std::unique_ptr<const Acts::Logger> m_logger;
0033 std::unique_ptr<Acts::TorchTruthGraphMetricsHook> m_truthGraphHook;
0034 std::unique_ptr<Acts::TorchTruthGraphMetricsHook> m_targetGraphHook;
0035 std::unique_ptr<Acts::TorchGraphStoreHook> m_graphStoreHook;
0036
0037 const Acts::Logger& logger() const { return *m_logger; }
0038
0039 struct HitInfo {
0040 std::size_t spacePointIndex;
0041 int32_t hitIndex;
0042 };
0043
0044 public:
0045 ExamplesEdmHook(const SimSpacePointContainer& spacepoints,
0046 const IndexMultimap<Index>& measHitMap,
0047 const SimHitContainer& simhits,
0048 const SimParticleContainer& particles,
0049 std::size_t targetMinHits, double targetMinPT,
0050 const Acts::Logger& logger)
0051 : m_targetPT(targetMinPT),
0052 m_targetSize(targetMinHits),
0053 m_logger(logger.clone("MetricsHook")) {
0054
0055 std::unordered_map<ActsFatras::Barcode, std::vector<HitInfo>> tracks;
0056
0057 for (auto i = 0ul; i < spacepoints.size(); ++i) {
0058 const auto measId = spacepoints[i]
0059 .sourceLinks()[0]
0060 .template get<IndexSourceLink>()
0061 .index();
0062
0063 auto [a, b] = measHitMap.equal_range(measId);
0064 for (auto it = a; it != b; ++it) {
0065 const auto& hit = *simhits.nth(it->second);
0066
0067 tracks[hit.particleId()].push_back({i, hit.index()});
0068 }
0069 }
0070
0071
0072 std::vector<int64_t> truthGraph;
0073 std::vector<int64_t> targetGraph;
0074
0075 for (auto& [pid, track] : tracks) {
0076
0077 std::sort(track.begin(), track.end(), [](const auto& a, const auto& b) {
0078 return a.hitIndex < b.hitIndex;
0079 });
0080
0081 auto found = particles.find(pid);
0082 if (found == particles.end()) {
0083 ACTS_WARNING("Did not find " << pid << ", skip track");
0084 continue;
0085 }
0086
0087 for (auto i = 0ul; i < track.size() - 1; ++i) {
0088 truthGraph.push_back(track[i].spacePointIndex);
0089 truthGraph.push_back(track[i + 1].spacePointIndex);
0090
0091 if (found->transverseMomentum() > m_targetPT &&
0092 track.size() >= m_targetSize) {
0093 targetGraph.push_back(track[i].spacePointIndex);
0094 targetGraph.push_back(track[i + 1].spacePointIndex);
0095 }
0096 }
0097 }
0098
0099 m_truthGraphHook = std::make_unique<Acts::TorchTruthGraphMetricsHook>(
0100 truthGraph, logger.clone());
0101 m_targetGraphHook = std::make_unique<Acts::TorchTruthGraphMetricsHook>(
0102 targetGraph, logger.clone());
0103 m_graphStoreHook = std::make_unique<Acts::TorchGraphStoreHook>();
0104 }
0105
0106 ~ExamplesEdmHook() {}
0107
0108 auto storedGraph() const { return m_graphStoreHook->storedGraph(); }
0109
0110 void operator()(const std::any& nodes, const std::any& edges,
0111 const std::any& weights) const override {
0112 ACTS_INFO("Metrics for total graph:");
0113 (*m_truthGraphHook)(nodes, edges, weights);
0114 ACTS_INFO("Metrics for target graph (pT > "
0115 << m_targetPT / Acts::UnitConstants::GeV
0116 << " GeV, nHits >= " << m_targetSize << "):");
0117 (*m_targetGraphHook)(nodes, edges, weights);
0118 (*m_graphStoreHook)(nodes, edges, weights);
0119 }
0120 };
0121
0122 }
0123
0124 ActsExamples::TrackFindingAlgorithmExaTrkX::TrackFindingAlgorithmExaTrkX(
0125 Config config, Acts::Logging::Level level)
0126 : ActsExamples::IAlgorithm("TrackFindingMLBasedAlgorithm", level),
0127 m_cfg(std::move(config)),
0128 m_pipeline(m_cfg.graphConstructor, m_cfg.edgeClassifiers,
0129 m_cfg.trackBuilder, logger().clone()) {
0130 if (m_cfg.inputSpacePoints.empty()) {
0131 throw std::invalid_argument("Missing spacepoint input collection");
0132 }
0133 if (m_cfg.outputProtoTracks.empty()) {
0134 throw std::invalid_argument("Missing protoTrack output collection");
0135 }
0136
0137
0138
0139
0140
0141
0142
0143 #if 0
0144 if( m_cfg.sanitize ) {
0145 Eigen::VectorXf dummyInput = Eigen::VectorXf::Random(3 * 15);
0146 std::vector<float> dummyInputVec(dummyInput.data(),
0147 dummyInput.data() + dummyInput.size());
0148 std::vector<int> spacepointIDs;
0149 std::iota(spacepointIDs.begin(), spacepointIDs.end(), 0);
0150
0151 runPipeline(dummyInputVec, spacepointIDs);
0152 }
0153 #endif
0154
0155 m_inputSpacePoints.initialize(m_cfg.inputSpacePoints);
0156 m_inputClusters.maybeInitialize(m_cfg.inputClusters);
0157 m_outputProtoTracks.initialize(m_cfg.outputProtoTracks);
0158
0159 m_inputSimHits.maybeInitialize(m_cfg.inputSimHits);
0160 m_inputParticles.maybeInitialize(m_cfg.inputParticles);
0161 m_inputMeasurementMap.maybeInitialize(m_cfg.inputMeasurementSimhitsMap);
0162
0163 m_outputGraph.maybeInitialize(m_cfg.outputGraph);
0164
0165
0166 m_timing.classifierTimes.resize(
0167 m_cfg.edgeClassifiers.size(),
0168 decltype(m_timing.classifierTimes)::value_type{0.f});
0169 }
0170
0171
namespace {

/// Column indices of the per-spacepoint feature row passed to the pipeline.
///
/// The first three columns (r, phi, z) are always filled. The cluster
/// columns (cell count, cell activation sum, cluster size in loc0/loc1) are
/// only used when cluster input is available.
///
/// The enum lives in an unnamed namespace so that it has internal linkage and
/// cannot collide with an identically named type in another translation
/// unit.
enum feat : std::size_t {
  eR = 0,
  ePhi,
  eZ,
  eCellCount,
  eCellSum,
  eClusterX,
  eClusterY
};

}  // namespace
0181
0182 ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmExaTrkX::execute(
0183 const ActsExamples::AlgorithmContext& ctx) const {
0184
0185 auto spacepoints = m_inputSpacePoints(ctx);
0186
0187 auto hook = std::make_unique<Acts::ExaTrkXHook>();
0188 if (m_inputSimHits.isInitialized() && m_inputMeasurementMap.isInitialized()) {
0189 hook = std::make_unique<ExamplesEdmHook>(
0190 spacepoints, m_inputMeasurementMap(ctx), m_inputSimHits(ctx),
0191 m_inputParticles(ctx), m_cfg.targetMinHits, m_cfg.targetMinPT,
0192 logger());
0193 }
0194
0195 std::optional<ClusterContainer> clusters;
0196 if (m_inputClusters.isInitialized()) {
0197 clusters = m_inputClusters(ctx);
0198 }
0199
0200
0201
0202 const std::size_t numSpacepoints = spacepoints.size();
0203 const std::size_t numFeatures = clusters ? 7 : 3;
0204 ACTS_INFO("Received " << numSpacepoints << " spacepoints");
0205
0206 std::vector<float> features(numSpacepoints * numFeatures);
0207 std::vector<int> spacepointIDs;
0208
0209 spacepointIDs.reserve(spacepoints.size());
0210
0211 double sumCells = 0.0;
0212 double sumActivation = 0.0;
0213
0214 for (auto i = 0ul; i < numSpacepoints; ++i) {
0215 const auto& sp = spacepoints[i];
0216
0217
0218 float* featurePtr = features.data() + i * numFeatures;
0219
0220
0221
0222 const auto& sl = sp.sourceLinks()[0].template get<IndexSourceLink>();
0223 spacepointIDs.push_back(sl.index());
0224
0225 featurePtr[eR] = std::hypot(sp.x(), sp.y()) / m_cfg.rScale;
0226 featurePtr[ePhi] = std::atan2(sp.y(), sp.x()) / m_cfg.phiScale;
0227 featurePtr[eZ] = sp.z() / m_cfg.zScale;
0228
0229 if (clusters) {
0230 const auto& cluster = clusters->at(sl.index());
0231 const auto& chnls = cluster.channels;
0232
0233 featurePtr[eCellCount] = cluster.channels.size() / m_cfg.cellCountScale;
0234 featurePtr[eCellSum] =
0235 std::accumulate(chnls.begin(), chnls.end(), 0.0,
0236 [](double s, const Cluster::Cell& c) {
0237 return s + c.activation;
0238 }) /
0239 m_cfg.cellSumScale;
0240 featurePtr[eClusterX] = cluster.sizeLoc0 / m_cfg.clusterXScale;
0241 featurePtr[eClusterY] = cluster.sizeLoc1 / m_cfg.clusterYScale;
0242
0243 sumCells += featurePtr[eCellCount];
0244 sumActivation += featurePtr[eCellSum];
0245 }
0246 }
0247
0248 ACTS_DEBUG("Avg cell count: " << sumCells / spacepoints.size());
0249 ACTS_DEBUG("Avg activation: " << sumActivation / sumCells);
0250
0251
0252 const auto trackCandidates = [&]() {
0253 const int deviceHint = -1;
0254 std::lock_guard<std::mutex> lock(m_mutex);
0255
0256 Acts::ExaTrkXTiming timing;
0257 auto res =
0258 m_pipeline.run(features, spacepointIDs, deviceHint, *hook, &timing);
0259
0260 m_timing.graphBuildingTime(timing.graphBuildingTime.count());
0261
0262 assert(timing.classifierTimes.size() == m_timing.classifierTimes.size());
0263 for (auto [aggr, a] :
0264 Acts::zip(m_timing.classifierTimes, timing.classifierTimes)) {
0265 aggr(a.count());
0266 }
0267
0268 m_timing.trackBuildingTime(timing.trackBuildingTime.count());
0269
0270 return res;
0271 }();
0272
0273 ACTS_DEBUG("Done with pipeline, received " << trackCandidates.size()
0274 << " candidates");
0275
0276
0277 std::vector<ProtoTrack> protoTracks;
0278 protoTracks.reserve(trackCandidates.size());
0279
0280 int nShortTracks = 0;
0281
0282 for (auto& x : trackCandidates) {
0283 if (m_cfg.filterShortTracks && x.size() < 3) {
0284 nShortTracks++;
0285 continue;
0286 }
0287
0288 ProtoTrack onetrack;
0289 onetrack.reserve(x.size());
0290
0291 std::copy(x.begin(), x.end(), std::back_inserter(onetrack));
0292 protoTracks.push_back(std::move(onetrack));
0293 }
0294
0295 ACTS_INFO("Removed " << nShortTracks << " with less then 3 hits");
0296 ACTS_INFO("Created " << protoTracks.size() << " proto tracks");
0297 m_outputProtoTracks(ctx, std::move(protoTracks));
0298
0299 if (auto dhook = dynamic_cast<ExamplesEdmHook*>(&*hook);
0300 dhook && m_outputGraph.isInitialized()) {
0301 auto graph = dhook->storedGraph();
0302 std::transform(
0303 graph.first.begin(), graph.first.end(), graph.first.begin(),
0304 [&](const auto& a) -> int64_t { return spacepointIDs.at(a); });
0305 m_outputGraph(ctx, std::move(graph));
0306 }
0307
0308 return ActsExamples::ProcessCode::SUCCESS;
0309 }
0310
0311 ActsExamples::ProcessCode TrackFindingAlgorithmExaTrkX::finalize() {
0312 namespace ba = boost::accumulators;
0313
0314 ACTS_INFO("Exa.TrkX timing info");
0315 {
0316 const auto& t = m_timing.graphBuildingTime;
0317 ACTS_INFO("- graph building: " << ba::mean(t) << " +- "
0318 << std::sqrt(ba::variance(t)));
0319 }
0320 for (const auto& t : m_timing.classifierTimes) {
0321 ACTS_INFO("- classifier: " << ba::mean(t) << " +- "
0322 << std::sqrt(ba::variance(t)));
0323 }
0324 {
0325 const auto& t = m_timing.trackBuildingTime;
0326 ACTS_INFO("- track building: " << ba::mean(t) << " +- "
0327 << std::sqrt(ba::variance(t)));
0328 }
0329
0330 return {};
0331 }