Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-12-17 09:21:03

0001 /*!
0002  *  \file       PHGenFitTrkFitter.C
0003  *  \brief      Refit SvtxTracks with PHGenFit.
0004  *  \details    Refit SvtxTracks with PHGenFit.
0005  *  \author     Haiwang Yu <yuhw@nmsu.edu>
0006  */
0007 
#include "PHGenFitTrkFitter.h"

#include <fun4all/Fun4AllReturnCodes.h>
#include <fun4all/PHTFileServer.h>
#include <fun4all/SubsysReco.h>  // for SubsysReco

#include <g4detectors/PHG4CylinderGeom.h>  // for PHG4CylinderGeom
#include <g4detectors/PHG4CylinderGeomContainer.h>

#include <intt/CylinderGeomIntt.h>
#include <intt/CylinderGeomInttHelper.h>

#include <micromegas/CylinderGeomMicromegas.h>
#include <micromegas/MicromegasDefs.h>

#include <mvtx/CylinderGeom_Mvtx.h>
#include <mvtx/CylinderGeom_MvtxHelper.h>

#include <phfield/PHFieldUtility.h>

#include <phgenfit/Fitter.h>
#include <phgenfit/Measurement.h>  // for Measurement
#include <phgenfit/PlanarMeasurement.h>
#include <phgenfit/SpacepointMeasurement.h>
#include <phgenfit/Track.h>

#include <phgeom/PHGeomUtility.h>

#include <phool/PHCompositeNode.h>
#include <phool/PHIODataNode.h>
#include <phool/PHNode.h>  // for PHNode
#include <phool/PHNodeIterator.h>
#include <phool/PHObject.h>  // for PHObject
#include <phool/getClass.h>
#include <phool/phool.h>

#include <trackbase/ActsGeometry.h>
#include <trackbase/InttDefs.h>
#include <trackbase/MvtxDefs.h>
#include <trackbase/TpcDefs.h>
#include <trackbase/TrkrCluster.h>  // for TrkrCluster
#include <trackbase/TrkrClusterContainer.h>
#include <trackbase/TrkrDefs.h>

#include <trackbase_historic/SvtxTrack.h>
#include <trackbase_historic/SvtxTrackMap.h>
#include <trackbase_historic/SvtxTrackMap_v2.h>
#include <trackbase_historic/SvtxTrackState.h>  // for SvtxTrackState
#include <trackbase_historic/SvtxTrackState_v2.h>
#include <trackbase_historic/SvtxTrack_v4.h>
#include <trackbase_historic/TrackSeed.h>
#include <trackbase_historic/TrackSeedContainer.h>
#include <trackbase_historic/TrackSeedHelper.h>

#include <GenFit/AbsMeasurement.h>  // for AbsMeasurement
#include <GenFit/Exception.h>       // for Exception
#include <GenFit/KalmanFitterInfo.h>
#include <GenFit/MeasuredStateOnPlane.h>
#include <GenFit/RKTrackRep.h>
#include <GenFit/Track.h>
#include <GenFit/TrackPoint.h>  // for TrackPoint

#include <TMatrixDSymfwd.h>  // for TMatrixDSym
#include <TMatrixFfwd.h>     // for TMatrixF
#include <TMatrixT.h>        // for TMatrixT, operator*
#include <TMatrixTSym.h>     // for TMatrixTSym
#include <TMatrixTUtils.h>   // for TMatrixTRow
#include <TRotation.h>
#include <TTree.h>
#include <TVector3.h>
#include <TVectorDfwd.h>  // for TVectorD
#include <TVectorT.h>     // for TVectorT

#include <algorithm>  // for std::copy
#include <cassert>    // for assert
#include <climits>    // for SHRT_MAX
#include <cmath>      // for sqrt, NAN
#include <iostream>
#include <iterator>  // for std::back_inserter
#include <limits>    // for std::numeric_limits
#include <map>
#include <memory>
#include <set>
#include <utility>
#include <vector>
0087 
// forward declarations
class TGeoManager;
namespace genfit { class AbsTrackRep; }

// logging helpers: write "<SEVERITY>: <file>: <line>: <value of exp>" to std::cout and flush
#define LogDebug(exp) \
  std::cout << "DEBUG: " << __FILE__ << ": " << __LINE__ << ": " << (exp) << std::endl
#define LogError(exp) \
  std::cout << "ERROR: " << __FILE__ << ": " << __LINE__ << ": " << (exp) << std::endl
#define LogWarning(exp) \
  std::cout << "WARNING: " << __FILE__ << ": " << __LINE__ << ": " << (exp) << std::endl

