Skip to content

Commit

Permalink
Refactored Louvain slightly, added Leiden based on Louvain with the a…
Browse files Browse the repository at this point in the history
…ddition of the refine partition step
  • Loading branch information
jwyles committed Aug 11, 2020
1 parent 18f5240 commit 05045b4
Show file tree
Hide file tree
Showing 12 changed files with 849 additions and 65 deletions.
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ add_library(cugraph SHARED
src/community/spectral_clustering.cu
src/community/louvain.cpp
src/community/louvain_kernels.cu
src/community/leiden.cpp
src/community/leiden_kernels.cu
src/community/ktruss.cu
src/community/ECG.cu
src/community/triangles_counting.cu
Expand Down
28 changes: 28 additions & 0 deletions cpp/include/algorithms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,34 @@ void louvain(GraphCSRView<vertex_t, edge_t, weight_t> const &graph,
int max_iter = 100,
weight_t resolution = weight_t{1});

/**
 * @brief Leiden implementation
 *
 * Compute a clustering of the graph by maximizing modularity using the Leiden improvements
 * to the Louvain method.
 *
 * @throws cugraph::logic_error when an error occurs.
 *
 * @tparam vertex_t Type of vertex identifiers.
 * Supported value : int (signed, 32-bit)
 * @tparam edge_t Type of edge identifiers.
 * Supported value : int (signed, 32-bit)
 * @tparam weight_t Type of edge weights. Supported values : float or double.
 *
 * @param[in] graph input graph object (CSR)
 * @param[out] final_modularity modularity of the returned clustering
 * @param[out] num_level number of levels of the returned clustering
 * @param[out] leiden_parts Pointer to device array where the clustering should be stored
 * @param[in] max_iter (optional) maximum number of iterations (levels) to run (default 100)
 * @param[in] resolution (optional) resolution value forwarded to the modularity
 * computation (default 1)
 */
template <typename vertex_t, typename edge_t, typename weight_t>
void leiden(GraphCSRView<vertex_t, edge_t, weight_t> const &graph,
weight_t *final_modularity,
int *num_level,
vertex_t *leiden_parts,
int max_iter = 100,
weight_t resolution = weight_t{1});

/**
* @brief Computes the ecg clustering of the given graph.
*
Expand Down
52 changes: 52 additions & 0 deletions cpp/src/community/leiden.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <algorithms.hpp>
#include <graph.hpp>

#include <rmm/thrust_rmm_allocator.h>

#include <thrust/sequence.h>

#include <community/leiden_kernels.hpp>

#include "utilities/error.hpp"

namespace cugraph {

/**
 * Public Leiden entry point: checks the caller-supplied arguments and then
 * forwards all of them, unchanged, to detail::leiden.
 *
 * @param graph            CSR graph view; its edge_data pointer must be non-null.
 * @param final_modularity Output pointer, must be non-null.
 * @param num_level        Output pointer, must be non-null.
 * @param leiden_parts     Output pointer for the clustering, must be non-null.
 * @param max_level        Forwarded to detail::leiden as the level limit.
 * @param resolution       Forwarded to detail::leiden.
 *
 * Each precondition is enforced with CUGRAPH_EXPECTS (the public header
 * documents failures as cugraph::logic_error).
 */
template <typename vertex_t, typename edge_t, typename weight_t>
void leiden(GraphCSRView<vertex_t, edge_t, weight_t> const &graph,
            weight_t *final_modularity,
            int *num_level,
            vertex_t *leiden_parts,
            int max_level,
            weight_t resolution)
{
  // The messages name this API and its own parameter; they previously said
  // "louvain" / "louvain_parts", left over from the Louvain wrapper.
  CUGRAPH_EXPECTS(graph.edge_data != nullptr, "API error, leiden expects a weighted graph");
  CUGRAPH_EXPECTS(final_modularity != nullptr, "API error, final_modularity is null");
  CUGRAPH_EXPECTS(num_level != nullptr, "API error, num_level is null");
  CUGRAPH_EXPECTS(leiden_parts != nullptr, "API error, leiden_parts is null");

  detail::leiden<vertex_t, edge_t, weight_t>(
    graph, final_modularity, num_level, leiden_parts, max_level, resolution);
}

template void leiden(
GraphCSRView<int32_t, int32_t, float> const &, float *, int *, int32_t *, int, float);
template void leiden(
GraphCSRView<int32_t, int32_t, double> const &, double *, int *, int32_t *, int, double);

} // namespace cugraph
296 changes: 296 additions & 0 deletions cpp/src/community/leiden_kernels.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,296 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <graph.hpp>

