File indexing completed on 2025-12-17 09:19:50
#include "onnxlib.h"

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
0004
namespace onnxlib
{
// Model input / output widths, written by onnxSession() from dimension 1 of the
// first input / output tensor shape. Both stay at -1 until a session is created.
int n_input = -1;
int n_output = -1;
}  // namespace onnxlib
0010
0011 Ort::Session *onnxSession(std::string &modelfile, int verbosity)
0012 {
0013 Ort::Env env(OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, "fit");
0014 Ort::SessionOptions sessionOptions;
0015 sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
0016 auto *session = new Ort::Session(env, modelfile.c_str(), sessionOptions);
0017 auto type_info = session->GetInputTypeInfo(0);
0018 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
0019 auto input_dims = tensor_info.GetShape();
0020 onnxlib::n_input = input_dims[1];
0021 type_info = session->GetOutputTypeInfo(0);
0022 tensor_info = type_info.GetTensorTypeAndShapeInfo();
0023 auto output_dims = tensor_info.GetShape();
0024 onnxlib::n_output = output_dims[1];
0025 if (verbosity > 0)
0026 {
0027 std::cout << "onnxlib: using model " << modelfile << std::endl;
0028 std::cout << "Number of Inputs: " << onnxlib::n_input << std::endl;
0029 std::cout << "Number of Outputs: " << onnxlib::n_output << std::endl;
0030 }
0031 return session;
0032 }
0033
0034 std::vector<float> onnxInference(Ort::Session *session, std::vector<float> &input, int N, int Nsamp, int Nreturn)
0035 {
0036 Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
0037
0038 Ort::AllocatorWithDefaultOptions allocator;
0039
0040 std::vector<Ort::Value> inputTensors;
0041 std::vector<Ort::Value> outputTensors;
0042
0043 std::vector<int64_t> inputDimsN = {N, Nsamp};
0044 std::vector<int64_t> outputDimsN = {N, Nreturn};
0045 int inputlen = N * Nsamp;
0046 int outputlen = N * Nreturn;
0047
0048 std::vector<float> outputTensorValuesN(outputlen);
0049
0050 inputTensors.push_back(Ort::Value::CreateTensor<float>(memoryInfo, input.data(), inputlen, inputDimsN.data(), inputDimsN.size()));
0051 outputTensors.push_back(Ort::Value::CreateTensor<float>(memoryInfo, outputTensorValuesN.data(), outputlen, outputDimsN.data(), outputDimsN.size()));
0052
0053 #if ORT_API_VERSION == 12
0054 std::vector<const char *> inputNames{session->GetInputName(0, allocator)};
0055 std::vector<const char *> outputNames{session->GetOutputName(0, allocator)};
0056 #elif ORT_API_VERSION == 22
0057 std::vector<const char *> inputNames;
0058 std::vector<const char *> outputNames;
0059
0060 char *name{nullptr};
0061 for (const std::string &s : session->GetInputNames())
0062 {
0063 name = new char[s.size() + 1];
0064 sprintf(name, "%s", s.c_str());
0065 inputNames.push_back(name);
0066 }
0067 for (const std::string &s : session->GetOutputNames())
0068 {
0069 name = new char[s.size() + 1];
0070 sprintf(name, "%s", s.c_str());
0071 outputNames.push_back(name);
0072 }
0073 #else
0074 #define XSTR(x) STR(x)
0075 #define STR(x) #x
0076 #pragma message "ORT_API_VERSION " XSTR(ORT_API_VERSION) " not implemented"
0077 #endif
0078 session->Run(Ort::RunOptions{nullptr}, inputNames.data(), inputTensors.data(), 1, outputNames.data(), outputTensors.data(), 1);
0079
0080 #if ORT_API_VERSION == 22
0081 for (const auto *iter : inputNames)
0082 {
0083 delete[] iter;
0084 }
0085 for (const auto *iter : outputNames)
0086 {
0087 delete[] iter;
0088 }
0089 #endif
0090 return outputTensorValuesN;
0091 }
0092
0093 std::vector<float> onnxInference(Ort::Session *session, std::vector<float> &input, int N, int Nx, int Ny, int Nz, int Nreturn)
0094 {
0095
0096 Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
0097
0098 Ort::AllocatorWithDefaultOptions allocator;
0099
0100 std::vector<int64_t> inputDims = {N, Nx, Ny, Nz};
0101 std::vector<int64_t> outputDimsN = {N, Nreturn};
0102 int inputlen = N * Nx * Ny * Nz;
0103 int outputlen = N * Nreturn;
0104
0105 std::vector<float> outputTensorValues(outputlen);
0106
0107 std::vector<Ort::Value> inputTensors;
0108 std::vector<Ort::Value> outputTensors;
0109
0110 inputTensors.push_back(Ort::Value::CreateTensor<float>(memoryInfo, input.data(), inputlen, inputDims.data(), inputDims.size()));
0111
0112 outputTensors.push_back(Ort::Value::CreateTensor<float>(memoryInfo, outputTensorValues.data(), outputlen, outputDimsN.data(), outputDimsN.size()));
0113
0114 #if ORT_API_VERSION == 12
0115 std::vector<const char *> inputNames{session->GetInputName(0, allocator)};
0116 std::vector<const char *> outputNames{session->GetOutputName(0, allocator)};
0117 #elif ORT_API_VERSION == 22
0118 std::vector<const char *> inputNames;
0119 std::vector<const char *> outputNames;
0120 char *name{nullptr};
0121 for (const std::string &s : session->GetInputNames())
0122 {
0123 name = new char[s.size() + 1];
0124 sprintf(name, "%s", s.c_str());
0125 inputNames.push_back(name);
0126 }
0127 for (const std::string &s : session->GetOutputNames())
0128 {
0129 name = new char[s.size() + 1];
0130 sprintf(name, "%s", s.c_str());
0131 outputNames.push_back(name);
0132 }
0133 #else
0134 #define XSTR(x) STR(x)
0135 #define STR(x) #x
0136 #pragma message "ORT_API_VERSION " XSTR(ORT_API_VERSION) " not implemented"
0137 #endif
0138 session->Run(Ort::RunOptions{nullptr}, inputNames.data(), inputTensors.data(), 1, outputNames.data(), outputTensors.data(), 1);
0139 #if ORT_API_VERSION == 22
0140 for (const auto *iter : inputNames)
0141 {
0142 delete[] iter;
0143 }
0144 for (const auto *iter : outputNames)
0145 {
0146 delete[] iter;
0147 }
0148 #endif
0149
0150 return outputTensorValues;
0151 }