File indexing completed on 2025-08-05 08:09:49
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "Acts/Definitions/Common.hpp"
0010 #include "Acts/Definitions/Direction.hpp"
0011 #include "Acts/EventData/MultiTrajectory.hpp"
0012 #include "Acts/EventData/TrackContainer.hpp"
0013 #include "Acts/EventData/TrackParameters.hpp"
0014 #include "Acts/EventData/TrackStatePropMask.hpp"
0015 #include "Acts/EventData/VectorMultiTrajectory.hpp"
0016 #include "Acts/EventData/VectorTrackContainer.hpp"
0017 #include "Acts/Geometry/GeometryIdentifier.hpp"
0018 #include "Acts/Propagator/DirectNavigator.hpp"
0019 #include "Acts/Propagator/MultiEigenStepperLoop.hpp"
0020 #include "Acts/Propagator/Navigator.hpp"
0021 #include "Acts/Propagator/Propagator.hpp"
0022 #include "Acts/TrackFitting/GainMatrixUpdater.hpp"
0023 #include "Acts/TrackFitting/GaussianSumFitter.hpp"
0024 #include "Acts/TrackFitting/GsfMixtureReduction.hpp"
0025 #include "Acts/TrackFitting/GsfOptions.hpp"
0026 #include "Acts/Utilities/Delegate.hpp"
0027 #include "Acts/Utilities/HashedString.hpp"
0028 #include "Acts/Utilities/Intersection.hpp"
0029 #include "Acts/Utilities/Logger.hpp"
0030 #include "Acts/Utilities/Zip.hpp"
0031 #include "ActsExamples/EventData/IndexSourceLink.hpp"
0032 #include "ActsExamples/EventData/MeasurementCalibration.hpp"
0033 #include "ActsExamples/EventData/Track.hpp"
0034 #include "ActsExamples/TrackFitting/RefittingCalibrator.hpp"
0035 #include "ActsExamples/TrackFitting/TrackFitterFunction.hpp"
0036
#include <algorithm>
#include <array>
#include <cstddef>
#include <map>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <tuple>
#include <utility>
#include <vector>
0048
0049 namespace Acts {
0050 class MagneticFieldProvider;
0051 class SourceLink;
0052 class Surface;
0053 class TrackingGeometry;
0054 }
0055
0056 using namespace ActsExamples;
0057
0058 namespace {
0059
0060 using MultiStepper = Acts::MultiEigenStepperLoop<>;
0061 using Propagator = Acts::Propagator<MultiStepper, Acts::Navigator>;
0062 using DirectPropagator = Acts::Propagator<MultiStepper, Acts::DirectNavigator>;
0063
0064 using Fitter = Acts::GaussianSumFitter<Propagator, BetheHeitlerApprox,
0065 Acts::VectorMultiTrajectory>;
0066 using DirectFitter =
0067 Acts::GaussianSumFitter<DirectPropagator, BetheHeitlerApprox,
0068 Acts::VectorMultiTrajectory>;
0069 using TrackContainer =
0070 Acts::TrackContainer<Acts::VectorTrackContainer,
0071 Acts::VectorMultiTrajectory, std::shared_ptr>;
0072
0073 struct GsfFitterFunctionImpl final : public ActsExamples::TrackFitterFunction {
0074 Fitter fitter;
0075 DirectFitter directFitter;
0076
0077 Acts::GainMatrixUpdater updater;
0078
0079 std::size_t maxComponents = 0;
0080 double weightCutoff = 0;
0081 const double momentumCutoff = 0;
0082 bool abortOnError = false;
0083 bool disableAllMaterialHandling = false;
0084 MixtureReductionAlgorithm reductionAlg =
0085 MixtureReductionAlgorithm::KLDistance;
0086 Acts::ComponentMergeMethod mergeMethod =
0087 Acts::ComponentMergeMethod::eMaxWeight;
0088
0089 IndexSourceLink::SurfaceAccessor m_slSurfaceAccessor;
0090
0091 GsfFitterFunctionImpl(Fitter&& f, DirectFitter&& df,
0092 const Acts::TrackingGeometry& trkGeo)
0093 : fitter(std::move(f)),
0094 directFitter(std::move(df)),
0095 m_slSurfaceAccessor{trkGeo} {}
0096
0097 template <typename calibrator_t>
0098 auto makeGsfOptions(const GeneralFitterOptions& options,
0099 const calibrator_t& calibrator) const {
0100 Acts::GsfExtensions<Acts::VectorMultiTrajectory> extensions;
0101 extensions.updater.connect<
0102 &Acts::GainMatrixUpdater::operator()<Acts::VectorMultiTrajectory>>(
0103 &updater);
0104
0105 Acts::GsfOptions<Acts::VectorMultiTrajectory> gsfOptions{
0106 options.geoContext,
0107 options.magFieldContext,
0108 options.calibrationContext,
0109 extensions,
0110 options.propOptions,
0111 &(*options.referenceSurface),
0112 maxComponents,
0113 weightCutoff,
0114 abortOnError,
0115 disableAllMaterialHandling};
0116 gsfOptions.componentMergeMethod = mergeMethod;
0117
0118 gsfOptions.extensions.calibrator.connect<&calibrator_t::calibrate>(
0119 &calibrator);
0120 gsfOptions.extensions.surfaceAccessor
0121 .connect<&IndexSourceLink::SurfaceAccessor::operator()>(
0122 &m_slSurfaceAccessor);
0123 switch (reductionAlg) {
0124 case MixtureReductionAlgorithm::weightCut: {
0125 gsfOptions.extensions.mixtureReducer
0126 .connect<&Acts::reduceMixtureLargestWeights>();
0127 } break;
0128 case MixtureReductionAlgorithm::KLDistance: {
0129 gsfOptions.extensions.mixtureReducer
0130 .connect<&Acts::reduceMixtureWithKLDistance>();
0131 } break;
0132 }
0133
0134 return gsfOptions;
0135 }
0136
0137 TrackFitterResult operator()(const std::vector<Acts::SourceLink>& sourceLinks,
0138 const TrackParameters& initialParameters,
0139 const GeneralFitterOptions& options,
0140 const MeasurementCalibratorAdapter& calibrator,
0141 TrackContainer& tracks) const override {
0142 const auto gsfOptions = makeGsfOptions(options, calibrator);
0143
0144 using namespace Acts::GsfConstants;
0145 if (!tracks.hasColumn(Acts::hashString(kFinalMultiComponentStateColumn))) {
0146 std::string key(kFinalMultiComponentStateColumn);
0147 tracks.template addColumn<FinalMultiComponentState>(key);
0148 }
0149
0150 if (!tracks.hasColumn(Acts::hashString(kFwdMaxMaterialXOverX0))) {
0151 tracks.template addColumn<double>(std::string(kFwdMaxMaterialXOverX0));
0152 }
0153
0154 if (!tracks.hasColumn(Acts::hashString(kFwdSumMaterialXOverX0))) {
0155 tracks.template addColumn<double>(std::string(kFwdSumMaterialXOverX0));
0156 }
0157
0158 return fitter.fit(sourceLinks.begin(), sourceLinks.end(), initialParameters,
0159 gsfOptions, tracks);
0160 }
0161
0162 TrackFitterResult operator()(
0163 const std::vector<Acts::SourceLink>& sourceLinks,
0164 const TrackParameters& initialParameters,
0165 const GeneralFitterOptions& options,
0166 const RefittingCalibrator& calibrator,
0167 const std::vector<const Acts::Surface*>& surfaceSequence,
0168 TrackContainer& tracks) const override {
0169 const auto gsfOptions = makeGsfOptions(options, calibrator);
0170
0171 using namespace Acts::GsfConstants;
0172 if (!tracks.hasColumn(Acts::hashString(kFinalMultiComponentStateColumn))) {
0173 std::string key(kFinalMultiComponentStateColumn);
0174 tracks.template addColumn<FinalMultiComponentState>(key);
0175 }
0176
0177 return directFitter.fit(sourceLinks.begin(), sourceLinks.end(),
0178 initialParameters, gsfOptions, surfaceSequence,
0179 tracks);
0180 }
0181 };
0182
0183 }
0184
0185 std::shared_ptr<TrackFitterFunction> ActsExamples::makeGsfFitterFunction(
0186 std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry,
0187 std::shared_ptr<const Acts::MagneticFieldProvider> magneticField,
0188 BetheHeitlerApprox betheHeitlerApprox, std::size_t maxComponents,
0189 double weightCutoff, Acts::ComponentMergeMethod componentMergeMethod,
0190 MixtureReductionAlgorithm mixtureReductionAlgorithm,
0191 const Acts::Logger& logger) {
0192
0193 MultiStepper stepper(magneticField, logger.cloneWithSuffix("Step"));
0194 const auto& geo = *trackingGeometry;
0195 Acts::Navigator::Config cfg{std::move(trackingGeometry)};
0196 cfg.resolvePassive = false;
0197 cfg.resolveMaterial = true;
0198 cfg.resolveSensitive = true;
0199 Acts::Navigator navigator(cfg, logger.cloneWithSuffix("Navigator"));
0200 Propagator propagator(std::move(stepper), std::move(navigator),
0201 logger.cloneWithSuffix("Propagator"));
0202 Fitter trackFitter(std::move(propagator),
0203 BetheHeitlerApprox(betheHeitlerApprox),
0204 logger.cloneWithSuffix("GSF"));
0205
0206
0207 MultiStepper directStepper(std::move(magneticField),
0208 logger.cloneWithSuffix("Step"));
0209 Acts::DirectNavigator directNavigator{
0210 logger.cloneWithSuffix("DirectNavigator")};
0211 DirectPropagator directPropagator(std::move(directStepper),
0212 std::move(directNavigator),
0213 logger.cloneWithSuffix("DirectPropagator"));
0214 DirectFitter directTrackFitter(std::move(directPropagator),
0215 BetheHeitlerApprox(betheHeitlerApprox),
0216 logger.cloneWithSuffix("DirectGSF"));
0217
0218
0219 auto fitterFunction = std::make_shared<GsfFitterFunctionImpl>(
0220 std::move(trackFitter), std::move(directTrackFitter), geo);
0221 fitterFunction->maxComponents = maxComponents;
0222 fitterFunction->weightCutoff = weightCutoff;
0223 fitterFunction->mergeMethod = componentMergeMethod;
0224 fitterFunction->reductionAlg = mixtureReductionAlgorithm;
0225
0226 return fitterFunction;
0227 }