Skip to content

Commit

Permalink
Added sts mean and std
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeff Ding committed Jul 18, 2024
1 parent 95da56b commit 018e339
Show file tree
Hide file tree
Showing 5 changed files with 439 additions and 63 deletions.
110 changes: 110 additions & 0 deletions sd_scripts/library/autostats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import torch
import torch.nn.functional as F
import os
from safetensors import safe_open
import numpy as np

from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)

# Unit Gaussian (mean 0, std 1) sampled by smooth() to build its blur kernel.
standard_normal_distribution = torch.distributions.normal.Normal(torch.tensor([0.0]), torch.tensor([1.0]))


def smooth(probs, step_size=0.5):
    """Blur a 1-D weight vector with a Gaussian kernel and renormalise it.

    The kernel is the standard normal density evaluated on
    ``torch.arange(-pi, pi, step_size)``; it is not normalised itself because
    the blurred result is divided by its own sum before being returned.

    Args:
        probs: 1-D tensor of weights (it is expanded to ``[1, 1, N]`` for
            ``conv1d``).
        step_size: Spacing of the kernel sample points; a smaller step gives a
            longer kernel.

    Returns:
        1-D float tensor with the same length as ``probs`` that sums to 1.
    """
    sample_points = torch.arange(-torch.pi, torch.pi, step_size)
    density = torch.exp(standard_normal_distribution.log_prob(sample_points))
    weights = density.to(probs.device).float()

    signal = probs.float()[None, None, :]
    # padding="same" zero-pads so the output keeps the input length.
    blurred = F.conv1d(signal, weights[None, None, :], padding="same")
    blurred = blurred.reshape(-1)
    return blurred / blurred.sum()

def kerras_timesteps(n, sigma_min=0.001, sigma_max=10.0):
    """Return ``n`` sigma values spaced evenly in arctan-space.

    Angles are interpolated linearly from ``arctan(sigma_max)`` towards
    ``arctan(sigma_min)`` and mapped back through ``tan``. Index 0 yields
    ``sigma_max``; the interpolation fraction stops at ``(n - 1) / n``, so
    ``sigma_min`` itself is never reached.

    Args:
        n: Number of sigmas to produce.
        sigma_min: Sigma the schedule approaches at the end.
        sigma_max: Sigma at index 0.

    Returns:
        1-D tensor of ``n`` sigmas.
    """
    angle_lo = torch.arctan(torch.tensor(sigma_min))
    angle_hi = torch.arctan(torch.tensor(sigma_max))

    fraction = torch.arange(n) / n
    angles = fraction * angle_lo + (1.0 - fraction) * angle_hi
    return torch.tan(angles)

# cribbed from A111
def read_metadata_from_safetensors(filename):
    """Read the ``__metadata__`` section from a safetensors file header.

    The file must start with an 8-byte little-endian length followed by that
    many bytes of JSON. Metadata values that are strings starting with ``{``
    are additionally decoded as JSON; if that decoding fails the raw string
    is kept.

    Args:
        filename: Path of the file to inspect.

    Returns:
        dict mapping metadata keys to their (possibly JSON-decoded) values.
        If the header JSON cannot be processed the error is logged and
        whatever was collected so far (possibly nothing) is returned.

    Raises:
        AssertionError: If the header length is not greater than 2 or the
            header does not start like a JSON object.
    """
    import json

    with open(filename, mode="rb") as file:
        metadata_len = int.from_bytes(file.read(8), "little")
        json_start = file.read(2)

        # Raised explicitly rather than via `assert` so the validation still
        # runs under `python -O`; the exception type callers see is unchanged.
        if metadata_len <= 2 or json_start not in (b'{"', b"{'"):
            raise AssertionError(f"{filename} is not a safetensors file")

        res = {}
        try:
            json_obj = json.loads(json_start + file.read(metadata_len - 2))
            for k, v in json_obj.get("__metadata__", {}).items():
                res[k] = v
                if isinstance(v, str) and v.startswith("{"):
                    try:
                        res[k] = json.loads(v)
                    except (ValueError, RecursionError):
                        # Best effort: the value only looked like JSON, keep the raw string.
                        pass
        except Exception:
            logger.error(f"Error reading metadata from file: {filename}", exc_info=True)

    return res

def interp_forward(t, timesteps):
    """Linearly interpolate per-channel statistics onto integer timesteps 0..999.

    Args:
        t: 2-D tensor laid out as ``[observed_timestep, channel]``; the first
            four channels are interpolated.
        timesteps: Timestep value for each row of ``t`` (tensor or sequence).
            The values may be in any order.

    Returns:
        Tensor of shape ``[1000, 4]`` placed on the device ``t`` was on.
    """
    # Remember the caller's device up front; the interpolation itself runs on CPU.
    source_device = t.device
    per_channel = t.permute(1, 0).float().cpu().numpy()  # channel-first: [channel, observed]

    # np.interp requires increasing sample points, so sort the observed
    # timesteps and reorder the values to match. Already-ascending input is
    # left unchanged by the stable sort.
    known_ts = np.asarray(torch.as_tensor(timesteps).tolist(), dtype=np.float64)
    order = np.argsort(known_ts, kind="stable")
    known_ts = known_ts[order]
    per_channel = per_channel[:, order]

    xs = np.arange(0, 1000)
    result = torch.stack([torch.tensor(list(np.interp(xs, known_ts, per_channel[i]))) for i in range(0, 4)])
    return result.permute(1, 0).to(source_device)

