Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ds-specific module id to avoid conflicts #6847

Merged
merged 15 commits on Jan 31, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions deepspeed/runtime/zero/parameter_offload.py
Original file line number Diff line number Diff line change
@@ -243,7 +243,7 @@ def _start_of_forward_hook(module, *args):
self.module.register_forward_pre_hook(_start_of_forward_hook)

#likely one of them should be enough but just to be safe
self._register_hooks_recursively(self.module)
self._register_deepspeed_module(self.module)

# Add top module to stack trace
global FWD_MODULE_STACK
@@ -269,19 +269,19 @@ def mark_persistent_parameters(self, param_threshold, model_threshold):

return persistent_params

def _register_hooks_recursively(self, module, count=[0]):
def _register_deepspeed_module(self, module, count=[0]):
my_count = count[0]
module.id = my_count
module.ds_id = my_count

#print(f"{module.__class__} : {module.id}")
#print(f"{module.__class__} : {module.ds_id}")

if z3_leaf_module(module):
for param in module.parameters():
param.ds_z3_leaf_module = module
else:
for child in module.children():
count[0] = count[0] + 1
self._register_hooks_recursively(child, count=count)
self._register_deepspeed_module(child, count=count)

@instrument_w_nvtx
def _pre_forward_module_hook(module, *args):
@@ -466,14 +466,16 @@ def pre_sub_module_forward_function(self, sub_module):

@torch.no_grad()
def post_sub_module_forward_function(self, sub_module):
see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release",
force=False)
see_memory_usage(
f"After sub module function {sub_module.__class__.__name__} {sub_module.ds_id} before release",
force=False)

param_coordinator = self.get_param_coordinator()
param_coordinator.release_sub_module(sub_module)

see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release",
force=False)
see_memory_usage(
f"After sub module function {sub_module.__class__.__name__} {sub_module.ds_id} after release",
force=False)

@torch.no_grad()
def pre_sub_module_backward_function(self, sub_module):
@@ -488,13 +490,13 @@ def pre_sub_module_backward_function(self, sub_module):
def post_sub_module_backward_function(self, sub_module):
# assert sub_module.training, "backward pass is invalid for module in evaluation mode"
see_memory_usage(
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} before release",
force=False)

self.get_param_coordinator().release_sub_module(sub_module)

see_memory_usage(
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} after release",
force=False)

def _set_z3_leaf_modules_by_threshold(self, module, zero_module_granularity_threshold):
24 changes: 12 additions & 12 deletions deepspeed/runtime/zero/partitioned_param_coordinator.py
Original file line number Diff line number Diff line change
@@ -175,17 +175,17 @@ def trace_prologue(self, sub_module: Module) -> None:
# sub_module must match expectation else invalidate trace cache
if len(self.__submodule_order) <= self.__step_id:
print_rank_0(
f"Invalidate trace cache @ step {self.__step_id} and module {sub_module.id}: "
f"Invalidate trace cache @ step {self.__step_id} and module {sub_module.ds_id}: "
f"cache has only {len(self.__submodule_order)} modules",
force=True)
self._invalidate_trace()
return

if sub_module != self.__submodule_order[self.__step_id]:
expected_module_id = self.__submodule_order[self.__step_id].id
expected_module_id = self.__submodule_order[self.__step_id].ds_id
print_rank_0(
f"Invalidate trace cache @ step {self.__step_id}: "
f"expected module {expected_module_id}, but got module {sub_module.id}",
f"expected module {expected_module_id}, but got module {sub_module.ds_id}",
force=True)
self._invalidate_trace()

@@ -199,7 +199,7 @@ def record_module(self, sub_module: Module) -> None:
raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}")

self.__submodule_order.append(sub_module)
self.__step_id_module_fetched_for[sub_module.id].append(self.__step_id)
self.__step_id_module_fetched_for[sub_module.ds_id].append(self.__step_id)

def record_parameters(self, sub_module: Module) -> None:
if is_compiling():
@@ -208,7 +208,7 @@ def record_parameters(self, sub_module: Module) -> None:
if not self.is_record_trace():
raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}")

step_id = self.__step_id_module_fetched_for[sub_module.id].popleft()
step_id = self.__step_id_module_fetched_for[sub_module.ds_id].popleft()
for param in sorted(set(iter_params(sub_module, recurse=z3_leaf_module(sub_module))), key=lambda p: p.ds_id):
self.__param_order.append(__class__.__ParamInTrace(param=param, step_id_last_used_at=step_id))

@@ -228,7 +228,7 @@ def reset_step(self) -> None:

if not self.is_complete_trace(): # not self.trace_complete:
# Make sure that recorded submodule orders are identical across ranks
assert_ints_same_as_other_ranks([m.id for m in self.__submodule_order])
assert_ints_same_as_other_ranks([m.ds_id for m in self.__submodule_order])

