File indexing completed on 2025-12-16 09:19:51
0001
0002
0003
0004
0005 #include "RawClusterCNNClassifier.h"
0006
0007 #include <calobase/RawCluster.h>
0008 #include <calobase/RawClusterContainer.h>
0009 #include <calobase/RawClusterUtility.h>
0010 #include <calobase/RawTower.h>
0011 #include <calobase/RawTowerContainer.h>
0012 #include <calobase/RawTowerDefs.h>
0013 #include <calobase/RawTowerGeom.h>
0014 #include <calobase/RawTowerGeomContainer.h>
0015
0016
0017 #include <calobase/TowerInfo.h>
0018 #include <calobase/TowerInfoContainer.h>
0019 #include <calobase/TowerInfoDefs.h>
0020
0021 #include <ffaobjects/EventHeader.h>
0022
0023 #include <fun4all/Fun4AllReturnCodes.h>
0024 #include <fun4all/Fun4AllServer.h>
0025
0026 #include <phool/PHCompositeNode.h>
0027 #include <phool/getClass.h>
0028 #include <phool/onnxlib.h>
0029 #include <phool/phool.h>
0030
0031 #include <iostream>
0032 #include <vector>
0033
0034 RawClusterCNNClassifier::RawClusterCNNClassifier(const std::string &name)
0035 : SubsysReco(name)
0036 {
0037 }
0038
0039 RawClusterCNNClassifier::~RawClusterCNNClassifier()
0040 {
0041 delete onnxmodule;
0042 }
0043
0044 int RawClusterCNNClassifier::Init(PHCompositeNode *topNode)
0045 {
0046
0047 onnxmodule = onnxSession(m_modelPath,Verbosity());
0048
0049 if (m_inputNodeName == m_outputNodeName)
0050 {
0051 std::cout << "RawClusterCNNClassifier::Init: inputNodeName and outputNodeName are the same, setting inplace to true" << std::endl;
0052 inplace = true;
0053 }
0054 CreateNodes(topNode);
0055
0056 return Fun4AllReturnCodes::EVENT_OK;
0057 }
0058
0059 int RawClusterCNNClassifier::process_event(PHCompositeNode *topNode)
0060 {
0061
0062 std::string clusterNodeName = m_inputNodeName;
0063 RawClusterContainer *clusterContainer = findNode::getClass<RawClusterContainer>(topNode, clusterNodeName);
0064 if (!clusterContainer)
0065 {
0066 std::cout << "RawClusterCNNClassifier::process_event::Could not locate input cluster node " << clusterNodeName << std::endl;
0067 return Fun4AllReturnCodes::ABORTEVENT;
0068 }
0069
0070 if (inplace)
0071 {
0072 _clusters = clusterContainer;
0073 }
0074 else
0075 {
0076 _clusters->Reset();
0077 RawClusterContainer::Map clusterMap = clusterContainer->getClustersMap();
0078 for (auto &clusterPair : clusterMap)
0079 {
0080
0081 RawCluster *recoCluster = (clusterPair.second);
0082 float clusterE = recoCluster->get_energy();
0083 if (clusterE < m_min_cluster_e)
0084 {
0085 continue;
0086 }
0087 RawCluster *newCluster = (RawCluster *) recoCluster->CloneMe();
0088 _clusters->AddCluster(newCluster);
0089 }
0090 }
0091
0092
0093 std::string towerNodeName = m_towerNodeName;
0094 TowerInfoContainer *emcTowerContainer = findNode::getClass<TowerInfoContainer>(topNode, towerNodeName);
0095 if (!emcTowerContainer)
0096 {
0097 std::cout << "RawClusterCNNClassifier::process_event Could not locate tower node " << towerNodeName << std::endl;
0098 return Fun4AllReturnCodes::ABORTEVENT;
0099 }
0100
0101 RawClusterContainer::Map clusterMap = _clusters->getClustersMap();
0102 for (auto &clusterPair : clusterMap)
0103 {
0104 RawCluster *recoCluster = clusterPair.second;
0105
0106 recoCluster->set_prob(-1);
0107 CLHEP::Hep3Vector vertex(0, 0, 0);
0108 CLHEP::Hep3Vector E_vec_cluster_Full = RawClusterUtility::GetEVec(*recoCluster, vertex);
0109 float ET = E_vec_cluster_Full.perp();
0110 if (ET < minET)
0111 {
0112 continue;
0113 }
0114 const RawCluster::TowerMap tower_map =
0115 recoCluster->get_towermap();
0116
0117 int maxtowerE = 0;
0118 int maxtowerieta = -1;
0119 int maxtoweriphi = -1;
0120
0121 for (auto tower_iter : tower_map)
0122 {
0123 RawTowerDefs::keytype tower_key = tower_iter.first;
0124
0125
0126 int ix = RawTowerDefs::decode_index2(tower_key);
0127 int iy = RawTowerDefs::decode_index1(tower_key);
0128 RawTowerDefs::CalorimeterId caloid = RawTowerDefs::decode_caloid(tower_key);
0129
0130 if (caloid != RawTowerDefs::CalorimeterId::CEMC)
0131 {
0132 continue;
0133 }
0134
0135 unsigned int towerinfokey = TowerInfoDefs::encode_emcal(iy, ix);
0136 TowerInfo *towerinfo = emcTowerContainer->get_tower_at_key(towerinfokey);
0137 if (!towerinfo)
0138 {
0139
0140 std::cout << "No towerinfo for tower key " << towerinfokey
0141 << std::endl;
0142 continue;
0143 }
0144 float towerE = towerinfo->get_energy();
0145 if (towerE > maxtowerE)
0146 {
0147 maxtowerE = towerE;
0148 maxtowerieta = iy;
0149 maxtoweriphi = ix;
0150 }
0151 }
0152
0153 std::vector<float> input;
0154
0155 int vectorSize = inputDimx * inputDimy;
0156 input.resize(vectorSize, 0);
0157
0158 if (maxtowerE > 0)
0159 {
0160 int xlength = ((inputDimx - 1) / 2);
0161 int ylength = ((inputDimy - 1) / 2);
0162 if (maxtowerieta - ylength < 0 || maxtowerieta + ylength >= 96)
0163 {
0164 continue;
0165 }
0166 for (int ieta = maxtowerieta - ylength; ieta <= maxtowerieta + ylength; ieta++)
0167 {
0168 for (int iphi = maxtoweriphi - xlength; iphi <= maxtoweriphi + xlength; iphi++)
0169 {
0170 int mappediphi = iphi;
0171
0172 if (mappediphi < 0)
0173 {
0174 mappediphi += 256;
0175 }
0176 if (mappediphi > 255)
0177 {
0178 mappediphi -= 256;
0179 }
0180 unsigned int towerinfokey = TowerInfoDefs::encode_emcal(ieta, mappediphi);
0181 TowerInfo *towerinfo = emcTowerContainer->get_tower_at_key(towerinfokey);
0182 if (!towerinfo)
0183 {
0184
0185 std::cout << "No towerinfo for tower key " << towerinfokey << std::endl;
0186 std::cout << "ieta: " << ieta << " iphi: " << mappediphi << std::endl;
0187 continue;
0188 }
0189 int index = ((ieta - maxtowerieta + ylength) * inputDimx) + iphi - maxtoweriphi + xlength;
0190 input.at(index) = towerinfo->get_energy();
0191 }
0192 }
0193 }
0194 std::vector<float> prob = onnxInference(onnxmodule, input, 1, inputDimx, inputDimy, inputDimz, outputDim);
0195
0196
0197 recoCluster->set_prob(prob[0]);
0198 }
0199
0200 return Fun4AllReturnCodes::EVENT_OK;
0201 }
0202
0203 void RawClusterCNNClassifier::CreateNodes(PHCompositeNode *topNode)
0204 {
0205 PHNodeIterator iter(topNode);
0206
0207
0208 PHCompositeNode *dstNode = dynamic_cast<PHCompositeNode *>(iter.findFirst("PHCompositeNode", "DST"));
0209 if (!dstNode)
0210 {
0211 std::cout << PHWHERE << "DST Node missing, doing nothing." << std::endl;
0212 throw std::runtime_error("Failed to find DST node in EmcRawTowerBuilder::CreateNodes");
0213 }
0214
0215
0216 PHCompositeNode *cemcNode = dynamic_cast<PHCompositeNode *>(iter.findFirst("PHCompositeNode", "CEMC"));
0217
0218
0219 if (!cemcNode)
0220 {
0221 cemcNode = new PHCompositeNode("CEMC");
0222 dstNode->addNode(cemcNode);
0223 }
0224 std::string clusterNodeName = m_outputNodeName;
0225 _clusters = findNode::getClass<RawClusterContainer>(dstNode, clusterNodeName);
0226 if (!_clusters)
0227 {
0228 _clusters = new RawClusterContainer();
0229 PHIODataNode<PHObject> *clusterNode = new PHIODataNode<PHObject>(_clusters, clusterNodeName, "PHObject");
0230 cemcNode->addNode(clusterNode);
0231 }
0232 }