#include "Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp"

#include "Acts/Plugins/ExaTrkX/detail/TensorVectorConversion.hpp"
#include "Acts/Plugins/ExaTrkX/detail/buildEdges.hpp"

#include <sstream>
#include <stdexcept>

#include <torch/script.h>
#include <torch/torch.h>

#include "printCudaMemInfo.hpp"

using namespace torch::indexing;

namespace Acts {

TorchMetricLearning::TorchMetricLearning(const Config &cfg,
                                         std::unique_ptr<const Logger> _logger)
    : m_logger(std::move(_logger)), m_cfg(cfg) {
  c10::InferenceMode guard(true);
  m_deviceType = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
  ACTS_DEBUG("Using torch version " << TORCH_VERSION_MAJOR << "."
                                    << TORCH_VERSION_MINOR << "."
                                    << TORCH_VERSION_PATCH);
#ifndef ACTS_EXATRKX_CPUONLY
  if (not torch::cuda::is_available()) {
    ACTS_INFO("CUDA not available, falling back to CPU");
  }
#endif

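  // Load the TorchScript model onto the selected device and switch it to
  // evaluation mode. Errors from torch::jit::load are rethrown as
  // std::invalid_argument so that a bad model path surfaces as a
  // configuration problem.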
  try {
    m_model = std::make_unique<torch::jit::Module>();
    *m_model = torch::jit::load(m_cfg.modelPath, m_deviceType);
    m_model->eval();
  } catch (const c10::Error &e) {
    throw std::invalid_argument("Failed to load model: " + e.msg());
  }
}

TorchMetricLearning::~TorchMetricLearning() {}

std::tuple<std::any, std::any> TorchMetricLearning::operator()(
    std::vector<float> &inputValues, std::size_t numNodes, int deviceHint) {
  ACTS_DEBUG("Start graph construction");
  c10::InferenceMode guard(true);
  const torch::Device device(m_deviceType, deviceHint);

  const int64_t numAllFeatures = inputValues.size() / numNodes;

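  // Log the feature vector of the first space point so the input layout is
  // visible at verbose level.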
  ACTS_VERBOSE("First spacepoint information: " << [&]() {
    std::stringstream ss;
    for (int i = 0; i < numAllFeatures; ++i) {
      ss << inputValues[i] << " ";
    }
    return ss.str();
  }());
  printCudaMemInfo(logger());

  auto inputTensor = detail::vectorToTensor2D(inputValues, numAllFeatures);

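  // If the tensor already lives on the target device, clone it to take
  // ownership of the memory; otherwise copy it over to the device.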
  if (inputTensor.options().device() == device) {
    inputTensor = inputTensor.clone();
  } else {
    inputTensor = inputTensor.to(device);
  }

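  // Embedding step: the metric-learning model maps every space point into a
  // latent space in which hits belonging to the same track should lie close
  // together.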
  if (m_cfg.numFeatures > numAllFeatures) {
    throw std::runtime_error("requested more features than available");
  }

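  // Clone the module and move the clone to the target device, leaving the
  // stored model untouched for subsequent calls.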
  auto model = m_model->clone();
  model.to(device);

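  // Feed the network only the first numFeatures columns; if the input carries
  // no extra features, move the whole tensor instead of slicing.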
  std::vector<torch::jit::IValue> inputTensors;
  inputTensors.push_back(
      m_cfg.numFeatures < numAllFeatures
          ? inputTensor.index({Slice{}, Slice{None, m_cfg.numFeatures}})
          : std::move(inputTensor));

  ACTS_DEBUG("embedding input tensor shape "
             << inputTensors[0].toTensor().size(0) << ", "
             << inputTensors[0].toTensor().size(1));

  auto output = model.forward(inputTensors).toTensor();

  ACTS_VERBOSE("Embedding space of the first SP:\n"
               << output.slice(0, 0, 1));
  printCudaMemInfo(logger());

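  // Build candidate edges in the embedding space: connect space points whose
  // embeddings lie within rVal of each other, keeping at most knnVal
  // neighbours per point.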
  auto edgeList = detail::buildEdges(output, m_cfg.rVal, m_cfg.knnVal,
                                     m_cfg.shuffleDirections);

  ACTS_VERBOSE("Shape of built edges: (" << edgeList.size(0) << ", "
                                         << edgeList.size(1) << ")");
  ACTS_VERBOSE("Slice of edgelist:\n" << edgeList.slice(1, 0, 5));
  printCudaMemInfo(logger());

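  // Return the (possibly sliced) node feature tensor and the edge list,
  // type-erased as std::any for the downstream stages of the pipeline.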
  return {std::move(inputTensors[0]).toTensor(), std::move(edgeList)};
}

}  // namespace Acts
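
// A minimal usage sketch (illustrative comment only, not compiled here). It
// shows how this stage could be constructed and invoked; the concrete values
// assigned to the Config fields and the assumed feature layout are
// illustrative, while the field names (modelPath, numFeatures, rVal, knnVal),
// the constructor and the call signature are taken from the code above.
//
//   Acts::TorchMetricLearning::Config cfg;
//   cfg.modelPath = "embedding.pt";  // TorchScript file exported from training
//   cfg.numFeatures = 3;             // e.g. r, phi, z (assumed layout)
//   cfg.rVal = 1.6f;                 // neighbourhood radius in embedding space
//   cfg.knnVal = 500;                // maximum neighbours per space point
//
//   Acts::TorchMetricLearning metricLearning(
//       cfg, Acts::getDefaultLogger("MetricLearning", Acts::Logging::INFO));
//
//   // Flattened row-major buffer of shape [numNodes, numFeatures]
//   std::vector<float> inputValues = /* fill from space points */;
//   auto [nodeTensor, edgeTensor] =
//       metricLearning(inputValues, numNodes, /*deviceHint=*/0);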