#include <rmm/thrust_rmm_allocator.h>

#include <utilities/cuda_utils.cuh>
#include <utilities/graph_utils.cuh>
#include <community/louvain_kernels.hpp>

//#define TIMING

#ifdef TIMING
#include <utilities/high_res_timer.hpp>
#endif

#include <converters/COOtoCSR.cuh>

namespace cugraph {
namespace detail {

template <typename vertex_t, typename edge_t, typename weight_t>
weight_t update_clustering_by_delta_modularity_constrained(
weight_t total_edge_weight,
weight_t resolution,
GraphCSRView<vertex_t, edge_t, weight_t> const &graph,
rmm::device_vector<vertex_t> const &src_indices,
rmm::device_vector<weight_t> const &vertex_weights,
rmm::device_vector<weight_t> &cluster_weights,
rmm::device_vector<vertex_t> &cluster,
rmm::device_vector<vertex_t>& constraint,
cudaStream_t stream)
{
rmm::device_vector<vertex_t> next_cluster(cluster);
rmm::device_vector<weight_t> delta_Q(graph.number_of_edges);
rmm::device_vector<vertex_t> cluster_hash(graph.number_of_edges);
rmm::device_vector<weight_t> old_cluster_sum(graph.number_of_vertices);

vertex_t *d_cluster_hash = cluster_hash.data().get();
vertex_t *d_cluster = cluster.data().get();
weight_t const *d_vertex_weights = vertex_weights.data().get();
weight_t *d_cluster_weights = cluster_weights.data().get();
weight_t *d_delta_Q = delta_Q.data().get();
vertex_t* d_constraint = constraint.data().get();
vertex_t const* d_src_indices = src_indices.data().get();
vertex_t const* d_dst_indices = graph.indices;

weight_t new_Q = modularity<vertex_t, edge_t, weight_t>(
total_edge_weight, resolution, graph, cluster.data().get(), stream);

weight_t cur_Q = new_Q - 1;

// To avoid the potential of having two vertices swap clusters
// we will only allow vertices to move up (true) or down (false)
// during each iteration of the loop
bool up_down = true;

while (new_Q > (cur_Q + 0.0001)) {
cur_Q = new_Q;

compute_delta_modularity(total_edge_weight,
resolution,
graph,
src_indices,
vertex_weights,
cluster_weights,
cluster,
cluster_hash,
delta_Q,
old_cluster_sum,
stream);

// Filter out positive delta_Q values for nodes not in the same constraint group
thrust::for_each(rmm::exec_policy(stream)->on(stream),
thrust::make_counting_iterator(0),
thrust::make_counting_iterator(graph.number_of_edges),
[d_src_indices, d_dst_indices, d_constraint, d_delta_Q]
__device__ (vertex_t i) {
vertex_t start_cluster = d_constraint[d_src_indices[i]];
vertex_t end_cluster = d_constraint[d_dst_indices[i]];
if (start_cluster != end_cluster)
d_delta_Q[i] = weight_t{0.0}; });

assign_nodes(graph,
delta_Q,
cluster_hash,
src_indices,
next_cluster,
vertex_weights,
cluster_weights,
up_down,
stream);

up_down = !up_down;

new_Q = modularity<vertex_t, edge_t, weight_t>(
total_edge_weight, resolution, graph, next_cluster.data().get(), stream);

if (new_Q > cur_Q) { thrust::copy(next_cluster.begin(), next_cluster.end(), cluster.begin()); }
}

return cur_Q;
}

template float update_clustering_by_delta_modularity_constrained(float,
float,
GraphCSRView<int32_t, int32_t, float> const &,
rmm::device_vector<int32_t> const &,
rmm::device_vector<float> const &,
rmm::device_vector<float> &,
rmm::device_vector<int32_t> &,
rmm::device_vector<int32_t> &,
cudaStream_t);

template double update_clustering_by_delta_modularity_constrained(double,
double,
GraphCSRView<int32_t, int32_t, double> const &,
rmm::device_vector<int32_t> const &,
rmm::device_vector<double> const &,
rmm::device_vector<double> &,
rmm::device_vector<int32_t> &,
rmm::device_vector<int32_t> &,
cudaStream_t);

/**
 * @brief Level-by-level Leiden driver.
 *
 * Works on a private copy of the input graph. Each level:
 *   1. runs the unconstrained update_clustering_by_delta_modularity();
 *   2. saves that partition as a constraint, resets every vertex to its own cluster and
 *      runs update_clustering_by_delta_modularity_constrained() (the refinement step);
 *   3. stops if the resulting modularity did not improve on the best seen so far;
 *   4. otherwise renumbers the clusters and collapses the working graph to one vertex
 *      per cluster.
 *
 * @param[in]  graph            Input CSR graph; its arrays are copied, not modified.
 * @param[out] final_modularity Receives the best modularity found.
 * @param[out] num_level        Receives the number of completed levels.
 * @param[out] cluster_vec      Receives the clustering. Treated as a device pointer
 *                              (the public header documents it as a device array).
 * @param[in]  max_level        Upper bound on the number of levels.
 * @param[in]  resolution       Forwarded to the clustering update routines.
 * @param[in]  stream           CUDA stream used for the Thrust calls and helpers.
 */
template <typename vertex_t, typename edge_t, typename weight_t>
void leiden(GraphCSRView<vertex_t, edge_t, weight_t> const &graph,
            weight_t *final_modularity,
            int *num_level,
            vertex_t *cluster_vec,
            int max_level,
            weight_t resolution,
            cudaStream_t stream)
{
#ifdef TIMING
  HighResTimer hr_timer;
#endif

  *num_level = 0;

  //
  //  Vectors to create a copy of the graph
  //
  rmm::device_vector<edge_t> offsets_v(graph.offsets, graph.offsets + graph.number_of_vertices + 1);
  rmm::device_vector<vertex_t> indices_v(graph.indices, graph.indices + graph.number_of_edges);
  rmm::device_vector<weight_t> weights_v(graph.edge_data, graph.edge_data + graph.number_of_edges);
  rmm::device_vector<vertex_t> src_indices_v(graph.number_of_edges);

  //
  //  Weights and clustering across iterations of algorithm
  //
  rmm::device_vector<weight_t> vertex_weights_v(graph.number_of_vertices);
  rmm::device_vector<weight_t> cluster_weights_v(graph.number_of_vertices);
  rmm::device_vector<vertex_t> cluster_v(graph.number_of_vertices);

  //
  //  Temporaries used within kernels.  Each iteration uses less
  //  of this memory
  //
  rmm::device_vector<vertex_t> tmp_arr_v(graph.number_of_vertices);
  rmm::device_vector<vertex_t> cluster_inverse_v(graph.number_of_vertices);

  //
  //  Constraint labels for the refinement step.  Allocated once at the size of
  //  the original graph and refilled from cluster_v on every level, instead of
  //  performing a device allocation per level.
  //
  rmm::device_vector<vertex_t> constraint(graph.number_of_vertices);

  weight_t total_edge_weight =
    thrust::reduce(rmm::exec_policy(stream)->on(stream), weights_v.begin(), weights_v.end());
  weight_t best_modularity = -1;

  //
  //  Initialize every cluster to reference each vertex to itself
  //
  thrust::sequence(rmm::exec_policy(stream)->on(stream), cluster_v.begin(), cluster_v.end());
  // The explicit device policy makes Thrust treat the raw cluster_vec pointer as
  // device memory; without a policy a raw pointer is dispatched as a host iterator.
  thrust::copy(
    rmm::exec_policy(stream)->on(stream), cluster_v.begin(), cluster_v.end(), cluster_vec);

  //
  //  Our copy of the graph.  Each iteration of the outer loop will
  //  shrink this copy of the graph.
  //
  GraphCSRView<vertex_t, edge_t, weight_t> current_graph(offsets_v.data().get(),
                                                         indices_v.data().get(),
                                                         weights_v.data().get(),
                                                         graph.number_of_vertices,
                                                         graph.number_of_edges);

  current_graph.get_source_indices(src_indices_v.data().get());

  while (*num_level < max_level) {
    //
    //  Sum the weights of all edges departing a vertex.  The sums depend on the
    //  current (possibly collapsed) graph, so they are recomputed once per level.
    //
    //  With every vertex in its own cluster, the cluster weights are equal to
    //  the vertex weights.
    //
#ifdef TIMING
    hr_timer.start("init");
#endif

    cugraph::detail::compute_vertex_sums(current_graph, vertex_weights_v, stream);
    thrust::copy(rmm::exec_policy(stream)->on(stream),
                 vertex_weights_v.begin(),
                 vertex_weights_v.end(),
                 cluster_weights_v.begin());

#ifdef TIMING
    hr_timer.stop();

    hr_timer.start("update_clustering");
#endif

    weight_t new_Q = update_clustering_by_delta_modularity(total_edge_weight,
                                                           resolution,
                                                           current_graph,
                                                           src_indices_v,
                                                           vertex_weights_v,
                                                           cluster_weights_v,
                                                           cluster_v,
                                                           stream);

    // After finding the initial unconstrained partition we use that partitioning as the
    // constraint for the second round.
    thrust::copy(rmm::exec_policy(stream)->on(stream),
                 cluster_v.begin(),
                 cluster_v.end(),
                 constraint.begin());

    // The refinement restarts from singleton clusters.  cluster_weights_v was handed to
    // the unconstrained pass by mutable reference and may now describe that partition,
    // so it is restored to the per-vertex weights (same copy as the initialization
    // above) to stay consistent with the reset cluster_v.
    thrust::sequence(rmm::exec_policy(stream)->on(stream), cluster_v.begin(), cluster_v.end());
    thrust::copy(rmm::exec_policy(stream)->on(stream),
                 vertex_weights_v.begin(),
                 vertex_weights_v.end(),
                 cluster_weights_v.begin());

    new_Q = update_clustering_by_delta_modularity_constrained(total_edge_weight,
                                                              resolution,
                                                              current_graph,
                                                              src_indices_v,
                                                              vertex_weights_v,
                                                              cluster_weights_v,
                                                              cluster_v,
                                                              constraint,
                                                              stream);

#ifdef TIMING
    hr_timer.stop();
#endif

    // No improvement over the best level so far: keep the previous result and stop.
    if (new_Q <= best_modularity) { break; }

    best_modularity = new_Q;

#ifdef TIMING
    hr_timer.start("shrinking graph");
#endif

    // renumber the clusters to the range 0..(num_clusters-1)
    vertex_t num_clusters = renumber_clusters(
      graph.number_of_vertices, cluster_v, tmp_arr_v, cluster_inverse_v, cluster_vec, stream);
    cluster_weights_v.resize(num_clusters);

    // shrink our graph to represent the graph of supervertices
    generate_superverticies_graph(current_graph, src_indices_v, num_clusters, cluster_v, stream);

    // assign each new vertex to its own cluster
    thrust::sequence(rmm::exec_policy(stream)->on(stream), cluster_v.begin(), cluster_v.end());

#ifdef TIMING
    hr_timer.stop();
#endif

    (*num_level)++;
  }

#ifdef TIMING
  hr_timer.display(std::cout);
#endif

  *final_modularity = best_modularity;
}

template void leiden(GraphCSRView<int32_t, int32_t, float> const &,
float *,
int *,
int32_t *,
int,
float,
cudaStream_t);
template void leiden(GraphCSRView<int32_t, int32_t, double> const &,
double *,
int *,
int32_t *,
int,
double,
cudaStream_t);

} // namespace detail
} // namespace cugraph
Loading

0 comments on commit 05045b4

Please sign in to comment.