From 01ab719d8a19fbbe584744348e8aa717a848413f Mon Sep 17 00:00:00 2001 From: bwanglzu Date: Wed, 24 Apr 2024 17:06:33 +0200 Subject: [PATCH 1/3] feat: add longclip --- src/open_clip/loss.py | 22 +++++++++++++++------- src/open_clip/utils.py | 18 ++++++++++++++++++ src/training/train.py | 8 ++++++++ 3 files changed, 41 insertions(+), 7 deletions(-) diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index 7454bee5e..72155b7da 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn from torch.nn import functional as F +from typing import Optional try: import torch.distributed.nn @@ -16,6 +17,7 @@ except ImportError: hvd = None +from utils import PCA class GatherFeatures: @@ -39,7 +41,7 @@ def __init__( if use_horovod: assert hvd is not None, 'Please install horovod' - def __call__(self, features: torch.Tensor): + def __call__(self, features: torch.Tensor, pca_dim: Optional[int] = None): if self.use_horovod: if self.gather_with_grad: all_features = hvd.allgather(features) @@ -70,7 +72,8 @@ def __call__(self, features: torch.Tensor): gathered_features[self.rank] = features all_features = torch.cat(gathered_features, dim=0) - + if pca_dim: + all_features = PCA(all_features) return all_features @@ -84,6 +87,7 @@ def gather_features( rank=0, world_size=1, use_horovod=False, + pca_dim=None ): gather = GatherFeatures( local_loss=local_loss, @@ -91,10 +95,11 @@ def gather_features( rank=rank, world_size=world_size, use_horovod=use_horovod, + pca_dim=pca_dim, ) return ( - gather(image_features), - gather(text_features), + gather(image_features, pca_dim=pca_dim), # apply PCA on image faetures if set + gather(text_features, pca_dim=None), # never apply PCA on text features gather(teacher_features) if teacher_features else None ) @@ -134,7 +139,7 @@ def get_ground_truth(self, device, num_logits) -> torch.Tensor: labels = self.labels[device] return labels - def get_logits(self, image_features, text_features, logit_scale): + def get_logits(self, image_features, text_features, logit_scale, pca_dim: Optional[int] = None): if self.world_size > 1: all_image_features, all_text_features, _ = gather_features( image_features=image_features, @@ -144,6 +149,7 @@ def get_logits(self, image_features, text_features, logit_scale): rank=self.rank, world_size=self.world_size, use_horovod=self.use_horovod, + pca_dim=pca_dim ) if self.local_loss: logits_per_image = logit_scale * image_features @ all_text_features.T @@ -154,15 +160,17 @@ def get_logits(self, image_features, text_features, logit_scale): ) logits_per_text = logits_per_image.T else: + if pca_dim: + image_features = PCA(image_features) logits_per_image = logit_scale * image_features @ text_features.T logits_per_text = logit_scale * text_features @ image_features.T return logits_per_image, logits_per_text - def forward(self, image_features, text_features, logit_scale, output_dict=False): + def forward(self, image_features, text_features, logit_scale, output_dict=False, pca_dim = None): device = image_features.device logits_per_image, logits_per_text = self.get_logits( - image_features, text_features, logit_scale + image_features, text_features, logit_scale, pca_dim, ) labels = self.get_ground_truth(device, logits_per_image.shape[0]) diff --git a/src/open_clip/utils.py b/src/open_clip/utils.py index 29efeac78..5dc38154b 100644 --- a/src/open_clip/utils.py +++ b/src/open_clip/utils.py @@ -6,6 +6,24 @@ from torchvision.ops.misc import FrozenBatchNorm2d +def PCA(input_tensor, PCA_dim): + mean = torch.mean(input_tensor, dim=0) + X_centered = input_tensor - mean.unsqueeze(0) + X_centered = X_centered.float() + cov_matrix = torch.mm(X_centered.T, X_centered) + eigenvalues, eigenvectors = torch.linalg.eig(cov_matrix) + eigenvalues = eigenvalues.float() + eigenvectors = eigenvectors.float() + sorted_indices = torch.argsort(eigenvalues, descending=True) + eigenvectors = eigenvectors[:, sorted_indices] + principal_components = eigenvectors[:, :PCA_dim] + X_transformed = torch.mm(X_centered, principal_components) + X_reversed = torch.mm(X_transformed, principal_components.T) + X_reversed += mean + return X_reversed + + + def freeze_batch_norm_2d(module, module_match={}, name=''): """ Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is diff --git a/src/training/train.py b/src/training/train.py index 814ad5edd..a1dc8203b 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -143,6 +143,11 @@ def train_one_epoch( images, texts = mm_batch images = images.to(device=device, dtype=input_dtype, non_blocking=True) texts = texts.to(device=device, non_blocking=True) + if args.longclip: + images_short = images.clone() + texts_short = [] + for text in texts: + texts_short.append(text.split(". ")[0]) if emb_batch: for batch in emb_batch: batch.to(device=device) @@ -200,6 +205,9 @@ def train_one_epoch( losses['embedding_loss'] = args.emb_loss_weight * embedding_loss + if args.longclip: + modelout_short = model(images_short, texts_short) + loss_short = loss(**modelout_short, output_dict=True, pca_dim=32) total_loss = sum(losses.values()) losses['loss'] = total_loss backward(total_loss, model, scaler=scaler, deepspeed=args.deepspeed) From 1c82bbc94b77578522d3a920cc1215e3941b5883 Mon Sep 17 00:00:00 2001 From: bwanglzu Date: Wed, 24 Apr 2024 17:09:34 +0200 Subject: [PATCH 2/3] feat: add args --- src/training/params.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/training/params.py b/src/training/params.py index 2d42e42b8..e37a81b86 100644 --- a/src/training/params.py +++ b/src/training/params.py @@ -689,6 +689,12 @@ def parse_args(args): help='The weighing factor for the embedding loss.', ) parser.add_argument('--local_rank', type=int, default=0) + parser.add_argument( + '--longclip', + default=False, + action='store_true', + help='If set to true apply pca to image features and collect long & short loss', + ) args = parser.parse_args(args) From 6103e88ed47a430d1694278e7b6768904daebe7d Mon Sep 17 00:00:00 2001 From: bwanglzu Date: Wed, 24 Apr 2024 17:25:21 +0200 Subject: [PATCH 3/3] feat: add longclip loss --- src/open_clip/loss.py | 2 +- src/training/train.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index 72155b7da..161d3be0d 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -17,7 +17,7 @@ except ImportError: hvd = None -from utils import PCA +from .utils import PCA class GatherFeatures: diff --git a/src/training/train.py b/src/training/train.py index a1dc8203b..cea2ddcf3 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -208,6 +208,7 @@ def train_one_epoch( if args.longclip: modelout_short = model(images_short, texts_short) loss_short = loss(**modelout_short, output_dict=True, pca_dim=32) + losses['short_loss'] = 0.1 * loss_short total_loss = sum(losses.values()) losses['loss'] = total_loss backward(total_loss, model, scaler=scaler, deepspeed=args.deepspeed)