Skip to content

Commit

Permalink
add a sync and nccl abort to solve issues with empty partitions
Browse files (browse the repository at this point in the history)
Signed-off-by: Erik Ordentlich <[email protected]>
  • Loading branch information
eordentlich committed Oct 7, 2023
1 parent 0c6bf0e commit 2e609ac
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 3 deletions.
8 changes: 7 additions & 1 deletion python/src/spark_rapids_ml/common/cuml_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,13 @@ def __exit__(self, *args: Any) -> None:
if not self.enable:
return
assert self._nccl_comm is not None
self._nccl_comm.destroy()

# if no exception occurred, clean up nicely; otherwise abort
if not args[0]:
self._nccl_comm.destroy()
else:
self._nccl_comm.abort()

del self._nccl_comm

del self._handle
Expand Down
5 changes: 4 additions & 1 deletion python/src/spark_rapids_ml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,9 +551,12 @@ def _train_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame:

if len(sizes) == 0 or all(sz == 0 for sz in sizes):
raise RuntimeError(
"A python worker received no data. Please increase amount of data or use fewer workers."
"A python worker received no data. Please ensure no empty partitions by increasing amount of data, using fewer workers, or repartitioning."
)

# not syncing here can result in hangs and unkilled python workers if there is an error in data loading (e.g. an empty partition)
context.barrier()

params[param_alias.handle] = cc.handle
params[param_alias.part_sizes] = sizes
params[param_alias.num_cols] = dimension
Expand Down
2 changes: 1 addition & 1 deletion python/tests/test_nearest_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def test_nearest_neighbors(
random_state=0,
) # make_blobs creates a random dataset of isotropic gaussian blobs.

# set average norm to be 1 to allow comparisons with default error thresholds
# set the average squared norm to 1 to allow comparisons with the default error thresholds
# below
root_ave_norm_sq = np.sqrt(np.average(np.linalg.norm(X, ord=2, axis=1) ** 2))
X = X / root_ave_norm_sq
Expand Down

0 comments on commit 2e609ac

Please sign in to comment.