#include "onnxlib.h"

#include <cstring>
#include <iostream>
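
// Create an ONNX Runtime session for the model in modelfile.
// The caller owns the returned pointer and should delete it when done.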
Ort::Session *onnxSession(std::string &modelfile)
{
  // the Ort::Env must outlive the session, so keep it static rather than
  // letting it be destroyed when this function returns
  static Ort::Env env(OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, "fit");
  Ort::SessionOptions sessionOptions;
  sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);

  return new Ort::Session(env, modelfile.c_str(), sessionOptions);
}

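// Run inference on a batch of N inputs of Nsamp samples each;
// returns the flattened N x Nreturn output tensor.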
std::vector<float> onnxInference(Ort::Session *session, std::vector<float> &input, int N, int Nsamp, int Nreturn)
{
  // Define the memory information for ONNX Runtime (CPU-resident tensors)
  Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);

  Ort::AllocatorWithDefaultOptions allocator;

  std::vector<Ort::Value> inputTensors;
  std::vector<Ort::Value> outputTensors;

  // the flat input buffer is interpreted as an N x Nsamp tensor,
  // the output as N x Nreturn
  std::vector<int64_t> inputDimsN = {N, Nsamp};
  std::vector<int64_t> outputDimsN = {N, Nreturn};
  int inputlen = N * Nsamp;
  int outputlen = N * Nreturn;

  std::vector<float> outputTensorValuesN(outputlen);

  // wrap the existing buffers; CreateTensor does not copy the data
  inputTensors.push_back(Ort::Value::CreateTensor<float>(memoryInfo, input.data(), inputlen, inputDimsN.data(), inputDimsN.size()));
  outputTensors.push_back(Ort::Value::CreateTensor<float>(memoryInfo, outputTensorValuesN.data(), outputlen, outputDimsN.data(), outputDimsN.size()));

#if ORT_API_VERSION == 12
  std::vector<const char *> inputNames{session->GetInputName(0, allocator)};
  std::vector<const char *> outputNames{session->GetOutputName(0, allocator)};
#elif ORT_API_VERSION == 22
  // the newer API returns std::strings, but Run() wants C strings, so copy
  // the names into buffers that are freed after the call
  std::vector<const char *> inputNames;
  std::vector<const char *> outputNames;

  char *name{nullptr};
  for (const std::string &s : session->GetInputNames())
  {
    name = new char[s.size() + 1];
    std::strcpy(name, s.c_str());
    inputNames.push_back(name);
  }
  for (const std::string &s : session->GetOutputNames())
  {
    name = new char[s.size() + 1];
    std::strcpy(name, s.c_str());
    outputNames.push_back(name);
  }
#else
#define XSTR(x) STR(x)
#define STR(x) #x
#pragma message "ORT_API_VERSION " XSTR(ORT_API_VERSION) " not implemented"
#endif
  session->Run(Ort::RunOptions{nullptr}, inputNames.data(), inputTensors.data(), 1, outputNames.data(), outputTensors.data(), 1);

#if ORT_API_VERSION == 22
  for (auto iter : inputNames)
  {
    delete[] iter;
  }
  for (auto iter : outputNames)
  {
    delete[] iter;
  }
#endif
  return outputTensorValuesN;
}

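// Overload for a 4-dimensional input tensor of shape N x Nx x Ny x Nz;
// returns the flattened N x Nreturn output tensor.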
std::vector<float> onnxInference(Ort::Session *session, std::vector<float> &input, int N, int Nx, int Ny, int Nz, int Nreturn)
{
  // Define the memory information for ONNX Runtime (CPU-resident tensors)
  Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);

  Ort::AllocatorWithDefaultOptions allocator;

  std::vector<int64_t> inputDims = {N, Nx, Ny, Nz};
  std::vector<int64_t> outputDimsN = {N, Nreturn};
  int inputlen = N * Nx * Ny * Nz;
  int outputlen = N * Nreturn;

  std::vector<float> outputTensorValues(outputlen);

  std::vector<Ort::Value> inputTensors;
  std::vector<Ort::Value> outputTensors;

  // wrap the existing buffers; CreateTensor does not copy the data
  inputTensors.push_back(Ort::Value::CreateTensor<float>(memoryInfo, input.data(), inputlen, inputDims.data(), inputDims.size()));

  outputTensors.push_back(Ort::Value::CreateTensor<float>(memoryInfo, outputTensorValues.data(), outputlen, outputDimsN.data(), outputDimsN.size()));

#if ORT_API_VERSION == 12
  std::vector<const char *> inputNames{session->GetInputName(0, allocator)};
  std::vector<const char *> outputNames{session->GetOutputName(0, allocator)};
#elif ORT_API_VERSION == 22
  // the newer API returns std::strings, but Run() wants C strings, so copy
  // the names into buffers that are freed after the call
  std::vector<const char *> inputNames;
  std::vector<const char *> outputNames;
  char *name{nullptr};
  for (const std::string &s : session->GetInputNames())
  {
    name = new char[s.size() + 1];
    std::strcpy(name, s.c_str());
    inputNames.push_back(name);
  }
  for (const std::string &s : session->GetOutputNames())
  {
    name = new char[s.size() + 1];
    std::strcpy(name, s.c_str());
    outputNames.push_back(name);
  }
#else
#define XSTR(x) STR(x)
#define STR(x) #x
#pragma message "ORT_API_VERSION " XSTR(ORT_API_VERSION) " not implemented"
#endif
  session->Run(Ort::RunOptions{nullptr}, inputNames.data(), inputTensors.data(), 1, outputNames.data(), outputTensors.data(), 1);
#if ORT_API_VERSION == 22
  for (auto iter : inputNames)
  {
    delete[] iter;
  }
  for (auto iter : outputNames)
  {
    delete[] iter;
  }
#endif

  return outputTensorValues;
}
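
// A minimal usage sketch (the model file name and dimensions below are
// hypothetical, not part of this library): run a single Nsamp=31 waveform
// through a model that returns Nreturn=3 fit parameters.
//
//   std::string modelfile = "fitmodel.onnx";         // hypothetical path
//   Ort::Session *session = onnxSession(modelfile);
//   std::vector<float> input(31, 0.0F);              // one waveform: N = 1, Nsamp = 31
//   std::vector<float> fit = onnxInference(session, input, 1, 31, 3);
//   delete session;                                  // caller owns the session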