Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-08-06 08:11:12

0001 // This file is part of the Acts project.
0002 //
0003 // Copyright (C) 2020-2021 CERN for the benefit of the Acts project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
0011 // SYCL plugin include(s).
0012 #include "Acts/Plugins/Sycl/Seeding/detail/Types.hpp"
0013 
0014 #include "../Utilities/Arrays.hpp"
0015 #include "SpacePointType.hpp"
0016 
0017 // SYCL include(s).
0018 #include <CL/sycl.hpp>
0019 
0020 // System include(s).
0021 #include <cassert>
0022 #include <cstdint>
0023 
0024 // VecMem include(s).
0025 #include "vecmem/containers/data/vector_view.hpp"
0026 #include "vecmem/containers/device_vector.hpp"
0027 
0028 namespace Acts::Sycl::detail {
0029 
0030 /// Functor performing a linear coordinate transformation on spacepoint pairs
0031 template <SpacePointType OtherSPType>
0032 class LinearTransform {
0033   // Sanity check(s).
0034   static_assert((OtherSPType == SpacePointType::Bottom) ||
0035                     (OtherSPType == SpacePointType::Top),
0036                 "Class must be instantiated with either "
0037                 "Acts::Sycl::detail::SpacePointType::Bottom or "
0038                 "Acts::Sycl::detail::SpacePointType::Top");
0039 
0040  public:
0041   /// Constructor with all the necessary arguments
0042   LinearTransform(
0043       vecmem::data::vector_view<const DeviceSpacePoint> middleSPs,
0044       vecmem::data::vector_view<const DeviceSpacePoint> otherSPs,
0045       vecmem::data::vector_view<uint32_t> middleIndexLUT,
0046       vecmem::data::vector_view<uint32_t> otherIndexLUT, uint32_t nEdges,
0047       vecmem::data::vector_view<detail::DeviceLinEqCircle> resultArray)
0048       : m_middleSPs(middleSPs),
0049         m_otherSPs(otherSPs),
0050         m_middleIndexLUT(middleIndexLUT),
0051         m_otherIndexLUT(otherIndexLUT),
0052         m_nEdges(nEdges),
0053         m_resultArray(resultArray) {}
0054 
0055   /// Operator performing the coordinate linear transformation
0056   void operator()(cl::sycl::nd_item<1> item) const {
0057     // Get the index to operate on.
0058     const auto idx = item.get_global_linear_id();
0059     if (idx >= m_nEdges) {
0060       return;
0061     }
0062 
0063     // Translate this one index into indices in the spacepoint arrays.
0064     // Note that using asserts with the CUDA backend of dpc++ is not working
0065     // quite correctly at the moment. :-( So these checks may need to be
0066     // disabled if you need to build for an NVidia backend in Debug mode.
0067     vecmem::device_vector<uint32_t> middleIndexLUT(m_middleIndexLUT);
0068     const uint32_t middleIndex = middleIndexLUT[idx];
0069     assert(middleIndex < m_middleSPs.size());
0070     (void)m_middleSPs.size();
0071     vecmem::device_vector<uint32_t> otherIndexLUT(m_otherIndexLUT);
0072     const uint32_t otherIndex = otherIndexLUT[idx];
0073     assert(otherIndex < m_otherSPs.size());
0074     (void)m_otherSPs.size();
0075 
0076     // Create a copy of the spacepoint objects for the current thread. On
0077     // dedicated GPUs this provides a better performance than accessing
0078     // variables one-by-one from global device memory.
0079     const vecmem::device_vector<const DeviceSpacePoint> middleSPs(m_middleSPs);
0080     const DeviceSpacePoint middleSP = middleSPs[middleIndex];
0081     const vecmem::device_vector<const DeviceSpacePoint> otherSPs(m_otherSPs);
0082     const DeviceSpacePoint otherSP = otherSPs[otherIndex];
0083 
0084     // Calculate some "helper variables" for the coordinate linear
0085     // transformation.
0086     const float cosPhiM = middleSP.x / middleSP.r;
0087     const float sinPhiM = middleSP.y / middleSP.r;
0088 
0089     const float deltaX = otherSP.x - middleSP.x;
0090     const float deltaY = otherSP.y - middleSP.y;
0091     const float deltaZ = otherSP.z - middleSP.z;
0092 
0093     const float x = deltaX * cosPhiM + deltaY * sinPhiM;
0094     const float y = deltaY * cosPhiM - deltaX * sinPhiM;
0095     const float iDeltaR2 = 1.f / (deltaX * deltaX + deltaY * deltaY);
0096 
0097     // Create the result object.
0098     DeviceLinEqCircle result;
0099     result.iDeltaR = cl::sycl::sqrt(iDeltaR2);
0100     result.cotTheta = deltaZ * result.iDeltaR;
0101     if constexpr (OtherSPType == SpacePointType::Bottom) {
0102       result.cotTheta = -(result.cotTheta);
0103     }
0104     result.zo = middleSP.z - middleSP.r * result.cotTheta;
0105     result.u = x * iDeltaR2;
0106     result.v = y * iDeltaR2;
0107     result.er =
0108         ((middleSP.varZ + otherSP.varZ) +
0109          (result.cotTheta * result.cotTheta) * (middleSP.varR + otherSP.varR)) *
0110         iDeltaR2;
0111 
0112     // Store the result in the result vector
0113     vecmem::device_vector<detail::DeviceLinEqCircle> resultArray(m_resultArray);
0114     resultArray[idx] = result;
0115     return;
0116   }
0117 
0118  private:
0119   /// Pointer to the middle spacepoints (in global device memory)
0120   vecmem::data::vector_view<const DeviceSpacePoint> m_middleSPs;
0121   /// Pointer to the "other" (bottom or top) spacepoints (in global device mem.)
0122   vecmem::data::vector_view<const DeviceSpacePoint> m_otherSPs;
0123 
0124   /// Look-Up Table from the iteration index to the middle spacepoint index
0125   vecmem::data::vector_view<uint32_t> m_middleIndexLUT;
0126   /// Loop-Up Table from the iteration index to the "other" spacepoint index
0127   vecmem::data::vector_view<uint32_t> m_otherIndexLUT;
0128 
0129   /// Total number of elements in the result array
0130   uint32_t m_nEdges;
0131 
0132   /// The result array in device global memory
0133   vecmem::data::vector_view<detail::DeviceLinEqCircle> m_resultArray;
0134 
0135 };  // class LinearTransform
0136 
0137 }  // namespace Acts::Sycl::detail