File indexing completed on 2025-08-05 08:09:54
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "Acts/EventData/SourceLink.hpp"
0010 #include "Acts/Utilities/CalibrationContext.hpp"
0011 #include "Acts/Utilities/UnitVectors.hpp"
0012 #include <ActsExamples/EventData/NeuralCalibrator.hpp>
0013
0014 #include <TFile.h>
0015
0016 namespace detail {
0017
0018 template <typename Array>
0019 std::size_t fillChargeMatrix(Array& arr, const ActsExamples::Cluster& cluster,
0020 std::size_t size0 = 7u, std::size_t size1 = 7u) {
0021
0022
0023 double totalAct = 0;
0024 for (const ActsExamples::Cluster::Cell& cell : cluster.channels) {
0025 totalAct += cell.activation;
0026 }
0027 std::vector<double> weights;
0028 for (const ActsExamples::Cluster::Cell& cell : cluster.channels) {
0029 weights.push_back(cell.activation / totalAct);
0030 }
0031
0032 double acc0 = 0;
0033 double acc1 = 0;
0034 for (std::size_t i = 0; i < cluster.channels.size(); i++) {
0035 acc0 += cluster.channels.at(i).bin[0] * weights.at(i);
0036 acc1 += cluster.channels.at(i).bin[1] * weights.at(i);
0037 }
0038
0039
0040
0041 int offset0 = static_cast<int>(acc0) - size0 / 2;
0042 int offset1 = static_cast<int>(acc1) - size1 / 2;
0043
0044
0045 arr = Eigen::ArrayXXf::Zero(1, size0 * size1);
0046
0047 for (const ActsExamples::Cluster::Cell& cell : cluster.channels) {
0048
0049 int iMat = cell.bin[0] - offset0;
0050 int jMat = cell.bin[1] - offset1;
0051 if (iMat >= 0 && iMat < (int)size0 && jMat >= 0 && jMat < (int)size1) {
0052 typename Array::Index index = iMat * size0 + jMat;
0053 if (index < arr.size()) {
0054 arr(index) = cell.activation;
0055 }
0056 }
0057 }
0058 return size0 * size1;
0059 }
0060
0061 }
0062
0063 ActsExamples::NeuralCalibrator::NeuralCalibrator(
0064 const std::filesystem::path& modelPath, std::size_t nComponents,
0065 std::vector<std::size_t> volumeIds)
0066 : m_env(ORT_LOGGING_LEVEL_WARNING, "NeuralCalibrator"),
0067 m_model(m_env, modelPath.c_str()),
0068 m_nComponents{nComponents},
0069 m_volumeIds{std::move(volumeIds)} {}
0070
0071 void ActsExamples::NeuralCalibrator::calibrate(
0072 const MeasurementContainer& measurements, const ClusterContainer* clusters,
0073 const Acts::GeometryContext& gctx, const Acts::CalibrationContext& cctx,
0074 const Acts::SourceLink& sourceLink,
0075 Acts::MultiTrajectory<Acts::VectorMultiTrajectory>::TrackStateProxy&
0076 trackState) const {
0077 trackState.setUncalibratedSourceLink(sourceLink);
0078 const IndexSourceLink& idxSourceLink = sourceLink.get<IndexSourceLink>();
0079 assert((idxSourceLink.index() < measurements.size()) and
0080 "Source link index is outside the container bounds");
0081
0082 if (std::find(m_volumeIds.begin(), m_volumeIds.end(),
0083 idxSourceLink.geometryId().volume()) == m_volumeIds.end()) {
0084 m_fallback.calibrate(measurements, clusters, gctx, cctx, sourceLink,
0085 trackState);
0086 return;
0087 }
0088
0089 Acts::NetworkBatchInput inputBatch(1, m_nInputs);
0090 auto input = inputBatch(0, Eigen::all);
0091
0092
0093 std::size_t matSize0 = 7u;
0094 std::size_t matSize1 = 7u;
0095 std::size_t iInput = ::detail::fillChargeMatrix(
0096 input, (*clusters)[idxSourceLink.index()], matSize0, matSize1);
0097
0098 input[iInput++] = idxSourceLink.geometryId().volume();
0099 input[iInput++] = idxSourceLink.geometryId().layer();
0100
0101 const Acts::Surface& referenceSurface = trackState.referenceSurface();
0102
0103 std::visit(
0104 [&](const auto& measurement) {
0105 auto E = measurement.expander();
0106 auto P = measurement.projector();
0107 Acts::ActsVector<Acts::eBoundSize> fpar = E * measurement.parameters();
0108 Acts::ActsSquareMatrix<Acts::eBoundSize> fcov =
0109 E * measurement.covariance() * E.transpose();
0110
0111 Acts::Vector3 dir = Acts::makeDirectionFromPhiTheta(
0112 fpar[Acts::eBoundPhi], fpar[Acts::eBoundTheta]);
0113 Acts::Vector3 globalPosition = referenceSurface.localToGlobal(
0114 gctx, fpar.segment<2>(Acts::eBoundLoc0), dir);
0115
0116
0117
0118
0119
0120
0121 Acts::RotationMatrix3 rot =
0122 referenceSurface.referenceFrame(gctx, globalPosition, dir)
0123 .inverse();
0124 std::pair<double, double> angles =
0125 Acts::VectorHelpers::incidentAngles(dir, rot);
0126
0127 input[iInput++] = angles.first;
0128 input[iInput++] = angles.second;
0129 input[iInput++] = fpar[Acts::eBoundLoc0];
0130 input[iInput++] = fpar[Acts::eBoundLoc1];
0131 input[iInput++] = fcov(Acts::eBoundLoc0, Acts::eBoundLoc0);
0132 input[iInput++] = fcov(Acts::eBoundLoc1, Acts::eBoundLoc1);
0133 if (iInput != m_nInputs) {
0134 throw std::runtime_error("Expected input size of " +
0135 std::to_string(m_nInputs) +
0136 ", got: " + std::to_string(iInput));
0137 }
0138
0139
0140 std::vector<float> output =
0141 m_model.runONNXInference(inputBatch).front();
0142
0143
0144
0145
0146 std::size_t nParams = 5 * m_nComponents;
0147 if (output.size() != nParams) {
0148 throw std::runtime_error(
0149 "Got output vector of size " + std::to_string(output.size()) +
0150 ", expected size " + std::to_string(nParams));
0151 }
0152
0153
0154 std::size_t iMax = 0;
0155 if (m_nComponents > 1) {
0156 iMax = std::distance(
0157 output.begin(),
0158 std::max_element(output.begin(), output.begin() + m_nComponents));
0159 }
0160 std::size_t iLoc0 = m_nComponents + iMax * 2;
0161 std::size_t iVar0 = 3 * m_nComponents + iMax * 2;
0162
0163 fpar[Acts::eBoundLoc0] = output[iLoc0];
0164 fpar[Acts::eBoundLoc1] = output[iLoc0 + 1];
0165 fcov(Acts::eBoundLoc0, Acts::eBoundLoc0) = output[iVar0];
0166 fcov(Acts::eBoundLoc1, Acts::eBoundLoc1) = output[iVar0 + 1];
0167
0168 constexpr std::size_t kSize =
0169 std::remove_reference_t<decltype(measurement)>::size();
0170 std::array<Acts::BoundIndices, kSize> indices = measurement.indices();
0171 Acts::ActsVector<kSize> cpar = P * fpar;
0172 Acts::ActsSquareMatrix<kSize> ccov = P * fcov * P.transpose();
0173
0174 Acts::SourceLink sl{idxSourceLink};
0175
0176 Acts::Measurement<Acts::BoundIndices, kSize> calibrated(
0177 std::move(sl), indices, cpar, ccov);
0178
0179 trackState.allocateCalibrated(calibrated.size());
0180 trackState.setCalibrated(calibrated);
0181 },
0182 measurements[idxSourceLink.index()]);
0183 }