Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-08-05 08:10:06

0001 // This file is part of the Acts project.
0002 //
0003 // Copyright (C) 2021-2023 CERN for the benefit of the Acts project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
0008 
#include "Acts/Definitions/Algebra.hpp"
#include "Acts/EventData/detail/CorrectedTransformationFreeToBound.hpp"
#include "Acts/Plugins/Python/Utilities.hpp"
#include "Acts/TrackFitting/BetheHeitlerApprox.hpp"
#include "Acts/TrackFitting/GsfOptions.hpp"
#include "Acts/Utilities/Logger.hpp"
#include "ActsExamples/EventData/Cluster.hpp"
#include "ActsExamples/EventData/MeasurementCalibration.hpp"
#include "ActsExamples/EventData/ScalingCalibrator.hpp"
#include "ActsExamples/TrackFitting/RefittingAlgorithm.hpp"
#include "ActsExamples/TrackFitting/SurfaceSortingAlgorithm.hpp"
#include "ActsExamples/TrackFitting/TrackFitterFunction.hpp"
#include "ActsExamples/TrackFitting/TrackFittingAlgorithm.hpp"

#include <cstddef>
#include <memory>
#include <string>
#include <vector>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace Acts {
class MagneticFieldProvider;
class TrackingGeometry;
}  // namespace Acts
namespace ActsExamples {
class IAlgorithm;
}  // namespace ActsExamples

namespace py = pybind11;

