Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ZBVZeroBubbleSchedule #817

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

Conversation

H-Huang
Copy link
Member

@H-Huang H-Huang commented Feb 4, 2025

This is dependent on the changes in this pytorch stack: pytorch/pytorch#146217

Add support for running ZBVZeroBubbleSchedule and v-shaped CSV schedules in torchtitan

Fixes #774

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Feb 4, 2025
@H-Huang H-Huang requested a review from wconstab February 4, 2025 19:01
elif is_last_stage:
losses = []
pp_schedule.step(target=labels, losses=losses)
targets = labels if has_last_stage else None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kinda nit picking but i feel like if the stage object inside the schedule already knows that it is first or last, we can avoid having the logic in the training loop too.

otoh it seems nice to be explicit at the train.py layer on whether we are asking to compute loss or not.

thoughts?
@tianyu-l

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels nice when we only explicitly pass in meaningful targets/losses when we are not sure if they'll be properly accessed, so I'm OK with these if-else statements.

But how different is input_ids? Can we just unify everything into pp_schedule.step(input_ids, target=targets, losses=losses)
and pass input_ids = None when not has_first_stage?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't do input_ids=None right now since we have logic that automatically splits all *args into microbatches. For example if the user wants to do step(tensors, None) that would be split up into microbatches of (tensor1, None), (tensor2, None), ... We could update the splitting logic but not sure if it is worth it

Copy link
Contributor

@tianyu-l tianyu-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GPU CI failed, not sure if it is due to the reason I commented.

targets = labels if has_last_stage else None
losses = [] if has_last_stage else None
if has_first_stage:
pp_schedule.step(input_ids, target=targets, losses=losses)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if a schedule has has_last_stage = True and has_first_stage = False for the output layer -- will it miss the chance to feed in losses?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops yeah, that was the issue. Updated it and will let the CI run again

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Meta Open Source bot.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

ZBVZeroBubble error
4 participants