Back to home page

sPHENIX code displayed by LXR

 
 

    


File indexing completed on 2025-08-05 08:10:17

0001 // This file is part of the Acts project.
0002 //
0003 // Copyright (C) 2023 CERN for the benefit of the Acts project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
0008 
0009 #include "Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp"
0010 
0011 #include "Acts/Plugins/ExaTrkX/detail/TensorVectorConversion.hpp"
0012 #include "Acts/Plugins/ExaTrkX/detail/buildEdges.hpp"
0013 
0014 #include <torch/script.h>
0015 #include <torch/torch.h>
0016 
0017 #include "printCudaMemInfo.hpp"
0018 
0019 using namespace torch::indexing;
0020 
0021 namespace Acts {
0022 
0023 TorchMetricLearning::TorchMetricLearning(const Config &cfg,
0024                                          std::unique_ptr<const Logger> _logger)
0025     : m_logger(std::move(_logger)), m_cfg(cfg) {
0026   c10::InferenceMode guard(true);
0027   m_deviceType = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
0028   ACTS_DEBUG("Using torch version " << TORCH_VERSION_MAJOR << "."
0029                                     << TORCH_VERSION_MINOR << "."
0030                                     << TORCH_VERSION_PATCH);
0031 #ifndef ACTS_EXATRKX_CPUONLY
0032   if (not torch::cuda::is_available()) {
0033     ACTS_INFO("CUDA not available, falling back to CPU");
0034   }
0035 #endif
0036 
0037   try {
0038     m_model = std::make_unique<torch::jit::Module>();
0039     *m_model = torch::jit::load(m_cfg.modelPath, m_deviceType);
0040     m_model->eval();
0041   } catch (const c10::Error &e) {
0042     throw std::invalid_argument("Failed to load models: " + e.msg());
0043   }
0044 }
0045 
0046 TorchMetricLearning::~TorchMetricLearning() {}
0047 
0048 std::tuple<std::any, std::any> TorchMetricLearning::operator()(
0049     std::vector<float> &inputValues, std::size_t numNodes, int deviceHint) {
0050   ACTS_DEBUG("Start graph construction");
0051   c10::InferenceMode guard(true);
0052   const torch::Device device(m_deviceType, deviceHint);
0053 
0054   const int64_t numAllFeatures = inputValues.size() / numNodes;
0055 
0056   // printout the r,phi,z of the first spacepoint
0057   ACTS_VERBOSE("First spacepoint information: " << [&]() {
0058     std::stringstream ss;
0059     for (int i = 0; i < numAllFeatures; ++i) {
0060       ss << inputValues[i] << "  ";
0061     }
0062     return ss.str();
0063   }());
0064   printCudaMemInfo(logger());
0065 
0066   auto inputTensor = detail::vectorToTensor2D(inputValues, numAllFeatures);
0067 
0068   // If we are on CPU, clone to get ownership (is this necessary?), else bring
0069   // to device.
0070   if (inputTensor.options().device() == device) {
0071     inputTensor = inputTensor.clone();
0072   } else {
0073     inputTensor = inputTensor.to(device);
0074   }
0075 
0076   // **********
0077   // Embedding
0078   // **********
0079 
0080   if (m_cfg.numFeatures > numAllFeatures) {
0081     throw std::runtime_error("requested more features then available");
0082   }
0083 
0084   // Clone models (solve memory leak? members can be const...)
0085   auto model = m_model->clone();
0086   model.to(device);
0087 
0088   std::vector<torch::jit::IValue> inputTensors;
0089   inputTensors.push_back(
0090       m_cfg.numFeatures < numAllFeatures
0091           ? inputTensor.index({Slice{}, Slice{None, m_cfg.numFeatures}})
0092           : std::move(inputTensor));
0093 
0094   ACTS_DEBUG("embedding input tensor shape "
0095              << inputTensors[0].toTensor().size(0) << ", "
0096              << inputTensors[0].toTensor().size(1));
0097 
0098   auto output = model.forward(inputTensors).toTensor();
0099 
0100   ACTS_VERBOSE("Embedding space of the first SP:\n"
0101                << output.slice(/*dim=*/0, /*start=*/0, /*end=*/1));
0102   printCudaMemInfo(logger());
0103 
0104   // ****************
0105   // Building Edges
0106   // ****************
0107 
0108   auto edgeList = detail::buildEdges(output, m_cfg.rVal, m_cfg.knnVal,
0109                                      m_cfg.shuffleDirections);
0110 
0111   ACTS_VERBOSE("Shape of built edges: (" << edgeList.size(0) << ", "
0112                                          << edgeList.size(1));
0113   ACTS_VERBOSE("Slice of edgelist:\n" << edgeList.slice(1, 0, 5));
0114   printCudaMemInfo(logger());
0115 
0116   return {std::move(inputTensors[0]).toTensor(), std::move(edgeList)};
0117 }
0118 }  // namespace Acts