Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-08-06 08:11:13

0001 // This file is part of the Acts project.
0002 //
0003 // Copyright (C) 2020-2021 CERN for the benefit of the Acts project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
0011 // Local include(s).
0012 #include "Acts/Plugins/Sycl/Seeding/detail/Types.hpp"
0013 
0014 #include "../Utilities/Arrays.hpp"
0015 #include "SpacePointType.hpp"
0016 
0017 // VecMem include(s).
0018 #include "vecmem/containers/data/jagged_vector_view.hpp"
0019 #include "vecmem/containers/data/vector_view.hpp"
0020 #include "vecmem/containers/device_vector.hpp"
0021 #include "vecmem/containers/jagged_device_vector.hpp"
0022 #include "vecmem/memory/atomic.hpp"
0023 
0024 // SYCL include(s).
0025 #include <CL/sycl.hpp>
0026 
0027 // System include(s).
0028 #include <cstdint>
0029 
0030 namespace Acts::Sycl::detail {
0031 
0032 /// Functor performing Triplet Search
0033 class TripletSearch {
0034  public:
0035   /// Constructor
0036   TripletSearch(
0037       vecmem::data::vector_view<uint32_t> sumBotTopCombView,
0038       const uint32_t numTripletSearchThreads, const uint32_t firstMiddle,
0039       const uint32_t lastMiddle,
0040       vecmem::data::jagged_vector_view<const uint32_t> midTopDupletView,
0041       vecmem::data::vector_view<uint32_t> sumBotMidView,
0042       vecmem::data::vector_view<uint32_t> sumTopMidView,
0043       vecmem::data::vector_view<detail::DeviceLinEqCircle> linearBotView,
0044       vecmem::data::vector_view<detail::DeviceLinEqCircle> linearTopView,
0045       vecmem::data::vector_view<const detail::DeviceSpacePoint> middleSPsView,
0046       vecmem::data::vector_view<uint32_t> indTopDupletview,
0047       vecmem::data::vector_view<uint32_t> countTripletsView,
0048       const DeviceSeedFinderConfig& config,
0049       vecmem::data::vector_view<detail::DeviceTriplet> curvImpactView)
0050       : m_sumBotTopCombView(sumBotTopCombView),
0051         m_numTripletSearchThreads(numTripletSearchThreads),
0052         m_firstMiddle(firstMiddle),
0053         m_lastMiddle(lastMiddle),
0054         m_midTopDupletView(midTopDupletView),
0055         m_sumBotMidView(sumBotMidView),
0056         m_sumTopMidView(sumTopMidView),
0057         m_linearBotView(linearBotView),
0058         m_linearTopView(linearTopView),
0059         m_middleSPsView(middleSPsView),
0060         m_indTopDupletView(indTopDupletview),
0061         m_countTripletsView(countTripletsView),
0062         m_config(config),
0063         m_curvImpactView(curvImpactView) {}
0064 
0065   /// Operator performing the triplet search
0066   void operator()(cl::sycl::nd_item<1> item) const {
0067     // Get the index
0068     const uint32_t idx = item.get_global_linear_id();
0069     if (idx < m_numTripletSearchThreads) {
0070       // Retrieve the index of the corresponding middle
0071       // space point by binary search
0072       vecmem::device_vector<uint32_t> sumBotTopCombPrefix(m_sumBotTopCombView);
0073       const auto sumCombUptoFirstMiddle = sumBotTopCombPrefix[m_firstMiddle];
0074       auto L = m_firstMiddle;
0075       auto R = m_lastMiddle;
0076       auto mid = L;
0077       while (L < R - 1) {
0078         mid = (L + R) / 2;
0079         // To be able to search in sumBotTopCombPrefix, we need
0080         // to use an offset (sumCombUptoFirstMiddle).
0081         if (idx + sumCombUptoFirstMiddle < sumBotTopCombPrefix[mid]) {
0082           R = mid;
0083         } else {
0084           L = mid;
0085         }
0086       }
0087       mid = L;
0088       vecmem::jagged_device_vector<const uint32_t> midTopDuplets(
0089           m_midTopDupletView);
0090       const auto numT = midTopDuplets.at(mid).size();
0091       const auto threadIdxForMiddleSP =
0092           (idx - sumBotTopCombPrefix[mid] + sumCombUptoFirstMiddle);
0093       /*
0094           NOTES ON THREAD MAPPING TO SPACE POINTS
0095 
0096           We need to map bottom and top SP indices to this
0097           thread.
0098 
0099           So we are mapping one bottom and one top SP to this thread
0100           (we already have a middle SP) which gives us a tiplet.
0101 
0102           This is done in the following way: We
0103           calculated the number of possible triplet
0104           combinations for this middle SP (let it be
0105           num_comp_bot*num_comp_top). Let num_comp_bot = 2
0106           and num_comp_top=3 in this example. So we have 2
0107           compatible bottom and 3 compatible top SP for this
0108           middle SP.
0109 
0110           That gives us 6 threads altogether:
0111                       ===========================================
0112           thread:    |  0   |  1   |  2   |  3   |  4   |  5   |
0113           bottom id: | bot0 | bot0 | bot0 | bot1 | bot1 | bot1 |
0114           top id:    | top0 | top1 | top2 | top0 | top1 | top2 |
0115                       ===========================================
0116 
0117           If we divide 6 by the number of compatible top SP
0118           for this middle SP, or deviceNumTopDuplets[mid]
0119           which is 3 now, we get the id for the bottom SP.
0120           Similarly, if we take modulo
0121           deviceNumTopDuplets[mid], we get the id for the
0122           top SP.
0123 
0124           So if threadIdxForMiddleSP = 3, then ib = 1 and it = 0.
0125 
0126           We can use these ids together with
0127           sumBotMidPrefix[mid] and deviceSumTop[mid] to be able
0128           to index our other arrays.
0129 
0130           These other arrays are deviceIndBot and deviceIndTop.
0131 
0132           So to retrieve the bottom SP index for this thread, we'd
0133           have to index the deviceIndBot array at
0134               sumBotMidPrefix[mid] + ib
0135           which is the id for the bottom SP that we just calculated
0136           (ib = 1 in the example).
0137       */
0138 
0139       vecmem::device_vector<uint32_t> sumBotMidPrefix(m_sumBotMidView),
0140           sumTopMidPrefix(m_sumTopMidView);
0141       const auto ib = sumBotMidPrefix[mid] + (threadIdxForMiddleSP / numT);
0142       const auto it = sumTopMidPrefix[mid] + (threadIdxForMiddleSP % numT);
0143       vecmem::device_vector<detail::DeviceLinEqCircle> deviceLinBot(
0144           m_linearBotView),
0145           deviceLinTop(m_linearTopView);
0146       const auto linBotEq = deviceLinBot[ib];
0147       const auto linTopEq = deviceLinTop[it];
0148       const vecmem::device_vector<const detail::DeviceSpacePoint> middleSPs(
0149           m_middleSPsView);
0150       const auto midSP = middleSPs[mid];
0151 
0152       const auto Vb = linBotEq.v;
0153       const auto Ub = linBotEq.u;
0154       const auto Erb = linBotEq.er;
0155       const auto cotThetab = linBotEq.cotTheta;
0156       const auto iDeltaRb = linBotEq.iDeltaR;
0157 
0158       const auto Vt = linTopEq.v;
0159       const auto Ut = linTopEq.u;
0160       const auto Ert = linTopEq.er;
0161       const auto cotThetat = linTopEq.cotTheta;
0162       const auto iDeltaRt = linTopEq.iDeltaR;
0163 
0164       const auto rM = midSP.r;
0165       const auto varianceRM = midSP.varR;
0166       const auto varianceZM = midSP.varZ;
0167 
0168       auto iSinTheta2 = (1.f + cotThetab * cotThetab);
0169       auto scatteringInRegion2 = m_config.maxScatteringAngle2 * iSinTheta2;
0170       scatteringInRegion2 *=
0171           m_config.sigmaScattering * m_config.sigmaScattering;
0172       auto error2 = Ert + Erb +
0173                     2.f * (cotThetab * cotThetat * varianceRM + varianceZM) *
0174                         iDeltaRb * iDeltaRt;
0175       auto deltaCotTheta = cotThetab - cotThetat;
0176       auto deltaCotTheta2 = deltaCotTheta * deltaCotTheta;
0177 
0178       deltaCotTheta = cl::sycl::abs(deltaCotTheta);
0179       auto error = cl::sycl::sqrt(error2);
0180       auto dCotThetaMinusError2 =
0181           deltaCotTheta2 + error2 - 2.f * deltaCotTheta * error;
0182       auto dU = Ut - Ub;
0183 
0184       if ((!(deltaCotTheta2 - error2 > 0.f) ||
0185            !(dCotThetaMinusError2 > scatteringInRegion2)) &&
0186           !(dU == 0.f)) {
0187         auto A = (Vt - Vb) / dU;
0188         auto S2 = 1.f + A * A;
0189         auto B = Vb - A * Ub;
0190         auto B2 = B * B;
0191 
0192         auto iHelixDiameter2 = B2 / S2;
0193         auto pT2scatter = 4.f * iHelixDiameter2 * m_config.pT2perRadius;
0194         auto p2scatter = pT2scatter * iSinTheta2;
0195         auto Im = cl::sycl::abs((A - B * rM) * rM);
0196 
0197         if (!(S2 < B2 * m_config.minHelixDiameter2) &&
0198             !((deltaCotTheta2 - error2 > 0.f) &&
0199               (dCotThetaMinusError2 > p2scatter * m_config.sigmaScattering *
0200                                           m_config.sigmaScattering)) &&
0201             !(Im > m_config.impactMax)) {
0202           vecmem::device_vector<uint32_t> deviceIndTopDuplets(
0203               m_indTopDupletView);
0204           const auto top = deviceIndTopDuplets[it];
0205           // this will be the t-th top space point for
0206           // fixed middle and bottom SP
0207           vecmem::device_vector<uint32_t> deviceCountTriplets(
0208               m_countTripletsView);
0209           vecmem::atomic obj(&deviceCountTriplets[ib]);
0210           auto t = obj.fetch_add(1);
0211           /*
0212               sumBotTopCombPrefix[mid] - sumCombUptoFirstMiddle:
0213               gives the memory location reserved for this
0214               middle SP
0215 
0216               (idx-sumBotTopCombPrefix[mid]+sumCombUptoFirstMiddle:
0217               this is the nth thread for this middle SP
0218 
0219               (idx-sumBotTopCombPrefix[mid]+sumCombUptoFirstMiddle)/numT:
0220               this is the mth bottom SP for this middle SP
0221 
0222               multiplying this by numT gives the memory
0223               location for this middle and bottom SP
0224 
0225               and by adding t to it, we will end up storing
0226               compatible triplet candidates for this middle
0227               and bottom SP right next to each other
0228               starting from the given memory
0229           */
0230           const auto tripletIdx =
0231               sumBotTopCombPrefix[mid] - sumCombUptoFirstMiddle +
0232               (((idx - sumBotTopCombPrefix[mid] + sumCombUptoFirstMiddle) /
0233                 numT) *
0234                numT) +
0235               t;
0236 
0237           detail::DeviceTriplet T;
0238           T.curvature = B / cl::sycl::sqrt(S2);
0239           T.impact = Im;
0240           T.topSPIndex = top;
0241           vecmem::device_vector<detail::DeviceTriplet> deviceCurvImpact(
0242               m_curvImpactView);
0243           deviceCurvImpact[tripletIdx] = T;
0244         }
0245       }
0246     }
0247   }
0248 
0249  private:
0250   vecmem::data::vector_view<uint32_t> m_sumBotTopCombView;
0251   const uint32_t m_numTripletSearchThreads;
0252   const uint32_t m_firstMiddle;
0253   const u_int32_t m_lastMiddle;
0254   vecmem::data::jagged_vector_view<const uint32_t> m_midTopDupletView;
0255   vecmem::data::vector_view<uint32_t> m_sumBotMidView;
0256   vecmem::data::vector_view<uint32_t> m_sumTopMidView;
0257   vecmem::data::vector_view<detail::DeviceLinEqCircle> m_linearBotView;
0258   vecmem::data::vector_view<detail::DeviceLinEqCircle> m_linearTopView;
0259   vecmem::data::vector_view<const detail::DeviceSpacePoint> m_middleSPsView;
0260   vecmem::data::vector_view<uint32_t> m_indTopDupletView;
0261   vecmem::data::vector_view<uint32_t> m_countTripletsView;
0262   DeviceSeedFinderConfig m_config;
0263   vecmem::data::vector_view<detail::DeviceTriplet> m_curvImpactView;
0264 };  // struct TripletSearch
0265 
0266 }  // namespace Acts::Sycl::detail