
selfattention block: Remove the fc linear layer if it is not used #8325

Open · wants to merge 4 commits into base: dev

Conversation

johnzielke
Contributor

Description

When include_fc=False, the nn.Linear layer is unused. This leads to errors and warnings when training with the PyTorch Distributed Data Parallel (DDP) infrastructure, since the parameters of the nn.Linear layer will not have gradients attached.
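
A minimal sketch of the idea behind the change (the class name and attributes below are illustrative, not the exact MONAI code): only instantiate the projection when it is actually used, so DDP does not see parameters that never receive gradients.

```python
import torch.nn as nn


class SABlockSketch(nn.Module):
    """Illustrative only: create the fc projection only when it is used."""

    def __init__(self, hidden_size: int, include_fc: bool = True) -> None:
        super().__init__()
        self.include_fc = include_fc
        # Before the fix the nn.Linear was always created; with include_fc=False its
        # parameters never receive gradients, which DDP reports as unused parameters.
        # nn.Identity has no parameters, so nothing is left dangling.
        self.out_proj = nn.Linear(hidden_size, hidden_size) if include_fc else nn.Identity()

    def forward(self, x):
        # ... attention computation elided ...
        return self.out_proj(x)
```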

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

@johnzielke force-pushed the bugfix/attention-remove-unused-parameters branch 3 times, most recently from 1bb3ce5 to 892edc6 on February 4, 2025 15:09
@johnzielke force-pushed the bugfix/attention-remove-unused-parameters branch from 892edc6 to 547ac94 on February 4, 2025 15:10
@ericspod
Member

ericspod commented Feb 5, 2025

Thanks for the contribution! In itself I think it's fine, however we have to check that this won't break old weights. We have the load_old_state_dict method for doing this with DiffusionModelUNet, which might not work if out_proj doesn't have weight or bias components. There are other load_old_state_dict methods doing this for other networks that should also be looked at.

We still want to maintain backwards compatibility with old stored weights, at least for now, but we should discuss when to deprecate these methods. CC @virginiafdez

@johnzielke
Contributor Author

If that's the only concern, I could update that method to ignore that key in the appropriate cases.
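
A hypothetical sketch of what "ignoring that key" could look like when mapping an old checkpoint onto a network built with include_fc=False; the helper name and key layout below are assumptions for illustration, not MONAI's actual load_old_state_dict.

```python
import torch
import torch.nn as nn


def filter_old_state_dict(old_state_dict: dict, model: nn.Module) -> dict:
    """Drop old checkpoint entries (e.g. '...out_proj.weight') that no longer
    exist in the new model because the fc layer was not created."""
    new_keys = set(model.state_dict().keys())
    return {k: v for k, v in old_state_dict.items() if k in new_keys}


# usage sketch:
# old_sd = torch.load("old_weights.pt")
# model.load_state_dict(filter_old_state_dict(old_sd, model), strict=False)
```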

@ericspod
Member

> If that's the only concern, I could update that method to ignore that key in the appropriate cases.

Please do have a look at how old state is loaded and see if there are any issues; otherwise yes, we should be good here. I've updated your branch after we've done a lot of test refactoring; we should perhaps also include a test that checks whether the network does or does not have the fc layer when appropriate.
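
A sketch of the kind of test suggested here; the SABlock import path, constructor arguments, and the out_proj attribute name are assumptions based on this discussion rather than verified API.

```python
import unittest

import torch.nn as nn

from monai.networks.blocks.selfattention import SABlock  # assumed import path


class TestSABlockIncludeFC(unittest.TestCase):
    def test_fc_layer_present(self):
        block = SABlock(hidden_size=64, num_heads=4, include_fc=True)
        self.assertIsInstance(block.out_proj, nn.Linear)

    def test_fc_layer_absent(self):
        block = SABlock(hidden_size=64, num_heads=4, include_fc=False)
        # with the fix, no unused nn.Linear parameters should remain
        self.assertNotIsInstance(block.out_proj, nn.Linear)


if __name__ == "__main__":
    unittest.main()
```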

@johnzielke
Contributor Author

I pushed the discussed changes. I wanted to test load_old_state_dict as well, but it seems there is no test yet that covers loading without cross-attention. I did not dive all the way into where the old state dicts are stored, etc. Is there an easy-to-use old state dict I could use in the test_compatibility_with_monai_generative() test?
