selfattention block: Remove the fc linear layer if it is not used #8325
base: dev
Conversation
Force-pushed from 1bb3ce5 to 892edc6
Signed-off-by: John Zielke <[email protected]>
Force-pushed from 892edc6 to 547ac94
Thanks for the contribution! In itself I think it's fine; however, we have to check that this won't break old weights. We have the method `load_old_state_dict` for doing this with `DiffusionModelUNet`, which might not work after this change. We still want to maintain backwards compatibility with old stored weights, at least for now, but we should discuss when to deprecate these methods. CC @virginiafdez
If that's the only concern, I could update that method to ignore that key in the appropriate cases.
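A minimal sketch of that idea, assuming the updated model simply lacks the fc entries in its state dict (the helper name and key handling are illustrative, not MONAI's actual API):

```python
import torch.nn as nn


def filter_old_state_dict(model: nn.Module, old_state_dict: dict) -> dict:
    """Drop stale entries (e.g. fc.weight/fc.bias) from an old state dict
    when the current model was built with include_fc=False. Illustrative
    sketch only, not MONAI's actual loading logic."""
    current_keys = set(model.state_dict().keys())
    # Keep only tensors that still have a matching parameter in the new
    # architecture, silently discarding those whose layer was removed.
    return {k: v for k, v in old_state_dict.items() if k in current_keys}
```

The remaining weights could then be loaded strictly, e.g. `model.load_state_dict(filter_old_state_dict(model, old_sd), strict=True)`.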
Please do have a look at how the old state is loaded and see if there are any issues; otherwise, yes, we should be good here. I've updated your branch after we've done a lot of test refactoring. We should perhaps also include a test that checks that the network does or does not have the fc layer when appropriate, along the lines of the sketch below.
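A minimal sketch of such a test, assuming the block is `SABlock` from `monai.networks.blocks.selfattention` and that the layer is exposed as an attribute named `fc` (the attribute name is an assumption based on the PR title):

```python
import unittest

import torch.nn as nn

from monai.networks.blocks.selfattention import SABlock


class TestSABlockFC(unittest.TestCase):
    def test_fc_created_when_included(self):
        block = SABlock(hidden_size=16, num_heads=2, include_fc=True)
        # Assumes the layer is exposed as `fc`; adjust to the real attribute name.
        self.assertIsInstance(block.fc, nn.Linear)

    def test_fc_absent_when_excluded(self):
        block = SABlock(hidden_size=16, num_heads=2, include_fc=False)
        self.assertIsNone(getattr(block, "fc", None))


if __name__ == "__main__":
    unittest.main()
```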
Signed-off-by: John Zielke <[email protected]>
I pushed the discussed changes. I wanted to test `load_old_state_dict` as well, but it seems there is no test yet that loads old weights without cross-attention. I did not dive all the way into where the old state dicts are stored, etc. Is there an easy-to-use old state dict I could use in the
Description
When `include_fc=False`, the `nn.Linear` layer is unused. This leads to errors and warnings when training with PyTorch's DistributedDataParallel (DDP) infrastructure, since the parameters of the unused `nn.Linear` layer never receive gradients.
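A minimal sketch of the fix, using a simplified stand-in block (class and attribute names are illustrative, not MONAI's exact code):

```python
import torch
import torch.nn as nn


class AttentionBlockSketch(nn.Module):
    """Illustrative block: only create the fc layer when it will be used."""

    def __init__(self, hidden_size: int, include_fc: bool = True):
        super().__init__()
        # Registering an nn.Linear that forward() never calls leaves its
        # parameters without gradients; DDP then raises an error unless
        # find_unused_parameters=True is set, which adds overhead.
        # Creating the layer conditionally avoids the problem entirely.
        self.fc = nn.Linear(hidden_size, hidden_size) if include_fc else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x) if self.fc is not None else x
```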
Types of changes
- Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- Documentation updated, tested `make html` command in the `docs/` folder.