using namespace ActsExamples;
using namespace Acts;
0042 
0043 namespace Acts::Python {
0044 
0045 void addTrackFitting(Context& ctx) {
0046   auto mex = ctx.get("examples");
0047 
0048   ACTS_PYTHON_DECLARE_ALGORITHM(ActsExamples::SurfaceSortingAlgorithm, mex,
0049                                 "SurfaceSortingAlgorithm", inputProtoTracks,
0050                                 inputSimHits, inputMeasurementSimHitsMap,
0051                                 outputProtoTracks);
0052 
0053   ACTS_PYTHON_DECLARE_ALGORITHM(ActsExamples::TrackFittingAlgorithm, mex,
0054                                 "TrackFittingAlgorithm", inputMeasurements,
0055                                 inputSourceLinks, inputProtoTracks,
0056                                 inputInitialTrackParameters, inputClusters,
0057                                 outputTracks, fit, pickTrack, calibrator);
0058 
0059   ACTS_PYTHON_DECLARE_ALGORITHM(ActsExamples::RefittingAlgorithm, mex,
0060                                 "RefittingAlgorithm", inputTracks, outputTracks,
0061                                 fit, pickTrack);
0062 
0063   {
0064     py::class_<TrackFitterFunction, std::shared_ptr<TrackFitterFunction>>(
0065         mex, "TrackFitterFunction");
0066 
0067     mex.def(
0068         "makeKalmanFitterFunction",
0069         [](std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry,
0070            std::shared_ptr<const Acts::MagneticFieldProvider> magneticField,
0071            bool multipleScattering, bool energyLoss,
0072            double reverseFilteringMomThreshold,
0073            Acts::FreeToBoundCorrection freeToBoundCorrection,
0074            Logging::Level level) {
0075           return ActsExamples::makeKalmanFitterFunction(
0076               trackingGeometry, magneticField, multipleScattering, energyLoss,
0077               reverseFilteringMomThreshold, freeToBoundCorrection,
0078               *Acts::getDefaultLogger("Kalman", level));
0079         },
0080         py::arg("trackingGeometry"), py::arg("magneticField"),
0081         py::arg("multipleScattering"), py::arg("energyLoss"),
0082         py::arg("reverseFilteringMomThreshold"),
0083         py::arg("freeToBoundCorrection"), py::arg("level"));
0084 
0085     py::class_<MeasurementCalibrator, std::shared_ptr<MeasurementCalibrator>>(
0086         mex, "MeasurementCalibrator");
0087 
0088     mex.def("makePassThroughCalibrator",
0089             []() -> std::shared_ptr<MeasurementCalibrator> {
0090               return std::make_shared<PassThroughCalibrator>();
0091             });
0092 
0093     mex.def(
0094         "makeScalingCalibrator",
0095         [](const char* path) -> std::shared_ptr<MeasurementCalibrator> {
0096           return std::make_shared<ActsExamples::ScalingCalibrator>(path);
0097         },
0098         py::arg("path"));
0099 
0100     py::enum_<Acts::ComponentMergeMethod>(mex, "ComponentMergeMethod")
0101         .value("mean", Acts::ComponentMergeMethod::eMean)
0102         .value("maxWeight", Acts::ComponentMergeMethod::eMaxWeight);
0103 
0104     py::enum_<ActsExamples::MixtureReductionAlgorithm>(
0105         mex, "MixtureReductionAlgorithm")
0106         .value("weightCut", MixtureReductionAlgorithm::weightCut)
0107         .value("KLDistance", MixtureReductionAlgorithm::KLDistance);
0108 
0109     py::class_<ActsExamples::BetheHeitlerApprox>(mex, "AtlasBetheHeitlerApprox")
0110         .def_static("loadFromFiles",
0111                     &ActsExamples::BetheHeitlerApprox::loadFromFiles,
0112                     py::arg("lowParametersPath"), py::arg("highParametersPath"),
0113                     py::arg("lowLimit") = 0.1, py::arg("highLimit") = 0.2)
0114         .def_static("makeDefault",
0115                     []() { return Acts::makeDefaultBetheHeitlerApprox(); });
0116     mex.def(
0117         "makeGsfFitterFunction",
0118         [](std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry,
0119            std::shared_ptr<const Acts::MagneticFieldProvider> magneticField,
0120            BetheHeitlerApprox betheHeitlerApprox, std::size_t maxComponents,
0121            double weightCutoff, Acts::ComponentMergeMethod componentMergeMethod,
0122            ActsExamples::MixtureReductionAlgorithm mixtureReductionAlgorithm,
0123            Logging::Level level) {
0124           return ActsExamples::makeGsfFitterFunction(
0125               trackingGeometry, magneticField, betheHeitlerApprox,
0126               maxComponents, weightCutoff, componentMergeMethod,
0127               mixtureReductionAlgorithm,
0128               *Acts::getDefaultLogger("GSFFunc", level));
0129         },
0130         py::arg("trackingGeometry"), py::arg("magneticField"),
0131         py::arg("betheHeitlerApprox"), py::arg("maxComponents"),
0132         py::arg("weightCutoff"), py::arg("componentMergeMethod"),
0133         py::arg("mixtureReductionAlgorithm"), py::arg("level"));
0134 
0135     mex.def(
0136         "makeGlobalChiSquareFitterFunction",
0137         [](std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry,
0138            std::shared_ptr<const Acts::MagneticFieldProvider> magneticField,
0139            bool multipleScattering, bool energyLoss,
0140            Acts::FreeToBoundCorrection freeToBoundCorrection,
0141            std::size_t nUpdateMax, bool zeroField, double relChi2changeCutOff,
0142            Logging::Level level) {
0143           return ActsExamples::makeGlobalChiSquareFitterFunction(
0144               trackingGeometry, magneticField, multipleScattering, energyLoss,
0145               freeToBoundCorrection, nUpdateMax, zeroField, relChi2changeCutOff,
0146               *Acts::getDefaultLogger("Gx2f", level));
0147         },
0148         py::arg("trackingGeometry"), py::arg("magneticField"),
0149         py::arg("multipleScattering"), py::arg("energyLoss"),
0150         py::arg("freeToBoundCorrection"), py::arg("nUpdateMax"),
0151         py::arg("zeroField"), py::arg("relChi2changeCutOff"), py::arg("level"));
0152   }
0153 
0154   {
0155     py::class_<FreeToBoundCorrection>(mex, "FreeToBoundCorrection")
0156         .def(py::init<>())
0157         .def(py::init<bool>(), py::arg("apply") = false)
0158         .def(py::init<bool, double, double>(), py::arg("apply") = false,
0159              py::arg("alpha") = 0.1, py::arg("beta") = 2);
0160   }
0161 }
0162 
0163 }  // namespace Acts::Python