Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-08-05 08:10:04

0001 // This file is part of the Acts project.
0002 //
0003 // Copyright (C) 2021 CERN for the benefit of the Acts project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
0008 
0009 #include "Acts/Plugins/ExaTrkX/BoostTrackBuilding.hpp"
0010 #include "Acts/Plugins/ExaTrkX/CugraphTrackBuilding.hpp"
0011 #include "Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp"
0012 #include "Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp"
0013 #include "Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp"
0014 #include "Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp"
0015 #include "Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp"
0016 #include "Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp"
0017 #include "Acts/Plugins/Python/Utilities.hpp"
0018 #include "ActsExamples/TrackFindingExaTrkX/PrototracksToParameters.hpp"
0019 #include "ActsExamples/TrackFindingExaTrkX/TrackFindingAlgorithmExaTrkX.hpp"
0020 #include "ActsExamples/TrackFindingExaTrkX/TrackFindingFromPrototrackAlgorithm.hpp"
0021 
0022 #include <memory>
0023 
0024 #include <pybind11/functional.h>
0025 #include <pybind11/pybind11.h>
0026 #include <pybind11/stl.h>
0027 
0028 namespace py = pybind11;
0029 
0030 using namespace ActsExamples;
0031 using namespace Acts;
0032 
0033 namespace Acts::Python {
0034 
0035 void addExaTrkXTrackFinding(Context &ctx) {
0036   auto [m, mex] = ctx.get("main", "examples");
0037 
0038   {
0039     using C = Acts::GraphConstructionBase;
0040     auto c = py::class_<C, std::shared_ptr<C>>(mex, "GraphConstructionBase");
0041   }
0042   {
0043     using C = Acts::EdgeClassificationBase;
0044     auto c = py::class_<C, std::shared_ptr<C>>(mex, "EdgeClassificationBase");
0045   }
0046   {
0047     using C = Acts::TrackBuildingBase;
0048     auto c = py::class_<C, std::shared_ptr<C>>(mex, "TrackBuildingBase");
0049   }
0050 
0051 #ifdef ACTS_EXATRKX_TORCH_BACKEND
0052   {
0053     using Alg = Acts::TorchMetricLearning;
0054     using Config = Alg::Config;
0055 
0056     auto alg =
0057         py::class_<Alg, Acts::GraphConstructionBase, std::shared_ptr<Alg>>(
0058             mex, "TorchMetricLearning")
0059             .def(py::init([](const Config &c, Logging::Level lvl) {
0060                    return std::make_shared<Alg>(
0061                        c, getDefaultLogger("MetricLearning", lvl));
0062                  }),
0063                  py::arg("config"), py::arg("level"))
0064             .def_property_readonly("config", &Alg::config);
0065 
0066     auto c = py::class_<Config>(alg, "Config").def(py::init<>());
0067     ACTS_PYTHON_STRUCT_BEGIN(c, Config);
0068     ACTS_PYTHON_MEMBER(modelPath);
0069     ACTS_PYTHON_MEMBER(numFeatures);
0070     ACTS_PYTHON_MEMBER(embeddingDim);
0071     ACTS_PYTHON_MEMBER(rVal);
0072     ACTS_PYTHON_MEMBER(knnVal);
0073     ACTS_PYTHON_STRUCT_END();
0074   }
0075   {
0076     using Alg = Acts::TorchEdgeClassifier;
0077     using Config = Alg::Config;
0078 
0079     auto alg =
0080         py::class_<Alg, Acts::EdgeClassificationBase, std::shared_ptr<Alg>>(
0081             mex, "TorchEdgeClassifier")
0082             .def(py::init([](const Config &c, Logging::Level lvl) {
0083                    return std::make_shared<Alg>(
0084                        c, getDefaultLogger("EdgeClassifier", lvl));
0085                  }),
0086                  py::arg("config"), py::arg("level"))
0087             .def_property_readonly("config", &Alg::config);
0088 
0089     auto c = py::class_<Config>(alg, "Config").def(py::init<>());
0090     ACTS_PYTHON_STRUCT_BEGIN(c, Config);
0091     ACTS_PYTHON_MEMBER(modelPath);
0092     ACTS_PYTHON_MEMBER(numFeatures);
0093     ACTS_PYTHON_MEMBER(cut);
0094     ACTS_PYTHON_MEMBER(nChunks);
0095     ACTS_PYTHON_MEMBER(undirected);
0096     ACTS_PYTHON_STRUCT_END();
0097   }
0098   {
0099     using Alg = Acts::BoostTrackBuilding;
0100 
0101     auto alg = py::class_<Alg, Acts::TrackBuildingBase, std::shared_ptr<Alg>>(
0102                    mex, "BoostTrackBuilding")
0103                    .def(py::init([](Logging::Level lvl) {
0104                           return std::make_shared<Alg>(
0105                               getDefaultLogger("EdgeClassifier", lvl));
0106                         }),
0107                         py::arg("level"));
0108   }
0109 #endif
0110 
0111 #ifdef ACTS_EXATRKX_ONNX_BACKEND
0112   {
0113     using Alg = Acts::OnnxMetricLearning;
0114     using Config = Alg::Config;
0115 
0116     auto alg =
0117         py::class_<Alg, Acts::GraphConstructionBase, std::shared_ptr<Alg>>(
0118             mex, "OnnxMetricLearning")
0119             .def(py::init([](const Config &c, Logging::Level lvl) {
0120                    return std::make_shared<Alg>(
0121                        c, getDefaultLogger("MetricLearning", lvl));
0122                  }),
0123                  py::arg("config"), py::arg("level"))
0124             .def_property_readonly("config", &Alg::config);
0125 
0126     auto c = py::class_<Config>(alg, "Config").def(py::init<>());
0127     ACTS_PYTHON_STRUCT_BEGIN(c, Config);
0128     ACTS_PYTHON_MEMBER(modelPath);
0129     ACTS_PYTHON_MEMBER(spacepointFeatures);
0130     ACTS_PYTHON_MEMBER(embeddingDim);
0131     ACTS_PYTHON_MEMBER(rVal);
0132     ACTS_PYTHON_MEMBER(knnVal);
0133     ACTS_PYTHON_STRUCT_END();
0134   }
0135   {
0136     using Alg = Acts::OnnxEdgeClassifier;
0137     using Config = Alg::Config;
0138 
0139     auto alg =
0140         py::class_<Alg, Acts::EdgeClassificationBase, std::shared_ptr<Alg>>(
0141             mex, "OnnxEdgeClassifier")
0142             .def(py::init([](const Config &c, Logging::Level lvl) {
0143                    return std::make_shared<Alg>(
0144                        c, getDefaultLogger("EdgeClassifier", lvl));
0145                  }),
0146                  py::arg("config"), py::arg("level"))
0147             .def_property_readonly("config", &Alg::config);
0148 
0149     auto c = py::class_<Config>(alg, "Config").def(py::init<>());
0150     ACTS_PYTHON_STRUCT_BEGIN(c, Config);
0151     ACTS_PYTHON_MEMBER(modelPath);
0152     ACTS_PYTHON_MEMBER(cut);
0153     ACTS_PYTHON_STRUCT_END();
0154   }
0155   {
0156     using Alg = Acts::CugraphTrackBuilding;
0157 
0158     auto alg = py::class_<Alg, Acts::TrackBuildingBase, std::shared_ptr<Alg>>(
0159                    mex, "CugraphTrackBuilding")
0160                    .def(py::init([](Logging::Level lvl) {
0161                           return std::make_shared<Alg>(
0162                               getDefaultLogger("EdgeClassifier", lvl));
0163                         }),
0164                         py::arg("level"));
0165   }
0166 #endif
0167 
0168   ACTS_PYTHON_DECLARE_ALGORITHM(
0169       ActsExamples::TrackFindingAlgorithmExaTrkX, mex,
0170       "TrackFindingAlgorithmExaTrkX", inputSpacePoints, inputSimHits,
0171       inputParticles, inputClusters, inputMeasurementSimhitsMap,
0172       outputProtoTracks, outputGraph, graphConstructor, edgeClassifiers,
0173       trackBuilder, rScale, phiScale, zScale, cellCountScale, cellSumScale,
0174       clusterXScale, clusterYScale, filterShortTracks, targetMinHits,
0175       targetMinPT);
0176 
0177   {
0178     auto cls =
0179         py::class_<Acts::ExaTrkXHook, std::shared_ptr<Acts::ExaTrkXHook>>(
0180             mex, "ExaTrkXHook");
0181   }
0182 
0183   {
0184     using Class = Acts::TorchTruthGraphMetricsHook;
0185 
0186     auto cls = py::class_<Class, Acts::ExaTrkXHook, std::shared_ptr<Class>>(
0187                    mex, "TorchTruthGraphMetricsHook")
0188                    .def(py::init(
0189                        [](const std::vector<int64_t> &g, Logging::Level lvl) {
0190                          return std::make_shared<Class>(
0191                              g, getDefaultLogger("PipelineHook", lvl));
0192                        }));
0193   }
0194 
0195   {
0196     using Class = Acts::ExaTrkXPipeline;
0197 
0198     auto cls =
0199         py::class_<Class, std::shared_ptr<Class>>(mex, "ExaTrkXPipeline")
0200             .def(py::init(
0201                      [](std::shared_ptr<GraphConstructionBase> g,
0202                         std::vector<std::shared_ptr<EdgeClassificationBase>> e,
0203                         std::shared_ptr<TrackBuildingBase> t,
0204                         Logging::Level lvl) {
0205                        return std::make_shared<Class>(
0206                            g, e, t, getDefaultLogger("MetricLearning", lvl));
0207                      }),
0208                  py::arg("graphConstructor"), py::arg("edgeClassifiers"),
0209                  py::arg("trackBuilder"), py::arg("level"))
0210             .def("run", &ExaTrkXPipeline::run, py::arg("features"),
0211                  py::arg("spacepoints"), py::arg("deviceHint") = -1,
0212                  py::arg("hook") = Acts::ExaTrkXHook{},
0213                  py::arg("timing") = nullptr);
0214   }
0215 
0216   ACTS_PYTHON_DECLARE_ALGORITHM(
0217       ActsExamples::PrototracksToParameters, mex, "PrototracksToParameters",
0218       inputProtoTracks, inputSpacePoints, outputSeeds, outputParameters,
0219       outputProtoTracks, geometry, magneticField, buildTightSeeds);
0220 
0221   ACTS_PYTHON_DECLARE_ALGORITHM(
0222       ActsExamples::TrackFindingFromPrototrackAlgorithm, mex,
0223       "TrackFindingFromPrototrackAlgorithm", inputProtoTracks,
0224       inputMeasurements, inputSourceLinks, inputInitialTrackParameters,
0225       outputTracks, measurementSelectorCfg, trackingGeometry, magneticField,
0226       findTracks, tag);
0227 }
0228 
0229 }  // namespace Acts::Python