File indexing completed on 2025-08-05 08:10:06
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "Acts/Definitions/Algebra.hpp"
0010 #include "Acts/EventData/detail/CorrectedTransformationFreeToBound.hpp"
0011 #include "Acts/Plugins/Python/Utilities.hpp"
0012 #include "Acts/TrackFitting/BetheHeitlerApprox.hpp"
0013 #include "Acts/TrackFitting/GsfOptions.hpp"
0014 #include "Acts/Utilities/Logger.hpp"
0015 #include "ActsExamples/EventData/Cluster.hpp"
0016 #include "ActsExamples/EventData/MeasurementCalibration.hpp"
0017 #include "ActsExamples/EventData/ScalingCalibrator.hpp"
0018 #include "ActsExamples/TrackFitting/RefittingAlgorithm.hpp"
0019 #include "ActsExamples/TrackFitting/SurfaceSortingAlgorithm.hpp"
0020 #include "ActsExamples/TrackFitting/TrackFitterFunction.hpp"
0021 #include "ActsExamples/TrackFitting/TrackFittingAlgorithm.hpp"
0022
0023 #include <cstddef>
0024 #include <memory>
0025 #include <vector>
0026
0027 #include <pybind11/pybind11.h>
0028 #include <pybind11/stl.h>
0029
// Forward declarations of types referenced by the bindings below.
// TrackingGeometry and MagneticFieldProvider are only used there through
// std::shared_ptr, so a declaration is sufficient.
namespace Acts {
class TrackingGeometry;
class MagneticFieldProvider;
}  // namespace Acts

