Back to home page

sPHENIX code displayed by LXR

 
 

    


File indexing completed on 2025-12-17 09:19:50

0001 #include <onnxruntime_cxx_api.h>
0002 #include <iostream>
0003 #include <string>
0004 #include <vector>
0005 
0006 int main(int argc, char* argv[])
0007 {
0008   if (argc < 2)
0009   {
0010     std::cerr << "Usage: " << argv[0] << " model.onnx" << std::endl;
0011     return 1;
0012   }
0013 
0014   const std::string model_path = argv[1];
0015 
0016   try
0017   {
0018     Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "meta_reader");
0019     Ort::SessionOptions session_options;
0020     session_options.SetIntraOpNumThreads(1);
0021 
0022     // Create session to load the model
0023     Ort::Session session(env, model_path.c_str(), session_options);
0024 
0025     std::cout << "✅ Model loaded successfully: " << model_path << "\n";
0026     std::cout << "----------------------------------------\n";
0027 
0028     // Print model input information
0029     Ort::AllocatorWithDefaultOptions allocator;
0030 
0031     size_t num_input_nodes = session.GetInputCount();
0032     std::vector<std::string> inputnames = session.GetInputNames();
0033     std::cout << "Inputs (" << num_input_nodes << "):\n";
0034     for (size_t i = 0; i < num_input_nodes; i++)
0035     {
0036       auto type_info = session.GetInputTypeInfo(i);
0037       auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
0038 
0039       ONNXTensorElementDataType type = tensor_info.GetElementType();
0040       auto input_dims = tensor_info.GetShape();
0041 
0042       std::cout << "  • " << inputnames[i] << " (type=" << type << ", shape=[";
0043       for (size_t j = 0; j < input_dims.size(); j++)
0044       {
0045         std::cout << input_dims[j];
0046         if (j + 1 < input_dims.size())
0047         {
0048           std::cout << ", ";
0049         }
0050       }
0051       std::cout << "])\n";
0052     }
0053 
0054     // Print model output information
0055     size_t num_output_nodes = session.GetOutputCount();
0056     std::cout << "\nOutputs (" << num_output_nodes << "):\n";
0057     std::vector<std::string> outputnames = session.GetOutputNames();
0058     for (size_t i = 0; i < num_output_nodes; i++)
0059     {
0060       auto type_info = session.GetOutputTypeInfo(i);
0061       auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
0062 
0063       ONNXTensorElementDataType type = tensor_info.GetElementType();
0064       auto output_dims = tensor_info.GetShape();
0065 
0066       std::cout << "  • " << outputnames[i] << " (type=" << type << ", shape=[";
0067       for (size_t j = 0; j < output_dims.size(); j++)
0068       {
0069         std::cout << output_dims[j];
0070         if (j + 1 < output_dims.size())
0071         {
0072           std::cout << ", ";
0073         }
0074       }
0075       std::cout << "])\n";
0076     }
0077 
0078     std::cout << "\n----------------------------------------\n";
0079     std::cout << "ONNX Runtime Version: " << Ort::GetVersionString() << "\n";
0080   }
0081   catch (const Ort::Exception& e)
0082   {
0083     std::cerr << "❌ ONNX Runtime Error: " << e.what() << std::endl;
0084     return 1;
0085   }
0086 
0087   return 0;
0088 }