File indexing completed on 2025-12-16 09:18:05
#include <TSystem.h>
#include <TStopwatch.h>

#include <iostream>
#include <optional>
#include <string>
#include <utility>
#include <vector>
0005
0006
0007 #include "/sphenix/user/jzhang1/testcode4all/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/Implement/src/PtCalculator.h" // SiCaloPt::PtCalculator & friends
0008 R__LOAD_LIBRARY(/sphenix/user/jzhang1/testcode4all/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/Implement/src/libPtCalc.so)
0009
0010
/// Default on-disk locations of the ML weights (ONNX models) and their
/// feature-scaler JSON files used by the tutorial.
///
/// An empty string means "not provided". The tutorial turns such entries
/// into an unset optional via make_opt.
struct DemoPaths
{
  // --- ML EMD model (no scaler file is supplied) ---
  std::string emd_onnx{"/sphenix/user/jzhang1/testcode4all/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/Implement/ML_Weight_Scaler/model_MLEMD.onnx"};
  std::string emd_scaler_json{""};

  // --- ML Eproj model and its scaler ---
  std::string eproj_onnx{"/sphenix/user/jzhang1/testcode4all/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/Implement/ML_Weight_Scaler/model_MLEproj.onnx"};
  std::string eproj_scaler_json{"/sphenix/user/jzhang1/testcode4all/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/Implement/ML_Weight_Scaler/scaler_MLEproj.json"};

  // --- ML Combined model (no scaler file is supplied) ---
  std::string combined_onnx{"/sphenix/user/jzhang1/testcode4all/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/Implement/ML_Weight_Scaler/model_MLCombined.onnx"};
  std::string combined_scaler_json{""};
};
0022
0023
/// Wrap a path string into an optional-like type.
///
/// @tparam Opt  Target type; it must be constructible from std::nullopt and
///              from std::string (e.g. std::optional<std::string>).
/// @param  s    Candidate value; an empty string means "not set".
/// @return An empty Opt when @p s is empty, otherwise an Opt holding @p s.
template<typename Opt>
Opt make_opt(const std::string& s)
{
  return s.empty() ? Opt{std::nullopt} : Opt{s};
}
0030
0031
0032 void PtCalcMLTutorial()
0033 {
0034
0035
0036
0037
0038
0039 DemoPaths WS_Path;
0040
0041
0042 SiCaloPt::PtCalculatorConfig cfg;
0043 cfg.mlEMD_model_path = make_opt<decltype(cfg.mlEMD_model_path)>(WS_Path.emd_onnx);
0044 cfg.mlEMD_scaler_json = make_opt<decltype(cfg.mlEMD_scaler_json)>(WS_Path.emd_scaler_json);
0045 cfg.mlEproj_model_path = make_opt<decltype(cfg.mlEproj_model_path)>(WS_Path.eproj_onnx);
0046 cfg.mlEproj_scaler_json = make_opt<decltype(cfg.mlEproj_scaler_json)>(WS_Path.eproj_scaler_json);
0047 cfg.mlCombined_model_path = make_opt<decltype(cfg.mlCombined_model_path)>(WS_Path.combined_onnx);
0048 cfg.mlCombined_scaler_json = make_opt<decltype(cfg.mlCombined_scaler_json)>(WS_Path.combined_scaler_json);
0049
0050
0051 SiCaloPt::PtCalculator calcTutorial(cfg);
0052
0053
0054 std::string err;
0055 if (!calcTutorial.init(&err))
0056 {
0057 std::cout << "[init] failed: " << err << std::endl;
0058 return;
0059 }
0060 std::cout << "[init] OK\n";
0061
0062
0063 {
0064 SiCaloPt::InputEMD in;
0065 in.EMD_Angle = 0.025;
0066 in.EMD_Eta = 0.00;
0067 in.EMD_Radius = 93.5;
0068
0069 calcTutorial.setParCeta(0.2);
0070 calcTutorial.setParPower(-1.0);
0071
0072 auto r = calcTutorial.ComputePt(SiCaloPt::Method::MethodEMD, SiCaloPt::AnyInput{in});
0073 std::cout << "[EMD-analytic] ok=" << r.ok
0074 << " pt=" << r.pt_reco
0075 << " err=\"" << r.err << "\"\n";
0076 }
0077
0078
0079 {
0080 SiCaloPt::InputEproj in;
0081 in.Energy_Calo = 1.8;
0082 in.Radius_Calo = 93.5;
0083 in.Z_Calo = 0.0;
0084 in.Radius_vertex = 0.0;
0085 in.Z_vertex = 0.0;
0086
0087 auto r = calcTutorial.ComputePt(SiCaloPt::Method::MethodEproj, SiCaloPt::AnyInput{in});
0088 std::cout << "[Eproj-analytic] ok=" << r.ok
0089 << " pt=" << r.pt_reco
0090 << " err=\"" << r.err << "\"\n";
0091 }
0092
0093
0094 {
0095
0096 std::vector<float> featsMLEMD = {15, 0};
0097
0098 SiCaloPt::InputMLEMD in{featsMLEMD};
0099 auto r = calcTutorial.ComputePt(SiCaloPt::Method::MethodMLEMD, SiCaloPt::AnyInput{in});
0100 std::cout << "[MLEMD-2D] ok=" << r.ok
0101 << " pt=" << r.pt_reco
0102 << " err=\"" << r.err << "\"\n";
0103 }
0104
0105
0106 {
0107
0108 std::vector<float> featsMLEproj = { 10.0, 5.0,
0109 15.0, 7.5,
0110 100.0, 50.0, 8.0 };
0111
0112 SiCaloPt::InputMLEproj in{featsMLEproj};
0113 auto r = calcTutorial.ComputePt(SiCaloPt::Method::MethodMLEproj, SiCaloPt::AnyInput{in});
0114 std::cout << "[MLEproj-7D] ok=" << r.ok
0115 << " pt=" << r.pt_reco
0116 << " err=\"" << r.err << "\"\n";
0117 }
0118
0119
0120 {
0121
0122 std::vector<float> featsMLCombined = {8.0, 9.5};
0123
0124 SiCaloPt::InputMLCombined in{featsMLCombined};
0125 auto r = calcTutorial.ComputePt(SiCaloPt::Method::MethodMLCombined, SiCaloPt::AnyInput{in});
0126 std::cout << "[MLCombined] ok=" << r.ok
0127 << " pt=" << r.pt_reco
0128 << " err=\"" << r.err << "\"\n";
0129 }
0130
0131
0132
0133
0134
0135
0136
0137
0138
0139
0140
0141
0142
0143
0144
0145
0146
0147
0148
0149
0150
0151
0152
0153
0154
0155 }