if self.is_record_trace():
# Successfully recorded a trace
@@ -241,7 +241,7 @@ def reset_step(self) -> None:
self.__param_order = tuple(self.__param_order) # freeze
self.__trace_mode = ZeRoTraceMode.COMPLETE
print_rank_0(
f"completed record trace of {len(self.__submodule_order)} sub modules: {[m.id for m in self.__submodule_order]}",
f"completed record trace of {len(self.__submodule_order)} sub modules: {[m.ds_id for m in self.__submodule_order]}",
force=False)
else:
# Enable trace recording for next forward/backward pass
@@ -284,7 +284,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
"""
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(
f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule, recurse=z3_leaf_module(current_submodule))]} "
f"{self.__step_id}: M{current_submodule.ds_id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule, recurse=z3_leaf_module(current_submodule))]} "
+ str({
"avail": f"{self.__n_available_params:.1e}",
"queue_sz": f"{len(self.__param_queue or [])}",
@@ -297,7 +297,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:

if fetch_numel > 0:
event_name = __class__.FORWARD_FETCH_SUBMIT if forward else __class__.BACKWARD_FETCH_SUBMIT
self._dump_param_ids(event_name, current_submodule.id,
self._dump_param_ids(event_name, current_submodule.ds_id,
[p.ds_id for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE])
self.__profiler.start_event(event_name)
# kick off all gather for params in the immediately required submodule
@@ -314,7 +314,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
fast_fetch = self.fast_sharding_for_leaf_module and z3_leaf_module(current_submodule)
# wait for parameters in the immediately needed submodule to become available
for param in params_to_fetch:
param.ds_active_sub_modules.add(current_submodule.id)
param.ds_active_sub_modules.add(current_submodule.ds_id)
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(f"-wait: {param.ds_summary()}")
if param in self.__inflight_param_registry:
@@ -358,7 +358,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
if discarded_from_prefetch_queue != params_not_already_fetched:
raise RuntimeError(
f"tracing error at step {self.__step_id}: \n"
f"module id: {current_submodule.id}, training: {current_submodule.training}\n"
f"module id: {current_submodule.ds_id}, training: {current_submodule.training}\n"
f"expected the next {len(params_not_already_fetched)} parameters in the "
f"parameter fetch queue to be {tuple(p.ds_summary(use_debug_name=True) for p in params_not_already_fetched)} \n"
f"but got \n {tuple(p.ds_summary(use_debug_name=True) for p in discarded_from_prefetch_queue)}.")
@@ -425,7 +425,7 @@ def release_sub_module(self, submodule: Module) -> None:
empty_buffer = torch.empty(1, device=get_accelerator().current_device())

for param in iter_params(submodule, recurse=z3_leaf_module(submodule)):
param.ds_active_sub_modules.discard(submodule.id)
param.ds_active_sub_modules.discard(submodule.ds_id)
if param.ds_id in params_to_release and not param.is_external_param:
self.__release_param(param, free_data)
if not free_data:
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
@@ -102,7 +102,7 @@ def unwrap_model_for_generation(model):
optimizer_offload = model.optimizer.parameter_offload
elif model.optimizer is not None:
optimizer_offload = model.optimizer
optimizer_offload._register_hooks_recursively(optimizer_offload.module)
optimizer_offload._register_deepspeed_module(optimizer_offload.module)
return


34 changes: 34 additions & 0 deletions tests/unit/runtime/zero/test_zero.py
Original file line number Diff line number Diff line change
@@ -1673,3 +1673,37 @@ def test(self, prefetch_ratio, zero_stage=3):
with torch.no_grad():
for batch in data_loader:
loss = model(batch[0], batch[1])


# Avoid overwriting client module id
# https://github.com/microsoft/DeepSpeed/issues/6772
class TestZero3ClientModuleID(DistributedTest):
    """Ensure DeepSpeed does not clobber a user-assigned ``module.id`` attribute.

    Regression test for https://github.com/microsoft/DeepSpeed/issues/6772:
    ZeRO-3 tags submodules via ``ds_id`` rather than ``id``, so a client's own
    ``id`` attribute must survive ``deepspeed.initialize``.
    """
    world_size = 2

    def test_client_module_id(self):
        ds_config = {
            "train_micro_batch_size_per_gpu": 1,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
            },
            "zero_optimization": {
                "stage": 3
            },
        }

        class ClientModel(torch.nn.Module):
            """Tiny model whose ``id`` attribute mimics client-side bookkeeping."""

            def __init__(self):
                super().__init__()
                # Arbitrary client use of `id`, e.g. for GPU placement decisions.
                self.id = 3
                self.fc = Linear(128, 128)

            def forward(self, x):
                return self.fc(x)

        net = ClientModel()
        id_before_init = net.id
        net, _, _, _ = deepspeed.initialize(model=net, model_parameters=net.parameters(), config=ds_config)
        id_after_init = net.id
        assert id_before_init == id_after_init