From cb2d841673c04daa44054623dc662fab6d25fbce Mon Sep 17 00:00:00 2001 From: Hugo Linsenmaier Date: Wed, 17 Feb 2021 11:23:30 -0800 Subject: [PATCH] TSP fix route return (#1412) The vertex list that was fed as input to TSP was of type `int64` which ended up being corrupted when passed down to the cpp layer as `vtx_ptr`. I updated the wrapper to cast the vertices to `int32`. In addition, I fixed the handling of nstart in the wrapper which was assuming vertex ids were starting at 1. Solves: https://github.com/rapidsai/cugraph/issues/1410 Authors: - Hugo Linsenmaier (@hlinsen) Approvers: - Brad Rees (@BradReesWork) URL: https://github.com/rapidsai/cugraph/pull/1412 --- python/cugraph/traversal/traveling_salesperson.py | 4 ++-- .../cugraph/traversal/traveling_salesperson_wrapper.pyx | 8 ++++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/python/cugraph/traversal/traveling_salesperson.py b/python/cugraph/traversal/traveling_salesperson.py index 80f9cd7441b..ae17555e4ea 100644 --- a/python/cugraph/traversal/traveling_salesperson.py +++ b/python/cugraph/traversal/traveling_salesperson.py @@ -20,7 +20,7 @@ def traveling_salesperson(pos_list, restarts=100000, beam_search=True, k=4, - nstart=1, + nstart=None, verbose=False, ): """ @@ -62,7 +62,7 @@ def traveling_salesperson(pos_list, null_check(pos_list['x']) null_check(pos_list['y']) - if not pos_list[pos_list['vertex'] == nstart].index: + if nstart is not None and not pos_list[pos_list['vertex'] == nstart].index: raise ValueError("nstart should be in vertex ids") route, cost = traveling_salesperson_wrapper.traveling_salesperson( diff --git a/python/cugraph/traversal/traveling_salesperson_wrapper.pyx b/python/cugraph/traversal/traveling_salesperson_wrapper.pyx index b728c3ff37d..5f87c42a638 100644 --- a/python/cugraph/traversal/traveling_salesperson_wrapper.pyx +++ b/python/cugraph/traversal/traveling_salesperson_wrapper.pyx @@ -31,7 +31,7 @@ def traveling_salesperson(pos_list, restarts=100000, beam_search=True, k=4, - nstart=1, + nstart=None, verbose=False, renumber=True, ): @@ -43,6 +43,7 @@ def traveling_salesperson(pos_list, cdef uintptr_t x_pos = NULL cdef uintptr_t y_pos = NULL + pos_list['vertex'] = pos_list['vertex'].astype(np.int32) pos_list['x'] = pos_list['x'].astype(np.float32) pos_list['y'] = pos_list['y'].astype(np.float32) x_pos = pos_list['x'].__cuda_array_interface__['data'][0] @@ -61,7 +62,10 @@ def traveling_salesperson(pos_list, cdef uintptr_t vtx_ptr = NULL vtx_ptr = pos_list['vertex'].__cuda_array_interface__['data'][0] - renumbered_nstart = pos_list[pos_list['vertex'] == nstart].index[0] + if nstart is None: + renumbered_nstart = 0 + else: + renumbered_nstart = pos_list[pos_list['vertex'] == nstart].index[0] final_cost = c_traveling_salesperson(handle_[0], vtx_ptr,