File indexing completed on 2025-08-05 08:10:17
0001
0002
0003
0004
0005
0006
0007
0008
#include "Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp"

#include "Acts/Plugins/ExaTrkX/detail/buildEdges.hpp"

#include <algorithm>
#include <cstdint>
#include <iterator>

#include <onnxruntime_cxx_api.h>
#include <torch/script.h>

#include "runSessionWithIoBinding.hpp"
0017
0018 namespace Acts {
0019
0020 OnnxMetricLearning::OnnxMetricLearning(const Config& cfg,
0021 std::unique_ptr<const Logger> logger)
0022 : m_logger(std::move(logger)), m_cfg(cfg) {
0023 m_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING,
0024 "ExaTrkX - metric learning");
0025
0026 Ort::SessionOptions session_options;
0027 session_options.SetIntraOpNumThreads(1);
0028 session_options.SetGraphOptimizationLevel(
0029 GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
0030
0031 m_model = std::make_unique<Ort::Session>(*m_env, m_cfg.modelPath.c_str(),
0032 session_options);
0033 }
0034
0035 OnnxMetricLearning::~OnnxMetricLearning() {}
0036
0037 void OnnxMetricLearning::buildEdgesWrapper(std::vector<float>& embedFeatures,
0038 std::vector<int64_t>& edgeList,
0039 int64_t numSpacepoints,
0040 const Logger& logger) const {
0041 torch::Device device(torch::kCUDA);
0042 auto options =
0043 torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
0044
0045 torch::Tensor embedTensor =
0046 torch::tensor(embedFeatures, options)
0047 .reshape({numSpacepoints, m_cfg.embeddingDim});
0048
0049 auto stackedEdges = detail::buildEdges(embedTensor, m_cfg.rVal, m_cfg.knnVal);
0050
0051 stackedEdges = stackedEdges.toType(torch::kInt64).to(torch::kCPU);
0052
0053 ACTS_VERBOSE("copy edges to std::vector");
0054 std::copy(stackedEdges.data_ptr<int64_t>(),
0055 stackedEdges.data_ptr<int64_t>() + stackedEdges.numel(),
0056 std::back_inserter(edgeList));
0057 }
0058
0059 std::tuple<std::any, std::any> OnnxMetricLearning::operator()(
0060 std::vector<float>& inputValues, std::size_t, int) {
0061 Ort::AllocatorWithDefaultOptions allocator;
0062 auto memoryInfo = Ort::MemoryInfo::CreateCpu(
0063 OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
0064
0065
0066
0067
0068
0069 int64_t numSpacepoints = inputValues.size() / m_cfg.spacepointFeatures;
0070 std::vector<int64_t> eInputShape{numSpacepoints, m_cfg.spacepointFeatures};
0071
0072 std::vector<const char*> eInputNames{"sp_features"};
0073 std::vector<Ort::Value> eInputTensor;
0074 eInputTensor.push_back(Ort::Value::CreateTensor<float>(
0075 memoryInfo, inputValues.data(), inputValues.size(), eInputShape.data(),
0076 eInputShape.size()));
0077
0078 std::vector<float> eOutputData(numSpacepoints * m_cfg.embeddingDim);
0079 std::vector<const char*> eOutputNames{"embedding_output"};
0080 std::vector<int64_t> eOutputShape{numSpacepoints, m_cfg.embeddingDim};
0081 std::vector<Ort::Value> eOutputTensor;
0082 eOutputTensor.push_back(Ort::Value::CreateTensor<float>(
0083 memoryInfo, eOutputData.data(), eOutputData.size(), eOutputShape.data(),
0084 eOutputShape.size()));
0085 runSessionWithIoBinding(*m_model, eInputNames, eInputTensor, eOutputNames,
0086 eOutputTensor);
0087
0088 ACTS_VERBOSE("Embedding space of the first SP: ");
0089 for (std::size_t i = 0; i < 3; i++) {
0090 ACTS_VERBOSE("\t" << eOutputData[i]);
0091 }
0092
0093
0094
0095
0096 std::vector<int64_t> edgeList;
0097 buildEdgesWrapper(eOutputData, edgeList, numSpacepoints, logger());
0098 int64_t numEdges = edgeList.size() / 2;
0099 ACTS_DEBUG("Graph construction: built " << numEdges << " edges.");
0100
0101 for (std::size_t i = 0; i < 10; i++) {
0102 ACTS_VERBOSE(edgeList[i]);
0103 }
0104 for (std::size_t i = 0; i < 10; i++) {
0105 ACTS_VERBOSE(edgeList[numEdges + i]);
0106 }
0107
0108 return {std::make_shared<Ort::Value>(std::move(eInputTensor[0])), edgeList};
0109 }
0110
0111 }