Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added unload_lora_weights to StableDiffusionPipeline #11172

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

dhawan98
Copy link

@dhawan98 dhawan98 commented Mar 30, 2025

-What does this PR do?
This PR adds a unload_lora_weights() method to StableDiffusionPipeline, enabling users to dynamically unload previously loaded LoRA adapters. This improves flexibility for scenarios where adapters need to be reset or replaced at runtime without reconstructing the pipeline.

-Motivation
Currently, StableDiffusionPipeline allows loading LoRA adapters, but there's no public method to unload them. This limitation makes it harder to manage LoRA state dynamically (e.g., in long-running applications or inference servers). This change enables a smoother adapter lifecycle.

-Before submitting the PR
Added tests under tests/loaders/test_lora_unload_reload.py that fail without the patch and pass with it.

-Verified that all existing tests pass: pytest tests/

-Code formatted with black and ruff

-Followed the module structure and naming conventions.

-Test Plan
You can verify the unload functionality via the test:

pytest tests/loaders/test_lora_unload_reload.py

It loads a LoRA adapter, unloads it using the new method, reloads it again, and checks that the adapter is active only when expected.

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@yiyixuxu
Copy link
Collaborator

cc @sayakpaul here

Copy link
Member

@hlky hlky left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dhawan98 The code from test_lora_unload_reload appears to work as expected on main, and after review the changes result in the exact same functionality, the additional code is moot as mentioned in the review comments.

Can you elaborate on the use case that currently does not work?

from diffusers import StableDiffusionPipeline
import torch

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float32
).to("cpu")

lora_repo = "latent-consistency/lcm-lora-sdv1-5"

# Load and activate LoRA
pipe.load_lora_weights(lora_repo)
adapters = pipe.get_list_adapters()
adapter_name = list(adapters["unet"])[0]

pipe.set_adapters([adapter_name], [1.0])

# Unload
pipe.unload_lora_weights()

# Reload
pipe.load_lora_weights(lora_repo)
adapters = pipe.get_list_adapters()
adapter_name = list(adapters["unet"])[0]

pipe.set_adapters([adapter_name], [0.8])

Comment on lines +523 to +524
if hasattr(model, "peft_config"):
model.peft_config.clear()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unload_lora deletes peft_config so this will always be a no-op.

if hasattr(self, "peft_config"):
del self.peft_config

Comment on lines +525 to +526
if hasattr(model, "active_adapter"):
model.active_adapter = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

active_adapter is a component of peft's BaseTunerLayer, not the model, this is also a no-op.

elif issubclass(model.__class__, PreTrainedModel):
_remove_text_encoder_monkey_patch(model)

torch.cuda.empty_cache()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will not work on non-cuda devices, we have device agnostic backend_empty_cache, however this addition seems unnecessary, what's the benefit of empty_cache here?

def backend_empty_cache(device: str):
return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In testing, the basic unload–reload cycle works fine on the main branch when you use a single LoRA repeatedly. However, the issue we observed in production arises when you need to switch between different LoRA adapters within the same session or rapidly update weights on an adapter without restarting the pipeline.

Specifically, while the current implementation of unload_lora_weights() appears to delete the peft_config (and thus, the adapter state) in many cases, in a dynamic server environment the underlying transformer modules may still retain residual adapter registration information (e.g., within the internal state of the BaseTunerLayer). This can lead to subtle issues where:

If you attempt to load a new adapter (or reload the same adapter with updated weights) without a full pipeline reset, you might encounter errors such as "Adapter name ... already in use in the transformer" even though the adapter isn’t visible via get_list_adapters().

In scenarios with rapid switching or dynamic weight updates, any residual state in the transformer (which isn’t fully cleared by the current unload code) can cause unpredictable behavior or slight performance degradation.

This patch was intended to fully clear all traces of adapter registrations from the model’s transformer modules so that subsequent adapter loads truly start from a clean slate. This is particularly important when you want to ensure that no stale parameters interfere with new weights in a production environment where you may be swapping adapters on the fly.

If the changes result in “the exact same functionality” in simple tests, that may be because the basic unload–reload cycle doesn’t fully expose the issue in our unit tests. The additional code is aimed at more complex use cases—such as multiple adapters being loaded and unloaded sequentially in a long-running inference service—where even a small residual state can eventually lead to conflicts or degraded generation quality.

In summary, while the minimal test passes on the main branch, the use case we’re targeting is:

Dynamic production pipelines where LoRA adapters need to be swapped or updated without restarting the server.

Preventing state leakage in long-running sessions, ensuring that the transformer’s internal adapter registrations are fully reset between loads.

For more context, here's the GitHub repo where I use dynamic LoRA loading/unloading in a production-style API setup: https://github.com/dhawan98/LoRA-Server.
The patch allows to dynamically switch LoRAs mid-session without memory leaks or adapter conflicts, which is critical in this setup.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants