Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-08-05 08:10:17

0001 // This file is part of the Acts project.
0002 //
0003 // Copyright (C) 2023 CERN for the benefit of the Acts project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
0008 
0009 #include "Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp"
0010 
0011 #include <onnxruntime_cxx_api.h>
0012 #include <torch/script.h>
0013 
0014 #include "runSessionWithIoBinding.hpp"
0015 
0016 using namespace torch::indexing;
0017 
0018 namespace Acts {
0019 
0020 OnnxEdgeClassifier::OnnxEdgeClassifier(const Config &cfg,
0021                                        std::unique_ptr<const Logger> logger)
0022     : m_logger(std::move(logger)), m_cfg(cfg) {
0023   m_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING,
0024                                      "ExaTrkX - edge classifier");
0025 
0026   Ort::SessionOptions session_options;
0027   session_options.SetIntraOpNumThreads(1);
0028   session_options.SetGraphOptimizationLevel(
0029       GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
0030 
0031   m_model = std::make_unique<Ort::Session>(*m_env, m_cfg.modelPath.c_str(),
0032                                            session_options);
0033 
0034   Ort::AllocatorWithDefaultOptions allocator;
0035 
0036   m_inputNameNodes =
0037       std::string(m_model->GetInputNameAllocated(0, allocator).get());
0038   m_inputNameEdges =
0039       std::string(m_model->GetInputNameAllocated(1, allocator).get());
0040   m_outputNameScores =
0041       std::string(m_model->GetOutputNameAllocated(0, allocator).get());
0042 }
0043 
0044 OnnxEdgeClassifier::~OnnxEdgeClassifier() {}
0045 
0046 std::tuple<std::any, std::any, std::any> OnnxEdgeClassifier::operator()(
0047     std::any inputNodes, std::any inputEdges, int) {
0048   Ort::AllocatorWithDefaultOptions allocator;
0049   auto memoryInfo = Ort::MemoryInfo::CreateCpu(
0050       OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
0051 
0052   auto eInputTensor = std::any_cast<std::shared_ptr<Ort::Value>>(inputNodes);
0053   auto edgeList = std::any_cast<std::vector<int64_t>>(inputEdges);
0054   const int numEdges = edgeList.size() / 2;
0055 
0056   std::vector<const char *> fInputNames{m_inputNameNodes.c_str(),
0057                                         m_inputNameEdges.c_str()};
0058   std::vector<Ort::Value> fInputTensor;
0059   fInputTensor.push_back(std::move(*eInputTensor));
0060   std::vector<int64_t> fEdgeShape{2, numEdges};
0061   fInputTensor.push_back(Ort::Value::CreateTensor<int64_t>(
0062       memoryInfo, edgeList.data(), edgeList.size(), fEdgeShape.data(),
0063       fEdgeShape.size()));
0064 
0065   // filtering outputs
0066   std::vector<const char *> fOutputNames{m_outputNameScores.c_str()};
0067   std::vector<float> fOutputData(numEdges);
0068 
0069   auto outputDims = m_model->GetOutputTypeInfo(0)
0070                         .GetTensorTypeAndShapeInfo()
0071                         .GetDimensionsCount();
0072   using Shape = std::vector<int64_t>;
0073   Shape fOutputShape = outputDims == 2 ? Shape{numEdges, 1} : Shape{numEdges};
0074   std::vector<Ort::Value> fOutputTensor;
0075   fOutputTensor.push_back(Ort::Value::CreateTensor<float>(
0076       memoryInfo, fOutputData.data(), fOutputData.size(), fOutputShape.data(),
0077       fOutputShape.size()));
0078   runSessionWithIoBinding(*m_model, fInputNames, fInputTensor, fOutputNames,
0079                           fOutputTensor);
0080 
0081   ACTS_DEBUG("Get scores for " << numEdges << " edges.");
0082   torch::Tensor edgeListCTen = torch::tensor(edgeList, {torch::kInt64});
0083   edgeListCTen = edgeListCTen.reshape({2, numEdges});
0084 
0085   torch::Tensor fOutputCTen = torch::tensor(fOutputData, {torch::kFloat32});
0086   fOutputCTen = fOutputCTen.sigmoid();
0087 
0088   torch::Tensor filterMask = fOutputCTen > m_cfg.cut;
0089   torch::Tensor edgesAfterFCTen = edgeListCTen.index({Slice(), filterMask});
0090 
0091   std::vector<int64_t> edgesAfterFiltering;
0092   std::copy(edgesAfterFCTen.data_ptr<int64_t>(),
0093             edgesAfterFCTen.data_ptr<int64_t>() + edgesAfterFCTen.numel(),
0094             std::back_inserter(edgesAfterFiltering));
0095 
0096   int64_t numEdgesAfterF = edgesAfterFiltering.size() / 2;
0097   ACTS_DEBUG("Finished edge classification, after cut: " << numEdgesAfterF
0098                                                          << " edges.");
0099 
0100   return {std::make_shared<Ort::Value>(std::move(fInputTensor[0])),
0101           edgesAfterFiltering, fOutputCTen};
0102 }
0103 
0104 }  // namespace Acts