Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-08-05 08:10:17

0001 // This file is part of the Acts project.
0002 //
0003 // Copyright (C) 2022 CERN for the benefit of the Acts project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
0008 
0009 #include "Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp"
0010 
0011 #include "Acts/Plugins/ExaTrkX/detail/buildEdges.hpp"
0012 
0013 #include <onnxruntime_cxx_api.h>
0014 #include <torch/script.h>
0015 
0016 #include "runSessionWithIoBinding.hpp"
0017 
0018 namespace Acts {
0019 
0020 OnnxMetricLearning::OnnxMetricLearning(const Config& cfg,
0021                                        std::unique_ptr<const Logger> logger)
0022     : m_logger(std::move(logger)), m_cfg(cfg) {
0023   m_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING,
0024                                      "ExaTrkX - metric learning");
0025 
0026   Ort::SessionOptions session_options;
0027   session_options.SetIntraOpNumThreads(1);
0028   session_options.SetGraphOptimizationLevel(
0029       GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
0030 
0031   m_model = std::make_unique<Ort::Session>(*m_env, m_cfg.modelPath.c_str(),
0032                                            session_options);
0033 }
0034 
0035 OnnxMetricLearning::~OnnxMetricLearning() {}
0036 
0037 void OnnxMetricLearning::buildEdgesWrapper(std::vector<float>& embedFeatures,
0038                                            std::vector<int64_t>& edgeList,
0039                                            int64_t numSpacepoints,
0040                                            const Logger& logger) const {
0041   torch::Device device(torch::kCUDA);
0042   auto options =
0043       torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
0044 
0045   torch::Tensor embedTensor =
0046       torch::tensor(embedFeatures, options)
0047           .reshape({numSpacepoints, m_cfg.embeddingDim});
0048 
0049   auto stackedEdges = detail::buildEdges(embedTensor, m_cfg.rVal, m_cfg.knnVal);
0050 
0051   stackedEdges = stackedEdges.toType(torch::kInt64).to(torch::kCPU);
0052 
0053   ACTS_VERBOSE("copy edges to std::vector");
0054   std::copy(stackedEdges.data_ptr<int64_t>(),
0055             stackedEdges.data_ptr<int64_t>() + stackedEdges.numel(),
0056             std::back_inserter(edgeList));
0057 }
0058 
0059 std::tuple<std::any, std::any> OnnxMetricLearning::operator()(
0060     std::vector<float>& inputValues, std::size_t, int) {
0061   Ort::AllocatorWithDefaultOptions allocator;
0062   auto memoryInfo = Ort::MemoryInfo::CreateCpu(
0063       OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
0064 
0065   // ************
0066   // Embedding
0067   // ************
0068 
0069   int64_t numSpacepoints = inputValues.size() / m_cfg.spacepointFeatures;
0070   std::vector<int64_t> eInputShape{numSpacepoints, m_cfg.spacepointFeatures};
0071 
0072   std::vector<const char*> eInputNames{"sp_features"};
0073   std::vector<Ort::Value> eInputTensor;
0074   eInputTensor.push_back(Ort::Value::CreateTensor<float>(
0075       memoryInfo, inputValues.data(), inputValues.size(), eInputShape.data(),
0076       eInputShape.size()));
0077 
0078   std::vector<float> eOutputData(numSpacepoints * m_cfg.embeddingDim);
0079   std::vector<const char*> eOutputNames{"embedding_output"};
0080   std::vector<int64_t> eOutputShape{numSpacepoints, m_cfg.embeddingDim};
0081   std::vector<Ort::Value> eOutputTensor;
0082   eOutputTensor.push_back(Ort::Value::CreateTensor<float>(
0083       memoryInfo, eOutputData.data(), eOutputData.size(), eOutputShape.data(),
0084       eOutputShape.size()));
0085   runSessionWithIoBinding(*m_model, eInputNames, eInputTensor, eOutputNames,
0086                           eOutputTensor);
0087 
0088   ACTS_VERBOSE("Embedding space of the first SP: ");
0089   for (std::size_t i = 0; i < 3; i++) {
0090     ACTS_VERBOSE("\t" << eOutputData[i]);
0091   }
0092 
0093   // ************
0094   // Building Edges
0095   // ************
0096   std::vector<int64_t> edgeList;
0097   buildEdgesWrapper(eOutputData, edgeList, numSpacepoints, logger());
0098   int64_t numEdges = edgeList.size() / 2;
0099   ACTS_DEBUG("Graph construction: built " << numEdges << " edges.");
0100 
0101   for (std::size_t i = 0; i < 10; i++) {
0102     ACTS_VERBOSE(edgeList[i]);
0103   }
0104   for (std::size_t i = 0; i < 10; i++) {
0105     ACTS_VERBOSE(edgeList[numEdges + i]);
0106   }
0107 
0108   return {std::make_shared<Ort::Value>(std::move(eInputTensor[0])), edgeList};
0109 }
0110 
0111 }  // namespace Acts