sPHENIX code displayed by LXR

#include "onnxlib.h"

#include <cstdio>  // sprintf, used for the ORT_API_VERSION == 22 name copies below
#include <iostream>

namespace onnxlib
{
  // Flat input/output widths of the model's first tensor; set by onnxSession().
  int n_input {-1};
  int n_output {-1};
}  // namespace onnxlib

Ort::Session *onnxSession(std::string &modelfile, int verbosity)
{
  // The Ort::Env must outlive every session created from it; making it static
  // keeps it alive for the lifetime of the program.
  static Ort::Env env(OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, "fit");
  Ort::SessionOptions sessionOptions;
  sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
  auto *session = new Ort::Session(env, modelfile.c_str(), sessionOptions);
  // Cache the flat widths (dimension 1) of the first input and output tensors.
  auto type_info = session->GetInputTypeInfo(0);
  auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
  auto input_dims = tensor_info.GetShape();
  onnxlib::n_input = input_dims[1];
  type_info = session->GetOutputTypeInfo(0);
  tensor_info = type_info.GetTensorTypeAndShapeInfo();
  auto output_dims = tensor_info.GetShape();
  onnxlib::n_output = output_dims[1];
  if (verbosity > 0)
  {
    std::cout << "onnxlib: using model " << modelfile << std::endl;
    std::cout << "Number of Inputs: " << onnxlib::n_input << std::endl;
    std::cout << "Number of Outputs: " << onnxlib::n_output << std::endl;
  }
  return session;
}

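// Usage sketch (illustrative only, not part of the original file; the model
// path below is a hypothetical placeholder):
//
//   std::string modelfile = "model.onnx";               // hypothetical path
//   Ort::Session *session = onnxSession(modelfile, 1);  // verbose: prints I/O widths
//   // onnxlib::n_input / onnxlib::n_output now hold dimension 1 of the
//   // model's first input / output tensor shapes.
//   // The caller owns the session:  delete session;  when done.
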
std::vector<float> onnxInference(Ort::Session *session, std::vector<float> &input, int N, int Nsamp, int Nreturn)
{
  // Define the memory information for ONNX Runtime (CPU, default arena allocator)
  Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);

  Ort::AllocatorWithDefaultOptions allocator;

  std::vector<Ort::Value> inputTensors;
  std::vector<Ort::Value> outputTensors;

  // Input is interpreted as N rows of Nsamp samples; output as N rows of Nreturn values.
  std::vector<int64_t> inputDimsN = {N, Nsamp};
  std::vector<int64_t> outputDimsN = {N, Nreturn};
  int inputlen = N * Nsamp;
  int outputlen = N * Nreturn;

  std::vector<float> outputTensorValuesN(outputlen);

  inputTensors.push_back(Ort::Value::CreateTensor<float>(memoryInfo, input.data(), inputlen, inputDimsN.data(), inputDimsN.size()));
  outputTensors.push_back(Ort::Value::CreateTensor<float>(memoryInfo, outputTensorValuesN.data(), outputlen, outputDimsN.data(), outputDimsN.size()));

#if ORT_API_VERSION == 12
  std::vector<const char *> inputNames{session->GetInputName(0, allocator)};
  std::vector<const char *> outputNames{session->GetOutputName(0, allocator)};
#elif ORT_API_VERSION == 22
  // The newer API returns std::string names; copy them into C strings for the
  // C-style Run() call and free them afterwards.
  std::vector<const char *> inputNames;
  std::vector<const char *> outputNames;

  char *name{nullptr};
  for (const std::string &s : session->GetInputNames())
  {
    name = new char[s.size() + 1];
    sprintf(name, "%s", s.c_str()); //NOLINT(hicpp-vararg)
    inputNames.push_back(name);
  }
  for (const std::string &s : session->GetOutputNames())
  {
    name = new char[s.size() + 1];
    sprintf(name, "%s", s.c_str()); //NOLINT(hicpp-vararg)
    outputNames.push_back(name);
  }
#else
#define XSTR(x) STR(x)
#define STR(x) #x
#pragma message "ORT_API_VERSION " XSTR(ORT_API_VERSION) " not implemented"
#endif
  session->Run(Ort::RunOptions{nullptr}, inputNames.data(), inputTensors.data(), 1, outputNames.data(), outputTensors.data(), 1);

#if ORT_API_VERSION == 22
  for (const auto *iter : inputNames)
  {
    delete[] iter;
  }
  for (const auto *iter : outputNames)
  {
    delete[] iter;
  }
#endif
  return outputTensorValuesN;
}

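// Usage sketch (illustrative only; N below is a hypothetical batch size):
//
//   int N = 128;                              // number of waveforms
//   int Nsamp = onnxlib::n_input;             // samples per waveform
//   std::vector<float> input(N * Nsamp);      // filled row-major by the caller
//   std::vector<float> out =
//       onnxInference(session, input, N, Nsamp, onnxlib::n_output);
//   // out holds N * Nreturn values, row-major: out[i * Nreturn + j].
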
std::vector<float> onnxInference(Ort::Session *session, std::vector<float> &input, int N, int Nx, int Ny, int Nz, int Nreturn)
{
  // Define the memory information for ONNX Runtime
  Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);

  Ort::AllocatorWithDefaultOptions allocator;

  // Input is interpreted as N blocks of shape (Nx, Ny, Nz); output as N rows of Nreturn values.
  std::vector<int64_t> inputDims = {N, Nx, Ny, Nz};
  std::vector<int64_t> outputDimsN = {N, Nreturn};
  int inputlen = N * Nx * Ny * Nz;
  int outputlen = N * Nreturn;

  std::vector<float> outputTensorValues(outputlen);

  std::vector<Ort::Value> inputTensors;
  std::vector<Ort::Value> outputTensors;

  inputTensors.push_back(Ort::Value::CreateTensor<float>(memoryInfo, input.data(), inputlen, inputDims.data(), inputDims.size()));

  outputTensors.push_back(Ort::Value::CreateTensor<float>(memoryInfo, outputTensorValues.data(), outputlen, outputDimsN.data(), outputDimsN.size()));

#if ORT_API_VERSION == 12
  std::vector<const char *> inputNames{session->GetInputName(0, allocator)};
  std::vector<const char *> outputNames{session->GetOutputName(0, allocator)};
#elif ORT_API_VERSION == 22
  // The newer API returns std::string names; copy them into C strings for the
  // C-style Run() call and free them afterwards.
  std::vector<const char *> inputNames;
  std::vector<const char *> outputNames;
  char *name{nullptr};
  for (const std::string &s : session->GetInputNames())
  {
    name = new char[s.size() + 1];
    sprintf(name, "%s", s.c_str()); //NOLINT(hicpp-vararg)
    inputNames.push_back(name);
  }
  for (const std::string &s : session->GetOutputNames())
  {
    name = new char[s.size() + 1];
    sprintf(name, "%s", s.c_str()); //NOLINT(hicpp-vararg)
    outputNames.push_back(name);
  }
#else
#define XSTR(x) STR(x)
#define STR(x) #x
#pragma message "ORT_API_VERSION " XSTR(ORT_API_VERSION) " not implemented"
#endif
  session->Run(Ort::RunOptions{nullptr}, inputNames.data(), inputTensors.data(), 1, outputNames.data(), outputTensors.data(), 1);
#if ORT_API_VERSION == 22
  for (const auto *iter : inputNames)
  {
    delete[] iter;
  }
  for (const auto *iter : outputNames)
  {
    delete[] iter;
  }
#endif

  return outputTensorValues;
}
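
// Usage sketch (illustrative only; the block shape below is hypothetical):
//
//   int N = 16, Nx = 5, Ny = 5, Nz = 3;           // N blocks of shape (Nx, Ny, Nz)
//   std::vector<float> blocks(N * Nx * Ny * Nz);  // filled row-major by the caller
//   std::vector<float> pred =
//       onnxInference(session, blocks, N, Nx, Ny, Nz, onnxlib::n_output);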