File indexing completed on 2025-08-06 08:11:13
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011
0012 #include "Acts/Plugins/Sycl/Seeding/detail/Types.hpp"
0013
0014 #include "../Utilities/Arrays.hpp"
0015 #include "SpacePointType.hpp"
0016
0017
0018 #include "vecmem/containers/data/jagged_vector_view.hpp"
0019 #include "vecmem/containers/data/vector_view.hpp"
0020 #include "vecmem/containers/device_vector.hpp"
0021 #include "vecmem/containers/jagged_device_vector.hpp"
0022 #include "vecmem/memory/atomic.hpp"
0023
0024
0025 #include <CL/sycl.hpp>
0026
0027
0028 #include <cstdint>
0029
0030 namespace Acts::Sycl::detail {
0031
0032
0033 class TripletSearch {
0034 public:
0035
0036 TripletSearch(
0037 vecmem::data::vector_view<uint32_t> sumBotTopCombView,
0038 const uint32_t numTripletSearchThreads, const uint32_t firstMiddle,
0039 const uint32_t lastMiddle,
0040 vecmem::data::jagged_vector_view<const uint32_t> midTopDupletView,
0041 vecmem::data::vector_view<uint32_t> sumBotMidView,
0042 vecmem::data::vector_view<uint32_t> sumTopMidView,
0043 vecmem::data::vector_view<detail::DeviceLinEqCircle> linearBotView,
0044 vecmem::data::vector_view<detail::DeviceLinEqCircle> linearTopView,
0045 vecmem::data::vector_view<const detail::DeviceSpacePoint> middleSPsView,
0046 vecmem::data::vector_view<uint32_t> indTopDupletview,
0047 vecmem::data::vector_view<uint32_t> countTripletsView,
0048 const DeviceSeedFinderConfig& config,
0049 vecmem::data::vector_view<detail::DeviceTriplet> curvImpactView)
0050 : m_sumBotTopCombView(sumBotTopCombView),
0051 m_numTripletSearchThreads(numTripletSearchThreads),
0052 m_firstMiddle(firstMiddle),
0053 m_lastMiddle(lastMiddle),
0054 m_midTopDupletView(midTopDupletView),
0055 m_sumBotMidView(sumBotMidView),
0056 m_sumTopMidView(sumTopMidView),
0057 m_linearBotView(linearBotView),
0058 m_linearTopView(linearTopView),
0059 m_middleSPsView(middleSPsView),
0060 m_indTopDupletView(indTopDupletview),
0061 m_countTripletsView(countTripletsView),
0062 m_config(config),
0063 m_curvImpactView(curvImpactView) {}
0064
0065
0066 void operator()(cl::sycl::nd_item<1> item) const {
0067
0068 const uint32_t idx = item.get_global_linear_id();
0069 if (idx < m_numTripletSearchThreads) {
0070
0071
0072 vecmem::device_vector<uint32_t> sumBotTopCombPrefix(m_sumBotTopCombView);
0073 const auto sumCombUptoFirstMiddle = sumBotTopCombPrefix[m_firstMiddle];
0074 auto L = m_firstMiddle;
0075 auto R = m_lastMiddle;
0076 auto mid = L;
0077 while (L < R - 1) {
0078 mid = (L + R) / 2;
0079
0080
0081 if (idx + sumCombUptoFirstMiddle < sumBotTopCombPrefix[mid]) {
0082 R = mid;
0083 } else {
0084 L = mid;
0085 }
0086 }
0087 mid = L;
0088 vecmem::jagged_device_vector<const uint32_t> midTopDuplets(
0089 m_midTopDupletView);
0090 const auto numT = midTopDuplets.at(mid).size();
0091 const auto threadIdxForMiddleSP =
0092 (idx - sumBotTopCombPrefix[mid] + sumCombUptoFirstMiddle);
0093
0094
0095
0096
0097
0098
0099
0100
0101
0102
0103
0104
0105
0106
0107
0108
0109
0110
0111
0112
0113
0114
0115
0116
0117
0118
0119
0120
0121
0122
0123
0124
0125
0126
0127
0128
0129
0130
0131
0132
0133
0134
0135
0136
0137
0138
0139 vecmem::device_vector<uint32_t> sumBotMidPrefix(m_sumBotMidView),
0140 sumTopMidPrefix(m_sumTopMidView);
0141 const auto ib = sumBotMidPrefix[mid] + (threadIdxForMiddleSP / numT);
0142 const auto it = sumTopMidPrefix[mid] + (threadIdxForMiddleSP % numT);
0143 vecmem::device_vector<detail::DeviceLinEqCircle> deviceLinBot(
0144 m_linearBotView),
0145 deviceLinTop(m_linearTopView);
0146 const auto linBotEq = deviceLinBot[ib];
0147 const auto linTopEq = deviceLinTop[it];
0148 const vecmem::device_vector<const detail::DeviceSpacePoint> middleSPs(
0149 m_middleSPsView);
0150 const auto midSP = middleSPs[mid];
0151
0152 const auto Vb = linBotEq.v;
0153 const auto Ub = linBotEq.u;
0154 const auto Erb = linBotEq.er;
0155 const auto cotThetab = linBotEq.cotTheta;
0156 const auto iDeltaRb = linBotEq.iDeltaR;
0157
0158 const auto Vt = linTopEq.v;
0159 const auto Ut = linTopEq.u;
0160 const auto Ert = linTopEq.er;
0161 const auto cotThetat = linTopEq.cotTheta;
0162 const auto iDeltaRt = linTopEq.iDeltaR;
0163
0164 const auto rM = midSP.r;
0165 const auto varianceRM = midSP.varR;
0166 const auto varianceZM = midSP.varZ;
0167
0168 auto iSinTheta2 = (1.f + cotThetab * cotThetab);
0169 auto scatteringInRegion2 = m_config.maxScatteringAngle2 * iSinTheta2;
0170 scatteringInRegion2 *=
0171 m_config.sigmaScattering * m_config.sigmaScattering;
0172 auto error2 = Ert + Erb +
0173 2.f * (cotThetab * cotThetat * varianceRM + varianceZM) *
0174 iDeltaRb * iDeltaRt;
0175 auto deltaCotTheta = cotThetab - cotThetat;
0176 auto deltaCotTheta2 = deltaCotTheta * deltaCotTheta;
0177
0178 deltaCotTheta = cl::sycl::abs(deltaCotTheta);
0179 auto error = cl::sycl::sqrt(error2);
0180 auto dCotThetaMinusError2 =
0181 deltaCotTheta2 + error2 - 2.f * deltaCotTheta * error;
0182 auto dU = Ut - Ub;
0183
0184 if ((!(deltaCotTheta2 - error2 > 0.f) ||
0185 !(dCotThetaMinusError2 > scatteringInRegion2)) &&
0186 !(dU == 0.f)) {
0187 auto A = (Vt - Vb) / dU;
0188 auto S2 = 1.f + A * A;
0189 auto B = Vb - A * Ub;
0190 auto B2 = B * B;
0191
0192 auto iHelixDiameter2 = B2 / S2;
0193 auto pT2scatter = 4.f * iHelixDiameter2 * m_config.pT2perRadius;
0194 auto p2scatter = pT2scatter * iSinTheta2;
0195 auto Im = cl::sycl::abs((A - B * rM) * rM);
0196
0197 if (!(S2 < B2 * m_config.minHelixDiameter2) &&
0198 !((deltaCotTheta2 - error2 > 0.f) &&
0199 (dCotThetaMinusError2 > p2scatter * m_config.sigmaScattering *
0200 m_config.sigmaScattering)) &&
0201 !(Im > m_config.impactMax)) {
0202 vecmem::device_vector<uint32_t> deviceIndTopDuplets(
0203 m_indTopDupletView);
0204 const auto top = deviceIndTopDuplets[it];
0205
0206
0207 vecmem::device_vector<uint32_t> deviceCountTriplets(
0208 m_countTripletsView);
0209 vecmem::atomic obj(&deviceCountTriplets[ib]);
0210 auto t = obj.fetch_add(1);
0211
0212
0213
0214
0215
0216
0217
0218
0219
0220
0221
0222
0223
0224
0225
0226
0227
0228
0229
0230 const auto tripletIdx =
0231 sumBotTopCombPrefix[mid] - sumCombUptoFirstMiddle +
0232 (((idx - sumBotTopCombPrefix[mid] + sumCombUptoFirstMiddle) /
0233 numT) *
0234 numT) +
0235 t;
0236
0237 detail::DeviceTriplet T;
0238 T.curvature = B / cl::sycl::sqrt(S2);
0239 T.impact = Im;
0240 T.topSPIndex = top;
0241 vecmem::device_vector<detail::DeviceTriplet> deviceCurvImpact(
0242 m_curvImpactView);
0243 deviceCurvImpact[tripletIdx] = T;
0244 }
0245 }
0246 }
0247 }
0248
0249 private:
0250 vecmem::data::vector_view<uint32_t> m_sumBotTopCombView;
0251 const uint32_t m_numTripletSearchThreads;
0252 const uint32_t m_firstMiddle;
0253 const u_int32_t m_lastMiddle;
0254 vecmem::data::jagged_vector_view<const uint32_t> m_midTopDupletView;
0255 vecmem::data::vector_view<uint32_t> m_sumBotMidView;
0256 vecmem::data::vector_view<uint32_t> m_sumTopMidView;
0257 vecmem::data::vector_view<detail::DeviceLinEqCircle> m_linearBotView;
0258 vecmem::data::vector_view<detail::DeviceLinEqCircle> m_linearTopView;
0259 vecmem::data::vector_view<const detail::DeviceSpacePoint> m_middleSPsView;
0260 vecmem::data::vector_view<uint32_t> m_indTopDupletView;
0261 vecmem::data::vector_view<uint32_t> m_countTripletsView;
0262 DeviceSeedFinderConfig m_config;
0263 vecmem::data::vector_view<detail::DeviceTriplet> m_curvImpactView;
0264 };
0265
0266 }