Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-08-06 08:18:25

0001 #include "PHActsGSF.h"
0002 #include "MakeSourceLinks.h"
0003 
0004 #include <fun4all/Fun4AllReturnCodes.h>
0005 #include <phool/PHCompositeNode.h>
0006 #include <phool/PHDataNode.h>
0007 #include <phool/PHNode.h>
0008 #include <phool/PHNodeIterator.h>
0009 #include <phool/PHObject.h>
0010 #include <phool/PHTimer.h>
0011 #include <phool/getClass.h>
0012 #include <phool/phool.h>
0013 
0014 #include <trackbase/ActsGeometry.h>
0015 #include <trackbase/ActsGsfTrackFittingAlgorithm.h>
0016 #include <trackbase/TpcDefs.h>
0017 #include <trackbase/TrkrCluster.h>
0018 #include <trackbase/TrkrClusterContainer.h>
0019 #include <trackbase/TrkrDefs.h>
0020 
0021 #include <trackbase_historic/ActsTransformations.h>
0022 #include <trackbase_historic/SvtxTrack.h>
0023 #include <trackbase_historic/SvtxTrackMap.h>
0024 #include <trackbase_historic/SvtxTrackMap_v2.h>
0025 #include <trackbase_historic/SvtxTrackState_v1.h>
0026 #include <trackbase_historic/TrackSeed_v2.h>
0027 #include <trackbase_historic/TrackSeedHelper.h>
0028 
0029 #include <globalvertex/SvtxVertex.h>
0030 #include <globalvertex/SvtxVertexMap.h>
0031 
0032 #include <Acts/EventData/MultiTrajectory.hpp>
0033 #include <Acts/EventData/MultiTrajectoryHelpers.hpp>
0034 #include <Acts/EventData/SourceLink.hpp>
0035 #include <Acts/EventData/TrackParameters.hpp>
0036 #include <Acts/Surfaces/PerigeeSurface.hpp>
0037 #include <Acts/Surfaces/PlaneSurface.hpp>
0038 #include <Acts/Surfaces/Surface.hpp>
0039 #include <Acts/TrackFitting/BetheHeitlerApprox.hpp>
0040 #include <Acts/TrackFitting/GainMatrixSmoother.hpp>
0041 #include <Acts/TrackFitting/GainMatrixUpdater.hpp>
0042 
0043 #include "ActsEvaluator.h"
0044 
0045 #include <TDatabasePDG.h>
0046 
0047 //____________________________________________________________________________..
0048 PHActsGSF::PHActsGSF(const std::string& name)
0049   : SubsysReco(name)
0050 {
0051 }
0052 
0053 //____________________________________________________________________________..
0054 int PHActsGSF::InitRun(PHCompositeNode* topNode)
0055 {
0056   if (Verbosity() > 1)
0057   {
0058     std::cout << "PHActsGSF::InitRun begin" << std::endl;
0059   }
0060 
0061   if (m_actsEvaluator)
0062   {
0063     PHNodeIterator iter(topNode);
0064 
0065     PHCompositeNode* dstNode = dynamic_cast<PHCompositeNode*>(iter.findFirst("PHCompositeNode", "DST"));
0066 
0067     if (!dstNode)
0068     {
0069       std::cerr << "DST node is missing, quitting" << std::endl;
0070       throw std::runtime_error("Failed to find DST node in PHActsTrkFitter::createNodes");
0071     }
0072 
0073     PHNodeIterator dstIter(topNode);
0074     PHCompositeNode* svtxNode = dynamic_cast<PHCompositeNode*>(dstIter.findFirst("PHCompositeNode", "SVTX"));
0075 
0076     if (!svtxNode)
0077     {
0078       svtxNode = new PHCompositeNode("SVTX");
0079       dstNode->addNode(svtxNode);
0080     }
0081     m_seedTracks = findNode::getClass<SvtxTrackMap>(topNode, _seed_track_map_name);
0082 
0083     if (!m_seedTracks)
0084     {
0085       m_seedTracks = new SvtxTrackMap_v2;
0086 
0087       PHIODataNode<PHObject>* seedNode =
0088           new PHIODataNode<PHObject>(m_seedTracks, _seed_track_map_name, "PHObject");
0089       svtxNode->addNode(seedNode);
0090     }
0091   }
0092 
0093   if (getNodes(topNode) != Fun4AllReturnCodes::EVENT_OK)
0094   {
0095     return Fun4AllReturnCodes::ABORTEVENT;
0096   }
0097 
0098   auto bha = Acts::makeDefaultBetheHeitlerApprox();
0099   ActsGsfTrackFittingAlgorithm gsf;
0100   m_fitCfg.fit = gsf.makeGsfFitterFunction(
0101       m_tGeometry->geometry().tGeometry,
0102       m_tGeometry->geometry().magField,
0103       bha,
0104       12, 1e-4, 
0105       MixtureReductionAlgorithm::KLDistance, false, false);
0106 
0107   if (m_actsEvaluator)
0108   {
0109     m_evaluator = std::make_unique<ActsEvaluator>(m_evalname);
0110     m_evaluator->Init(topNode);
0111     m_evaluator->verbosity(Verbosity());
0112   }
0113 
0114   return Fun4AllReturnCodes::EVENT_OK;
0115 }
0116 
0117 //____________________________________________________________________________..
0118 int PHActsGSF::process_event(PHCompositeNode* topNode)
0119 {
0120   auto logLevel = Acts::Logging::FATAL;
0121 
0122   if (m_actsEvaluator)
0123   {
0124     m_evaluator->next_event(topNode);
0125   }
0126 
0127   if (Verbosity() > 4)
0128   {
0129     logLevel = Acts::Logging::VERBOSE;
0130   }
0131 
0132   /// Fill an additional track map if using the acts evaluator
0133   /// for proto track comparison to fitted track
0134   if (m_actsEvaluator)
0135   {
0136     /// wipe at the beginning of every new fit pass, so that the seeds
0137     /// are whatever is currently in SvtxTrackMap
0138     m_seedTracks->clear();
0139     for (const auto& [key, track] : *m_trackMap)
0140     {
0141       m_seedTracks->insert(track);
0142     }
0143   }
0144 
0145   auto logger = Acts::getDefaultLogger("PHActsGSF", logLevel);
0146 
0147   for (const auto& [key, track] : *m_trackMap)
0148   {
0149     auto pSurface = makePerigee(track);
0150     if (!pSurface)
0151     {
0152       //! If no vertex was assigned to track, just skip it
0153       continue;
0154     }
0155     const auto seed = makeSeed(track, pSurface);
0156 
0157     auto svtxseed = new TrackSeed_v2();
0158     std::map<TrkrDefs::cluskey, Acts::Vector3> clusterPositions;
0159     for (auto& cKey : get_cluster_keys(track))
0160     {
0161       auto cluster = m_clusterContainer->findCluster(cKey);
0162       auto globalPosition = m_tGeometry->getGlobalPosition(cKey, cluster);
0163       clusterPositions.insert(std::make_pair(cKey, globalPosition));
0164       svtxseed->insert_cluster_key(cKey);
0165     }
0166     svtxseed->set_phi(track->get_phi());
0167     TrackSeedHelper::circleFitByTaubin(svtxseed,clusterPositions, 0, 57);
0168     TrackSeedHelper::lineFit(svtxseed, clusterPositions, 7, 57);
0169 
0170     ActsTrackFittingAlgorithm::MeasurementContainer measurements;
0171     TrackSeed* tpcseed = track->get_tpc_seed();
0172     TrackSeed* silseed = track->get_silicon_seed();
0173 
0174     /// We only fit full sPHENIX tracks
0175     if (!silseed or !tpcseed)
0176     {
0177       continue;
0178     }
0179 
0180     auto crossing = silseed->get_crossing();
0181     if (crossing == SHRT_MAX)
0182     {
0183       continue;
0184     }
0185 
0186     /*
0187     auto sourceLinks = getSourceLinks(tpcseed, measurements, crossing);
0188     auto silSourceLinks = getSourceLinks(silseed, measurements, crossing);
0189     */
0190 
0191     // loop over modifiedTransformSet and replace transient elements modified for the previous track with the default transforms
0192     MakeSourceLinks makeSourceLinks;
0193     makeSourceLinks.setVerbosity(Verbosity());
0194     makeSourceLinks.set_pp_mode(m_pp_mode);
0195 
0196     makeSourceLinks.resetTransientTransformMap(
0197         m_alignmentTransformationMapTransient,
0198         m_transient_id_set,
0199         m_tGeometry);
0200 
0201     // TPC source links
0202     auto sourceLinks = makeSourceLinks.getSourceLinks(
0203         tpcseed,
0204         measurements,
0205         m_clusterContainer,
0206         m_tGeometry,
0207         m_globalPositionWrapper,
0208         m_alignmentTransformationMapTransient,
0209         m_transient_id_set,
0210         crossing);
0211 
0212     // silicon source links
0213     auto silSourceLinks = makeSourceLinks.getSourceLinks(
0214         silseed,
0215         measurements,
0216         m_clusterContainer,
0217         m_tGeometry,
0218         m_globalPositionWrapper,
0219         m_alignmentTransformationMapTransient,
0220         m_transient_id_set,
0221         crossing);
0222 
0223     // copy transient map for this track into transient geoContext
0224     m_transient_geocontext = m_alignmentTransformationMapTransient;
0225 
0226     for (auto& siSL : silSourceLinks)
0227     {
0228       sourceLinks.push_back(siSL);
0229     }
0230 
0231     auto calibptr = std::make_unique<Calibrator>();
0232     CalibratorAdapter calibrator(*calibptr, measurements);
0233     auto magcontext = m_tGeometry->geometry().magFieldContext;
0234     auto calcontext = m_tGeometry->geometry().calibContext;
0235 
0236     auto ppoptions = Acts::PropagatorPlainOptions();
0237 
0238     ActsTrackFittingAlgorithm::GeneralFitterOptions options{
0239         m_transient_geocontext,
0240         magcontext,
0241         calcontext,
0242         &(*pSurface),
0243         ppoptions};
0244     if (Verbosity() > 2)
0245     {
0246       std::cout << "calling gsf with position "
0247                 << seed.position(m_transient_geocontext).transpose()
0248                 << " and momentum " << seed.momentum().transpose()
0249                 << std::endl;
0250     }
0251     auto trackContainer = std::make_shared<Acts::VectorTrackContainer>();
0252     auto trackStateContainer = std::make_shared<Acts::VectorMultiTrajectory>();
0253     ActsTrackFittingAlgorithm::TrackContainer tracks(trackContainer, trackStateContainer);
0254     auto result = fitTrack(sourceLinks, seed, options, calibrator, tracks);
0255 
0256     if (result.ok())
0257     {
0258       updateTrack(result, track, tracks, svtxseed, measurements);
0259     }
0260   }
0261 
0262   return Fun4AllReturnCodes::EVENT_OK;
0263 }
0264 
0265 std::shared_ptr<Acts::PerigeeSurface> PHActsGSF::makePerigee(SvtxTrack* track) const
0266 {
0267   SvtxVertex* vertex = m_vertexMap->get(track->get_vertex_id());
0268   if (!vertex)
0269   {
0270     return nullptr;
0271   }
0272 
0273   Acts::Vector3 vertexpos(vertex->get_x() * Acts::UnitConstants::cm,
0274                           vertex->get_y() * Acts::UnitConstants::cm,
0275                           vertex->get_z() * Acts::UnitConstants::cm);
0276 
0277   return Acts::Surface::makeShared<Acts::PerigeeSurface>(
0278       vertexpos);
0279 }
0280 
0281 ActsTrackFittingAlgorithm::TrackParameters PHActsGSF::makeSeed(SvtxTrack* track,
0282                                                                const std::shared_ptr<Acts::PerigeeSurface>& psurf) const
0283 {
0284   Acts::Vector4 fourpos(track->get_x() * Acts::UnitConstants::cm,
0285                         track->get_y() * Acts::UnitConstants::cm,
0286                         track->get_z() * Acts::UnitConstants::cm,
0287                         10 * Acts::UnitConstants::ns);
0288 
0289   int charge = track->get_charge();
0290   Acts::Vector3 momentum(track->get_px(),
0291                          track->get_py(),
0292                          track->get_pz());
0293 
0294   ActsTransformations transformer;
0295   auto cov = transformer.rotateSvtxTrackCovToActs(track);
0296 
0297   return ActsTrackFittingAlgorithm::TrackParameters::create(psurf,
0298                                                             m_tGeometry->geometry().getGeoContext(),
0299                                                             fourpos,
0300                                                             momentum,
0301                                                             charge / momentum.norm(),
0302                                                             cov,
0303                                                             Acts::ParticleHypothesis::electron(),
0304                                                             1 * Acts::UnitConstants::cm)
0305       .value();
0306 }
0307 
0308 ActsTrackFittingAlgorithm::TrackFitterResult PHActsGSF::fitTrack(
0309     const std::vector<Acts::SourceLink>& sourceLinks,
0310     const ActsTrackFittingAlgorithm::TrackParameters& seed,
0311     const ActsTrackFittingAlgorithm::GeneralFitterOptions& options,
0312     const CalibratorAdapter& calibrator,
0313     ActsTrackFittingAlgorithm::TrackContainer& tracks)
0314 {
0315   return (*m_fitCfg.fit)(sourceLinks, seed, options, calibrator, tracks);
0316 }
0317 
0318 void PHActsGSF::updateTrack(FitResult& result, SvtxTrack* track,
0319                             ActsTrackFittingAlgorithm::TrackContainer& tracks,
0320                             const TrackSeed* seed,
0321                             const ActsTrackFittingAlgorithm::MeasurementContainer& measurements)
0322 {
0323   std::vector<Acts::MultiTrajectoryTraits::IndexType> trackTips;
0324   trackTips.reserve(1);
0325   auto& outtrack = result.value();
0326   trackTips.emplace_back(outtrack.tipIndex());
0327   ActsExamples::Trajectories::IndexedParameters indexedParams;
0328 
0329   indexedParams.emplace(std::pair{outtrack.tipIndex(),
0330                                   ActsExamples::TrackParameters{outtrack.referenceSurface().getSharedPtr(),
0331                                                                 outtrack.parameters(), outtrack.covariance(), Acts::ParticleHypothesis::electron()}});
0332 
0333   updateSvtxTrack(trackTips, indexedParams, tracks, track);
0334 
0335   if (m_actsEvaluator)
0336   {
0337     m_evaluator->evaluateTrackFit(tracks, trackTips, indexedParams, track,
0338                                   seed, measurements);
0339   }
0340 }
0341 
0342 void PHActsGSF::updateSvtxTrack(std::vector<Acts::MultiTrajectoryTraits::IndexType>& tips,
0343                                 Trajectory::IndexedParameters& paramsMap,
0344                                 ActsTrackFittingAlgorithm::TrackContainer& tracks,
0345                                 SvtxTrack* track)
0346 {
0347   const auto& mj = tracks.trackStateContainer();
0348   const auto& tracktip = tips.front();
0349   const auto& params = paramsMap.find(tracktip)->second;
0350   const auto trajState =
0351       Acts::MultiTrajectoryHelpers::trajectoryState(mj, tracktip);
0352 
0353   if (Verbosity() > 1)
0354   {
0355     std::cout << "Old track parameters: " << std::endl
0356               << "   (" << track->get_x()
0357               << ", " << track->get_y() << ", " << track->get_z()
0358               << ")" << std::endl
0359               << "   (" << track->get_px() << ", " << track->get_py()
0360               << ", " << track->get_pz() << ")" << std::endl;
0361     std::cout << "New GSF track parameters: " << std::endl
0362               << "   " << params.position(m_transient_geocontext).transpose()
0363               << std::endl
0364               << "   " << params.momentum().transpose()
0365               << std::endl;
0366   }
0367 
0368   /// Will create new states
0369   track->clear_states();
0370 
0371   // create a state at pathlength = 0.0
0372   // This state holds the track parameters, which will be updated below
0373   float pathlength = 0.0;
0374   SvtxTrackState_v1 out(pathlength);
0375   out.set_x(0.0);
0376   out.set_y(0.0);
0377   out.set_z(0.0);
0378   track->insert_state(&out);
0379 
0380   track->set_x(params.position(m_transient_geocontext)(0) / Acts::UnitConstants::cm);
0381   track->set_y(params.position(m_transient_geocontext)(1) / Acts::UnitConstants::cm);
0382   track->set_z(params.position(m_transient_geocontext)(2) / Acts::UnitConstants::cm);
0383 
0384   track->set_px(params.momentum()(0));
0385   track->set_py(params.momentum()(1));
0386   track->set_pz(params.momentum()(2));
0387   track->set_charge(params.charge());
0388   track->set_chisq(trajState.chi2Sum);
0389   track->set_ndf(trajState.NDF);
0390 
0391   ActsTransformations transformer;
0392   transformer.setVerbosity(Verbosity());
0393 
0394   if (params.covariance())
0395   {
0396     auto rotatedCov = transformer.rotateActsCovToSvtxTrack(params);
0397     for (int i = 0; i < 6; i++)
0398     {
0399       for (int j = 0; j < 6; j++)
0400       {
0401         track->set_error(i, j, rotatedCov(i, j));
0402       }
0403     }
0404   }
0405 
0406   transformer.fillSvtxTrackStates(mj, tracktip, track, m_transient_geocontext);
0407 }
0408 
0409 //____________________________________________________________________________..
0410 std::vector<TrkrDefs::cluskey> PHActsGSF::get_cluster_keys(SvtxTrack* track)
0411 {
0412   std::vector<TrkrDefs::cluskey> out;
0413   for (const auto& seed : {track->get_silicon_seed(), track->get_tpc_seed()})
0414   {
0415     if (seed)
0416     {
0417       std::copy(seed->begin_cluster_keys(), seed->end_cluster_keys(), std::back_inserter(out));
0418     }
0419   }
0420 
0421   return out;
0422 }
0423 
0424 //____________________________________________________________________________..
0425 int PHActsGSF::End(PHCompositeNode* /*unused*/)
0426 {
0427   if (m_actsEvaluator)
0428   {
0429     m_evaluator->End();
0430   }
0431   return Fun4AllReturnCodes::EVENT_OK;
0432 }
0433 
0434 //____________________________________________________________________________..
0435 int PHActsGSF::getNodes(PHCompositeNode* topNode)
0436 {
0437   // track map
0438   m_trackMap = findNode::getClass<SvtxTrackMap>(topNode, m_trackMapName);
0439   if (!m_trackMap)
0440   {
0441     std::cout << PHWHERE << " The input track map is not available. Exiting PHActsGSF" << std::endl;
0442     return Fun4AllReturnCodes::ABORTEVENT;
0443   }
0444 
0445   // cluster map
0446   m_clusterContainer = findNode::getClass<TrkrClusterContainer>(topNode, "TRKR_CLUSTER");
0447   if (!m_clusterContainer)
0448   {
0449     std::cout << PHWHERE << "The input cluster container is not available. Exiting PHActsGSF" << std::endl;
0450     return Fun4AllReturnCodes::ABORTEVENT;
0451   }
0452 
0453   // acts geometry
0454   m_tGeometry = findNode::getClass<ActsGeometry>(topNode, "ActsGeometry");
0455   if (!m_tGeometry)
0456   {
0457     std::cout << PHWHERE << "The input Acts tracking geometry is not available. Exiting PHActsGSF" << std::endl;
0458     return Fun4AllReturnCodes::ABORTEVENT;
0459   }
0460 
0461   // global position wrapper
0462   m_globalPositionWrapper.loadNodes(topNode);
0463 
0464   // vertex map
0465   m_vertexMap = findNode::getClass<SvtxVertexMap>(topNode, "SvtxVertexMap");
0466   if (!m_vertexMap)
0467   {
0468     std::cout << PHWHERE << "Vertex map unavailable, exiting PHActsGSF" << std::endl;
0469     return Fun4AllReturnCodes::ABORTEVENT;
0470   }
0471 
0472   m_alignmentTransformationMapTransient = findNode::getClass<alignmentTransformationContainer>(topNode, "alignmentTransformationContainerTransient");
0473   if (!m_alignmentTransformationMapTransient)
0474   {
0475     std::cout << PHWHERE << "alignmentTransformationContainerTransient not on node tree. Bailing"
0476               << std::endl;
0477     return Fun4AllReturnCodes::ABORTEVENT;
0478   }
0479 
0480   return Fun4AllReturnCodes::EVENT_OK;
0481 }