// NOTE(review): IAlgorithm is not named anywhere in the visible code;
// presumably ACTS_PYTHON_DECLARE_ALGORITHM relies on it — confirm before
// removing.
namespace ActsExamples {
class IAlgorithm;
}  // namespace ActsExamples
0037
0038 namespace py = pybind11;
0039
0040 using namespace ActsExamples;
0041 using namespace Acts;
0042
0043 namespace Acts::Python {
0044
0045 void addTrackFitting(Context& ctx) {
0046 auto mex = ctx.get("examples");
0047
0048 ACTS_PYTHON_DECLARE_ALGORITHM(ActsExamples::SurfaceSortingAlgorithm, mex,
0049 "SurfaceSortingAlgorithm", inputProtoTracks,
0050 inputSimHits, inputMeasurementSimHitsMap,
0051 outputProtoTracks);
0052
0053 ACTS_PYTHON_DECLARE_ALGORITHM(ActsExamples::TrackFittingAlgorithm, mex,
0054 "TrackFittingAlgorithm", inputMeasurements,
0055 inputSourceLinks, inputProtoTracks,
0056 inputInitialTrackParameters, inputClusters,
0057 outputTracks, fit, pickTrack, calibrator);
0058
0059 ACTS_PYTHON_DECLARE_ALGORITHM(ActsExamples::RefittingAlgorithm, mex,
0060 "RefittingAlgorithm", inputTracks, outputTracks,
0061 fit, pickTrack);
0062
0063 {
0064 py::class_<TrackFitterFunction, std::shared_ptr<TrackFitterFunction>>(
0065 mex, "TrackFitterFunction");
0066
0067 mex.def(
0068 "makeKalmanFitterFunction",
0069 [](std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry,
0070 std::shared_ptr<const Acts::MagneticFieldProvider> magneticField,
0071 bool multipleScattering, bool energyLoss,
0072 double reverseFilteringMomThreshold,
0073 Acts::FreeToBoundCorrection freeToBoundCorrection,
0074 Logging::Level level) {
0075 return ActsExamples::makeKalmanFitterFunction(
0076 trackingGeometry, magneticField, multipleScattering, energyLoss,
0077 reverseFilteringMomThreshold, freeToBoundCorrection,
0078 *Acts::getDefaultLogger("Kalman", level));
0079 },
0080 py::arg("trackingGeometry"), py::arg("magneticField"),
0081 py::arg("multipleScattering"), py::arg("energyLoss"),
0082 py::arg("reverseFilteringMomThreshold"),
0083 py::arg("freeToBoundCorrection"), py::arg("level"));
0084
0085 py::class_<MeasurementCalibrator, std::shared_ptr<MeasurementCalibrator>>(
0086 mex, "MeasurementCalibrator");
0087
0088 mex.def("makePassThroughCalibrator",
0089 []() -> std::shared_ptr<MeasurementCalibrator> {
0090 return std::make_shared<PassThroughCalibrator>();
0091 });
0092
0093 mex.def(
0094 "makeScalingCalibrator",
0095 [](const char* path) -> std::shared_ptr<MeasurementCalibrator> {
0096 return std::make_shared<ActsExamples::ScalingCalibrator>(path);
0097 },
0098 py::arg("path"));
0099
0100 py::enum_<Acts::ComponentMergeMethod>(mex, "ComponentMergeMethod")
0101 .value("mean", Acts::ComponentMergeMethod::eMean)
0102 .value("maxWeight", Acts::ComponentMergeMethod::eMaxWeight);
0103
0104 py::enum_<ActsExamples::MixtureReductionAlgorithm>(
0105 mex, "MixtureReductionAlgorithm")
0106 .value("weightCut", MixtureReductionAlgorithm::weightCut)
0107 .value("KLDistance", MixtureReductionAlgorithm::KLDistance);
0108
0109 py::class_<ActsExamples::BetheHeitlerApprox>(mex, "AtlasBetheHeitlerApprox")
0110 .def_static("loadFromFiles",
0111 &ActsExamples::BetheHeitlerApprox::loadFromFiles,
0112 py::arg("lowParametersPath"), py::arg("highParametersPath"),
0113 py::arg("lowLimit") = 0.1, py::arg("highLimit") = 0.2)
0114 .def_static("makeDefault",
0115 []() { return Acts::makeDefaultBetheHeitlerApprox(); });
0116 mex.def(
0117 "makeGsfFitterFunction",
0118 [](std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry,
0119 std::shared_ptr<const Acts::MagneticFieldProvider> magneticField,
0120 BetheHeitlerApprox betheHeitlerApprox, std::size_t maxComponents,
0121 double weightCutoff, Acts::ComponentMergeMethod componentMergeMethod,
0122 ActsExamples::MixtureReductionAlgorithm mixtureReductionAlgorithm,
0123 Logging::Level level) {
0124 return ActsExamples::makeGsfFitterFunction(
0125 trackingGeometry, magneticField, betheHeitlerApprox,
0126 maxComponents, weightCutoff, componentMergeMethod,
0127 mixtureReductionAlgorithm,
0128 *Acts::getDefaultLogger("GSFFunc", level));
0129 },
0130 py::arg("trackingGeometry"), py::arg("magneticField"),
0131 py::arg("betheHeitlerApprox"), py::arg("maxComponents"),
0132 py::arg("weightCutoff"), py::arg("componentMergeMethod"),
0133 py::arg("mixtureReductionAlgorithm"), py::arg("level"));
0134
0135 mex.def(
0136 "makeGlobalChiSquareFitterFunction",
0137 [](std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry,
0138 std::shared_ptr<const Acts::MagneticFieldProvider> magneticField,
0139 bool multipleScattering, bool energyLoss,
0140 Acts::FreeToBoundCorrection freeToBoundCorrection,
0141 std::size_t nUpdateMax, bool zeroField, double relChi2changeCutOff,
0142 Logging::Level level) {
0143 return ActsExamples::makeGlobalChiSquareFitterFunction(
0144 trackingGeometry, magneticField, multipleScattering, energyLoss,
0145 freeToBoundCorrection, nUpdateMax, zeroField, relChi2changeCutOff,
0146 *Acts::getDefaultLogger("Gx2f", level));
0147 },
0148 py::arg("trackingGeometry"), py::arg("magneticField"),
0149 py::arg("multipleScattering"), py::arg("energyLoss"),
0150 py::arg("freeToBoundCorrection"), py::arg("nUpdateMax"),
0151 py::arg("zeroField"), py::arg("relChi2changeCutOff"), py::arg("level"));
0152 }
0153
0154 {
0155 py::class_<FreeToBoundCorrection>(mex, "FreeToBoundCorrection")
0156 .def(py::init<>())
0157 .def(py::init<bool>(), py::arg("apply") = false)
0158 .def(py::init<bool, double, double>(), py::arg("apply") = false,
0159 py::arg("alpha") = 0.1, py::arg("beta") = 2);
0160 }
0161 }
0162
0163 }