diff --git a/docs/source/package_reference/helpers.md b/docs/source/package_reference/helpers.md
index 70ede29476..83e129d6ea 100644
--- a/docs/source/package_reference/helpers.md
+++ b/docs/source/package_reference/helpers.md
@@ -14,4 +14,9 @@ A collection of helper functions for PEFT.
 ## Temporarily Rescaling Adapter Scale in LoraLayer Modules
 
 [[autodoc]] helpers.rescale_adapter_scale
-    - all
\ No newline at end of file
+    - all
+
+## Context manager to disable input dtype casting in the `forward` method of LoRA layers
+
+[[autodoc]] helpers.disable_input_dtype_casting
+    - all
diff --git a/src/peft/helpers.py b/src/peft/helpers.py
index dc64486882..225bc5003f 100644
--- a/src/peft/helpers.py
+++ b/src/peft/helpers.py
@@ -18,8 +18,10 @@
 from functools import update_wrapper
 from types import MethodType
 
+from torch import nn
+
 from .peft_model import PeftConfig, PeftModel
-from .tuners.lora.layer import LoraLayer
+from .tuners.lora import LoraLayer
 
 
 def update_forward_signature(model: PeftModel) -> None:
@@ -209,3 +211,42 @@ def rescale_adapter_scale(model, multiplier):
     # restore original scaling values after exiting the context
     for module, scaling in original_scaling.items():
         module.scaling = scaling
+
+
+@contextmanager
+def disable_input_dtype_casting(model: nn.Module, active: bool = True):
+    """
+    Context manager that disables input dtype casting to the dtype of the weight.
+
+    Currently specifically works for LoRA.
+
+    Parameters:
+        model (nn.Module):
+            The model containing PEFT modules whose input dtype casting is to be adjusted.
+        active (bool):
+            Whether the context manager is active (default) or inactive.
+
+    """
+    # Additional info: Normally, the dtype of the weight and input need to match, which is why the dtype is cast.
+    # However, in certain circumstances, this is handled by forward hooks, e.g. when using layerwise casting in
+    # diffusers. In that case, PEFT casting the dtype interferes with the layerwise casting, which is why the option to
+    # disable it is given.
+    if not active:
+        yield
+        return
+
+    original_values = {}
+    for name, module in model.named_modules():
+        if not isinstance(module, LoraLayer):
+            continue
+        original_values[name] = module.cast_input_dtype_enabled
+        module.cast_input_dtype_enabled = False
+
+    try:
+        yield
+    finally:
+        for name, module in model.named_modules():
+            if not isinstance(module, LoraLayer):
+                continue
+            if name in original_values:
+                module.cast_input_dtype_enabled = original_values[name]
diff --git a/src/peft/tuners/adalora/bnb.py b/src/peft/tuners/adalora/bnb.py
index b8c32a815c..fef3d25e65 100644
--- a/src/peft/tuners/adalora/bnb.py
+++ b/src/peft/tuners/adalora/bnb.py
@@ -129,9 +129,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
             requires_conversion = not torch.is_autocast_enabled()
             if requires_conversion:
                 expected_dtype = result.dtype
-                compute_dtype = lora_A.dtype
-                if x.dtype != compute_dtype:
-                    x = x.to(compute_dtype)
+                x = self._cast_input_dtype(x, lora_A.dtype)
 
             output = dropout(x) @ (lora_A * lora_E).T @ lora_B.T
             if requires_conversion:
diff --git a/src/peft/tuners/adalora/gptq.py b/src/peft/tuners/adalora/gptq.py
index 910377c5db..bed1a0a7ca 100644
--- a/src/peft/tuners/adalora/gptq.py
+++ b/src/peft/tuners/adalora/gptq.py
@@ -55,8 +55,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             requires_conversion = not torch.is_autocast_enabled()
             if requires_conversion:
                 expected_dtype = result.dtype
-                if x.dtype != torch.float32:
-                    x = x.float()
+                x = self._cast_input_dtype(x, torch.float32)
 
             output = (dropout(x) @ (lora_A * lora_E).T @ lora_B.T) * scaling / ranknum
             # TODO: here, the dtype conversion is applied on the *whole expression*,
diff --git a/src/peft/tuners/adalora/layer.py b/src/peft/tuners/adalora/layer.py
index a3a1334d18..ea142c512f 100644
--- a/src/peft/tuners/adalora/layer.py
+++ b/src/peft/tuners/adalora/layer.py
@@ -180,7 +180,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                 scaling = self.scaling[active_adapter]
                 ranknum = self.ranknum[active_adapter] + 1e-5
 
-                x = x.to(lora_A.dtype)
+                x = self._cast_input_dtype(x, lora_A.dtype)
                 result += (dropout(x) @ (lora_A * lora_E).T @ lora_B.T) * scaling / ranknum
 
         return result
diff --git a/src/peft/tuners/lora/aqlm.py b/src/peft/tuners/lora/aqlm.py
index 1715e62646..81c7cdbb4e 100644
--- a/src/peft/tuners/lora/aqlm.py
+++ b/src/peft/tuners/lora/aqlm.py
@@ -75,7 +75,7 @@ def forward(self, x: torch.Tensor):
             requires_conversion = not torch.is_autocast_enabled()
             if requires_conversion:
                 expected_dtype = result.dtype
-                x = x.to(lora_A.weight.dtype)
+                x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
             output = lora_B(lora_A(dropout(x)))
             if requires_conversion:
diff --git a/src/peft/tuners/lora/awq.py b/src/peft/tuners/lora/awq.py
index 86989d9000..61eb487ad6 100644
--- a/src/peft/tuners/lora/awq.py
+++ b/src/peft/tuners/lora/awq.py
@@ -75,7 +75,7 @@ def forward(self, x: torch.Tensor):
             requires_conversion = not torch.is_autocast_enabled()
             if requires_conversion:
                 expected_dtype = result.dtype
-                x = x.to(lora_A.weight.dtype)
+                x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
             output = lora_B(lora_A(dropout(x)))
             if requires_conversion:
diff --git a/src/peft/tuners/lora/bnb.py b/src/peft/tuners/lora/bnb.py
index 3b364aac4a..36dcc72468 100644
--- a/src/peft/tuners/lora/bnb.py
+++ b/src/peft/tuners/lora/bnb.py
@@ -204,9 +204,7 @@ def _mixed_batch_forward(
             requires_conversion = not torch.is_autocast_enabled()
             if requires_conversion:
                 expected_dtype = result.dtype
-                compute_dtype = lora_A.weight.dtype
-                if x.dtype != compute_dtype:
-                    x = x.to(compute_dtype)
+                x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
             # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
             # layer output
@@ -243,9 +241,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
             requires_conversion = not torch.is_autocast_enabled()
             if requires_conversion:
                 expected_dtype = result.dtype
-                compute_dtype = lora_A.weight.dtype
-                if x.dtype != compute_dtype:
-                    x = x.to(compute_dtype)
+                x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
             if not self.use_dora[active_adapter]:
                 output = lora_B(lora_A(dropout(x))) * scaling
@@ -470,7 +466,7 @@ def _mixed_batch_forward(
             requires_conversion = not torch.is_autocast_enabled()
             if requires_conversion:
                 expected_dtype = result.dtype
-                x = x.to(lora_A.weight.dtype)
+                x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
             # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
             # layer output
@@ -514,7 +510,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
             requires_conversion = not torch.is_autocast_enabled()
             if requires_conversion:
                 expected_dtype = result.dtype
-                x = x.to(lora_A.weight.dtype)
+                x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
             if not self.use_dora[active_adapter]:
                 output = lora_B(lora_A(dropout(x))) * scaling
diff --git a/src/peft/tuners/lora/eetq.py b/src/peft/tuners/lora/eetq.py
index d1eb79b878..e56d58d856 100644
--- a/src/peft/tuners/lora/eetq.py
+++ b/src/peft/tuners/lora/eetq.py
@@ -76,7 +76,7 @@ def forward(self, x: torch.Tensor):
             requires_conversion = not torch.is_autocast_enabled()
             if requires_conversion:
                 expected_dtype = result.dtype
-                x = x.to(lora_A.weight.dtype)
+                x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
             output = lora_B(lora_A(dropout(x)))
             if requires_conversion:
diff --git a/src/peft/tuners/lora/gptq.py b/src/peft/tuners/lora/gptq.py
index 0fb8cd49a3..4208265c4f 100644
--- a/src/peft/tuners/lora/gptq.py
+++ b/src/peft/tuners/lora/gptq.py
@@ -75,7 +75,7 @@ def forward(self, x: torch.Tensor):
             requires_conversion = not torch.is_autocast_enabled()
             if requires_conversion:
                 expected_dtype = result.dtype
-                x = x.to(lora_A.weight.dtype)
+                x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
             output = lora_B(lora_A(dropout(x)))
             if requires_conversion:
diff --git a/src/peft/tuners/lora/hqq.py b/src/peft/tuners/lora/hqq.py
index d623f7ae2a..4f5a1e0f66 100644
--- a/src/peft/tuners/lora/hqq.py
+++ b/src/peft/tuners/lora/hqq.py
@@ -178,9 +178,7 @@ def _mixed_batch_forward(
             requires_conversion = not torch.is_autocast_enabled()
             if requires_conversion:
                 expected_dtype = result.dtype
-                compute_dtype = lora_A.weight.dtype
-                if x.dtype != compute_dtype:
-                    x = x.to(compute_dtype)
+                x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
             # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
             # layer output
@@ -218,9 +216,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
             requires_conversion = not torch.is_autocast_enabled()
            if requires_conversion:
                 expected_dtype = result.dtype
-                compute_dtype = lora_A.weight.dtype
-                if x.dtype != compute_dtype:
-                    x = x.to(compute_dtype)
+                x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
             if not self.use_dora[active_adapter]:
                 result = result + lora_B(lora_A(dropout(x))) * scaling
diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py
index 0067d30cd9..33a532a1d5 100644
--- a/src/peft/tuners/lora/layer.py
+++ b/src/peft/tuners/lora/layer.py
@@ -62,6 +62,8 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, *
         self.lora_magnitude_vector = torch.nn.ModuleDict()  # for DoRA
         self._caches: dict[str, Any] = {}
         self.ephemeral_gpu_offload: bool = ephemeral_gpu_offload
+        # flag to enable/disable casting of input to weight dtype during forward call
+        self.cast_input_dtype_enabled: bool = True
         self.kwargs = kwargs
 
         base_layer = self.get_base_layer()
@@ -492,6 +494,19 @@ def _mixed_batch_forward(
 
         return result
 
+    def _cast_input_dtype(self, x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+        """
+        Cast the dtype of the input to the given dtype, if input dtype casting is enabled.
+
+        Usually, we want to enable this to align the input dtype with the dtype of the weight, but by setting
+        layer.cast_input_dtype_enabled = False, this can be disabled if necessary.
+
+        Enabling or disabling can be managed via the peft.helpers.disable_input_dtype_casting context manager.
+        """
+        if (not self.cast_input_dtype_enabled) or (x.dtype == dtype):
+            return x
+        return x.to(dtype=dtype)
+
 
 # Below code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
 # and modified to work with PyTorch FSDP
@@ -703,7 +718,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
                 lora_B = self.lora_B[active_adapter]
                 dropout = self.lora_dropout[active_adapter]
                 scaling = self.scaling[active_adapter]
-                x = x.to(lora_A.weight.dtype)
+                x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
                 if not self.use_dora[active_adapter]:
                     result = result + lora_B(lora_A(dropout(x))) * scaling
@@ -1268,7 +1283,7 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
                 lora_B = self.lora_B[active_adapter]
                 dropout = self.lora_dropout[active_adapter]
                 scaling = self.scaling[active_adapter]
-                x = x.to(lora_A.weight.dtype)
+                x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
                 if not self.use_dora[active_adapter]:
                     result = result + lora_B(lora_A(dropout(x))) * scaling
diff --git a/src/peft/tuners/lora/tp_layer.py b/src/peft/tuners/lora/tp_layer.py
index f47c565588..fabb3881d3 100644
--- a/src/peft/tuners/lora/tp_layer.py
+++ b/src/peft/tuners/lora/tp_layer.py
@@ -205,7 +205,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
                 lora_B = self.lora_B[active_adapter]
                 dropout = self.lora_dropout[active_adapter]
                 scaling = self.scaling[active_adapter]
-                x = x.to(lora_A.weight.dtype)
+                x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
                 if not self.use_dora[active_adapter]:
                     result = result + lora_B(lora_A(dropout(x))) * scaling
diff --git a/tests/test_helpers.py b/tests/test_helpers.py
index 334e9e745b..ccf982c849 100644
--- a/tests/test_helpers.py
+++ b/tests/test_helpers.py
@@ -16,11 +16,13 @@
 import pytest
 import torch
 from diffusers import StableDiffusionPipeline
+from torch import nn
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from peft import LoraConfig, get_peft_model
-from peft.helpers import check_if_peft_model, rescale_adapter_scale
+from peft.helpers import check_if_peft_model, disable_input_dtype_casting, rescale_adapter_scale
 from peft.tuners.lora.layer import LoraLayer
+from peft.utils import infer_device
 
 
 class TestCheckIsPeftModel:
@@ -369,3 +371,102 @@ def test_merging_adapter(self, tokenizer):
         logits_merged_scaling = model(**inputs).logits
 
         assert torch.allclose(logits_merged_scaling, logits_unmerged_scaling, atol=1e-4, rtol=1e-4)
+
+
+class TestDisableInputDtypeCasting:
+    """Test the context manager `disable_input_dtype_casting` that temporarily disables input dtype casting
+    in the model.
+
+    The test works as follows:
+
+    We create a simple MLP and convert it to a PeftModel. The model dtype is set to float16. Then a pre-forward hook is
+    added that casts the model parameters to float32. Moreover, a post-forward hook is added that casts the weights
+    back to float16. The input dtype is float32.
+
+    Without the disable_input_dtype_casting context, what would happen is that PEFT detects that the input dtype is
+    float32 but the weight dtype is float16, so it casts the input to float16. Then the pre-forward hook casts the
+    weight to float32, which results in a RuntimeError.
+
+    With the disable_input_dtype_casting context, the input dtype is left as float32 and there is no error. We also add
+    a hook to record the dtype of the result from the LoraLayer to ensure that it is indeed float32.
+
+    """
+
+    device = infer_device()
+    dtype_record = []
+
+    @torch.no_grad()
+    def cast_params_to_fp32_pre_hook(self, module, input):
+        for param in module.parameters(recurse=False):
+            param.data = param.data.float()
+        return input
+
+    @torch.no_grad()
+    def cast_params_to_fp16_hook(self, module, input, output):
+        for param in module.parameters(recurse=False):
+            param.data = param.data.half()
+        return output
+
+    def record_dtype_hook(self, module, input, output):
+        self.dtype_record.append(output[0].dtype)
+
+    @pytest.fixture
+    def inputs(self):
+        return torch.randn(4, 10, device=self.device, dtype=torch.float32)
+
+    @pytest.fixture
+    def base_model(self):
+        class MLP(nn.Module):
+            def __init__(self, bias=True):
+                super().__init__()
+                self.lin0 = nn.Linear(10, 20, bias=bias)
+                self.lin1 = nn.Linear(20, 2, bias=bias)
+                self.sm = nn.LogSoftmax(dim=-1)
+
+            def forward(self, X):
+                X = self.lin0(X)
+                X = self.lin1(X)
+                X = self.sm(X)
+                return X
+
+        return MLP()
+
+    @pytest.fixture
+    def model(self, base_model):
+        config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
+        model = get_peft_model(base_model, config).to(device=self.device, dtype=torch.float16)
+        # Register hooks on the submodules that hold parameters
+        for module in model.modules():
+            if sum(p.numel() for p in module.parameters()) > 0:
+                module.register_forward_pre_hook(self.cast_params_to_fp32_pre_hook)
+                module.register_forward_hook(self.cast_params_to_fp16_hook)
+            if isinstance(module, LoraLayer):
+                module.register_forward_hook(self.record_dtype_hook)
+        return model
+
+    def test_disable_input_dtype_casting_active(self, model, inputs):
+        self.dtype_record.clear()
+        with disable_input_dtype_casting(model, active=True):
+            model(inputs)
+        assert self.dtype_record == [torch.float32]
+
+    def test_no_disable_input_dtype_casting(self, model, inputs):
+        msg = r"expected m.*1 and m.*2 to have the same dtype"
+        with pytest.raises(RuntimeError, match=msg):
+            model(inputs)
+
+    def test_disable_input_dtype_casting_inactive(self, model, inputs):
+        msg = r"expected m.*1 and m.*2 to have the same dtype"
+        with pytest.raises(RuntimeError, match=msg):
+            with disable_input_dtype_casting(model, active=False):
+                model(inputs)
+
+    def test_disable_input_dtype_casting_inactive_after_existing_context(self, model, inputs):
+        # this is to ensure that when the context is left, we return to the previous behavior
+        with disable_input_dtype_casting(model, active=True):
+            model(inputs)
+
+        # after the context exited, we're back to the error
+        msg = r"expected m.*1 and m.*2 to have the same dtype"
+        with pytest.raises(RuntimeError, match=msg):
+            model(inputs)
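
Usage note (not part of the patch): below is a minimal sketch of how the new `disable_input_dtype_casting` context manager is meant to be used, mirroring the setup of the new test. The toy `nn.Sequential` model, the `upcast_params_pre_hook` hook, and the layer sizes are illustrative assumptions, not code from this PR.

```python
import torch
from torch import nn

from peft import LoraConfig, get_peft_model
from peft.helpers import disable_input_dtype_casting

# Toy float16 model with a LoRA adapter on its only linear layer (illustrative setup).
base_model = nn.Sequential(nn.Linear(10, 10))
model = get_peft_model(base_model, LoraConfig(target_modules=["0"])).half()


# Hypothetical pre-forward hook that upcasts parameters to float32 right before each
# forward call, similar in spirit to layerwise casting in diffusers.
@torch.no_grad()
def upcast_params_pre_hook(module, args):
    for param in module.parameters(recurse=False):
        param.data = param.data.float()
    return args


# Register the hook on every submodule that directly holds parameters.
for module in model.modules():
    if list(module.parameters(recurse=False)):
        module.register_forward_pre_hook(upcast_params_pre_hook)

x = torch.randn(2, 10)  # float32 input

# Without the context manager, PEFT would cast x to float16 to match the stored LoRA
# weight dtype, which would then clash with the float32 weights produced by the hook.
with disable_input_dtype_casting(model):
    output = model(x)

print(output.dtype)  # torch.float32
```

The context manager only flips `cast_input_dtype_enabled` on `LoraLayer` instances and restores the previous values on exit, so leaving the `with` block returns the model to the default casting behavior.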