Can not align FID with provided checkpoint #69

Open
LiCHH opened this issue Jun 6, 2024 · 6 comments

LiCHH commented Jun 6, 2024

Hello, I wrote a script based on demo_sample.ipynb to generate 50,000 samples (50 per class) and evaluated them with OpenAI's FID evaluation toolkit. However, the metrics do not match the reported numbers. Could you help me identify the problem?
With the d20 checkpoint I got:
[image]
and with the d30 checkpoint:
[image]
The script is below:

################## 1. Download checkpoints and build models
import os
import os.path as osp
import torch, torchvision
import random
from tqdm import tqdm
import numpy as np
import PIL.Image as PImage, PIL.ImageDraw as PImageDraw
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)     # disable default parameter init for faster speed
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)  # disable default parameter init for faster speed
from models import VQVAE, build_vae_var

MODEL_DEPTH = 20    # TODO: =====> please specify MODEL_DEPTH <=====
assert MODEL_DEPTH in {16, 20, 24, 30}


# download checkpoint
hf_home = 'https://huggingface.co/FoundationVision/var/resolve/main'
vae_ckpt, var_ckpt = 'vae_ch160v4096z32.pth', f'var_d{MODEL_DEPTH}.pth'
if not osp.exists(vae_ckpt): os.system(f'wget {hf_home}/{vae_ckpt}')
if not osp.exists(var_ckpt): os.system(f'wget {hf_home}/{var_ckpt}')

# build vae, var
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if 'vae' not in globals() or 'var' not in globals():
    vae, var = build_vae_var(
        V=4096, Cvae=32, ch=160, share_quant_resi=4,    # hard-coded VQVAE hyperparameters
        device=device, patch_nums=patch_nums,
        num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,
    )

# load checkpoints
vae.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)
var.load_state_dict(torch.load(var_ckpt, map_location='cpu'), strict=True)
vae.eval(), var.eval()
for p in vae.parameters(): p.requires_grad_(False)
for p in var.parameters(): p.requires_grad_(False)
print(f'prepare finished.')

############################# 2. Sample with classifier-free guidance

# set args
seed = 1 #@param {type:"number"}
torch.manual_seed(seed)
num_sampling_steps = 250 #@param {type:"slider", min:0, max:1000, step:1}
cfg = 1.5 #@param {type:"slider", min:1, max:10, step:0.1}
more_smooth = False # True for more smooth output

# seed
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# run faster
tf32 = True
torch.backends.cudnn.allow_tf32 = bool(tf32)
torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
torch.set_float32_matmul_precision('high' if tf32 else 'highest')

# sample 50 images per class, B at a time
B = 25
out_dir = f'./samples_d{MODEL_DEPTH}'
os.makedirs(out_dir, exist_ok=True)    # img.save() fails if the folder does not exist
for img_cls in tqdm(range(1000)):
    for i in range(50 // B):
        label_B = torch.tensor([img_cls] * B, device=device)
        with torch.inference_mode():
            with torch.autocast('cuda', enabled=True, dtype=torch.float16, cache_enabled=True):    # using bfloat16 can be faster
                # note: the same g_seed is passed on every call
                recon_B3HW = var.autoregressive_infer_cfg(B=B, label_B=label_B, cfg=cfg, top_k=900, top_p=0.96, g_seed=seed, more_smooth=more_smooth)
            # (B, 3, H, W) -> (B, H, W, 3), scaled to 0..255
            bhwc = recon_B3HW.permute(0, 2, 3, 1).mul_(255).cpu().numpy()
        bhwc = bhwc.astype(np.uint8)
        for j in range(B):
            img = PImage.fromarray(bhwc[j])
            img.save(f"{out_dir}/sample_{img_cls * 50 + i * B + j}.png")
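
For reference, a minimal sketch of turning sample batches into a single `.npz` file, which is the input format I believe OpenAI's evaluation script expects (an `N x H x W x 3` uint8 array stored under the key `arr_0`; please check this against the toolkit's README). The helper names `to_uint8_bhwc` and `pack_npz` are hypothetical and not part of the VAR repo. Note that this sketch rounds before casting, while the script above truncates with `astype(np.uint8)`, so pixel values can differ by one.

```python
import numpy as np

def to_uint8_bhwc(batch_b3hw):
    """Convert a float batch in [0, 1] with layout (B, 3, H, W) to uint8 (B, H, W, 3)."""
    x = np.asarray(batch_b3hw, dtype=np.float32)
    x = np.clip(x * 255.0, 0, 255).round()      # assumption: round instead of truncate
    return x.transpose(0, 2, 3, 1).astype(np.uint8)

def pack_npz(batches, path):
    """Concatenate uint8 batches and save them under numpy's default key `arr_0`."""
    arr = np.concatenate(batches, axis=0)
    assert arr.dtype == np.uint8 and arr.ndim == 4 and arr.shape[-1] == 3
    np.savez(path, arr)                         # positional arrays are stored as arr_0, arr_1, ...
    return arr.shape
```

Collecting the uint8 batches in a list during sampling and calling `pack_npz` once at the end avoids re-reading 50,000 PNG files.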

ma-xu commented Jun 11, 2024

@keyu-tian Thanks for your great work! Could you please help with this issue? Any insights would be helpful.

Kumbong commented Jul 3, 2024

+1. Thanks for the great work! Would you be willing to share your evaluation scripts?

ChenDRAG commented Jul 4, 2024

@LiCHH @ma-xu @Kumbong It seems the same generation seed is reused for every batch of every class:
recon_B3HW = var.autoregressive_infer_cfg(B=B, label_B=label_B, cfg=cfg, top_k=900, top_p=0.96, g_seed=seed, more_smooth=more_smooth)

g_seed should be different for each call. I would use

    for i in range(50 // B):
        label_B = torch.tensor([img_cls] * 25, device=device)
        # B = len(class_labels)
        # label_B: torch.LongTensor = torch.tensor(class_labels, device=device)
        with torch.inference_mode():
            with torch.autocast('cuda', enabled=True, dtype=torch.float16, cache_enabled=True):    # using bfloat16 can be faster
                recon_B3HW = var.autoregressive_infer_cfg(B=B, label_B=label_B, cfg=cfg, top_k=900, top_p=0.96, g_seed=seed + i, more_smooth=more_smooth)
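
Note that `seed + i` separates the two batches within a class, but every class still receives the same pair of seeds. If fully distinct seeds are wanted, one option is to derive the seed from both the class index and the batch index. This is only a sketch: `batch_seed` is a hypothetical helper, and nothing in this thread confirms that the authors' evaluation used such a scheme.

```python
def batch_seed(base_seed: int, img_cls: int, batch_idx: int, batches_per_class: int) -> int:
    """Hypothetical helper: map (class, batch) to a seed that no other pair shares."""
    if not 0 <= batch_idx < batches_per_class:
        raise ValueError("batch_idx out of range")
    return base_seed + img_cls * batches_per_class + batch_idx

# 1000 classes x 2 batches of 25 images -> 2000 distinct seeds
seeds = {batch_seed(1, c, i, 2) for c in range(1000) for i in range(2)}
print(len(seeds))  # 2000
```

In the loop above this would be passed as `g_seed=batch_seed(seed, img_cls, i, 50 // B)`.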

Tom-zgt commented Jul 17, 2024

I tried your suggestion, but the results got worse. The randomness of c2i has already been introduced by rng, so shouldn't it be unnecessary to change the random seed? Or could you share your evaluation code with me? @ChenDRAG @keyu-tian
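
One way to check empirically whether a sampling script repeats images is to count byte-identical files in the output folder. The sketch below uses only the standard library; `count_duplicate_files` is a hypothetical helper. It only detects exact duplicates, so it can miss near-duplicates if GPU kernels are not bit-for-bit deterministic.

```python
import hashlib
from collections import Counter
from pathlib import Path

def count_duplicate_files(folder: str, pattern: str = "*.png") -> int:
    """Return how many files are byte-identical copies of another file in the folder."""
    digests = Counter(
        hashlib.sha256(p.read_bytes()).hexdigest()
        for p in sorted(Path(folder).glob(pattern))
    )
    # each group of n identical files contributes n - 1 redundant copies
    return sum(n - 1 for n in digests.values())
```

A result of 0 on the sample folder would suggest the seed handling is not producing exact repeats; a large count would point to repeated batches.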

sen-ye commented Aug 6, 2024

You can try to use B=50, since random seed is shared across different batches.
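
To illustrate the point, here is a toy stand-in (it does not call VAR). It assumes the sampler re-seeds its generator with `g_seed` on every call, which is how I read the comment above; `fake_sample_batch` and `unique_per_class` are hypothetical names. With the same label and the same seed, two batches of 25 yield the same 25 draws, while one batch of 50 yields 50 different draws.

```python
import random

def fake_sample_batch(label: int, batch_size: int, g_seed: int):
    """Stand-in for a seeded sampler: the generator is re-seeded on every call."""
    rng = random.Random(g_seed)
    return [(label, rng.getrandbits(64)) for _ in range(batch_size)]

def unique_per_class(batch_size: int, per_class: int = 50, g_seed: int = 1) -> int:
    """Count distinct samples when a class is generated in several same-seed batches."""
    samples = []
    for _ in range(per_class // batch_size):
        samples += fake_sample_batch(0, batch_size, g_seed)
    return len(set(samples))

# identical seed and label -> identical batch
print(fake_sample_batch(0, 25, 1) == fake_sample_batch(0, 25, 1))  # True
```

In this toy model, `unique_per_class(25)` returns only half as many distinct samples as `unique_per_class(50)`, which is the effect the B=50 suggestion avoids.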

adreamwu commented Aug 15, 2024

> You can try to use B=50, since random seed is shared across different batches.

Have you been able to reproduce the results in Table 1 of the paper? Could you share the inference script?

We use B=50 for each class and var_d16 for evaluation.

              FID    IS       Precision   Recall
  reported    3.60   257.5    0.85        0.48
  reproduced  3.49   281.71   0.85        0.50

The main discrepancy is the IS (281.71 reproduced vs. 257.5 reported). The other metrics are close to those in the paper. Thanks.
