Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-08-06 08:10:47

0001 // This file is part of the Acts project.
0002 //
0003 // Copyright (C) 2021 CERN for the benefit of the Acts project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
0008 
0009 #include "ActsExamples/Io/Csv/CsvTrackWriter.hpp"
0010 
0011 #include "Acts/Definitions/Algebra.hpp"
0012 #include "Acts/EventData/MultiTrajectory.hpp"
0013 #include "Acts/EventData/ProxyAccessor.hpp"
0014 #include "Acts/EventData/VectorMultiTrajectory.hpp"
0015 #include "Acts/Utilities/Helpers.hpp"
0016 #include "Acts/Utilities/MultiIndex.hpp"
0017 #include "ActsExamples/EventData/IndexSourceLink.hpp"
0018 #include "ActsExamples/EventData/Track.hpp"
0019 #include "ActsExamples/Framework/AlgorithmContext.hpp"
0020 #include "ActsExamples/Utilities/Paths.hpp"
0021 #include "ActsExamples/Utilities/Range.hpp"
0022 #include "ActsExamples/Validation/TrackClassification.hpp"
0023 
0024 #include <algorithm>
0025 #include <fstream>
0026 #include <iomanip>
0027 #include <map>
0028 #include <memory>
0029 #include <stdexcept>
0030 #include <string>
0031 #include <unordered_map>
0032 #include <unordered_set>
0033 #include <utility>
0034 
// Forward declaration of the source-link type extracted in writeT(); the
// complete definition comes from "ActsExamples/EventData/IndexSourceLink.hpp".
namespace ActsExamples { class IndexSourceLink; }  // namespace ActsExamples

