File indexing completed on 2025-08-05 08:10:04
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "Acts/Plugins/ExaTrkX/BoostTrackBuilding.hpp"
0010 #include "Acts/Plugins/ExaTrkX/CugraphTrackBuilding.hpp"
0011 #include "Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp"
0012 #include "Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp"
0013 #include "Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp"
0014 #include "Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp"
0015 #include "Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp"
0016 #include "Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp"
0017 #include "Acts/Plugins/Python/Utilities.hpp"
0018 #include "ActsExamples/TrackFindingExaTrkX/PrototracksToParameters.hpp"
0019 #include "ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp"
0020 #include "ActsExamples/TrackFindingExaTrkX/TrackFindingFromPrototrackAlgorithm.hpp"
0021
0022 #include <memory>
0023
0024 #include <pybind11/functional.h>
0025 #include <pybind11/pybind11.h>
0026 #include <pybind11/stl.h>
0027
0028 namespace py = pybind11;
0029
0030 using namespace ActsExamples;
0031 using namespace Acts;
0032
0033 namespace Acts::Python {
0034
0035 void addExaTrkXTrackFinding(Context &ctx) {
0036 auto [m, mex] = ctx.get("main", "examples");
0037
0038 {
0039 using C = Acts::GraphConstructionBase;
0040 auto c = py::class_<C, std::shared_ptr<C>>(mex, "GraphConstructionBase");
0041 }
0042 {
0043 using C = Acts::EdgeClassificationBase;
0044 auto c = py::class_<C, std::shared_ptr<C>>(mex, "EdgeClassificationBase");
0045 }
0046 {
0047 using C = Acts::TrackBuildingBase;
0048 auto c = py::class_<C, std::shared_ptr<C>>(mex, "TrackBuildingBase");
0049 }
0050
0051 #ifdef ACTS_EXATRKX_TORCH_BACKEND
0052 {
0053 using Alg = Acts::TorchMetricLearning;
0054 using Config = Alg::Config;
0055
0056 auto alg =
0057 py::class_<Alg, Acts::GraphConstructionBase, std::shared_ptr<Alg>>(
0058 mex, "TorchMetricLearning")
0059 .def(py::init([](const Config &c, Logging::Level lvl) {
0060 return std::make_shared<Alg>(
0061 c, getDefaultLogger("MetricLearning", lvl));
0062 }),
0063 py::arg("config"), py::arg("level"))
0064 .def_property_readonly("config", &Alg::config);
0065
0066 auto c = py::class_<Config>(alg, "Config").def(py::init<>());
0067 ACTS_PYTHON_STRUCT_BEGIN(c, Config);
0068 ACTS_PYTHON_MEMBER(modelPath);
0069 ACTS_PYTHON_MEMBER(numFeatures);
0070 ACTS_PYTHON_MEMBER(embeddingDim);
0071 ACTS_PYTHON_MEMBER(rVal);
0072 ACTS_PYTHON_MEMBER(knnVal);
0073 ACTS_PYTHON_STRUCT_END();
0074 }
0075 {
0076 using Alg = Acts::TorchEdgeClassifier;
0077 using Config = Alg::Config;
0078
0079 auto alg =
0080 py::class_<Alg, Acts::EdgeClassificationBase, std::shared_ptr<Alg>>(
0081 mex, "TorchEdgeClassifier")
0082 .def(py::init([](const Config &c, Logging::Level lvl) {
0083 return std::make_shared<Alg>(
0084 c, getDefaultLogger("EdgeClassifier", lvl));
0085 }),
0086 py::arg("config"), py::arg("level"))
0087 .def_property_readonly("config", &Alg::config);
0088
0089 auto c = py::class_<Config>(alg, "Config").def(py::init<>());
0090 ACTS_PYTHON_STRUCT_BEGIN(c, Config);
0091 ACTS_PYTHON_MEMBER(modelPath);
0092 ACTS_PYTHON_MEMBER(numFeatures);
0093 ACTS_PYTHON_MEMBER(cut);
0094 ACTS_PYTHON_MEMBER(nChunks);
0095 ACTS_PYTHON_MEMBER(undirected);
0096 ACTS_PYTHON_STRUCT_END();
0097 }
0098 {
0099 using Alg = Acts::BoostTrackBuilding;
0100
0101 auto alg = py::class_<Alg, Acts::TrackBuildingBase, std::shared_ptr<Alg>>(
0102 mex, "BoostTrackBuilding")
0103 .def(py::init([](Logging::Level lvl) {
0104 return std::make_shared<Alg>(
0105 getDefaultLogger("EdgeClassifier", lvl));
0106 }),
0107 py::arg("level"));
0108 }
0109 #endif
0110
0111 #ifdef ACTS_EXATRKX_ONNX_BACKEND
0112 {
0113 using Alg = Acts::OnnxMetricLearning;
0114 using Config = Alg::Config;
0115
0116 auto alg =
0117 py::class_<Alg, Acts::GraphConstructionBase, std::shared_ptr<Alg>>(
0118 mex, "OnnxMetricLearning")
0119 .def(py::init([](const Config &c, Logging::Level lvl) {
0120 return std::make_shared<Alg>(
0121 c, getDefaultLogger("MetricLearning", lvl));
0122 }),
0123 py::arg("config"), py::arg("level"))
0124 .def_property_readonly("config", &Alg::config);
0125
0126 auto c = py::class_<Config>(alg, "Config").def(py::init<>());
0127 ACTS_PYTHON_STRUCT_BEGIN(c, Config);
0128 ACTS_PYTHON_MEMBER(modelPath);
0129 ACTS_PYTHON_MEMBER(spacepointFeatures);
0130 ACTS_PYTHON_MEMBER(embeddingDim);
0131 ACTS_PYTHON_MEMBER(rVal);
0132 ACTS_PYTHON_MEMBER(knnVal);
0133 ACTS_PYTHON_STRUCT_END();
0134 }
0135 {
0136 using Alg = Acts::OnnxEdgeClassifier;
0137 using Config = Alg::Config;
0138
0139 auto alg =
0140 py::class_<Alg, Acts::EdgeClassificationBase, std::shared_ptr<Alg>>(
0141 mex, "OnnxEdgeClassifier")
0142 .def(py::init([](const Config &c, Logging::Level lvl) {
0143 return std::make_shared<Alg>(
0144 c, getDefaultLogger("EdgeClassifier", lvl));
0145 }),
0146 py::arg("config"), py::arg("level"))
0147 .def_property_readonly("config", &Alg::config);
0148
0149 auto c = py::class_<Config>(alg, "Config").def(py::init<>());
0150 ACTS_PYTHON_STRUCT_BEGIN(c, Config);
0151 ACTS_PYTHON_MEMBER(modelPath);
0152 ACTS_PYTHON_MEMBER(cut);
0153 ACTS_PYTHON_STRUCT_END();
0154 }
0155 {
0156 using Alg = Acts::CugraphTrackBuilding;
0157
0158 auto alg = py::class_<Alg, Acts::TrackBuildingBase, std::shared_ptr<Alg>>(
0159 mex, "CugraphTrackBuilding")
0160 .def(py::init([](Logging::Level lvl) {
0161 return std::make_shared<Alg>(
0162 getDefaultLogger("EdgeClassifier", lvl));
0163 }),
0164 py::arg("level"));
0165 }
0166 #endif
0167
0168 ACTS_PYTHON_DECLARE_ALGORITHM(
0169 ActsExamples::TrackFindingAlgorithmExaTrkX, mex,
0170 "TrackFindingAlgorithmExaTrkX", inputSpacePoints, inputSimHits,
0171 inputParticles, inputClusters, inputMeasurementSimhitsMap,
0172 outputProtoTracks, outputGraph, graphConstructor, edgeClassifiers,
0173 trackBuilder, rScale, phiScale, zScale, cellCountScale, cellSumScale,
0174 clusterXScale, clusterYScale, filterShortTracks, targetMinHits,
0175 targetMinPT);
0176
0177 {
0178 auto cls =
0179 py::class_<Acts::ExaTrkXHook, std::shared_ptr<Acts::ExaTrkXHook>>(
0180 mex, "ExaTrkXHook");
0181 }
0182
0183 {
0184 using Class = Acts::TorchTruthGraphMetricsHook;
0185
0186 auto cls = py::class_<Class, Acts::ExaTrkXHook, std::shared_ptr<Class>>(
0187 mex, "TorchTruthGraphMetricsHook")
0188 .def(py::init(
0189 [](const std::vector<int64_t> &g, Logging::Level lvl) {
0190 return std::make_shared<Class>(
0191 g, getDefaultLogger("PipelineHook", lvl));
0192 }));
0193 }
0194
0195 {
0196 using Class = Acts::ExaTrkXPipeline;
0197
0198 auto cls =
0199 py::class_<Class, std::shared_ptr<Class>>(mex, "ExaTrkXPipeline")
0200 .def(py::init(
0201 [](std::shared_ptr<GraphConstructionBase> g,
0202 std::vector<std::shared_ptr<EdgeClassificationBase>> e,
0203 std::shared_ptr<TrackBuildingBase> t,
0204 Logging::Level lvl) {
0205 return std::make_shared<Class>(
0206 g, e, t, getDefaultLogger("MetricLearning", lvl));
0207 }),
0208 py::arg("graphConstructor"), py::arg("edgeClassifiers"),
0209 py::arg("trackBuilder"), py::arg("level"))
0210 .def("run", &ExaTrkXPipeline::run, py::arg("features"),
0211 py::arg("spacepoints"), py::arg("deviceHint") = -1,
0212 py::arg("hook") = Acts::ExaTrkXHook{},
0213 py::arg("timing") = nullptr);
0214 }
0215
0216 ACTS_PYTHON_DECLARE_ALGORITHM(
0217 ActsExamples::PrototracksToParameters, mex, "PrototracksToParameters",
0218 inputProtoTracks, inputSpacePoints, outputSeeds, outputParameters,
0219 outputProtoTracks, geometry, magneticField, buildTightSeeds);
0220
0221 ACTS_PYTHON_DECLARE_ALGORITHM(
0222 ActsExamples::TrackFindingFromPrototrackAlgorithm, mex,
0223 "TrackFindingFromPrototrackAlgorithm", inputProtoTracks,
0224 inputMeasurements, inputSourceLinks, inputInitialTrackParameters,
0225 outputTracks, measurementSelectorCfg, trackingGeometry, magneticField,
0226 findTracks, tag);
0227 }
0228
0229 }