Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-08-06 08:09:57

0001 // This file is part of the Acts project.
0002 //
0003 // Copyright (C) 2022 CERN for the benefit of the Acts project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
0008 
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdlib>
#include <iterator>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <vector>

#include <boost/pending/disjoint_sets.hpp>
0013 
0014 namespace Acts::Ccl::internal {
0015 
0016 // Machinery for validating generic Cell/Cluster types at compile-time
0017 
0018 template <typename, std::size_t, typename T = void>
0019 struct cellTypeHasRequiredFunctions : std::false_type {};
0020 
0021 template <typename T>
0022 struct cellTypeHasRequiredFunctions<
0023     T, 2,
0024     std::void_t<decltype(getCellRow(std::declval<T>())),
0025                 decltype(getCellColumn(std::declval<T>())),
0026                 decltype(getCellLabel(std::declval<T&>()))>> : std::true_type {
0027 };
0028 
0029 template <typename T>
0030 struct cellTypeHasRequiredFunctions<
0031     T, 1,
0032     std::void_t<decltype(getCellColumn(std::declval<T>())),
0033                 decltype(getCellLabel(std::declval<T&>()))>> : std::true_type {
0034 };
0035 
0036 template <typename, typename, typename T = void>
0037 struct clusterTypeHasRequiredFunctions : std::false_type {};
0038 
0039 template <typename T, typename U>
0040 struct clusterTypeHasRequiredFunctions<
0041     T, U,
0042     std::void_t<decltype(clusterAddCell(std::declval<T>(), std::declval<U>()))>>
0043     : std::true_type {};
0044 
0045 template <std::size_t GridDim>
0046 constexpr void staticCheckGridDim() {
0047   static_assert(
0048       GridDim == 1 || GridDim == 2,
0049       "mergeClusters is only defined for grid dimensions of 1 or 2. ");
0050 }
0051 
0052 template <typename T, std::size_t GridDim>
0053 constexpr void staticCheckCellType() {
0054   constexpr bool hasFns = cellTypeHasRequiredFunctions<T, GridDim>();
0055   static_assert(hasFns,
0056                 "Cell type should have the following functions: "
0057                 "'int getCellRow(const Cell&)', "
0058                 "'int getCellColumn(const Cell&)', "
0059                 "'Label& getCellLabel(Cell&)'");
0060 }
0061 
0062 template <typename T, typename U>
0063 constexpr void staticCheckClusterType() {
0064   constexpr bool hasFns = clusterTypeHasRequiredFunctions<T, U>();
0065   static_assert(hasFns,
0066                 "Cluster type should have the following function: "
0067                 "'void clusterAddCell(Cluster&, const Cell&)'");
0068 }
0069 
0070 template <typename Cell, std::size_t GridDim>
0071 struct Compare {
0072   static_assert(GridDim != 1 && GridDim != 2,
0073                 "Only grid dimensions of 1 or 2 are supported");
0074 };
0075 
0076 // Comparator function object for cells, column-wise ordering
0077 // Specialization for 2-D grid
0078 template <typename Cell>
0079 struct Compare<Cell, 2> {
0080   bool operator()(const Cell& c0, const Cell& c1) const {
0081     int row0 = getCellRow(c0);
0082     int row1 = getCellRow(c1);
0083     int col0 = getCellColumn(c0);
0084     int col1 = getCellColumn(c1);
0085     return (col0 == col1) ? row0 < row1 : col0 < col1;
0086   }
0087 };
0088 
0089 // Specialization for 1-D grids
0090 template <typename Cell>
0091 struct Compare<Cell, 1> {
0092   bool operator()(const Cell& c0, const Cell& c1) const {
0093     int col0 = getCellColumn(c0);
0094     int col1 = getCellColumn(c1);
0095     return col0 < col1;
0096   }
0097 };
0098 
0099 // Simple wrapper around boost::disjoint_sets. In theory, could use
0100 // boost::vector_property_map and use boost::disjoint_sets without
0101 // wrapping, but it's way slower
0102 class DisjointSets {
0103  public:
0104   explicit DisjointSets(std::size_t initial_size = 128)
0105       : m_size(initial_size),
0106         m_rank(m_size),
0107         m_parent(m_size),
0108         m_ds(&m_rank[0], &m_parent[0]) {}
0109 
0110   Label makeSet() {
0111     // Empirically, m_size = 128 seems to be good default. If we
0112     // exceed this, take a performance hit and do the right thing.
0113     while (m_globalId >= m_size) {
0114       m_size *= 2;
0115       m_rank.resize(m_size);
0116       m_parent.resize(m_size);
0117       m_ds = boost::disjoint_sets<std::size_t*, std::size_t*>(&m_rank[0],
0118                                                               &m_parent[0]);
0119     }
0120     m_ds.make_set(m_globalId);
0121     return static_cast<Label>(m_globalId++);
0122   }
0123 
0124   void unionSet(std::size_t x, std::size_t y) { m_ds.union_set(x, y); }
0125   Label findSet(std::size_t x) { return static_cast<Label>(m_ds.find_set(x)); }
0126 
0127  private:
0128   std::size_t m_globalId = 1;
0129   std::size_t m_size;
0130   std::vector<std::size_t> m_rank;
0131   std::vector<std::size_t> m_parent;
0132   boost::disjoint_sets<std::size_t*, std::size_t*> m_ds;
0133 };
0134 
0135 template <std::size_t BufSize>
0136 struct ConnectionsBase {
0137   std::size_t nconn{0};
0138   std::array<Label, BufSize> buf;
0139   ConnectionsBase() { std::fill(buf.begin(), buf.end(), NO_LABEL); }
0140 };
0141 
0142 template <std::size_t GridDim>
0143 class Connections {};
0144 
0145 // On 1-D grid, cells have 1 backward neighbor
0146 template <>
0147 struct Connections<1> : public ConnectionsBase<1> {
0148   using ConnectionsBase::ConnectionsBase;
0149 };
0150 
0151 // On a 2-D grid, cells have 4 backward neighbors
0152 template <>
0153 struct Connections<2> : public ConnectionsBase<4> {
0154   using ConnectionsBase::ConnectionsBase;
0155 };
0156 
0157 // Cell collection logic
0158 template <typename Cell, typename Connect, std::size_t GridDim>
0159 Connections<GridDim> getConnections(typename std::vector<Cell>::iterator it,
0160                                     std::vector<Cell>& set, Connect connect) {
0161   Connections<GridDim> seen;
0162   typename std::vector<Cell>::iterator it_2{it};
0163 
0164   while (it_2 != set.begin()) {
0165     it_2 = std::prev(it_2);
0166 
0167     ConnectResult cr = connect(*it, *it_2);
0168     if (cr == ConnectResult::eNoConnStop) {
0169       break;
0170     }
0171     if (cr == ConnectResult::eNoConn) {
0172       continue;
0173     }
0174     if (cr == ConnectResult::eConn) {
0175       seen.buf[seen.nconn] = getCellLabel(*it_2);
0176       seen.nconn += 1;
0177       if (seen.nconn == seen.buf.size()) {
0178         break;
0179       }
0180     }
0181   }
0182   return seen;
0183 }
0184 
0185 template <typename CellCollection, typename ClusterCollection>
0186 ClusterCollection mergeClustersImpl(CellCollection& cells) {
0187   using Cluster = typename ClusterCollection::value_type;
0188 
0189   if (cells.empty()) {
0190     return {};
0191   }
0192 
0193   // Accumulate clusters into the output collection
0194   ClusterCollection outv;
0195   Cluster cl;
0196   int lbl = getCellLabel(cells.front());
0197   for (auto& cell : cells) {
0198     if (getCellLabel(cell) != lbl) {
0199       // New cluster, save previous one
0200       outv.push_back(std::move(cl));
0201       cl = Cluster();
0202       lbl = getCellLabel(cell);
0203     }
0204     clusterAddCell(cl, cell);
0205   }
0206   // Get the last cluster as well
0207   outv.push_back(std::move(cl));
0208 
0209   return outv;
0210 }
0211 
0212 }  // namespace Acts::Ccl::internal
0213 
0214 namespace Acts::Ccl {
0215 
0216 template <typename Cell>
0217 ConnectResult Connect2D<Cell>::operator()(const Cell& ref,
0218                                           const Cell& iter) const {
0219   int deltaRow = std::abs(getCellRow(ref) - getCellRow(iter));
0220   int deltaCol = std::abs(getCellColumn(ref) - getCellColumn(iter));
0221   // Iteration is column-wise, so if too far in column, can
0222   // safely stop
0223   if (deltaCol > 1) {
0224     return ConnectResult::eNoConnStop;
0225   }
0226   // For same reason, if too far in row we know the pixel is not
0227   // connected, but need to keep iterating
0228   if (deltaRow > 1) {
0229     return ConnectResult::eNoConn;
0230   }
0231   // Decide whether or not cluster is connected based on 4- or
0232   // 8-connectivity
0233   if ((deltaRow + deltaCol) <= (conn8 ? 2 : 1)) {
0234     return ConnectResult::eConn;
0235   }
0236   return ConnectResult::eNoConn;
0237 }
0238 
0239 template <typename Cell>
0240 ConnectResult Connect1D<Cell>::operator()(const Cell& ref,
0241                                           const Cell& iter) const {
0242   int deltaCol = std::abs(getCellColumn(ref) - getCellColumn(iter));
0243   return deltaCol == 1 ? ConnectResult::eConn : ConnectResult::eNoConnStop;
0244 }
0245 
0246 template <std::size_t GridDim>
0247 void recordEquivalences(const internal::Connections<GridDim> seen,
0248                         internal::DisjointSets& ds) {
0249   // Sanity check: first element should always have
0250   // label if nconn > 0
0251   if (seen.nconn > 0 && seen.buf[0] == NO_LABEL) {
0252     throw std::logic_error("seen.nconn > 0 but seen.buf[0] == NO_LABEL");
0253   }
0254   for (std::size_t i = 1; i < seen.nconn; i++) {
0255     // Sanity check: since connection lookup is always backward
0256     // while iteration is forward, all connected cells found here
0257     // should have a label
0258     if (seen.buf[i] == NO_LABEL) {
0259       throw std::logic_error("i < seen.nconn but see.buf[i] == NO_LABEL");
0260     }
0261     // Only record equivalence if needed
0262     if (seen.buf[0] != seen.buf[i]) {
0263       ds.unionSet(seen.buf[0], seen.buf[i]);
0264     }
0265   }
0266 }
0267 
0268 template <typename CellCollection, std::size_t GridDim, typename Connect>
0269 void labelClusters(CellCollection& cells, Connect connect) {
0270   using Cell = typename CellCollection::value_type;
0271   internal::staticCheckCellType<Cell, GridDim>();
0272 
0273   internal::DisjointSets ds{};
0274 
0275   // Sort cells by position to enable in-order scan
0276   std::sort(cells.begin(), cells.end(), internal::Compare<Cell, GridDim>());
0277 
0278   // First pass: Allocate labels and record equivalences
0279   for (auto it = cells.begin(); it != cells.end(); ++it) {
0280     const internal::Connections<GridDim> seen =
0281         internal::getConnections<Cell, Connect, GridDim>(it, cells, connect);
0282     if (seen.nconn == 0) {
0283       // Allocate new label
0284       getCellLabel(*it) = ds.makeSet();
0285     } else {
0286       recordEquivalences(seen, ds);
0287       // Set label for current cell
0288       getCellLabel(*it) = seen.buf[0];
0289     }
0290   }
0291 
0292   // Second pass: Merge labels based on recorded equivalences
0293   for (auto& cell : cells) {
0294     Label& lbl = getCellLabel(cell);
0295     lbl = ds.findSet(lbl);
0296   }
0297 }
0298 
0299 template <typename CellCollection, typename ClusterCollection,
0300           std::size_t GridDim = 2>
0301 ClusterCollection mergeClusters(CellCollection& cells) {
0302   using Cell = typename CellCollection::value_type;
0303   using Cluster = typename ClusterCollection::value_type;
0304   internal::staticCheckGridDim<GridDim>();
0305   internal::staticCheckCellType<Cell, GridDim>();
0306   internal::staticCheckClusterType<Cluster&, const Cell&>();
0307 
0308   if constexpr (GridDim > 1) {
0309     // Sort the cells by their cluster label, only needed if more than
0310     // one spatial dimension
0311     std::sort(cells.begin(), cells.end(), [](Cell& lhs, Cell& rhs) {
0312       return getCellLabel(lhs) < getCellLabel(rhs);
0313     });
0314   }
0315 
0316   return internal::mergeClustersImpl<CellCollection, ClusterCollection>(cells);
0317 }
0318 
0319 template <typename CellCollection, typename ClusterCollection,
0320           std::size_t GridDim, typename Connect>
0321 ClusterCollection createClusters(CellCollection& cells, Connect connect) {
0322   using Cell = typename CellCollection::value_type;
0323   using Cluster = typename ClusterCollection::value_type;
0324   internal::staticCheckCellType<Cell, GridDim>();
0325   internal::staticCheckClusterType<Cluster&, const Cell&>();
0326   labelClusters<CellCollection, GridDim, Connect>(cells, connect);
0327   return mergeClusters<CellCollection, ClusterCollection, GridDim>(cells);
0328 }
0329 
0330 }  // namespace Acts::Ccl