using namespace ActsExamples;
0040 
0041 CsvTrackWriter::CsvTrackWriter(const CsvTrackWriter::Config& config,
0042                                Acts::Logging::Level level)
0043     : WriterT<ConstTrackContainer>(config.inputTracks, "CsvTrackWriter", level),
0044       m_cfg(config) {
0045   if (m_cfg.inputTracks.empty()) {
0046     throw std::invalid_argument("Missing input tracks collection");
0047   }
0048 
0049   m_inputMeasurementParticlesMap.initialize(m_cfg.inputMeasurementParticlesMap);
0050 }
0051 
0052 ProcessCode CsvTrackWriter::writeT(const AlgorithmContext& context,
0053                                    const ConstTrackContainer& tracks) {
0054   // open per-event file
0055   std::string path =
0056       perEventFilepath(m_cfg.outputDir, m_cfg.fileName, context.eventNumber);
0057   std::ofstream mos(path, std::ofstream::out | std::ofstream::trunc);
0058   if (!mos) {
0059     throw std::ios_base::failure("Could not open '" + path + "' to write");
0060   }
0061 
0062   const auto& hitParticlesMap = m_inputMeasurementParticlesMap(context);
0063 
0064   std::unordered_map<Acts::MultiTrajectoryTraits::IndexType, TrackInfo> infoMap;
0065 
0066   // Counter of truth-matched reco tracks
0067   using RecoTrackInfo = std::pair<TrackInfo, std::size_t>;
0068   std::map<ActsFatras::Barcode, std::vector<RecoTrackInfo>> matched;
0069 
0070   std::size_t trackId = 0;
0071   for (const auto& track : tracks) {
0072     // Reco track selection
0073     //@TODO: add interface for applying others cuts on reco tracks:
0074     // -> pT, d0, z0, detector-specific hits/holes number cut
0075     if (track.nMeasurements() < m_cfg.nMeasurementsMin) {
0076       continue;
0077     }
0078 
0079     // Check if the reco track has fitted track parameters
0080     if (!track.hasReferenceSurface()) {
0081       ACTS_WARNING(
0082           "No fitted track parameters for trajectory with entry index = "
0083           << track.tipIndex());
0084       continue;
0085     }
0086 
0087     // Get the majority truth particle to this track
0088     std::vector<ParticleHitCount> particleHitCount;
0089     identifyContributingParticles(hitParticlesMap, track, particleHitCount);
0090     if (m_cfg.onlyTruthMatched && particleHitCount.empty()) {
0091       ACTS_WARNING(
0092           "No truth particle associated with this trajectory with entry "
0093           "index = "
0094           << track.tipIndex());
0095       continue;
0096     }
0097 
0098     // Requirement on the pT of the track
0099     auto params = track.createParametersAtReference();
0100     const auto momentum = params.momentum();
0101     const auto pT = Acts::VectorHelpers::perp(momentum);
0102     if (pT < m_cfg.ptMin) {
0103       continue;
0104     }
0105     std::size_t nMajorityHits = 0;
0106     ActsFatras::Barcode majorityParticleId;
0107     if (!particleHitCount.empty()) {
0108       // Get the majority particle counts
0109       majorityParticleId = particleHitCount.front().particleId;
0110       // n Majority hits
0111       nMajorityHits = particleHitCount.front().hitCount;
0112     }
0113 
0114     static const Acts::ConstProxyAccessor<unsigned int> seedNumber(
0115         "trackGroup");
0116 
0117     // track info
0118     TrackInfo toAdd;
0119     toAdd.trackId = trackId;
0120     if (tracks.hasColumn(Acts::hashString("trackGroup"))) {
0121       toAdd.seedID = seedNumber(track);
0122     } else {
0123       toAdd.seedID = 0;
0124     }
0125     toAdd.particleId = majorityParticleId;
0126     toAdd.nStates = track.nTrackStates();
0127     toAdd.nMajorityHits = nMajorityHits;
0128     toAdd.nMeasurements = track.nMeasurements();
0129     toAdd.nOutliers = track.nOutliers();
0130     toAdd.nHoles = track.nHoles();
0131     toAdd.nSharedHits = track.nSharedHits();
0132     toAdd.chi2Sum = track.chi2();
0133     toAdd.NDF = track.nDoF();
0134     toAdd.truthMatchProb = toAdd.nMajorityHits * 1. / track.nMeasurements();
0135     toAdd.fittedParameters = params;
0136     toAdd.trackType = "unknown";
0137 
0138     for (const auto& state : track.trackStatesReversed()) {
0139       if (state.typeFlags().test(Acts::TrackStateFlag::MeasurementFlag)) {
0140         auto sl =
0141             state.getUncalibratedSourceLink().template get<IndexSourceLink>();
0142         auto hitIndex = sl.index();
0143         toAdd.measurementsID.insert(toAdd.measurementsID.begin(), hitIndex);
0144       }
0145     }
0146 
0147     // Check if the trajectory is matched with truth.
0148     if (toAdd.truthMatchProb >= m_cfg.truthMatchProbMin) {
0149       matched[toAdd.particleId].push_back({toAdd, toAdd.trackId});
0150     } else {
0151       toAdd.trackType = "fake";
0152     }
0153 
0154     infoMap[toAdd.trackId] = toAdd;
0155 
0156     trackId++;
0157   }
0158 
0159   // Find duplicates
0160   std::unordered_set<std::size_t> listGoodTracks;
0161   for (auto& [particleId, matchedTracks] : matched) {
0162     std::sort(matchedTracks.begin(), matchedTracks.end(),
0163               [](const RecoTrackInfo& lhs, const RecoTrackInfo& rhs) {
0164                 // sort by nMajorityHits
0165                 if (lhs.first.nMajorityHits != rhs.first.nMajorityHits) {
0166                   return (lhs.first.nMajorityHits > rhs.first.nMajorityHits);
0167                 }
0168                 // sort by nOutliers
0169                 if (lhs.first.nOutliers != rhs.first.nOutliers) {
0170                   return (lhs.first.nOutliers < rhs.first.nOutliers);
0171                 }
0172                 // sort by chi2
0173                 return (lhs.first.chi2Sum < rhs.first.chi2Sum);
0174               });
0175 
0176     listGoodTracks.insert(matchedTracks.front().first.trackId);
0177   }
0178 
0179   // write csv header
0180   mos << "track_id,seed_id,particleId,"
0181       << "nStates,nMajorityHits,nMeasurements,nOutliers,nHoles,nSharedHits,"
0182       << "chi2,ndf,chi2/ndf,"
0183       << "pT,eta,phi,"
0184       << "truthMatchProbability,"
0185       << "good/duplicate/fake,"
0186       << "Hits_ID";
0187 
0188   mos << '\n';
0189   mos << std::setprecision(m_cfg.outputPrecision);
0190 
0191   // good/duplicate/fake = 0/1/2
0192   for (auto& [id, trajState] : infoMap) {
0193     if (listGoodTracks.find(id) != listGoodTracks.end()) {
0194       trajState.trackType = "good";
0195     } else if (trajState.trackType != "fake") {
0196       trajState.trackType = "duplicate";
0197     }
0198 
0199     const auto& params = *trajState.fittedParameters;
0200     const auto momentum = params.momentum();
0201 
0202     // write the track info
0203     mos << trajState.trackId << ",";
0204     mos << trajState.seedID << ",";
0205     mos << trajState.particleId << ",";
0206     mos << trajState.nStates << ",";
0207     mos << trajState.nMajorityHits << ",";
0208     mos << trajState.nMeasurements << ",";
0209     mos << trajState.nOutliers << ",";
0210     mos << trajState.nHoles << ",";
0211     mos << trajState.nSharedHits << ",";
0212     mos << trajState.chi2Sum << ",";
0213     mos << trajState.NDF << ",";
0214     mos << trajState.chi2Sum * 1.0 / trajState.NDF << ",";
0215     mos << Acts::VectorHelpers::perp(momentum) << ",";
0216     mos << Acts::VectorHelpers::eta(momentum) << ",";
0217     mos << Acts::VectorHelpers::phi(momentum) << ",";
0218     mos << trajState.truthMatchProb << ",";
0219     mos << trajState.trackType << ",";
0220     mos << "\"[";
0221     for (auto& ID : trajState.measurementsID) {
0222       mos << ID << ",";
0223     }
0224     mos << "]\"";
0225     mos << '\n';
0226   }
0227 
0228   return ProcessCode::SUCCESS;
0229 }