File indexing completed on 2025-08-05 08:10:17
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp"
0010
0011 #include <torch/script.h>
0012 #include <torch/torch.h>
0013
0014 #include "printCudaMemInfo.hpp"
0015
0016 using namespace torch::indexing;
0017
0018 namespace Acts {
0019
0020 TorchEdgeClassifier::TorchEdgeClassifier(const Config& cfg,
0021 std::unique_ptr<const Logger> _logger)
0022 : m_logger(std::move(_logger)), m_cfg(cfg) {
0023 c10::InferenceMode guard(true);
0024 m_deviceType = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
0025 ACTS_DEBUG("Using torch version " << TORCH_VERSION_MAJOR << "."
0026 << TORCH_VERSION_MINOR << "."
0027 << TORCH_VERSION_PATCH);
0028 #ifndef ACTS_EXATRKX_CPUONLY
0029 if (not torch::cuda::is_available()) {
0030 ACTS_INFO("CUDA not available, falling back to CPU");
0031 }
0032 #endif
0033
0034 try {
0035 m_model = std::make_unique<torch::jit::Module>();
0036 *m_model = torch::jit::load(m_cfg.modelPath.c_str(), m_deviceType);
0037 m_model->eval();
0038 } catch (const c10::Error& e) {
0039 throw std::invalid_argument("Failed to load models: " + e.msg());
0040 }
0041 }
0042
0043 TorchEdgeClassifier::~TorchEdgeClassifier() {}
0044
0045 std::tuple<std::any, std::any, std::any> TorchEdgeClassifier::operator()(
0046 std::any inputNodes, std::any inputEdges, int deviceHint) {
0047 ACTS_DEBUG("Start edge classification");
0048 c10::InferenceMode guard(true);
0049 const torch::Device device(m_deviceType, deviceHint);
0050
0051 auto nodes = std::any_cast<torch::Tensor>(inputNodes).to(device);
0052 auto edgeList = std::any_cast<torch::Tensor>(inputEdges).to(device);
0053
0054 auto model = m_model->clone();
0055 model.to(device);
0056
0057 if (m_cfg.numFeatures > nodes.size(1)) {
0058 throw std::runtime_error("requested more features then available");
0059 }
0060
0061 torch::Tensor output;
0062
0063
0064 {
0065 auto edgeListTmp = m_cfg.undirected
0066 ? torch::cat({edgeList, edgeList.flip(0)}, 1)
0067 : edgeList;
0068
0069 std::vector<torch::jit::IValue> inputTensors(2);
0070 inputTensors[0] =
0071 m_cfg.numFeatures < nodes.size(1)
0072 ? nodes.index({Slice{}, Slice{None, m_cfg.numFeatures}})
0073 : nodes;
0074
0075 if (m_cfg.nChunks > 1) {
0076 std::vector<at::Tensor> results;
0077 results.reserve(m_cfg.nChunks);
0078
0079 auto chunks = at::chunk(edgeListTmp, m_cfg.nChunks, 1);
0080 for (auto& chunk : chunks) {
0081 ACTS_VERBOSE("Process chunk with shape" << chunk.sizes());
0082 inputTensors[1] = chunk;
0083
0084 results.push_back(model.forward(inputTensors).toTensor());
0085 results.back().squeeze_();
0086 }
0087
0088 output = torch::cat(results);
0089 } else {
0090 inputTensors[1] = edgeListTmp;
0091 output = model.forward(inputTensors).toTensor();
0092 output.squeeze_();
0093 }
0094 }
0095
0096 output.sigmoid_();
0097
0098 if (m_cfg.undirected) {
0099 auto newSize = output.size(0) / 2;
0100 output = output.index({Slice(None, newSize)});
0101 }
0102
0103 ACTS_VERBOSE("Size after classifier: " << output.size(0));
0104 ACTS_VERBOSE("Slice of classified output:\n"
0105 << output.slice(0, 0, 9));
0106 printCudaMemInfo(logger());
0107
0108 torch::Tensor mask = output > m_cfg.cut;
0109 torch::Tensor edgesAfterCut = edgeList.index({Slice(), mask});
0110 edgesAfterCut = edgesAfterCut.to(torch::kInt64);
0111
0112 ACTS_VERBOSE("Size after score cut: " << edgesAfterCut.size(1));
0113 printCudaMemInfo(logger());
0114
0115 return {std::move(nodes), std::move(edgesAfterCut),
0116 output.masked_select(mask)};
0117 }
0118
0119 }