[Training] Better image interpolation in training scripts #11206

Open
Member

@asomoza asomoza commented Apr 4, 2025

What does this PR do?

As discussed here, this PR adds LANCZOS as the default interpolation mode for image resizing in the training scripts; users who prefer it can still choose BILINEAR.
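For illustration, the option described above could be exposed roughly as follows. This is a minimal sketch, not the PR's actual diff: the flag name is taken from `args.image_interpolation_mode` in the review thread, while the `choices` list, the help text, and the `build_parser` helper are assumptions.

```python
import argparse


def build_parser():
    """Hypothetical subset of a training script's argument parser."""
    parser = argparse.ArgumentParser(description="Illustrative training arguments.")
    parser.add_argument(
        "--image_interpolation_mode",
        type=str,
        default="lanczos",  # default proposed in this PR
        choices=["lanczos", "bilinear"],  # assumption: only the two modes named in the description
        help="Interpolation mode used when resizing training images.",
    )
    return parser


# Parse an explicit empty list so the sketch does not depend on sys.argv.
args = build_parser().parse_args([])
print(args.image_interpolation_mode)
```

Passing `--image_interpolation_mode bilinear` on the command line would select the other mode.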

I'll add this to the most popular recent scripts and leave the rest to the community, in case they want to add it to other training scripts.

I'll do some training runs first to see whether the difference shows up, but I already know that LANCZOS is better and that models can pick up subtle details that the human eye can't.

Fixes #6397

Who can review?

@bghira @linoytsaban @sayakpaul

Additional

Here's a little script if you want to test and try to see the difference:

import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms

from diffusers.utils import load_image


image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/training/circle.png")

# Define the resize transforms for 256x256
resize_bilinear_256 = transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BILINEAR)
resize_lanczos_256 = transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.LANCZOS)

# Apply the 256x256 transforms
image_bilinear_256 = resize_bilinear_256(image)
image_lanczos_256 = resize_lanczos_256(image)
image_pil_lanczos_256 = image.resize((256, 256), Image.Resampling.LANCZOS)

# Define the resize transforms for 32x32
resize_bilinear_32 = transforms.Resize((32, 32), interpolation=transforms.InterpolationMode.BILINEAR)
resize_lanczos_32 = transforms.Resize((32, 32), interpolation=transforms.InterpolationMode.LANCZOS)

# Apply the 32x32 transforms
image_bilinear_32 = resize_bilinear_32(image)
image_lanczos_32 = resize_lanczos_32(image)
image_pil_lanczos_32 = image.resize((32, 32), Image.Resampling.LANCZOS)

# Display the results
plt.figure(figsize=(16, 12))

# First row - 256x256 resizing
plt.subplot(2, 4, 1)
plt.title("Original Image")
plt.imshow(image)
plt.axis("off")

plt.subplot(2, 4, 2)
plt.title("Bilinear (256x256)")
plt.imshow(image_bilinear_256)
plt.axis("off")

plt.subplot(2, 4, 3)
plt.title("torchvision Lanczos (256x256)")
plt.imshow(image_lanczos_256)
plt.axis("off")

plt.subplot(2, 4, 4)
plt.title("PIL Lanczos (256x256)")
plt.imshow(image_pil_lanczos_256)
plt.axis("off")

# Second row - 32x32 resizing
plt.subplot(2, 4, 5)
plt.title("Original Image")
plt.imshow(image)
plt.axis("off")

plt.subplot(2, 4, 6)
plt.title("Bilinear (32x32)")
plt.imshow(image_bilinear_32)
plt.axis("off")

plt.subplot(2, 4, 7)
plt.title("torchvision Lanczos (32x32)")
plt.imshow(image_lanczos_32)
plt.axis("off")

plt.subplot(2, 4, 8)
plt.title("PIL Lanczos (32x32)")
plt.imshow(image_pil_lanczos_32)
plt.axis("off")

plt.tight_layout()
plt.show()


Comment on lines +801 to +804
if args.image_interpolation_mode == "bilinear":
    train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
else:
    train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.LANCZOS)
Member

Suggested change
- if args.image_interpolation_mode == "bilinear":
-     train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
- else:
-     train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.LANCZOS)
+ interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
+ if interpolation is None:
+     raise ValueError(f"Unsupported interpolation mode.")
+ train_resize = transforms.Resize(size, interpolation=interpolation)

Could we do something like this?

Member Author

yeah, not sure about the other interpolation modes but I guess if the user uses another one it's because they know what they're doing

Collaborator

yeah, since we have the default as lanczos I also think it's ok to assume the user knows what they're doing if they use a mode other than lanczos/bilinear

Member

Just cleaner code and preemptively solves the inevitable issue report: "Setting image_interpolation_mode other than bilinear uses lanczos"

Contributor

bghira commented Apr 4, 2025

thank you

Successfully merging this pull request may close these issues.

Add a section about how training data resizing might affect the quality of the end models