File indexing completed on 2025-08-06 08:18:25
0001 #include "PHActsGSF.h"
0002 #include "MakeSourceLinks.h"
0003
0004 #include <fun4all/Fun4AllReturnCodes.h>
0005 #include <phool/PHCompositeNode.h>
0006 #include <phool/PHDataNode.h>
0007 #include <phool/PHNode.h>
0008 #include <phool/PHNodeIterator.h>
0009 #include <phool/PHObject.h>
0010 #include <phool/PHTimer.h>
0011 #include <phool/getClass.h>
0012 #include <phool/phool.h>
0013
0014 #include <trackbase/ActsGeometry.h>
0015 #include <trackbase/ActsGsfTrackFittingAlgorithm.h>
0016 #include <trackbase/TpcDefs.h>
0017 #include <trackbase/TrkrCluster.h>
0018 #include <trackbase/TrkrClusterContainer.h>
0019 #include <trackbase/TrkrDefs.h>
0020
0021 #include <trackbase_historic/ActsTransformations.h>
0022 #include <trackbase_historic/SvtxTrack.h>
0023 #include <trackbase_historic/SvtxTrackMap.h>
0024 #include <trackbase_historic/SvtxTrackMap_v2.h>
0025 #include <trackbase_historic/SvtxTrackState_v1.h>
0026 #include <trackbase_historic/TrackSeed_v2.h>
0027 #include <trackbase_historic/TrackSeedHelper.h>
0028
0029 #include <globalvertex/SvtxVertex.h>
0030 #include <globalvertex/SvtxVertexMap.h>
0031
0032 #include <Acts/EventData/MultiTrajectory.hpp>
0033 #include <Acts/EventData/MultiTrajectoryHelpers.hpp>
0034 #include <Acts/EventData/SourceLink.hpp>
0035 #include <Acts/EventData/TrackParameters.hpp>
0036 #include <Acts/Surfaces/PerigeeSurface.hpp>
0037 #include <Acts/Surfaces/PlaneSurface.hpp>
0038 #include <Acts/Surfaces/Surface.hpp>
0039 #include <Acts/TrackFitting/BetheHeitlerApprox.hpp>
0040 #include <Acts/TrackFitting/GainMatrixSmoother.hpp>
0041 #include <Acts/TrackFitting/GainMatrixUpdater.hpp>
0042
0043 #include "ActsEvaluator.h"
0044
0045 #include <TDatabasePDG.h>
0046
0047
0048 PHActsGSF::PHActsGSF(const std::string& name)
0049 : SubsysReco(name)
0050 {
0051 }
0052
0053
0054 int PHActsGSF::InitRun(PHCompositeNode* topNode)
0055 {
0056 if (Verbosity() > 1)
0057 {
0058 std::cout << "PHActsGSF::InitRun begin" << std::endl;
0059 }
0060
0061 if (m_actsEvaluator)
0062 {
0063 PHNodeIterator iter(topNode);
0064
0065 PHCompositeNode* dstNode = dynamic_cast<PHCompositeNode*>(iter.findFirst("PHCompositeNode", "DST"));
0066
0067 if (!dstNode)
0068 {
0069 std::cerr << "DST node is missing, quitting" << std::endl;
0070 throw std::runtime_error("Failed to find DST node in PHActsTrkFitter::createNodes");
0071 }
0072
0073 PHNodeIterator dstIter(topNode);
0074 PHCompositeNode* svtxNode = dynamic_cast<PHCompositeNode*>(dstIter.findFirst("PHCompositeNode", "SVTX"));
0075
0076 if (!svtxNode)
0077 {
0078 svtxNode = new PHCompositeNode("SVTX");
0079 dstNode->addNode(svtxNode);
0080 }
0081 m_seedTracks = findNode::getClass<SvtxTrackMap>(topNode, _seed_track_map_name);
0082
0083 if (!m_seedTracks)
0084 {
0085 m_seedTracks = new SvtxTrackMap_v2;
0086
0087 PHIODataNode<PHObject>* seedNode =
0088 new PHIODataNode<PHObject>(m_seedTracks, _seed_track_map_name, "PHObject");
0089 svtxNode->addNode(seedNode);
0090 }
0091 }
0092
0093 if (getNodes(topNode) != Fun4AllReturnCodes::EVENT_OK)
0094 {
0095 return Fun4AllReturnCodes::ABORTEVENT;
0096 }
0097
0098 auto bha = Acts::makeDefaultBetheHeitlerApprox();
0099 ActsGsfTrackFittingAlgorithm gsf;
0100 m_fitCfg.fit = gsf.makeGsfFitterFunction(
0101 m_tGeometry->geometry().tGeometry,
0102 m_tGeometry->geometry().magField,
0103 bha,
0104 12, 1e-4,
0105 MixtureReductionAlgorithm::KLDistance, false, false);
0106
0107 if (m_actsEvaluator)
0108 {
0109 m_evaluator = std::make_unique<ActsEvaluator>(m_evalname);
0110 m_evaluator->Init(topNode);
0111 m_evaluator->verbosity(Verbosity());
0112 }
0113
0114 return Fun4AllReturnCodes::EVENT_OK;
0115 }
0116
0117
0118 int PHActsGSF::process_event(PHCompositeNode* topNode)
0119 {
0120 auto logLevel = Acts::Logging::FATAL;
0121
0122 if (m_actsEvaluator)
0123 {
0124 m_evaluator->next_event(topNode);
0125 }
0126
0127 if (Verbosity() > 4)
0128 {
0129 logLevel = Acts::Logging::VERBOSE;
0130 }
0131
0132
0133
0134 if (m_actsEvaluator)
0135 {
0136
0137
0138 m_seedTracks->clear();
0139 for (const auto& [key, track] : *m_trackMap)
0140 {
0141 m_seedTracks->insert(track);
0142 }
0143 }
0144
0145 auto logger = Acts::getDefaultLogger("PHActsGSF", logLevel);
0146
0147 for (const auto& [key, track] : *m_trackMap)
0148 {
0149 auto pSurface = makePerigee(track);
0150 if (!pSurface)
0151 {
0152
0153 continue;
0154 }
0155 const auto seed = makeSeed(track, pSurface);
0156
0157 auto svtxseed = new TrackSeed_v2();
0158 std::map<TrkrDefs::cluskey, Acts::Vector3> clusterPositions;
0159 for (auto& cKey : get_cluster_keys(track))
0160 {
0161 auto cluster = m_clusterContainer->findCluster(cKey);
0162 auto globalPosition = m_tGeometry->getGlobalPosition(cKey, cluster);
0163 clusterPositions.insert(std::make_pair(cKey, globalPosition));
0164 svtxseed->insert_cluster_key(cKey);
0165 }
0166 svtxseed->set_phi(track->get_phi());
0167 TrackSeedHelper::circleFitByTaubin(svtxseed,clusterPositions, 0, 57);
0168 TrackSeedHelper::lineFit(svtxseed, clusterPositions, 7, 57);
0169
0170 ActsTrackFittingAlgorithm::MeasurementContainer measurements;
0171 TrackSeed* tpcseed = track->get_tpc_seed();
0172 TrackSeed* silseed = track->get_silicon_seed();
0173
0174
0175 if (!silseed or !tpcseed)
0176 {
0177 continue;
0178 }
0179
0180 auto crossing = silseed->get_crossing();
0181 if (crossing == SHRT_MAX)
0182 {
0183 continue;
0184 }
0185
0186
0187
0188
0189
0190
0191
0192 MakeSourceLinks makeSourceLinks;
0193 makeSourceLinks.setVerbosity(Verbosity());
0194 makeSourceLinks.set_pp_mode(m_pp_mode);
0195
0196 makeSourceLinks.resetTransientTransformMap(
0197 m_alignmentTransformationMapTransient,
0198 m_transient_id_set,
0199 m_tGeometry);
0200
0201
0202 auto sourceLinks = makeSourceLinks.getSourceLinks(
0203 tpcseed,
0204 measurements,
0205 m_clusterContainer,
0206 m_tGeometry,
0207 m_globalPositionWrapper,
0208 m_alignmentTransformationMapTransient,
0209 m_transient_id_set,
0210 crossing);
0211
0212
0213 auto silSourceLinks = makeSourceLinks.getSourceLinks(
0214 silseed,
0215 measurements,
0216 m_clusterContainer,
0217 m_tGeometry,
0218 m_globalPositionWrapper,
0219 m_alignmentTransformationMapTransient,
0220 m_transient_id_set,
0221 crossing);
0222
0223
0224 m_transient_geocontext = m_alignmentTransformationMapTransient;
0225
0226 for (auto& siSL : silSourceLinks)
0227 {
0228 sourceLinks.push_back(siSL);
0229 }
0230
0231 auto calibptr = std::make_unique<Calibrator>();
0232 CalibratorAdapter calibrator(*calibptr, measurements);
0233 auto magcontext = m_tGeometry->geometry().magFieldContext;
0234 auto calcontext = m_tGeometry->geometry().calibContext;
0235
0236 auto ppoptions = Acts::PropagatorPlainOptions();
0237
0238 ActsTrackFittingAlgorithm::GeneralFitterOptions options{
0239 m_transient_geocontext,
0240 magcontext,
0241 calcontext,
0242 &(*pSurface),
0243 ppoptions};
0244 if (Verbosity() > 2)
0245 {
0246 std::cout << "calling gsf with position "
0247 << seed.position(m_transient_geocontext).transpose()
0248 << " and momentum " << seed.momentum().transpose()
0249 << std::endl;
0250 }
0251 auto trackContainer = std::make_shared<Acts::VectorTrackContainer>();
0252 auto trackStateContainer = std::make_shared<Acts::VectorMultiTrajectory>();
0253 ActsTrackFittingAlgorithm::TrackContainer tracks(trackContainer, trackStateContainer);
0254 auto result = fitTrack(sourceLinks, seed, options, calibrator, tracks);
0255
0256 if (result.ok())
0257 {
0258 updateTrack(result, track, tracks, svtxseed, measurements);
0259 }
0260 }
0261
0262 return Fun4AllReturnCodes::EVENT_OK;
0263 }
0264
0265 std::shared_ptr<Acts::PerigeeSurface> PHActsGSF::makePerigee(SvtxTrack* track) const
0266 {
0267 SvtxVertex* vertex = m_vertexMap->get(track->get_vertex_id());
0268 if (!vertex)
0269 {
0270 return nullptr;
0271 }
0272
0273 Acts::Vector3 vertexpos(vertex->get_x() * Acts::UnitConstants::cm,
0274 vertex->get_y() * Acts::UnitConstants::cm,
0275 vertex->get_z() * Acts::UnitConstants::cm);
0276
0277 return Acts::Surface::makeShared<Acts::PerigeeSurface>(
0278 vertexpos);
0279 }
0280
0281 ActsTrackFittingAlgorithm::TrackParameters PHActsGSF::makeSeed(SvtxTrack* track,
0282 const std::shared_ptr<Acts::PerigeeSurface>& psurf) const
0283 {
0284 Acts::Vector4 fourpos(track->get_x() * Acts::UnitConstants::cm,
0285 track->get_y() * Acts::UnitConstants::cm,
0286 track->get_z() * Acts::UnitConstants::cm,
0287 10 * Acts::UnitConstants::ns);
0288
0289 int charge = track->get_charge();
0290 Acts::Vector3 momentum(track->get_px(),
0291 track->get_py(),
0292 track->get_pz());
0293
0294 ActsTransformations transformer;
0295 auto cov = transformer.rotateSvtxTrackCovToActs(track);
0296
0297 return ActsTrackFittingAlgorithm::TrackParameters::create(psurf,
0298 m_tGeometry->geometry().getGeoContext(),
0299 fourpos,
0300 momentum,
0301 charge / momentum.norm(),
0302 cov,
0303 Acts::ParticleHypothesis::electron(),
0304 1 * Acts::UnitConstants::cm)
0305 .value();
0306 }
0307
0308 ActsTrackFittingAlgorithm::TrackFitterResult PHActsGSF::fitTrack(
0309 const std::vector<Acts::SourceLink>& sourceLinks,
0310 const ActsTrackFittingAlgorithm::TrackParameters& seed,
0311 const ActsTrackFittingAlgorithm::GeneralFitterOptions& options,
0312 const CalibratorAdapter& calibrator,
0313 ActsTrackFittingAlgorithm::TrackContainer& tracks)
0314 {
0315 return (*m_fitCfg.fit)(sourceLinks, seed, options, calibrator, tracks);
0316 }
0317
0318 void PHActsGSF::updateTrack(FitResult& result, SvtxTrack* track,
0319 ActsTrackFittingAlgorithm::TrackContainer& tracks,
0320 const TrackSeed* seed,
0321 const ActsTrackFittingAlgorithm::MeasurementContainer& measurements)
0322 {
0323 std::vector<Acts::MultiTrajectoryTraits::IndexType> trackTips;
0324 trackTips.reserve(1);
0325 auto& outtrack = result.value();
0326 trackTips.emplace_back(outtrack.tipIndex());
0327 ActsExamples::Trajectories::IndexedParameters indexedParams;
0328
0329 indexedParams.emplace(std::pair{outtrack.tipIndex(),
0330 ActsExamples::TrackParameters{outtrack.referenceSurface().getSharedPtr(),
0331 outtrack.parameters(), outtrack.covariance(), Acts::ParticleHypothesis::electron()}});
0332
0333 updateSvtxTrack(trackTips, indexedParams, tracks, track);
0334
0335 if (m_actsEvaluator)
0336 {
0337 m_evaluator->evaluateTrackFit(tracks, trackTips, indexedParams, track,
0338 seed, measurements);
0339 }
0340 }
0341
0342 void PHActsGSF::updateSvtxTrack(std::vector<Acts::MultiTrajectoryTraits::IndexType>& tips,
0343 Trajectory::IndexedParameters& paramsMap,
0344 ActsTrackFittingAlgorithm::TrackContainer& tracks,
0345 SvtxTrack* track)
0346 {
0347 const auto& mj = tracks.trackStateContainer();
0348 const auto& tracktip = tips.front();
0349 const auto& params = paramsMap.find(tracktip)->second;
0350 const auto trajState =
0351 Acts::MultiTrajectoryHelpers::trajectoryState(mj, tracktip);
0352
0353 if (Verbosity() > 1)
0354 {
0355 std::cout << "Old track parameters: " << std::endl
0356 << " (" << track->get_x()
0357 << ", " << track->get_y() << ", " << track->get_z()
0358 << ")" << std::endl
0359 << " (" << track->get_px() << ", " << track->get_py()
0360 << ", " << track->get_pz() << ")" << std::endl;
0361 std::cout << "New GSF track parameters: " << std::endl
0362 << " " << params.position(m_transient_geocontext).transpose()
0363 << std::endl
0364 << " " << params.momentum().transpose()
0365 << std::endl;
0366 }
0367
0368
0369 track->clear_states();
0370
0371
0372
0373 float pathlength = 0.0;
0374 SvtxTrackState_v1 out(pathlength);
0375 out.set_x(0.0);
0376 out.set_y(0.0);
0377 out.set_z(0.0);
0378 track->insert_state(&out);
0379
0380 track->set_x(params.position(m_transient_geocontext)(0) / Acts::UnitConstants::cm);
0381 track->set_y(params.position(m_transient_geocontext)(1) / Acts::UnitConstants::cm);
0382 track->set_z(params.position(m_transient_geocontext)(2) / Acts::UnitConstants::cm);
0383
0384 track->set_px(params.momentum()(0));
0385 track->set_py(params.momentum()(1));
0386 track->set_pz(params.momentum()(2));
0387 track->set_charge(params.charge());
0388 track->set_chisq(trajState.chi2Sum);
0389 track->set_ndf(trajState.NDF);
0390
0391 ActsTransformations transformer;
0392 transformer.setVerbosity(Verbosity());
0393
0394 if (params.covariance())
0395 {
0396 auto rotatedCov = transformer.rotateActsCovToSvtxTrack(params);
0397 for (int i = 0; i < 6; i++)
0398 {
0399 for (int j = 0; j < 6; j++)
0400 {
0401 track->set_error(i, j, rotatedCov(i, j));
0402 }
0403 }
0404 }
0405
0406 transformer.fillSvtxTrackStates(mj, tracktip, track, m_transient_geocontext);
0407 }
0408
0409
0410 std::vector<TrkrDefs::cluskey> PHActsGSF::get_cluster_keys(SvtxTrack* track)
0411 {
0412 std::vector<TrkrDefs::cluskey> out;
0413 for (const auto& seed : {track->get_silicon_seed(), track->get_tpc_seed()})
0414 {
0415 if (seed)
0416 {
0417 std::copy(seed->begin_cluster_keys(), seed->end_cluster_keys(), std::back_inserter(out));
0418 }
0419 }
0420
0421 return out;
0422 }
0423
0424
0425 int PHActsGSF::End(PHCompositeNode* )
0426 {
0427 if (m_actsEvaluator)
0428 {
0429 m_evaluator->End();
0430 }
0431 return Fun4AllReturnCodes::EVENT_OK;
0432 }
0433
0434
0435 int PHActsGSF::getNodes(PHCompositeNode* topNode)
0436 {
0437
0438 m_trackMap = findNode::getClass<SvtxTrackMap>(topNode, m_trackMapName);
0439 if (!m_trackMap)
0440 {
0441 std::cout << PHWHERE << " The input track map is not available. Exiting PHActsGSF" << std::endl;
0442 return Fun4AllReturnCodes::ABORTEVENT;
0443 }
0444
0445
0446 m_clusterContainer = findNode::getClass<TrkrClusterContainer>(topNode, "TRKR_CLUSTER");
0447 if (!m_clusterContainer)
0448 {
0449 std::cout << PHWHERE << "The input cluster container is not available. Exiting PHActsGSF" << std::endl;
0450 return Fun4AllReturnCodes::ABORTEVENT;
0451 }
0452
0453
0454 m_tGeometry = findNode::getClass<ActsGeometry>(topNode, "ActsGeometry");
0455 if (!m_tGeometry)
0456 {
0457 std::cout << PHWHERE << "The input Acts tracking geometry is not available. Exiting PHActsGSF" << std::endl;
0458 return Fun4AllReturnCodes::ABORTEVENT;
0459 }
0460
0461
0462 m_globalPositionWrapper.loadNodes(topNode);
0463
0464
0465 m_vertexMap = findNode::getClass<SvtxVertexMap>(topNode, "SvtxVertexMap");
0466 if (!m_vertexMap)
0467 {
0468 std::cout << PHWHERE << "Vertex map unavailable, exiting PHActsGSF" << std::endl;
0469 return Fun4AllReturnCodes::ABORTEVENT;
0470 }
0471
0472 m_alignmentTransformationMapTransient = findNode::getClass<alignmentTransformationContainer>(topNode, "alignmentTransformationContainerTransient");
0473 if (!m_alignmentTransformationMapTransient)
0474 {
0475 std::cout << PHWHERE << "alignmentTransformationContainerTransient not on node tree. Bailing"
0476 << std::endl;
0477 return Fun4AllReturnCodes::ABORTEVENT;
0478 }
0479
0480 return Fun4AllReturnCodes::EVENT_OK;
0481 }