def load_model_noise_stats(args):
    """Load recorded noise observations and reduce them to per-timestep stats.

    Args:
        args: Object whose ``autostats`` attribute is the path of a
            safetensors file holding ``observations`` and ``timesteps``
            tensors, or ``None``.

    Returns:
        ``(None, None)`` when the path is unset or does not exist; otherwise
        the pair returned by ``transform_observations``.
    """
    stats_path = args.autostats
    if stats_path is None:
        return None, None
    if not os.path.exists(stats_path):
        return None, None

    with safe_open(stats_path, framework="pt") as handle:
        observations, timesteps = (handle.get_tensor(key) for key in ("observations", "timesteps"))
    return transform_observations(observations, timesteps)

def transform_observations(observations, timesteps):
    """Reduce raw observations to interpolated per-timestep channel statistics.

    Args:
        observations: Tensor laid out as ``[timestep, sample, channels, h, w]``.
        timesteps: Timestep value for each entry along the first axis.

    Returns:
        ``(means, stds)``, each passed through ``interp_forward``.
    """
    # Collapse sample, h and w so that what remains is [timestep, channel].
    reduce_dims = (1, 3, 4)
    channel_means = torch.mean(observations, dim=reduce_dims)
    channel_stds = torch.std(observations, dim=reduce_dims)

    interpolated_means = interp_forward(channel_means, timesteps)
    interpolated_stds = interp_forward(channel_stds, timesteps)
    return interpolated_means, interpolated_stds

def autostats(args, generator):
    """Build per-timestep noise mean/std targets and timestep sampling weights.

    Args:
        args: Object providing ``autostats`` (consumed by
            ``load_model_noise_stats``), ``autostats_true_noise_weight``
            (indexable: ``[0]`` scales the std's deviation from 1, ``[1]``
            scales the mean) and ``autostats_timestep_weighting`` (truthy to
            weight timesteps by how far the std is from 1).
        generator: Zero-argument callable invoked once when the stats cannot
            be loaded, after which loading is retried. Presumably it produces
            the stats file -- confirm against the caller.

    Returns:
        Tuple ``(std_target_by_ts, mean_target_by_ts, scaled_std_target_by_ts,
        scaled_mean_target_by_ts, timestep_probs)``. The four targets are
        viewed as ``[-1, 4, 1, 1]``; ``timestep_probs`` is 1-D and sums to 1.

    Raises:
        ValueError: If the stats are still unavailable after ``generator()``.
    """
    mean_target_by_ts, std_target_by_ts = load_model_noise_stats(args)
    if mean_target_by_ts is None:
        # No stats yet: let the caller-supplied generator run, then retry.
        generator()
        mean_target_by_ts, std_target_by_ts = load_model_noise_stats(args)

    if mean_target_by_ts is None:
        raise ValueError("Could not load noise stats from model")

    # Add trailing singleton dims so the targets broadcast over h and w.
    std_target_by_ts = std_target_by_ts.view(-1, 4, 1, 1)
    mean_target_by_ts = mean_target_by_ts.view(-1, 4, 1, 1)

    effect_scale = args.autostats_true_noise_weight
    # Scale the std's deviation from 1 (and the mean's deviation from 0).
    scaled_std_target_by_ts = (std_target_by_ts - 1.0) * effect_scale[0] + 1.0
    scaled_mean_target_by_ts = mean_target_by_ts * effect_scale[1]

    timestep_probs = torch.ones(1000)
    if args.autostats_timestep_weighting:
        timestep_probs = (std_target_by_ts - 1).abs().mean(dim=1).reshape(-1)
        # Overwrite the first 15 entries with the value at index 15 before smoothing.
        timestep_probs[:15] = timestep_probs[15]
        timestep_probs = smooth(timestep_probs)

    timestep_probs = timestep_probs / timestep_probs.sum()

    print("std", scaled_std_target_by_ts.view(-1, 4))
    print("mean", scaled_mean_target_by_ts.view(-1, 4))

    return std_target_by_ts, mean_target_by_ts, scaled_std_target_by_ts, scaled_mean_target_by_ts, timestep_probs
13 changes: 9 additions & 4 deletions sd_scripts/library/custom_train_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,15 +485,20 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
return noise


def apply_masked_loss(loss, batch):
def get_mask(batch, latents):
# mask image is -1 to 1. we need to convert it to 0 to 1
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
mask_image = batch["conditioning_images"].to(dtype=latents.dtype)[:, 0].unsqueeze(1) # use R channel

# resize to the same size as the loss
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
mask_image = torch.nn.functional.interpolate(mask_image, size=latents.shape[2:], mode="area")
mask_image = mask_image / 2 + 0.5
return mask_image


def apply_masked_loss(loss, batch):
mask_image = get_mask(batch, loss)
loss = loss * mask_image
return loss, mask_image
return loss


## Custom loss function for weighting a character, and specific details of the character, differently from the background
Expand Down
Loading

0 comments on commit 018e339

Please sign in to comment.