Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-08-05 08:10:17

0001 // This file is part of the Acts project.
0002 //
0003 // Copyright (C) 2022 CERN for the benefit of the Acts project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
0008 
0009 #pragma once
0010 
0011 #include <Acts/Utilities/Logger.hpp>
0012 
0013 #include <cerrno>
0014 #include <cstring>
0015 #include <fstream>
0016 #include <iostream>
0017 #include <string>
0018 #include <tuple>
0019 #include <vector>
0020 
0021 #include <boost/range/combine.hpp>
0022 #include <cugraph/algorithms.hpp>
0023 #include <cugraph/graph.hpp>
0024 #include <cugraph/graph_functions.hpp>
0025 #include <cugraph/graph_view.hpp>
0026 #include <cugraph/partition_manager.hpp>
0027 #include <cugraph/utilities/error.hpp>
0028 #include <raft/cudart_utils.h>
0029 #include <raft/handle.hpp>
0030 
#ifndef CUDA_RT_CALL
// Evaluate a CUDA runtime API call exactly once and, if it does not return
// cudaSuccess, print the failing expression, source location and CUDA error
// string to stderr. The failure is only reported: control continues after the
// macro, so callers that cannot tolerate a failed call must check separately.
//
// The body is wrapped in do { ... } while (0) so that `CUDA_RT_CALL(x);`
// behaves as a single statement (safe inside un-braced if/else), and the
// argument is parenthesised so that any expression can be passed.
#define CUDA_RT_CALL(call)                                                    \
  do {                                                                        \
    const cudaError_t cudaStatus = (call);                                    \
    if (cudaSuccess != cudaStatus) {                                          \
      fprintf(stderr,                                                         \
              "ERROR: CUDA RT call \"%s\" in line %d of file %s failed with " \
              "%s (%d).\n",                                                   \
              #call, __LINE__, __FILE__, cudaGetErrorString(cudaStatus),      \
              static_cast<int>(cudaStatus));                                  \
    }                                                                         \
  } while (0)
#endif  // CUDA_RT_CALL
0044 
0045 template <typename vertex_t, typename edge_t, typename weight_t>
0046 __global__ void weaklyConnectedComponents(std::vector<vertex_t>& rowIndices,
0047                                           std::vector<vertex_t>& colIndices,
0048                                           std::vector<weight_t>& edgeWeights,
0049                                           std::vector<vertex_t>& trackLabels,
0050                                           const Acts::Logger& logger) {
0051   cudaStream_t stream;
0052   CUDA_RT_CALL(cudaStreamCreate(&stream));
0053 
0054   ACTS_VERBOSE("Weakly components Start");
0055   ACTS_VERBOSE("edge size: " << rowIndices.size() << " " << colIndices.size());
0056   raft::handle_t handle{stream};
0057 
0058   cugraph::graph_t<vertex_t, edge_t, weight_t, false, false> graph(handle);
0059 
0060   // learn from matrix_market_file_utilities.cu
0061   vertex_t maxVertexID_row =
0062       *std::max_element(rowIndices.begin(), rowIndices.end());
0063   vertex_t maxVertexID_col =
0064       *std::max_element(colIndices.begin(), colIndices.end());
0065   vertex_t maxVertex = std::max(maxVertexID_row, maxVertexID_col);
0066 
0067   vertex_t number_of_vertices = maxVertex;
0068   rmm::device_uvector<vertex_t> d_vertices(number_of_vertices,
0069                                            handle.get_stream());
0070   std::vector<vertex_t> vertex_idx(number_of_vertices);
0071   for (vertex_t idx = 0; idx < number_of_vertices; idx++) {
0072     vertex_idx[idx] = idx;
0073   }
0074 
0075   rmm::device_uvector<vertex_t> src_v(rowIndices.size(), handle.get_stream());
0076   rmm::device_uvector<vertex_t> dst_v(colIndices.size(), handle.get_stream());
0077   rmm::device_uvector<weight_t> weights_v(edgeWeights.size(),
0078                                           handle.get_stream());
0079 
0080   raft::update_device(src_v.data(), rowIndices.data(), rowIndices.size(),
0081                       handle.get_stream());
0082   raft::update_device(dst_v.data(), colIndices.data(), colIndices.size(),
0083                       handle.get_stream());
0084   raft::update_device(weights_v.data(), edgeWeights.data(), edgeWeights.size(),
0085                       handle.get_stream());
0086   raft::update_device(d_vertices.data(), vertex_idx.data(), vertex_idx.size(),
0087                       handle.get_stream());
0088 
0089   std::tie(graph, std::ignore) =
0090       cugraph::create_graph_from_edgelist<vertex_t, edge_t, weight_t, false,
0091                                           false>(
0092           handle, std::move(d_vertices), std::move(src_v), std::move(dst_v),
0093           std::move(weights_v), cugraph::graph_properties_t{true, false},
0094           false);
0095 
0096   auto graph_view = graph.view();
0097   CUDA_TRY(cudaDeviceSynchronize());  // for consistent performance measurement
0098 
0099   rmm::device_uvector<vertex_t> d_components(
0100       graph_view.get_number_of_vertices(), handle.get_stream());
0101 
0102   ACTS_VERBOSE("2back from construct_graph");
0103   cugraph::weakly_connected_components(handle, graph_view, d_components.data());
0104 
0105   ACTS_VERBOSE("number of components: " << d_components.size());
0106   raft::update_host(trackLabels.data(), d_components.data(),
0107                     d_components.size(), handle.get_stream());
0108 }