File indexing completed on 2025-08-05 08:09:45
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "ActsExamples/Digitization/ModuleClusters.hpp"
0010
0011 #include "Acts/Clusterization/Clusterization.hpp"
0012 #include "ActsExamples/Digitization/MeasurementCreation.hpp"
0013 #include "ActsFatras/Digitization/Channelizer.hpp"
0014
0015 #include <array>
0016 #include <cmath>
0017 #include <cstdint>
0018 #include <cstdlib>
0019 #include <memory>
0020 #include <stdexcept>
0021 #include <type_traits>
0022
0023 namespace ActsExamples {
0024
0025 void ModuleClusters::add(DigitizedParameters params, simhit_t simhit) {
0026 ModuleValue mval;
0027 mval.paramIndices = std::move(params.indices);
0028 mval.paramValues = std::move(params.values);
0029 mval.paramVariances = std::move(params.variances);
0030 mval.sources = {simhit};
0031
0032 if (m_merge && !params.cluster.channels.empty()) {
0033
0034 for (auto cell : params.cluster.channels) {
0035 ModuleValue mval_cell = mval;
0036 mval_cell.value = cell;
0037 m_moduleValues.push_back(std::move(mval_cell));
0038 }
0039 } else {
0040
0041 mval.value = std::move(params.cluster);
0042 m_moduleValues.push_back(std::move(mval));
0043 }
0044 }
0045
0046 std::vector<std::pair<DigitizedParameters, std::set<ModuleClusters::simhit_t>>>
0047 ModuleClusters::digitizedParameters() {
0048 if (m_merge) {
0049 merge();
0050 }
0051 std::vector<std::pair<DigitizedParameters, std::set<simhit_t>>> retv;
0052 for (ModuleValue& mval : m_moduleValues) {
0053 if (std::holds_alternative<Cluster::Cell>(mval.value)) {
0054
0055
0056
0057 throw std::runtime_error("Invalid cluster!");
0058 }
0059 DigitizedParameters dpars;
0060 dpars.indices = mval.paramIndices;
0061 dpars.values = mval.paramValues;
0062 dpars.variances = mval.paramVariances;
0063 dpars.cluster = std::get<Cluster>(mval.value);
0064 retv.emplace_back(std::move(dpars), mval.sources);
0065 }
0066 return retv;
0067 }
0068
0069
0070 int getCellRow(const ModuleValue& mval) {
0071 if (std::holds_alternative<ActsExamples::Cluster::Cell>(mval.value)) {
0072 return std::get<ActsExamples::Cluster::Cell>(mval.value).bin[0];
0073 }
0074 throw std::domain_error("ModuleValue does not contain cell!");
0075 }
0076
0077 int getCellColumn(const ActsExamples::ModuleValue& mval) {
0078 if (std::holds_alternative<ActsExamples::Cluster::Cell>(mval.value)) {
0079 return std::get<ActsExamples::Cluster::Cell>(mval.value).bin[1];
0080 }
0081 throw std::domain_error("ModuleValue does not contain cell!");
0082 }
0083
0084 int& getCellLabel(ActsExamples::ModuleValue& mval) {
0085 return mval.label;
0086 }
0087
0088 void clusterAddCell(std::vector<ModuleValue>& cl, const ModuleValue& ce) {
0089 cl.push_back(ce);
0090 }
0091
0092 std::vector<ModuleValue> ModuleClusters::createCellCollection() {
0093 std::vector<ModuleValue> cells;
0094 for (ModuleValue& mval : m_moduleValues) {
0095 if (std::holds_alternative<Cluster::Cell>(mval.value)) {
0096 cells.push_back(mval);
0097 }
0098 }
0099 return cells;
0100 }
0101
0102 void ModuleClusters::merge() {
0103 std::vector<ModuleValue> cells = createCellCollection();
0104
0105 std::vector<ModuleValue> newVals;
0106
0107 if (!cells.empty()) {
0108
0109 std::vector<std::vector<ModuleValue>> merged =
0110 Acts::Ccl::createClusters<std::vector<ModuleValue>,
0111 std::vector<std::vector<ModuleValue>>>(
0112 cells, Acts::Ccl::DefaultConnect<ModuleValue>(m_commonCorner));
0113
0114 for (std::vector<ModuleValue>& cellv : merged) {
0115
0116
0117
0118
0119
0120
0121 for (std::vector<ModuleValue>& remerged : mergeParameters(cellv)) {
0122 newVals.push_back(squash(remerged));
0123 }
0124 }
0125 m_moduleValues = std::move(newVals);
0126 } else {
0127
0128 for (std::vector<ModuleValue>& merged : mergeParameters(m_moduleValues)) {
0129 newVals.push_back(squash(merged));
0130 }
0131 m_moduleValues = std::move(newVals);
0132 }
0133 }
0134
0135
0136 std::vector<std::size_t> ModuleClusters::nonGeoEntries(
0137 std::vector<Acts::BoundIndices>& indices) {
0138 std::vector<std::size_t> retv;
0139 for (std::size_t i = 0; i < indices.size(); i++) {
0140 auto idx = indices.at(i);
0141 if (std::find(m_geoIndices.begin(), m_geoIndices.end(), idx) ==
0142 m_geoIndices.end()) {
0143 retv.push_back(i);
0144 }
0145 }
0146 return retv;
0147 }
0148
0149
0150 std::vector<std::vector<ModuleValue>> ModuleClusters::mergeParameters(
0151 std::vector<ModuleValue> values) {
0152 std::vector<std::vector<ModuleValue>> retv;
0153
0154 std::vector<bool> used(values.size(), false);
0155 for (std::size_t i = 0; i < values.size(); i++) {
0156 if (used.at(i)) {
0157 continue;
0158 }
0159
0160 retv.emplace_back();
0161 std::vector<ModuleValue>& thisvec = retv.back();
0162
0163
0164 thisvec.push_back(std::move(values.at(i)));
0165 used.at(i) = true;
0166
0167
0168
0169
0170 for (std::size_t j = i + 1; j < values.size(); j++) {
0171
0172 if (used.at(j)) {
0173 continue;
0174 }
0175
0176
0177
0178
0179
0180 bool matched = true;
0181
0182
0183
0184
0185 for (ModuleValue& thisval : thisvec) {
0186
0187 for (auto k : nonGeoEntries(thisval.paramIndices)) {
0188 Acts::ActsScalar p_i = thisval.paramValues.at(k);
0189 Acts::ActsScalar p_j = values.at(j).paramValues.at(k);
0190 Acts::ActsScalar v_i = thisval.paramVariances.at(k);
0191 Acts::ActsScalar v_j = values.at(j).paramVariances.at(k);
0192
0193 Acts::ActsScalar left = 0, right = 0;
0194 if (p_i < p_j) {
0195 left = p_i + m_nsigma * std::sqrt(v_i);
0196 right = p_j - m_nsigma * std::sqrt(v_j);
0197 } else {
0198 left = p_j + m_nsigma * std::sqrt(v_j);
0199 right = p_i - m_nsigma * std::sqrt(v_i);
0200 }
0201 if (left < right) {
0202
0203
0204 matched = false;
0205 break;
0206 }
0207 }
0208 if (matched) {
0209
0210
0211
0212 break;
0213 }
0214 }
0215 if (matched) {
0216
0217 used.at(j) = true;
0218 thisvec.push_back(std::move(values.at(j)));
0219 }
0220 }
0221 }
0222 return retv;
0223 }
0224
0225 ModuleValue ModuleClusters::squash(std::vector<ModuleValue>& values) {
0226 ModuleValue mval;
0227 Acts::ActsScalar tot = 0;
0228 Acts::ActsScalar tot2 = 0;
0229 std::vector<Acts::ActsScalar> weights;
0230
0231
0232 for (ModuleValue& other : values) {
0233 if (std::holds_alternative<Cluster::Cell>(other.value)) {
0234 weights.push_back(std::get<Cluster::Cell>(other.value).activation);
0235 } else {
0236 weights.push_back(1);
0237 }
0238 tot += weights.back();
0239 tot2 += weights.back() * weights.back();
0240 }
0241
0242
0243 for (std::size_t i = 0; i < values.size(); i++) {
0244 ModuleValue& other = values.at(i);
0245 for (std::size_t j = 0; j < other.paramIndices.size(); j++) {
0246 auto idx = other.paramIndices.at(j);
0247 if (std::find(m_geoIndices.begin(), m_geoIndices.end(), idx) ==
0248 m_geoIndices.end()) {
0249 if (std::find(mval.paramIndices.begin(), mval.paramIndices.end(),
0250 idx) == mval.paramIndices.end()) {
0251 mval.paramIndices.push_back(idx);
0252 }
0253 if (mval.paramValues.size() < (j + 1)) {
0254 mval.paramValues.push_back(0);
0255 mval.paramVariances.push_back(0);
0256 }
0257 Acts::ActsScalar f = weights.at(i) / (tot > 0 ? tot : 1);
0258 Acts::ActsScalar f2 =
0259 weights.at(i) * weights.at(i) / (tot2 > 0 ? tot2 : 1);
0260 mval.paramValues.at(j) += f * other.paramValues.at(j);
0261 mval.paramVariances.at(j) += f2 * other.paramVariances.at(j);
0262 }
0263 }
0264 }
0265
0266
0267 Cluster clus;
0268
0269 const auto& binningData = m_segmentation.binningData();
0270 Acts::Vector2 pos(0., 0.);
0271 Acts::Vector2 var(0., 0.);
0272
0273 std::size_t b0min = SIZE_MAX;
0274 std::size_t b0max = 0;
0275 std::size_t b1min = SIZE_MAX;
0276 std::size_t b1max = 0;
0277
0278 for (std::size_t i = 0; i < values.size(); i++) {
0279 ModuleValue& other = values.at(i);
0280 if (!std::holds_alternative<Cluster::Cell>(other.value)) {
0281 continue;
0282 }
0283
0284 Cluster::Cell ch = std::get<Cluster::Cell>(other.value);
0285 auto bin = ch.bin;
0286
0287 std::size_t b0 = bin[0];
0288 std::size_t b1 = bin[1];
0289
0290 b0min = std::min(b0min, b0);
0291 b0max = std::max(b0max, b0);
0292 b1min = std::min(b1min, b1);
0293 b1max = std::max(b1max, b1);
0294
0295 float p0 = binningData[0].center(b0);
0296 float w0 = binningData[0].width(b0);
0297 float p1 = binningData[1].center(b1);
0298 float w1 = binningData[1].width(b1);
0299
0300 pos += Acts::Vector2(weights.at(i) * p0, weights.at(i) * p1);
0301
0302
0303
0304 var += Acts::Vector2(weights.at(i) * weights.at(i) * w0 * w0 / 12,
0305 weights.at(i) * weights.at(i) * w1 * w1 / 12);
0306
0307 clus.channels.push_back(std::move(ch));
0308
0309
0310
0311 clus.sizeLoc0 = b0max - b0min + 1;
0312 clus.sizeLoc1 = b1max - b1min + 1;
0313 }
0314
0315 if (tot > 0) {
0316 pos /= tot;
0317 var /= (tot * tot);
0318 }
0319
0320 for (auto idx : m_geoIndices) {
0321 mval.paramIndices.push_back(idx);
0322 mval.paramValues.push_back(pos[idx]);
0323 mval.paramVariances.push_back(var[idx]);
0324 }
0325
0326 mval.value = std::move(clus);
0327
0328
0329 for (ModuleValue& other : values) {
0330 mval.sources.merge(other.sources);
0331 }
0332
0333 return mval;
0334 }
0335
0336 }