using namespace std;
0099 
0100 //______________________________________________________
0101 namespace
0102 {
0103 
0104   // square
0105   template <class T>
0106   inline static constexpr T square(const T& x)
0107   {
0108     return x * x;
0109   }
0110 
0111   // square
0112   template <class T>
0113   inline static T get_r(const T& x, const T& y)
0114   {
0115     return std::sqrt( square(x)+square(y));
0116   }
0117 
0118   // convert gf state to SvtxTrackState_v2
0119   SvtxTrackState_v2 create_track_state(float pathlength, const genfit::MeasuredStateOnPlane* gf_state)
0120   {
0121     SvtxTrackState_v2 out(pathlength);
0122     out.set_x(gf_state->getPos().x());
0123     out.set_y(gf_state->getPos().y());
0124     out.set_z(gf_state->getPos().z());
0125 
0126     out.set_px(gf_state->getMom().x());
0127     out.set_py(gf_state->getMom().y());
0128     out.set_pz(gf_state->getMom().z());
0129 
0130     for (int i = 0; i < 6; i++)
0131     {
0132       for (int j = i; j < 6; j++)
0133       {
0134         out.set_error(i, j, gf_state->get6DCov()[i][j]);
0135       }
0136     }
0137 
0138     return out;
0139   }
0140 
0141   // get cluster keys from a given track
0142   std::vector<TrkrDefs::cluskey> get_cluster_keys(const SvtxTrack* track)
0143   {
0144     std::vector<TrkrDefs::cluskey> out;
0145     for (const auto& seed : {track->get_silicon_seed(), track->get_tpc_seed()})
0146     {
0147       if (seed)
0148       {
0149         std::copy(seed->begin_cluster_keys(), seed->end_cluster_keys(), std::back_inserter(out));
0150       }
0151     }
0152     return out;
0153   }
0154 
0155   [[maybe_unused]] std::ostream& operator<<(std::ostream& out, const Acts::Vector3& vector)
0156   {
0157     out << "(" << vector.x() << ", " << vector.y() << ", " << vector.z() << ")";
0158     return out;
0159   }
0160 
0161   TVector3 get_world_from_local_vect( ActsGeometry* geometry, Surface surface, const TVector3& local_vect )
0162   {
0163 
0164     // get global vector from local, using ACTS surface
0165     Acts::Vector3 local(
0166       local_vect.x()*Acts::UnitConstants::cm,
0167       local_vect.y()*Acts::UnitConstants::cm,
0168       local_vect.z()*Acts::UnitConstants::cm );
0169 
0170     // TODO: check signification of the last two parameters to referenceFrame.
0171     const Acts::Vector3 global = surface->referenceFrame(geometry->geometry().getGeoContext(), {0,0,0}, {0,0,0})*local;
0172     return TVector3(
0173       global.x()/Acts::UnitConstants::cm,
0174       global.y()/Acts::UnitConstants::cm,
0175       global.z()/Acts::UnitConstants::cm );
0176   }
0177 
0178 }  // namespace
0179 
0180 /*
0181  * Constructor
0182  */
0183 PHGenFitTrkFitter::PHGenFitTrkFitter(const string& name)
0184   : SubsysReco(name)
0185 {
0186 }
0187 
0188 /*
0189  * Init
0190  */
0191 int PHGenFitTrkFitter::Init(PHCompositeNode* /*topNode*/)
0192 {
0193   return Fun4AllReturnCodes::EVENT_OK;
0194 }
0195 
0196 /*
0197  * Init run
0198  */
0199 int PHGenFitTrkFitter::InitRun(PHCompositeNode* topNode)
0200 {
0201   CreateNodes(topNode);
0202 
0203   auto tgeo_manager = PHGeomUtility::GetTGeoManager(topNode);
0204   auto field = PHFieldUtility::GetFieldMapNode(nullptr, topNode);
0205 
0206   _fitter.reset(PHGenFit::Fitter::getInstance(tgeo_manager, field, _track_fitting_alg_name, "RKTrackRep", false));
0207   _fitter->set_verbosity(Verbosity());
0208 
0209   std::cout << "PHGenFitTrkFitter::InitRun - m_fit_silicon_mms: " << m_fit_silicon_mms << std::endl;
0210   std::cout << "PHGenFitTrkFitter::InitRun - m_use_micromegas: " << m_use_micromegas << std::endl;
0211 
0212   // print disabled layers
0213   // if( Verbosity() )
0214   {
0215     for (const auto& layer : _disabled_layers)
0216     {
0217       std::cout << PHWHERE << " Layer " << layer << " is disabled." << std::endl;
0218     }
0219   }
0220 
0221   return Fun4AllReturnCodes::EVENT_OK;
0222 }
0223 
0224 /*
0225  * process_event():
0226  *  Call user instructions for every event.
0227  *  This function contains the analysis structure.
0228  *
0229  */
0230 int PHGenFitTrkFitter::process_event(PHCompositeNode* topNode)
0231 {
0232   ++_event;
0233 
0234   if (Verbosity() > 1)
0235   {
0236     std::cout << PHWHERE << "Events processed: " << _event << std::endl;
0237   }
0238 
0239   // clear global position map
0240   GetNodes(topNode);
0241 
0242   // clear default track map, fill with seeds
0243   m_trackMap->Reset();
0244 
0245   unsigned int trackid = 0;
0246   for (const auto& track : *m_seedMap)
0247   {
0248     if (!track) continue;
0249 
0250     // get silicon seed and check
0251     const auto siid = track->get_silicon_seed_index();
0252     if (siid == std::numeric_limits<unsigned int>::max()) continue;
0253     const auto siseed = m_siliconSeeds->get(siid);
0254     if (!siseed) continue;
0255 
0256     // get crossing number and check
0257     const auto crossing = siseed->get_crossing();
0258     if (crossing == SHRT_MAX) continue;
0259 
0260     // get tpc seed and check
0261     const auto tpcid = track->get_tpc_seed_index();
0262     const auto tpcseed = m_tpcSeeds->get(tpcid);
0263     if (!tpcseed) continue;
0264 
0265     // build track
0266     auto svtxtrack = std::make_unique<SvtxTrack_v4>();
0267     svtxtrack->set_id(trackid++);
0268     svtxtrack->set_silicon_seed(siseed);
0269     svtxtrack->set_tpc_seed(tpcseed);
0270     svtxtrack->set_crossing(crossing);
0271 
0272     // track position comes from silicon seed
0273     const auto position = TrackSeedHelper::get_xyz(siseed);
0274     svtxtrack->set_x(position.x());
0275     svtxtrack->set_y(position.y());
0276     svtxtrack->set_z(position.z());
0277 
0278     // track momentum comes from tpc seed
0279     svtxtrack->set_charge(tpcseed->get_qOverR() > 0 ? 1 : -1);
0280     svtxtrack->set_px(tpcseed->get_px());
0281     svtxtrack->set_py(tpcseed->get_py());
0282     svtxtrack->set_pz(tpcseed->get_pz());
0283 
0284     // insert in map
0285     m_trackMap->insert(svtxtrack.get());
0286   }
0287 
0288   // stands for Refit_GenFit_Tracks
0289   vector<genfit::Track*> rf_gf_tracks;
0290   vector<std::shared_ptr<PHGenFit::Track> > rf_phgf_tracks;
0291 
0292   map<unsigned int, unsigned int> svtxtrack_genfittrack_map;
0293 
0294   for (const auto& [key, svtx_track] : *m_trackMap)
0295   {
0296     if (!svtx_track) continue;
0297 
0298     if (Verbosity() > 10)
0299     {
0300       cout << "   process SVTXTrack " << key << endl;
0301       svtx_track->identify();
0302     }
0303 
0304     if (!(svtx_track->get_pt() > _fit_min_pT)) continue;
0305 
0306     // This is the final track (re)fit. It does not include the collision vertex. If fit_primary_track is set, a refit including the vertex is done below.
0307     // rf_phgf_track stands for Refit_PHGenFit_Track
0308     const auto rf_phgf_track = ReFitTrack(topNode, svtx_track);
0309     if (rf_phgf_track)
0310     {
0311       svtxtrack_genfittrack_map[svtx_track->get_id()] = rf_phgf_tracks.size();
0312       rf_phgf_tracks.push_back(rf_phgf_track);
0313       if (rf_phgf_track->get_ndf() > _vertex_min_ndf)
0314       {
0315         rf_gf_tracks.push_back(rf_phgf_track->getGenFitTrack());
0316       }
0317 
0318       if (Verbosity() > 10) cout << "Done refitting input track" << svtx_track->get_id() << " or rf_phgf_track " << rf_phgf_tracks.size() << endl;
0319     }
0320     else if (Verbosity() >= 1)
0321     {
0322       cout << "failed refitting input track# " << key << endl;
0323     }
0324   }
0325 
0326   // Finds the refitted rf_phgf_track corresponding to each SvtxTrackMap entry
0327   // Converts it to an SvtxTrack in MakeSvtxTrack
0328   // MakeSvtxTrack takes a vertex that it gets from the map made in FillSvtxVertex
0329   // If the refit was succesful, the track on the node tree is replaced with the new one
0330   // If not, the track is erased from the node tree
0331   for (SvtxTrackMap::Iter iter = m_trackMap->begin(); iter != m_trackMap->end();)
0332   {
0333     std::shared_ptr<PHGenFit::Track> rf_phgf_track;
0334 
0335     // find the genfit track that corresponds to this one on the node tree
0336     unsigned int itrack = 0;
0337     if (svtxtrack_genfittrack_map.find(iter->second->get_id()) != svtxtrack_genfittrack_map.end())
0338     {
0339       itrack = svtxtrack_genfittrack_map[iter->second->get_id()];
0340       rf_phgf_track = rf_phgf_tracks[itrack];
0341     }
0342 
0343     if (rf_phgf_track)
0344     {
0345       const auto rf_track = MakeSvtxTrack(iter->second, rf_phgf_track);
0346       if (rf_track)
0347       {
0348         // replace track in map
0349         iter->second->CopyFrom(rf_track.get());
0350       }
0351       else
0352       {
0353         // converting track failed. erase track from map
0354         auto key = iter->first;
0355         ++iter;
0356         m_trackMap->erase(key);
0357         continue;
0358       }
0359     }
0360     else
0361     {
0362       // genfit track is invalid. erase track from map
0363       auto key = iter->first;
0364       ++iter;
0365       m_trackMap->erase(key);
0366       continue;
0367     }
0368 
0369     ++iter;
0370   }
0371 
0372   // clear genfit tracks
0373   rf_phgf_tracks.clear();
0374 
0375   return Fun4AllReturnCodes::EVENT_OK;
0376 }
0377 
0378 /*
0379  * End
0380  */
0381 int PHGenFitTrkFitter::End(PHCompositeNode* /*topNode*/)
0382 {
0383   return Fun4AllReturnCodes::EVENT_OK;
0384 }
0385 
0386 int PHGenFitTrkFitter::CreateNodes(PHCompositeNode* topNode)
0387 {
0388   // create nodes...
0389   PHNodeIterator iter(topNode);
0390 
0391   auto dstNode = static_cast<PHCompositeNode*>(iter.findFirst("PHCompositeNode", "DST"));
0392   if (!dstNode)
0393   {
0394     cerr << PHWHERE << "DST Node missing, doing nothing." << endl;
0395     return Fun4AllReturnCodes::ABORTEVENT;
0396   }
0397   PHNodeIterator iter_dst(dstNode);
0398 
0399   // Create the SVTX node
0400   auto svtx_node = dynamic_cast<PHCompositeNode*>(iter_dst.findFirst("PHCompositeNode", "SVTX"));
0401   if (!svtx_node)
0402   {
0403     svtx_node = new PHCompositeNode("SVTX");
0404     dstNode->addNode(svtx_node);
0405     if (Verbosity())
0406     {
0407       cout << "SVTX node added" << endl;
0408     }
0409   }
0410 
0411   // default track map
0412   m_trackMap = findNode::getClass<SvtxTrackMap>(topNode, _trackMap_name);
0413   if (!m_trackMap)
0414   {
0415     m_trackMap = new SvtxTrackMap_v2;
0416     auto node = new PHIODataNode<PHObject>(m_trackMap, _trackMap_name, "PHObject");
0417     svtx_node->addNode(node);
0418   }
0419 
0420   return Fun4AllReturnCodes::EVENT_OK;
0421 }
0422 
0423 //______________________________________________________
0424 void PHGenFitTrkFitter::disable_layer(int layer, bool disabled)
0425 {
0426   if (disabled)
0427     _disabled_layers.insert(layer);
0428   else
0429     _disabled_layers.erase(layer);
0430 }
0431 
0432 //______________________________________________________
0433 void PHGenFitTrkFitter::set_disabled_layers(const std::set<int>& layers)
0434 {
0435   _disabled_layers = layers;
0436 }
0437 
0438 //______________________________________________________
0439 void PHGenFitTrkFitter::clear_disabled_layers()
0440 {
0441   _disabled_layers.clear();
0442 }
0443 
0444 //______________________________________________________
0445 const std::set<int>& PHGenFitTrkFitter::get_disabled_layers() const
0446 {
0447   return _disabled_layers;
0448 }
0449 
0450 //______________________________________________________
0451 void PHGenFitTrkFitter::set_fit_silicon_mms(bool value)
0452 {
0453   // store flags
0454   m_fit_silicon_mms = value;
0455 
0456   // disable/enable layers accordingly
0457   for (int layer = 7; layer < 23; ++layer)
0458   {
0459     disable_layer(layer, value);
0460   }
0461   for (int layer = 23; layer < 39; ++layer)
0462   {
0463     disable_layer(layer, value);
0464   }
0465   for (int layer = 39; layer < 55; ++layer)
0466   {
0467     disable_layer(layer, value);
0468   }
0469 }
0470 
0471 /*
0472  * GetNodes():
0473  *  Get all the all the required nodes off the node tree
0474  */
0475 int PHGenFitTrkFitter::GetNodes(PHCompositeNode* topNode)
0476 {
0477   // acts geometry
0478   m_tgeometry = findNode::getClass<ActsGeometry>(topNode, "ActsGeometry");
0479   if (!m_tgeometry)
0480   {
0481     std::cout << "PHGenFitTrkFitter::GetNodes - No acts tracking geometry, can't proceed" << std::endl;
0482     return Fun4AllReturnCodes::ABORTEVENT;
0483   }
0484 
0485   // DST objects
0486   // clusters
0487   m_clustermap = findNode::getClass<TrkrClusterContainer>(topNode, "CORRECTED_TRKR_CLUSTER");
0488   if (m_clustermap)
0489   {
0490     if (_event < 2)
0491     {
0492       std::cout << "PHGenFitTrkFitter::GetNodes - Using CORRECTED_TRKR_CLUSTER node " << std::endl;
0493     }
0494   }
0495   else
0496   {
0497     if (_event < 2)
0498     {
0499       std::cout << "PHGenFitTrkFitter::GetNodes - CORRECTED_TRKR_CLUSTER node not found, using TRKR_CLUSTER" << std::endl;
0500     }
0501     m_clustermap = findNode::getClass<TrkrClusterContainer>(topNode, "TRKR_CLUSTER");
0502   }
0503 
0504   if (!m_clustermap)
0505   {
0506     cout << PHWHERE << "PHGenFitTrkFitter::GetNodes - TRKR_CLUSTER node not found on node tree" << endl;
0507     return Fun4AllReturnCodes::ABORTEVENT;
0508   }
0509 
0510   // seeds
0511   m_seedMap = findNode::getClass<TrackSeedContainer>(topNode, _seedMap_name);
0512   if (!m_seedMap)
0513   {
0514     std::cout << "PHGenFitTrkFitter::GetNodes - No Svtx seed map on node tree. Exiting." << std::endl;
0515     return Fun4AllReturnCodes::ABORTEVENT;
0516   }
0517 
0518   m_tpcSeeds = findNode::getClass<TrackSeedContainer>(topNode, "TpcTrackSeedContainer");
0519   if (!m_tpcSeeds)
0520   {
0521     std::cout << "PHGenFitTrkFitter::GetNodes - TpcTrackSeedContainer not on node tree. Bailing"
0522               << std::endl;
0523     return Fun4AllReturnCodes::ABORTEVENT;
0524   }
0525 
0526   m_siliconSeeds = findNode::getClass<TrackSeedContainer>(topNode, "SiliconTrackSeedContainer");
0527   if (!m_siliconSeeds)
0528   {
0529     std::cout << "PHGenFitTrkFitter::GetNodes - SiliconTrackSeedContainer not on node tree. Bailing"
0530               << std::endl;
0531     return Fun4AllReturnCodes::ABORTEVENT;
0532   }
0533 
0534   // Svtx Tracks
0535   m_trackMap = findNode::getClass<SvtxTrackMap>(topNode, _trackMap_name);
0536   if (!m_trackMap && _event < 2)
0537   {
0538     cout << "PHGenFitTrkFitter::GetNodes - SvtxTrackMap node not found on node tree" << endl;
0539     return Fun4AllReturnCodes::ABORTEVENT;
0540   }
0541 
0542   // global position wrapper
0543   m_globalPositionWrapper.loadNodes(topNode);
0544   if (m_disable_module_edge_corr) { m_globalPositionWrapper.set_enable_module_edge_corr(false); }
0545   if (m_disable_static_corr) { m_globalPositionWrapper.set_enable_static_corr(false); }
0546   if (m_disable_average_corr) { m_globalPositionWrapper.set_enable_average_corr(false); }
0547   if (m_disable_fluctuation_corr) { m_globalPositionWrapper.set_enable_fluctuation_corr(false); }
0548 
0549   return Fun4AllReturnCodes::EVENT_OK;
0550 }
0551 
0552 /*
0553  * fit track with SvtxTrack as input seed.
0554  * \param intrack Input SvtxTrack
0555  */
0556 std::shared_ptr<PHGenFit::Track> PHGenFitTrkFitter::ReFitTrack(PHCompositeNode* /*topNode*/, const SvtxTrack* intrack)
0557 {
0558   // std::shared_ptr<PHGenFit::Track> empty_track(nullptr);
0559   if (!intrack)
0560   {
0561     cerr << PHWHERE << " Input SvtxTrack is nullptr!" << endl;
0562     return nullptr;
0563   }
0564 
0565   // get crossing from track
0566   const auto crossing = intrack->get_crossing();
0567   assert(crossing != SHRT_MAX);
0568 
0569   // prepare seed
0570   TVector3 seed_mom(100, 0, 0);
0571   TVector3 seed_pos(0, 0, 0);
0572   TMatrixDSym seed_cov(6);
0573   for (int i = 0; i < 6; i++)
0574   {
0575     for (int j = 0; j < 6; j++)
0576     {
0577       seed_cov[i][j] = 100.;
0578     }
0579   }
0580 
0581   // Create measurements
0582   std::vector<PHGenFit::Measurement*> measurements;
0583 
0584   // sort clusters with radius before fitting
0585   if (Verbosity() > 10)
0586   {
0587     intrack->identify();
0588   }
0589   std::map<float, TrkrDefs::cluskey> m_r_cluster_id;
0590 
0591   unsigned int n_silicon_clusters = 0;
0592   unsigned int n_micromegas_clusters = 0;
0593 
0594   for (const auto& cluster_key : get_cluster_keys(intrack))
0595   {
0596     // count clusters
0597     switch (TrkrDefs::getTrkrId(cluster_key))
0598     {
0599     case TrkrDefs::mvtxId:
0600     case TrkrDefs::inttId:
0601       ++n_silicon_clusters;
0602       break;
0603 
0604     case TrkrDefs::micromegasId:
0605       ++n_micromegas_clusters;
0606       break;
0607 
0608     default:
0609       break;
0610     }
0611 
0612     const auto cluster = m_clustermap->findCluster(cluster_key);
0613     const auto globalPosition = m_globalPositionWrapper.getGlobalPositionDistortionCorrected(cluster_key, cluster, crossing);
0614     const float r = get_r(globalPosition.x(), globalPosition.y());
0615     m_r_cluster_id.emplace(r, cluster_key);
0616     if (Verbosity() > 10)
0617     {
0618       const int layer_out = TrkrDefs::getLayer(cluster_key);
0619       cout << "    Layer " << layer_out << " cluster " << cluster_key << " radius " << r << endl;
0620     }
0621   }
0622 
0623   // discard track if not enough clusters when fitting with silicon + mm only
0624   if (m_fit_silicon_mms)
0625   {
0626     if (n_silicon_clusters == 0)
0627     {
0628       return nullptr;
0629     }
0630     if (m_use_micromegas && n_micromegas_clusters == 0)
0631     {
0632       return nullptr;
0633     }
0634   }
0635 
0636   for (const auto& [r, cluster_key] : m_r_cluster_id)
0637   {
0638     const int layer = TrkrDefs::getLayer(cluster_key);
0639 
0640     // skip disabled layers
0641     if (_disabled_layers.find(layer) != _disabled_layers.end())
0642     {
0643       continue;
0644     }
0645 
0646     auto cluster = m_clustermap->findCluster(cluster_key);
0647     if (!cluster)
0648     {
0649       LogError("No cluster Found!");
0650       continue;
0651     }
0652 
0653     const auto globalPosition_acts = m_globalPositionWrapper.getGlobalPositionDistortionCorrected(cluster_key, cluster, crossing);
0654     const TVector3 pos(globalPosition_acts.x(), globalPosition_acts.y(), globalPosition_acts.z());
0655 
0656     const double cluster_rphi_error = cluster->getRPhiError();
0657     const double cluster_z_error = cluster->getZError();
0658 
0659     seed_mom.SetPhi(pos.Phi());
0660     seed_mom.SetTheta(pos.Theta());
0661 
0662     std::unique_ptr<PHGenFit::PlanarMeasurement> meas;
0663     switch (TrkrDefs::getTrkrId(cluster_key))
0664         {
0665         case TrkrDefs::mvtxId:
0666         {
0667           auto hitsetkey = TrkrDefs::getHitSetKeyFromClusKey(cluster_key);
0668           auto surface = m_tgeometry->maps().getSiliconSurface(hitsetkey);
0669         const auto u = get_world_from_local_vect(m_tgeometry, surface, {1, 0, 0});
0670         const auto v = get_world_from_local_vect(m_tgeometry, surface, {0, 1, 0});
0671           meas.reset( new PHGenFit::PlanarMeasurement(pos, u, v, cluster_rphi_error, cluster_z_error) );
0672 
0673           break;
0674         }
0675 
0676         case TrkrDefs::inttId:
0677         {
0678           auto hitsetkey = TrkrDefs::getHitSetKeyFromClusKey(cluster_key);
0679           auto surface = m_tgeometry->maps().getSiliconSurface(hitsetkey);
0680         const auto u = get_world_from_local_vect(m_tgeometry, surface, {1, 0, 0});
0681         const auto v = get_world_from_local_vect(m_tgeometry, surface, {0, 1, 0});
0682           meas.reset( new PHGenFit::PlanarMeasurement(pos, u, v, cluster_rphi_error, cluster_z_error) );
0683           break;
0684         }
0685 
0686         case TrkrDefs::micromegasId:
0687         {
0688 
0689           // get geometry
0690           /* a situation where micromegas clusters are found, but not the geometry, should not happen */
0691           auto hitsetkey = TrkrDefs::getHitSetKeyFromClusKey(cluster_key);
0692           auto surface = m_tgeometry->maps().getMMSurface(hitsetkey);
0693         const auto u = get_world_from_local_vect(m_tgeometry, surface, {1, 0, 0});
0694         const auto v = get_world_from_local_vect(m_tgeometry, surface, {0, 1, 0});
0695           meas.reset( new PHGenFit::PlanarMeasurement(pos, u, v, cluster_rphi_error, cluster_z_error) );
0696           break;
0697         }
0698 
0699         case TrkrDefs::tpcId:
0700         {
0701           // create measurement
0702           const TVector3 n(globalPosition_acts.x(), globalPosition_acts.y(), 0);
0703           meas.reset( new PHGenFit::PlanarMeasurement(pos, n, cluster_rphi_error, cluster_z_error) );
0704           break;
0705         }
0706 
0707     }
0708 
0709     // assign cluster key to measurement
0710     meas->set_cluster_key( cluster_key );
0711 
0712     // add to list
0713     measurements.push_back(meas.release());
0714   }
0715 
0716   /*!
0717    * mu+:   -13
0718    * mu-:   13
0719    * pi+:   211
0720    * pi-:   -211
0721    * e-:    11
0722    * e+:    -11
0723    */
0724   // TODO Add multiple TrackRep choices.
0725   // int pid = 211;
0726   auto rep = new genfit::RKTrackRep(_primary_pid_guess);
0727   std::shared_ptr<PHGenFit::Track> track(new PHGenFit::Track(rep, seed_pos, seed_mom, seed_cov));
0728 
0729   // TODO unsorted measurements, should use sorted ones?
0730   track->addMeasurements(measurements);
0731 
0732   /*!
0733    *  Fit the track
0734    *  ret code 0 means 0 error or good status
0735    */
0736 
0737   if (_fitter->processTrack(track.get(), false) != 0)
0738   {
0739     // if (Verbosity() >= 1)
0740     {
0741       LogWarning("Track fitting failed");
0742     }
0743     // delete track;
0744     return nullptr;
0745   }
0746 
0747   if (Verbosity() > 10)
0748     cout << " track->getChisq() " << track->get_chi2() << " get_ndf " << track->get_ndf()
0749          << " mom.X " << track->get_mom().X()
0750          << " mom.Y " << track->get_mom().Y()
0751          << " mom.Z " << track->get_mom().Z()
0752          << endl;
0753 
0754   return track;
0755 }
0756 
0757 /*
0758  * Make SvtxTrack from PHGenFit::Track and SvtxTrack
0759  */
0760 // SvtxTrack* PHGenFitTrkFitter::MakeSvtxTrack(const SvtxTrack* svtx_track,
0761 std::shared_ptr<SvtxTrack> PHGenFitTrkFitter::MakeSvtxTrack(const SvtxTrack* svtx_track, const std::shared_ptr<PHGenFit::Track>& phgf_track)
0762 {
0763   double chi2 = phgf_track->get_chi2();
0764   double ndf = phgf_track->get_ndf();
0765 
0766   TVector3 vertex_position(0, 0, 0);
0767   TMatrixF vertex_cov(3, 3);
0768   double dvr2 = 0;
0769   double dvz2 = 0;
0770 
0771   std::unique_ptr<genfit::MeasuredStateOnPlane> gf_state_beam_line_ca;
0772   try
0773   {
0774     gf_state_beam_line_ca.reset(phgf_track->extrapolateToLine(vertex_position, TVector3(0., 0., 1.)));
0775   }
0776   catch (...)
0777   {
0778     if (Verbosity() >= 2)
0779     {
0780       LogWarning("extrapolateToLine failed!");
0781     }
0782   }
0783   if (!gf_state_beam_line_ca)
0784   {
0785     return nullptr;
0786   }
0787 
0788   /*!
0789    *  1/p, u'/z', v'/z', u, v
0790    *  u is defined as momentum X beam line at POCA of the beam line
0791    *  v is alone the beam line
0792    *  so u is the dca2d direction
0793    */
0794 
0795   double u = gf_state_beam_line_ca->getState()[3];
0796   double v = gf_state_beam_line_ca->getState()[4];
0797 
0798   double du2 = gf_state_beam_line_ca->getCov()[3][3];
0799   double dv2 = gf_state_beam_line_ca->getCov()[4][4];
0800   // cout << PHWHERE << "        u " << u << " v " << v << " du2 " << du2 << " dv2 " << dv2 << " dvr2 " << dvr2 << endl;
0801   // delete gf_state_beam_line_ca;
0802 
0803   // create new track
0804   auto out_track = std::make_shared<SvtxTrack_v4>(*svtx_track);
0805 
0806   // clear states and insert empty one for vertex position
0807   out_track->clear_states();
0808   {
0809     /*
0810     insert first, dummy state, as done in constructor,
0811     so that the track state list is never empty. Note that insert_state, despite taking a pointer as argument,
0812     does not take ownership of the state
0813     */
0814     SvtxTrackState_v2 first(0.0);
0815     out_track->insert_state(&first);
0816   }
0817 
0818   out_track->set_dca2d(u);
0819   out_track->set_dca2d_error(sqrt(du2 + dvr2));
0820 
0821   std::unique_ptr<genfit::MeasuredStateOnPlane> gf_state_vertex_ca;
0822   try
0823   {
0824     gf_state_vertex_ca.reset(phgf_track->extrapolateToPoint(vertex_position));
0825   }
0826   catch (...)
0827   {
0828     if (Verbosity() >= 2)
0829     {
0830       LogWarning("extrapolateToPoint failed!");
0831     }
0832   }
0833   if (!gf_state_vertex_ca)
0834   {
0835     // delete out_track;
0836     return nullptr;
0837   }
0838 
0839   const auto mom = gf_state_vertex_ca->getMom();
0840   const auto pos = gf_state_vertex_ca->getPos();
0841   const auto cov = gf_state_vertex_ca->get6DCov();
0842 
0843   //    genfit::MeasuredStateOnPlane* gf_state_vertex_ca =
0844   //            phgf_track->extrapolateToLine(vertex_position,
0845   //                    TVector3(0., 0., 1.));
0846 
0847   u = gf_state_vertex_ca->getState()[3];
0848   v = gf_state_vertex_ca->getState()[4];
0849 
0850   du2 = gf_state_vertex_ca->getCov()[3][3];
0851   dv2 = gf_state_vertex_ca->getCov()[4][4];
0852 
0853   double dca3d = sqrt(square(u) + square(v));
0854   double dca3d_error = sqrt(du2 + dv2 + dvr2 + dvz2);
0855 
0856   out_track->set_dca(dca3d);
0857   out_track->set_dca_error(dca3d_error);
0858 
0859   //
0860   // in: X, Y, Z; out; r: n X Z, Z X r, Z
0861 
0862   float dca3d_xy = NAN;
0863   float dca3d_z = NAN;
0864   float dca3d_xy_error = NAN;
0865   float dca3d_z_error = NAN;
0866 
0867   try
0868   {
0869     TMatrixF pos_in(3, 1);
0870     TMatrixF cov_in(3, 3);
0871     TMatrixF pos_out(3, 1);
0872     TMatrixF cov_out(3, 3);
0873 
0874     TVectorD state6(6);      // pos(3), mom(3)
0875     TMatrixDSym cov6(6, 6);  //
0876 
0877     gf_state_vertex_ca->get6DStateCov(state6, cov6);
0878 
0879     TVector3 vn(state6[3], state6[4], state6[5]);
0880 
0881     // mean of two multivariate gaussians Pos - Vertex
0882     pos_in[0][0] = state6[0] - vertex_position.X();
0883     pos_in[1][0] = state6[1] - vertex_position.Y();
0884     pos_in[2][0] = state6[2] - vertex_position.Z();
0885 
0886     for (int i = 0; i < 3; ++i)
0887     {
0888       for (int j = 0; j < 3; ++j)
0889       {
0890         cov_in[i][j] = cov6[i][j] + vertex_cov[i][j];
0891       }
0892     }
0893 
0894     // vn is momentum vector, pos_in is position vector (of what?)
0895     pos_cov_XYZ_to_RZ(vn, pos_in, cov_in, pos_out, cov_out);
0896 
0897     if (Verbosity() > 30)
0898     {
0899       cout << " vn.X " << vn.X() << " vn.Y " << vn.Y() << " vn.Z " << vn.Z() << endl;
0900       cout << " pos_in.X " << pos_in[0][0] << " pos_in.Y " << pos_in[1][0] << " pos_in.Z " << pos_in[2][0] << endl;
0901       cout << " pos_out.X " << pos_out[0][0] << " pos_out.Y " << pos_out[1][0] << " pos_out.Z " << pos_out[2][0] << endl;
0902     }
0903 
0904     dca3d_xy = pos_out[0][0];
0905     dca3d_z = pos_out[2][0];
0906     dca3d_xy_error = sqrt(cov_out[0][0]);
0907     dca3d_z_error = sqrt(cov_out[2][2]);
0908 
0909 #ifdef _DEBUG_
0910     cout << __LINE__ << ": Vertex: ----------------" << endl;
0911     vertex_position.Print();
0912     vertex_cov.Print();
0913 
0914     cout << __LINE__ << ": State: ----------------" << endl;
0915     state6.Print();
0916     cov6.Print();
0917 
0918     cout << __LINE__ << ": Mean: ----------------" << endl;
0919     pos_in.Print();
0920     cout << "===>" << endl;
0921     pos_out.Print();
0922 
0923     cout << __LINE__ << ": Cov: ----------------" << endl;
0924     cov_in.Print();
0925     cout << "===>" << endl;
0926     cov_out.Print();
0927 
0928     cout << endl;
0929 #endif
0930   }
0931   catch (...)
0932   {
0933     if (Verbosity())
0934     {
0935       LogWarning("DCA calculationfailed!");
0936     }
0937   }
0938 
0939   out_track->set_dca3d_xy(dca3d_xy);
0940   out_track->set_dca3d_z(dca3d_z);
0941   out_track->set_dca3d_xy_error(dca3d_xy_error);
0942   out_track->set_dca3d_z_error(dca3d_z_error);
0943 
0944   // if(gf_state_vertex_ca) delete gf_state_vertex_ca;
0945 
0946   out_track->set_chisq(chi2);
0947   out_track->set_ndf(ndf);
0948   out_track->set_charge(phgf_track->get_charge());
0949 
0950   out_track->set_px(mom.Px());
0951   out_track->set_py(mom.Py());
0952   out_track->set_pz(mom.Pz());
0953 
0954   out_track->set_x(pos.X());
0955   out_track->set_y(pos.Y());
0956   out_track->set_z(pos.Z());
0957 
0958   for (int i = 0; i < 6; i++)
0959   {
0960     for (int j = i; j < 6; j++)
0961     {
0962       out_track->set_error(i, j, cov[i][j]);
0963     }
0964   }
0965 
0966 #ifdef _DEBUG_
0967   cout << __LINE__ << endl;
0968 #endif
0969 
0970   const auto gftrack = phgf_track->getGenFitTrack();
0971   const auto rep = gftrack->getCardinalRep();
0972   for (unsigned int id = 0; id < gftrack->getNumPointsWithMeasurement(); ++id)
0973   {
0974     genfit::TrackPoint* trpoint = gftrack->getPointWithMeasurementAndFitterInfo(id, gftrack->getCardinalRep());
0975 
0976     if (!trpoint)
0977     {
0978       if (Verbosity() > 1) LogWarning("!trpoint");
0979       continue;
0980     }
0981 
0982     auto kfi = static_cast<genfit::KalmanFitterInfo*>(trpoint->getFitterInfo(rep));
0983     if (!kfi)
0984     {
0985       if (Verbosity() > 1) LogWarning("!kfi");
0986       continue;
0987     }
0988 
0989     const genfit::MeasuredStateOnPlane* gf_state = nullptr;
0990     try
0991     {
0992       // this works because KalmanFitterInfo returns a const reference to internal object and not a temporary object
0993       gf_state = &kfi->getFittedState(true);
0994     }
0995     catch (...)
0996     {
0997       if (Verbosity() >= 1)
0998         LogWarning("Exrapolation failed!");
0999     }
1000     if (!gf_state)
1001     {
1002       if (Verbosity() >= 1)
1003         LogWarning("Exrapolation failed!");
1004       continue;
1005     }
1006     genfit::MeasuredStateOnPlane temp;
1007     float pathlength = -phgf_track->extrapolateToPoint(temp, vertex_position, id);
1008 
1009     // create new svtx state and add to track
1010     auto state = create_track_state(pathlength, gf_state);
1011 
1012     // get matching cluster key from phgf_track and assign to state
1013     state.set_cluskey(phgf_track->get_cluster_keys()[id]);
1014 
1015     out_track->insert_state(&state);
1016 
1017 #ifdef _DEBUG_
1018     cout
1019         << __LINE__
1020         << ": " << id
1021         << ": " << pathlength << " => "
1022         << sqrt(square(state->get_x()) + square(state->get_y()))
1023         << endl;
1024 #endif
1025   }
1026 
1027   // loop over clusters, check if layer is disabled, include extrapolated SvtxTrackState
1028   if (!_disabled_layers.empty())
1029   {
1030     // get crossing
1031     const auto crossing = svtx_track->get_crossing();
1032     assert(crossing != SHRT_MAX);
1033 
1034     unsigned int id_min = 0;
1035     for (const auto& cluster_key : get_cluster_keys(svtx_track))
1036     {
1037       const auto cluster = m_clustermap->findCluster(cluster_key);
1038       const auto layer = TrkrDefs::getLayer(cluster_key);
1039 
1040       // skip enabled layers
1041       if (_disabled_layers.find(layer) == _disabled_layers.end())
1042       {
1043         continue;
1044       }
1045 
1046       // get position
1047       const auto globalPosition = m_globalPositionWrapper.getGlobalPositionDistortionCorrected( cluster_key, cluster, crossing );
1048       const TVector3 pos_A(globalPosition.x(), globalPosition.y(), globalPosition.z() );
1049       const float r_cluster = std::sqrt( square(globalPosition.x()) + square(globalPosition.y()) );
1050 
1051       // loop over states
1052       /* find first state whose radius is larger than that of cluster if any */
1053       unsigned int id = id_min;
1054       for (; id < gftrack->getNumPointsWithMeasurement(); ++id)
1055       {
1056         auto trpoint = gftrack->getPointWithMeasurementAndFitterInfo(id, rep);
1057         if (!trpoint) continue;
1058 
1059         auto kfi = static_cast<genfit::KalmanFitterInfo*>(trpoint->getFitterInfo(rep));
1060         if (!kfi) continue;
1061 
1062         const genfit::MeasuredStateOnPlane* gf_state = nullptr;
1063         try
1064         {
1065           gf_state = &kfi->getFittedState(true);
1066         }
1067         catch (...)
1068         {
1069           if (Verbosity())
1070           {
1071             LogWarning("Failed to get kf fitted state");
1072           }
1073         }
1074 
1075         if (!gf_state) continue;
1076 
1077         float r_track = std::sqrt(square(gf_state->getPos().x()) + square(gf_state->getPos().y()));
1078         if (r_track > r_cluster) break;
1079       }
1080 
1081       // forward extrapolation
1082       genfit::MeasuredStateOnPlane gf_state;
1083       float pathlength = 0;
1084 
1085       // first point is previous, if valid
1086       if (id > 0) id_min = id - 1;
1087 
1088       // extrapolate forward
1089       try
1090       {
1091         auto trpoint = gftrack->getPointWithMeasurementAndFitterInfo(id_min, rep);
1092         if (!trpoint) continue;
1093 
1094         auto kfi = static_cast<genfit::KalmanFitterInfo*>(trpoint->getFitterInfo(rep));
1095         gf_state = *kfi->getForwardUpdate();
1096         pathlength = gf_state.extrapolateToPoint( pos_A );
1097         auto tmp = *kfi->getBackwardUpdate();
1098         pathlength -= tmp.extrapolateToPoint(vertex_position);
1099       }
1100       catch (...)
1101       {
1102         if (Verbosity())
1103         {
1104           std::cerr << PHWHERE << "Failed to forward extrapolate from id " << id_min << " to disabled layer " << layer << std::endl;
1105         }
1106         continue;
1107       }
1108 
1109       // also extrapolate backward from next state if any
1110       // and take the weighted average between both points
1111       if (id > 0 && id < gftrack->getNumPointsWithMeasurement())
1112         try
1113         {
1114           auto trpoint = gftrack->getPointWithMeasurementAndFitterInfo(id, rep);
1115           if (!trpoint) continue;
1116 
1117           auto kfi = static_cast<genfit::KalmanFitterInfo*>(trpoint->getFitterInfo(rep));
1118           genfit::KalmanFittedStateOnPlane gf_state_backward = *kfi->getBackwardUpdate();
1119           gf_state_backward.extrapolateToPlane(gf_state.getPlane());
1120           gf_state = genfit::calcAverageState(gf_state, gf_state_backward);
1121         }
1122         catch (...)
1123         {
1124           if (Verbosity())
1125           {
1126             std::cerr << PHWHERE << "Failed to backward extrapolate from id " << id << " to disabled layer " << layer << std::endl;
1127           }
1128           continue;
1129         }
1130 
1131       // create new svtx state and add to track
1132       auto state = create_track_state(pathlength, &gf_state);
1133       state.set_cluskey(cluster_key);
1134       out_track->insert_state(&state);
1135     }
1136   }
1137 
1138   // printout all track state
1139   if (Verbosity())
1140   {
1141     for (auto&& iter = out_track->begin_states(); iter != out_track->end_states(); ++iter)
1142     {
1143       const auto& [pathlength, state] = *iter;
1144       const auto r = std::sqrt(square(state->get_x()) + square(state->get_y()));
1145       const auto phi = std::atan2(state->get_y(), state->get_x());
1146       std::cout << "PHGenFitTrkFitter::MakeSvtxTrack -"
1147                 << " pathlength: " << pathlength
1148                 << " radius: " << r
1149                 << " phi: " << phi
1150                 << " z: " << state->get_z()
1151                 << std::endl;
1152     }
1153 
1154     std::cout << std::endl;
1155   }
1156   return out_track;
1157 }
1158 
1159 //______________________________________________________________________
1160 bool PHGenFitTrkFitter::pos_cov_XYZ_to_RZ(
1161     const TVector3& n, const TMatrixF& pos_in, const TMatrixF& cov_in,
1162     TMatrixF& pos_out, TMatrixF& cov_out) const
1163 {
1164   if (pos_in.GetNcols() != 1 || pos_in.GetNrows() != 3)
1165   {
1166     if (Verbosity())
1167     {
1168       LogWarning("pos_in.GetNcols() != 1 || pos_in.GetNrows() != 3");
1169     }
1170     return false;
1171   }
1172 
1173   if (cov_in.GetNcols() != 3 || cov_in.GetNrows() != 3)
1174   {
1175     if (Verbosity())
1176     {
1177       LogWarning("cov_in.GetNcols() != 3 || cov_in.GetNrows() != 3");
1178     }
1179     return false;
1180   }
1181 
1182   // produces a vector perpendicular to both the momentum vector and beam line - i.e. in the direction of the dca_xy
1183   // only the angle of r will be used, not the magnitude
1184   TVector3 r = n.Cross(TVector3(0., 0., 1.));
1185   if (r.Mag() < 0.00001)
1186   {
1187     if (Verbosity())
1188     {
1189       LogWarning("n is parallel to z");
1190     }
1191     return false;
1192   }
1193 
1194   // R: rotation from u,v,n to n X Z, nX(nXZ), n
1195   TMatrixF R(3, 3);
1196   TMatrixF R_T(3, 3);
1197 
1198   try
1199   {
1200     // rotate u along z to up
1201     float phi = -TMath::ATan2(r.Y(), r.X());
1202     R[0][0] = cos(phi);
1203     R[0][1] = -sin(phi);
1204     R[0][2] = 0;
1205     R[1][0] = sin(phi);
1206     R[1][1] = cos(phi);
1207     R[1][2] = 0;
1208     R[2][0] = 0;
1209     R[2][1] = 0;
1210     R[2][2] = 1;
1211 
1212     R_T.Transpose(R);
1213   }
1214   catch (...)
1215   {
1216     if (Verbosity())
1217     {
1218       LogWarning("Can't get rotation matrix");
1219     }
1220 
1221     return false;
1222   }
1223 
1224   pos_out.ResizeTo(3, 1);
1225   cov_out.ResizeTo(3, 3);
1226 
1227   pos_out = R * pos_in;
1228   cov_out = R * cov_in * R_T;
1229 
1230   return true;
1231 }