File indexing completed on 2025-08-05 08:10:17
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp"
0010
0011 #include <onnxruntime_cxx_api.h>
0012 #include <torch/script.h>
0013
0014 #include "runSessionWithIoBinding.hpp"
0015
0016 using namespace torch::indexing;
0017
0018 namespace Acts {
0019
0020 OnnxEdgeClassifier::OnnxEdgeClassifier(const Config &cfg,
0021 std::unique_ptr<const Logger> logger)
0022 : m_logger(std::move(logger)), m_cfg(cfg) {
0023 m_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING,
0024 "ExaTrkX - edge classifier");
0025
0026 Ort::SessionOptions session_options;
0027 session_options.SetIntraOpNumThreads(1);
0028 session_options.SetGraphOptimizationLevel(
0029 GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
0030
0031 m_model = std::make_unique<Ort::Session>(*m_env, m_cfg.modelPath.c_str(),
0032 session_options);
0033
0034 Ort::AllocatorWithDefaultOptions allocator;
0035
0036 m_inputNameNodes =
0037 std::string(m_model->GetInputNameAllocated(0, allocator).get());
0038 m_inputNameEdges =
0039 std::string(m_model->GetInputNameAllocated(1, allocator).get());
0040 m_outputNameScores =
0041 std::string(m_model->GetOutputNameAllocated(0, allocator).get());
0042 }
0043
0044 OnnxEdgeClassifier::~OnnxEdgeClassifier() {}
0045
0046 std::tuple<std::any, std::any, std::any> OnnxEdgeClassifier::operator()(
0047 std::any inputNodes, std::any inputEdges, int) {
0048 Ort::AllocatorWithDefaultOptions allocator;
0049 auto memoryInfo = Ort::MemoryInfo::CreateCpu(
0050 OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
0051
0052 auto eInputTensor = std::any_cast<std::shared_ptr<Ort::Value>>(inputNodes);
0053 auto edgeList = std::any_cast<std::vector<int64_t>>(inputEdges);
0054 const int numEdges = edgeList.size() / 2;
0055
0056 std::vector<const char *> fInputNames{m_inputNameNodes.c_str(),
0057 m_inputNameEdges.c_str()};
0058 std::vector<Ort::Value> fInputTensor;
0059 fInputTensor.push_back(std::move(*eInputTensor));
0060 std::vector<int64_t> fEdgeShape{2, numEdges};
0061 fInputTensor.push_back(Ort::Value::CreateTensor<int64_t>(
0062 memoryInfo, edgeList.data(), edgeList.size(), fEdgeShape.data(),
0063 fEdgeShape.size()));
0064
0065
0066 std::vector<const char *> fOutputNames{m_outputNameScores.c_str()};
0067 std::vector<float> fOutputData(numEdges);
0068
0069 auto outputDims = m_model->GetOutputTypeInfo(0)
0070 .GetTensorTypeAndShapeInfo()
0071 .GetDimensionsCount();
0072 using Shape = std::vector<int64_t>;
0073 Shape fOutputShape = outputDims == 2 ? Shape{numEdges, 1} : Shape{numEdges};
0074 std::vector<Ort::Value> fOutputTensor;
0075 fOutputTensor.push_back(Ort::Value::CreateTensor<float>(
0076 memoryInfo, fOutputData.data(), fOutputData.size(), fOutputShape.data(),
0077 fOutputShape.size()));
0078 runSessionWithIoBinding(*m_model, fInputNames, fInputTensor, fOutputNames,
0079 fOutputTensor);
0080
0081 ACTS_DEBUG("Get scores for " << numEdges << " edges.");
0082 torch::Tensor edgeListCTen = torch::tensor(edgeList, {torch::kInt64});
0083 edgeListCTen = edgeListCTen.reshape({2, numEdges});
0084
0085 torch::Tensor fOutputCTen = torch::tensor(fOutputData, {torch::kFloat32});
0086 fOutputCTen = fOutputCTen.sigmoid();
0087
0088 torch::Tensor filterMask = fOutputCTen > m_cfg.cut;
0089 torch::Tensor edgesAfterFCTen = edgeListCTen.index({Slice(), filterMask});
0090
0091 std::vector<int64_t> edgesAfterFiltering;
0092 std::copy(edgesAfterFCTen.data_ptr<int64_t>(),
0093 edgesAfterFCTen.data_ptr<int64_t>() + edgesAfterFCTen.numel(),
0094 std::back_inserter(edgesAfterFiltering));
0095
0096 int64_t numEdgesAfterF = edgesAfterFiltering.size() / 2;
0097 ACTS_DEBUG("Finished edge classification, after cut: " << numEdgesAfterF
0098 << " edges.");
0099
0100 return {std::make_shared<Ort::Value>(std::move(fInputTensor[0])),
0101 edgesAfterFiltering, fOutputCTen};
0102 }
0103
0104 }