File indexing completed on 2025-12-16 09:19:49
0001 #include "CaloWaveformProcessing.h"
0002 #include "CaloWaveformFitting.h"
0003
0004 #include <ffamodules/CDBInterface.h>
0005
0006 #include <phool/onnxlib.h>
0007
0008 #include <algorithm> // for max
0009 #include <cassert>
0010 #include <cstdlib> // for getenv
0011 #include <iostream>
0012 #include <limits>
0013 #include <memory> // for allocator_traits<>::value_type
0014 #include <string>
0015
0016 namespace
0017 {
0018 Ort::Session *onnxmodule;
0019 }
0020
0021 CaloWaveformProcessing::~CaloWaveformProcessing()
0022 {
0023 delete m_Fitter;
0024 }
0025
0026 void CaloWaveformProcessing::initialize_processing()
0027 {
0028 char *calibrationsroot = getenv("CALIBRATIONROOT");
0029 assert(calibrationsroot);
0030 if (m_processingtype == CaloWaveformProcessing::TEMPLATE || m_processingtype == CaloWaveformProcessing::TEMPLATE_NOSAT)
0031 {
0032 std::string calibrations_repo_template = std::string(calibrationsroot) + "/WaveformProcessing/templates/" + m_template_input_file;
0033 url_template = CDBInterface::instance()->getUrl(m_template_name, calibrations_repo_template);
0034 m_Fitter = new CaloWaveformFitting();
0035 m_Fitter->initialize_processing(url_template);
0036 if (m_processingtype == CaloWaveformProcessing::TEMPLATE_NOSAT)
0037 {
0038 m_Fitter->set_handleSaturation(false);
0039 }
0040 m_Fitter->set_nthreads(get_nthreads());
0041 if (m_setTimeLim)
0042 {
0043 m_Fitter->set_timeFitLim(m_timeLim_low, m_timeLim_high);
0044 }
0045
0046 if (_bdosoftwarezerosuppression)
0047 {
0048 m_Fitter->set_softwarezerosuppression(_bdosoftwarezerosuppression, _nsoftwarezerosuppression);
0049 }
0050 if (_dobitfliprecovery)
0051 {
0052 m_Fitter->set_bitFlipRecovery(_dobitfliprecovery);
0053 }
0054 }
0055 else if (m_processingtype == CaloWaveformProcessing::ONNX)
0056 {
0057
0058
0059 onnxmodule = onnxSession(m_model_name, Verbosity());
0060 }
0061 else if (m_processingtype == CaloWaveformProcessing::NYQUIST)
0062 {
0063 std::string calibrations_repo_template = std::string(calibrationsroot) + "/WaveformProcessing/templates/" + m_template_input_file;
0064 url_template = CDBInterface::instance()->getUrl(m_template_name, calibrations_repo_template);
0065 m_Fitter = new CaloWaveformFitting();
0066 m_Fitter->initialize_processing(url_template);
0067 }
0068 }
0069
0070 std::vector<std::vector<float>> CaloWaveformProcessing::process_waveform(std::vector<std::vector<float>> waveformvector)
0071 {
0072 unsigned int size1 = waveformvector.size();
0073 std::vector<std::vector<float>> fitresults;
0074 if (m_processingtype == CaloWaveformProcessing::TEMPLATE || m_processingtype == CaloWaveformProcessing::TEMPLATE_NOSAT)
0075 {
0076 for (unsigned int i = 0; i < size1; i++)
0077 {
0078 waveformvector.at(i).push_back((float) i);
0079 }
0080 fitresults = m_Fitter->calo_processing_templatefit(waveformvector);
0081 }
0082 if (m_processingtype == CaloWaveformProcessing::ONNX)
0083 {
0084 fitresults = CaloWaveformProcessing::calo_processing_ONNX(waveformvector);
0085 }
0086 if (m_processingtype == CaloWaveformProcessing::FAST)
0087 {
0088 fitresults = CaloWaveformFitting::calo_processing_fast(waveformvector);
0089 }
0090 if (m_processingtype == CaloWaveformProcessing::NYQUIST)
0091 {
0092 fitresults = m_Fitter->calo_processing_nyquist(waveformvector);
0093 }
0094 return fitresults;
0095 }
0096
0097 std::vector<std::vector<float>> CaloWaveformProcessing::calo_processing_ONNX(const std::vector<std::vector<float>> &chnlvector)
0098 {
0099 std::vector<std::vector<float>> fit_values;
0100 std::vector<float> val;
0101 unsigned int nchnls = chnlvector.size();
0102 for (unsigned int m = 0; m < nchnls; m++)
0103 {
0104 val.clear();
0105 const std::vector<float> &v = chnlvector.at(m);
0106 int size1 = v.size();
0107 if (size1 == _nzerosuppresssamples)
0108 {
0109 val.push_back(v.at(1) - v.at(0));
0110 val.push_back(std::numeric_limits<float>::quiet_NaN());
0111 val.push_back(v.at(0));
0112 if (v.at(0) != 0 && v.at(1) == 0)
0113 {
0114 val.push_back(1000000);
0115 }
0116 else
0117 {
0118 val.push_back(std::numeric_limits<float>::quiet_NaN());
0119 }
0120 val.push_back(0);
0121 fit_values.push_back(val);
0122 }
0123 else
0124 {
0125 float maxheight = 0;
0126 int maxbin = 0;
0127 for (int i = 0; i < size1; i++)
0128 {
0129 if (v.at(i) > maxheight)
0130 {
0131 maxheight = v.at(i);
0132 maxbin = i;
0133 }
0134 }
0135 float pedestal = 1500;
0136 if (maxbin > 4)
0137 {
0138 pedestal = 0.5 * (v.at(maxbin - 4) + v.at(maxbin - 5));
0139 }
0140 else if (maxbin > 3)
0141 {
0142 pedestal = (v.at(maxbin - 4));
0143 }
0144 else
0145 {
0146 pedestal = 0.5 * (v.at(size1 - 3) + v.at(size1 - 2));
0147 }
0148
0149 if ((_bdosoftwarezerosuppression && v.at(6) - v.at(0) < _nsoftwarezerosuppression) || (_maxsoftwarezerosuppression && maxheight - pedestal < _nsoftwarezerosuppression))
0150 {
0151 val.push_back(v.at(6) - v.at(0));
0152 val.push_back(std::numeric_limits<float>::quiet_NaN());
0153 val.push_back(v.at(0));
0154 if (v.at(0) != 0 && v.at(1) == 0)
0155 {
0156 val.push_back(1000000);
0157 }
0158 else
0159 {
0160 val.push_back(std::numeric_limits<float>::quiet_NaN());
0161 }
0162 val.push_back(0);
0163 fit_values.push_back(val);
0164 }
0165 else
0166 {
0167 unsigned int nsamples = v.size();
0168 if (nsamples == 12)
0169 {
0170
0171
0172 std::vector<float> vtmp(v);
0173 val = onnxInference(onnxmodule, vtmp, 1, onnxlib::n_input, onnxlib::n_output);
0174 unsigned int nvals = val.size();
0175 for (unsigned int i = 0; i < nvals; i++)
0176 {
0177 val.at(i) = val.at(i) * m_Onnx_factor.at(i) + m_Onnx_offset.at(i);
0178 }
0179 val.push_back(2000);
0180 val.push_back(0);
0181 fit_values.push_back(val);
0182 }
0183 else
0184 {
0185 float v_diff = v[1] - v[0];
0186 std::vector<float> val1{v_diff, std::numeric_limits<float>::quiet_NaN(), v[1], std::numeric_limits<float>::quiet_NaN(), 0};
0187 fit_values.push_back(val1);
0188 }
0189 }
0190 }
0191 }
0192 return fit_values;
0193 }
0194
0195 int CaloWaveformProcessing::get_nthreads()
0196 {
0197 if (m_Fitter)
0198 {
0199 return m_Fitter->get_nthreads();
0200 }
0201 return _nthreads;
0202 }
0203 void CaloWaveformProcessing::set_nthreads(int nthreads)
0204 {
0205 _nthreads = nthreads;
0206 if (m_Fitter)
0207 {
0208 m_Fitter->set_nthreads(nthreads);
0209 }
0210 return;
0211 }