Back to home page

sPHENIX code displayed by LXR

 
 

    


File indexing completed on 2025-08-05 08:10:17

0001 // This file is part of the Acts project.
0002 //
0003 // Copyright (C) 2023 CERN for the benefit of the Acts project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
0008 
0009 #include "Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp"
0010 
0011 #include <torch/script.h>
0012 #include <torch/torch.h>
0013 
0014 #include "printCudaMemInfo.hpp"
0015 
0016 using namespace torch::indexing;
0017 
0018 namespace Acts {
0019 
0020 TorchEdgeClassifier::TorchEdgeClassifier(const Config& cfg,
0021                                          std::unique_ptr<const Logger> _logger)
0022     : m_logger(std::move(_logger)), m_cfg(cfg) {
0023   c10::InferenceMode guard(true);
0024   m_deviceType = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
0025   ACTS_DEBUG("Using torch version " << TORCH_VERSION_MAJOR << "."
0026                                     << TORCH_VERSION_MINOR << "."
0027                                     << TORCH_VERSION_PATCH);
0028 #ifndef ACTS_EXATRKX_CPUONLY
0029   if (not torch::cuda::is_available()) {
0030     ACTS_INFO("CUDA not available, falling back to CPU");
0031   }
0032 #endif
0033 
0034   try {
0035     m_model = std::make_unique<torch::jit::Module>();
0036     *m_model = torch::jit::load(m_cfg.modelPath.c_str(), m_deviceType);
0037     m_model->eval();
0038   } catch (const c10::Error& e) {
0039     throw std::invalid_argument("Failed to load models: " + e.msg());
0040   }
0041 }
0042 
0043 TorchEdgeClassifier::~TorchEdgeClassifier() {}
0044 
0045 std::tuple<std::any, std::any, std::any> TorchEdgeClassifier::operator()(
0046     std::any inputNodes, std::any inputEdges, int deviceHint) {
0047   ACTS_DEBUG("Start edge classification");
0048   c10::InferenceMode guard(true);
0049   const torch::Device device(m_deviceType, deviceHint);
0050 
0051   auto nodes = std::any_cast<torch::Tensor>(inputNodes).to(device);
0052   auto edgeList = std::any_cast<torch::Tensor>(inputEdges).to(device);
0053 
0054   auto model = m_model->clone();
0055   model.to(device);
0056 
0057   if (m_cfg.numFeatures > nodes.size(1)) {
0058     throw std::runtime_error("requested more features then available");
0059   }
0060 
0061   torch::Tensor output;
0062 
0063   // Scope this to keep inference objects separate
0064   {
0065     auto edgeListTmp = m_cfg.undirected
0066                            ? torch::cat({edgeList, edgeList.flip(0)}, 1)
0067                            : edgeList;
0068 
0069     std::vector<torch::jit::IValue> inputTensors(2);
0070     inputTensors[0] =
0071         m_cfg.numFeatures < nodes.size(1)
0072             ? nodes.index({Slice{}, Slice{None, m_cfg.numFeatures}})
0073             : nodes;
0074 
0075     if (m_cfg.nChunks > 1) {
0076       std::vector<at::Tensor> results;
0077       results.reserve(m_cfg.nChunks);
0078 
0079       auto chunks = at::chunk(edgeListTmp, m_cfg.nChunks, 1);
0080       for (auto& chunk : chunks) {
0081         ACTS_VERBOSE("Process chunk with shape" << chunk.sizes());
0082         inputTensors[1] = chunk;
0083 
0084         results.push_back(model.forward(inputTensors).toTensor());
0085         results.back().squeeze_();
0086       }
0087 
0088       output = torch::cat(results);
0089     } else {
0090       inputTensors[1] = edgeListTmp;
0091       output = model.forward(inputTensors).toTensor();
0092       output.squeeze_();
0093     }
0094   }
0095 
0096   output.sigmoid_();
0097 
0098   if (m_cfg.undirected) {
0099     auto newSize = output.size(0) / 2;
0100     output = output.index({Slice(None, newSize)});
0101   }
0102 
0103   ACTS_VERBOSE("Size after classifier: " << output.size(0));
0104   ACTS_VERBOSE("Slice of classified output:\n"
0105                << output.slice(/*dim=*/0, /*start=*/0, /*end=*/9));
0106   printCudaMemInfo(logger());
0107 
0108   torch::Tensor mask = output > m_cfg.cut;
0109   torch::Tensor edgesAfterCut = edgeList.index({Slice(), mask});
0110   edgesAfterCut = edgesAfterCut.to(torch::kInt64);
0111 
0112   ACTS_VERBOSE("Size after score cut: " << edgesAfterCut.size(1));
0113   printCudaMemInfo(logger());
0114 
0115   return {std::move(nodes), std::move(edgesAfterCut),
0116           output.masked_select(mask)};
0117 }
0118 
0119 }  // namespace Acts