Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-08-06 08:11:12

0001 // This file is part of the Acts project.
0002 //
0003 // Copyright (C) 2020-2021 CERN for the benefit of the Acts project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
0011 // Local include(s).
0012 #include "Acts/Plugins/Sycl/Seeding/detail/Types.hpp"
0013 
0014 #include "../Utilities/Arrays.hpp"
0015 #include "SpacePointType.hpp"
0016 
0017 // VecMem include(s).
0018 #include "vecmem/containers/data/jagged_vector_view.hpp"
0019 #include "vecmem/containers/data/vector_view.hpp"
0020 #include "vecmem/containers/device_vector.hpp"
0021 #include "vecmem/containers/jagged_device_vector.hpp"
0022 
0023 // SYCL include(s).
0024 #include <CL/sycl.hpp>
0025 
0026 // System include(s).
0027 #include <cstdint>
0028 
0029 namespace Acts::Sycl::detail {
0030 
0031 /// Functor performing Triplet Filter
0032 class TripletFilter {
0033  public:
0034   /// Constructor
0035   TripletFilter(
0036       const uint32_t numTripletFilterThreads,
0037       vecmem::data::vector_view<uint32_t> sumBotMidView,
0038       const uint32_t firstMiddle,
0039       vecmem::data::vector_view<uint32_t> indMidBotCompView,
0040       vecmem::data::vector_view<uint32_t> indBotDupletView,
0041       vecmem::data::vector_view<uint32_t> sumBotTopCombView,
0042       vecmem::data::jagged_vector_view<uint32_t> midTopDupletView,
0043       vecmem::data::vector_view<detail::DeviceTriplet> curvImpactView,
0044       vecmem::data::vector_view<const detail::DeviceSpacePoint> topSPsView,
0045       vecmem::data::vector_view<const detail::DeviceSpacePoint> middleSPsView,
0046       vecmem::data::vector_view<const detail::DeviceSpacePoint> bottomSPsView,
0047       vecmem::data::vector_view<uint32_t> countTripletsView,
0048       vecmem::data::vector_view<detail::SeedData> seedArrayView,
0049       const DeviceSeedFinderConfig& config, const DeviceExperimentCuts& cuts)
0050       : m_numTripletFilterThreads(numTripletFilterThreads),
0051         m_sumBotMidView(sumBotMidView),
0052         m_firstMiddle(firstMiddle),
0053         m_indMidBotCompView(indMidBotCompView),
0054         m_indBotDupletView(indBotDupletView),
0055         m_sumBotTopCombView(sumBotTopCombView),
0056         m_midTopDupletView(midTopDupletView),
0057         m_curvImpactView(curvImpactView),
0058         m_topSPsView(topSPsView),
0059         m_middleSPsView(middleSPsView),
0060         m_bottomSPsView(bottomSPsView),
0061         m_countTripletsView(countTripletsView),
0062         m_seedArrayView(seedArrayView),
0063         m_config(config),
0064         m_cuts(cuts) {}
0065 
0066   /// Operator performing filtering
0067   void operator()(cl::sycl::nd_item<1> item) const {
0068     if (item.get_global_linear_id() < m_numTripletFilterThreads) {
0069       vecmem::device_vector<uint32_t> sumBotMidPrefix(m_sumBotMidView),
0070           deviceIndMidBot(m_indMidBotCompView),
0071           deviceIndBotDuplets(m_indBotDupletView),
0072           sumBotTopCombPrefix(m_sumBotTopCombView),
0073           deviceCountTriplets(m_countTripletsView);
0074       vecmem::jagged_device_vector<uint32_t> midTopDuplets(m_midTopDupletView);
0075       const auto idx =
0076           sumBotMidPrefix[m_firstMiddle] + item.get_global_linear_id();
0077       const auto mid = deviceIndMidBot[idx];
0078       const auto bot = deviceIndBotDuplets[idx];
0079       const auto sumCombUptoFirstMiddle = sumBotTopCombPrefix[m_firstMiddle];
0080       const auto tripletBegin =
0081           sumBotTopCombPrefix[mid] - sumCombUptoFirstMiddle +
0082           (idx - sumBotMidPrefix[mid]) * midTopDuplets.at(mid).size();
0083       const auto tripletEnd = tripletBegin + deviceCountTriplets[idx];
0084       const vecmem::device_vector<detail::DeviceTriplet> deviceCurvImpactConst(
0085           m_curvImpactView);
0086       for (auto i1 = tripletBegin; i1 < tripletEnd; ++i1) {
0087         const auto current = deviceCurvImpactConst[i1];
0088         const auto top = current.topSPIndex;
0089 
0090         const auto invHelixDiameter = current.curvature;
0091         const auto lowerLimitCurv =
0092             invHelixDiameter - m_config.deltaInvHelixDiameter;
0093         const auto upperLimitCurv =
0094             invHelixDiameter + m_config.deltaInvHelixDiameter;
0095         const vecmem::device_vector<const detail::DeviceSpacePoint> topSPs(
0096             m_topSPsView);
0097         const auto currentTop_r = topSPs[top].r;
0098         auto weight = -(current.impact * m_config.impactWeightFactor);
0099 
0100         uint32_t compatCounter = 0;
0101         // By default compatSeedLimit is 2 -> 2 is
0102         // currently hard coded, because variable length
0103         // arrays are not supported in SYCL kernels.
0104         float compatibleSeedR[2];
0105         for (auto i2 = tripletBegin;
0106              i2 < tripletEnd && compatCounter < m_config.compatSeedLimit;
0107              ++i2) {
0108           const auto other = deviceCurvImpactConst[i2];
0109 
0110           const auto otherCurv = other.curvature;
0111           const auto otherTop_r = topSPs[other.topSPIndex].r;
0112           const float deltaR = cl::sycl::abs(currentTop_r - otherTop_r);
0113           if (deltaR >= m_config.filterDeltaRMin &&
0114               otherCurv >= lowerLimitCurv && otherCurv <= upperLimitCurv) {
0115             uint32_t c = 0;
0116             for (; c < compatCounter &&
0117                    cl::sycl::abs(compatibleSeedR[c] - otherTop_r) >=
0118                        m_config.filterDeltaRMin;
0119                  ++c) {
0120             }
0121             if (c == compatCounter) {
0122               compatibleSeedR[c] = otherTop_r;
0123               ++compatCounter;
0124             }
0125           }
0126         }
0127         weight += compatCounter * m_config.compatSeedWeight;
0128         const vecmem::device_vector<const detail::DeviceSpacePoint> middleSPs(
0129             m_middleSPsView),
0130             bottomSPs(m_bottomSPsView);
0131         const auto bottomSP = bottomSPs[bot];
0132         const auto middleSP = middleSPs[mid];
0133         const auto topSP = topSPs[top];
0134 
0135         weight += m_cuts.seedWeight(bottomSP, middleSP, topSP);
0136 
0137         if (m_cuts.singleSeedCut(weight, bottomSP, middleSP, topSP)) {
0138           detail::SeedData D;
0139           D.bottom = bot;
0140           D.top = top;
0141           D.middle = mid;
0142           D.weight = weight;
0143           vecmem::device_vector<detail::SeedData> seedArray(m_seedArrayView);
0144           seedArray.push_back(D);
0145         }
0146       }
0147     }
0148   }
0149 
0150  private:
0151   const uint32_t m_numTripletFilterThreads;
0152   vecmem::data::vector_view<uint32_t> m_sumBotMidView;
0153   const uint32_t m_firstMiddle;
0154   vecmem::data::vector_view<uint32_t> m_indMidBotCompView;
0155   vecmem::data::vector_view<uint32_t> m_indBotDupletView;
0156   vecmem::data::vector_view<uint32_t> m_sumBotTopCombView;
0157   vecmem::data::jagged_vector_view<uint32_t> m_midTopDupletView;
0158   vecmem::data::vector_view<detail::DeviceTriplet> m_curvImpactView;
0159   vecmem::data::vector_view<const detail::DeviceSpacePoint> m_topSPsView;
0160   vecmem::data::vector_view<const detail::DeviceSpacePoint> m_middleSPsView;
0161   vecmem::data::vector_view<const detail::DeviceSpacePoint> m_bottomSPsView;
0162   vecmem::data::vector_view<uint32_t> m_countTripletsView;
0163   vecmem::data::vector_view<detail::SeedData> m_seedArrayView;
0164   DeviceSeedFinderConfig m_config;
0165   DeviceExperimentCuts m_cuts;
0166 };  // struct TripletFilter
0167 }  // namespace Acts::Sycl::detail