File indexing completed on 2025-08-05 08:16:19
#include "onnxlib.h"

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
0003
0004 Ort::Session *onnxSession(std::string &modelfile)
0005 {
0006 Ort::Env env(OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, "fit");
0007 Ort::SessionOptions sessionOptions;
0008 sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
0009
0010 return new Ort::Session(env, modelfile.c_str(), sessionOptions);
0011 }
0012
0013 std::vector<float> onnxInference(Ort::Session *session, std::vector<float> &input, int N, int Nsamp, int Nreturn)
0014 {
0015 Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
0016
0017 Ort::AllocatorWithDefaultOptions allocator;
0018
0019 std::vector<Ort::Value> inputTensors;
0020 std::vector<Ort::Value> outputTensors;
0021
0022 std::vector<int64_t> inputDimsN = {N, Nsamp};
0023 std::vector<int64_t> outputDimsN = {N, Nreturn};
0024 int inputlen = N * Nsamp;
0025 int outputlen = N * Nreturn;
0026
0027 std::vector<float> outputTensorValuesN(outputlen);
0028
0029 inputTensors.push_back(Ort::Value::CreateTensor<float>(memoryInfo, input.data(), inputlen, inputDimsN.data(), inputDimsN.size()));
0030 outputTensors.push_back(Ort::Value::CreateTensor<float>(memoryInfo, outputTensorValuesN.data(), outputlen, outputDimsN.data(), outputDimsN.size()));
0031
0032 #if ORT_API_VERSION == 12
0033 std::vector<const char *> inputNames{session->GetInputName(0, allocator)};
0034 std::vector<const char *> outputNames{session->GetOutputName(0, allocator)};
0035 #elif ORT_API_VERSION == 22
0036 std::vector<const char *> inputNames;
0037 std::vector<const char *> outputNames;
0038
0039 char *name{nullptr};
0040 for (const std::string &s : session->GetInputNames())
0041 {
0042 name = new char[s.size() + 1];
0043 sprintf(name, "%s", s.c_str());
0044 inputNames.push_back(name);
0045 }
0046 for (const std::string &s : session->GetOutputNames())
0047 {
0048 name = new char[s.size() + 1];
0049 sprintf(name, "%s", s.c_str());
0050 outputNames.push_back(name);
0051 }
0052 #else
0053 #define XSTR(x) STR(x)
0054 #define STR(x) #x
0055 #pragma message "ORT_API_VERSION " XSTR(ORT_API_VERSION) " not implemented"
0056 #endif
0057 session->Run(Ort::RunOptions{nullptr}, inputNames.data(), inputTensors.data(), 1, outputNames.data(), outputTensors.data(), 1);
0058
0059 #if ORT_API_VERSION == 22
0060 for (auto iter : inputNames)
0061 {
0062 delete[] iter;
0063 }
0064 for (auto iter : outputNames)
0065 {
0066 delete[] iter;
0067 }
0068 #endif
0069 return outputTensorValuesN;
0070 }
0071
0072 std::vector<float> onnxInference(Ort::Session *session, std::vector<float> &input, int N, int Nx, int Ny, int Nz, int Nreturn)
0073 {
0074
0075 Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
0076
0077 Ort::AllocatorWithDefaultOptions allocator;
0078
0079 std::vector<int64_t> inputDims = {N, Nx, Ny, Nz};
0080 std::vector<int64_t> outputDimsN = {N, Nreturn};
0081 int inputlen = N * Nx * Ny * Nz;
0082 int outputlen = N * Nreturn;
0083
0084 std::vector<float> outputTensorValues(outputlen);
0085
0086 std::vector<Ort::Value> inputTensors;
0087 std::vector<Ort::Value> outputTensors;
0088
0089 inputTensors.push_back(Ort::Value::CreateTensor<float>(memoryInfo, input.data(), inputlen, inputDims.data(), inputDims.size()));
0090
0091 outputTensors.push_back(Ort::Value::CreateTensor<float>(memoryInfo, outputTensorValues.data(), outputlen, outputDimsN.data(), outputDimsN.size()));
0092
0093 #if ORT_API_VERSION == 12
0094 std::vector<const char *> inputNames{session->GetInputName(0, allocator)};
0095 std::vector<const char *> outputNames{session->GetOutputName(0, allocator)};
0096 #elif ORT_API_VERSION == 22
0097 std::vector<const char *> inputNames;
0098 std::vector<const char *> outputNames;
0099 char *name{nullptr};
0100 for (const std::string &s : session->GetInputNames())
0101 {
0102 name = new char[s.size() + 1];
0103 sprintf(name, "%s", s.c_str());
0104 inputNames.push_back(name);
0105 }
0106 for (const std::string &s : session->GetOutputNames())
0107 {
0108 name = new char[s.size() + 1];
0109 sprintf(name, "%s", s.c_str());
0110 outputNames.push_back(name);
0111 }
0112 #else
0113 #define XSTR(x) STR(x)
0114 #define STR(x) #x
0115 #pragma message "ORT_API_VERSION " XSTR(ORT_API_VERSION) " not implemented"
0116 #endif
0117 session->Run(Ort::RunOptions{nullptr}, inputNames.data(), inputTensors.data(), 1, outputNames.data(), outputTensors.data(), 1);
0118 #if ORT_API_VERSION == 22
0119 for (auto iter : inputNames)
0120 {
0121 delete[] iter;
0122 }
0123 for (auto iter : outputNames)
0124 {
0125 delete[] iter;
0126 }
0127 #endif
0128
0129 return outputTensorValues;
0130 }