diff --git a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py index 2364c4911da4f..ade08d9a34824 100644 --- a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py +++ b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py @@ -548,7 +548,7 @@ def return_actor(self, actor: ray.actor.ActorHandle): self._num_tasks_in_flight[actor] -= 1 if self._should_kill_idle_actors and self._num_tasks_in_flight[actor] == 0: - self._kill_running_actor(actor) + self._remove_actor(actor) def get_pending_actor_refs(self) -> List[ray.ObjectRef]: return list(self._pending_actors.keys()) @@ -585,7 +585,7 @@ def kill_inactive_actor(self) -> bool: def _maybe_kill_pending_actor(self) -> bool: if self._pending_actors: # At least one pending actor, so kill first one. - self._kill_pending_actor(next(iter(self._pending_actors.keys()))) + self._remove_actor(next(iter(self._pending_actors.keys()))) return True # No pending actors, so indicate to the caller that no actors were killed. return False @@ -594,7 +594,7 @@ def _maybe_kill_idle_actor(self) -> bool: for actor, tasks_in_flight in self._num_tasks_in_flight.items(): if tasks_in_flight == 0: # At least one idle actor, so kill first one found. - self._kill_running_actor(actor) + self._remove_actor(actor) return True # No idle actors, so indicate to the caller that no actors were killed. return False @@ -621,7 +621,7 @@ def kill_all_actors(self): def _kill_all_pending_actors(self): pending_actor_refs = list(self._pending_actors.keys()) for ref in pending_actor_refs: - self._kill_pending_actor(ref) + self._remove_actor(ref) def _kill_all_idle_actors(self): idle_actors = [ @@ -630,23 +630,26 @@ def _kill_all_idle_actors(self): if tasks_in_flight == 0 ] for actor in idle_actors: - self._kill_running_actor(actor) + self._remove_actor(actor) self._should_kill_idle_actors = True def _kill_all_running_actors(self): actors = list(self._num_tasks_in_flight.keys()) for actor in actors: - self._kill_running_actor(actor) - - def _kill_running_actor(self, actor: ray.actor.ActorHandle): - """Kill the provided actor and remove it from the pool.""" - ray.kill(actor) - del self._num_tasks_in_flight[actor] - - def _kill_pending_actor(self, ready_ref: ray.ObjectRef): - """Kill the provided pending actor and remove it from the pool.""" - actor = self._pending_actors.pop(ready_ref) - ray.kill(actor) + self._remove_actor(actor) + + def _remove_actor(self, actor: ray.actor.ActorHandle): + """Remove the given actor from the pool.""" + # NOTE: we remove references to the actor and let ref counting + # garbage collect the actor, instead of using ray.kill. + # Because otherwise the actor cannot be restarted upon lineage reconstruction. + for state_dict in [ + self._num_tasks_in_flight, + self._actor_locations, + self._pending_actors, + ]: + if actor in state_dict: + del state_dict[actor] def _get_location(self, bundle: RefBundle) -> Optional[NodeIdStr]: """Ask Ray for the node id of the given bundle. diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py index 6406e742438d2..9e01ab3b83c41 100644 --- a/python/ray/data/tests/test_map.py +++ b/python/ray/data/tests/test_map.py @@ -13,9 +13,11 @@ import pytest import ray +from ray._private.test_utils import wait_for_condition from ray.data._internal.execution.interfaces.ref_bundle import ( _ref_bundles_iterator_to_block_refs_list, ) +from ray.data._internal.execution.operators.actor_pool_map_operator import _MapWorker from ray.data.context import DataContext from ray.data.exceptions import UserCodeException from ray.data.tests.conftest import * # noqa @@ -76,6 +78,19 @@ def test_basic_actors(shutdown_only): concurrency=(8, 4), ) + # Make sure all actors are dead after dataset execution finishes. + def _all_actors_dead(): + actor_table = ray.state.actors() + actors = { + id: actor_info + for actor_info in actor_table.values() + if actor_info["ActorClassName"] == _MapWorker.__name__ + } + assert len(actors) > 0 + return all(actor_info["State"] == "DEAD" for actor_info in actors.values()) + + wait_for_condition(_all_actors_dead) + def test_callable_classes(shutdown_only): ray.init(num_cpus=2)