Skip to content

Commit

Permalink
fix intern2vl save bugs (#274)
Browse files Browse the repository at this point in the history
  • Loading branch information
helloyongyang authored Dec 23, 2024
1 parent 167365f commit c99b2b7
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 63 deletions.
21 changes: 15 additions & 6 deletions llmc/compression/quantization/base_blockwise_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,13 +866,22 @@ def copy_tokenizer(self, path):

@torch.no_grad()
def contiguous_params(self):
for name, param in self.model.model.named_parameters():
if not param.is_contiguous():
param.data = param.data.contiguous()
if self.model.mm_model is not None:
for name, param in self.model.mm_model.named_parameters():
if not param.is_contiguous():
param.data = param.data.contiguous()

for name, param in self.model.mm_model.named_buffers():
if not param.is_contiguous():
param.data = param.data.contiguous()
else:
for name, param in self.model.model.named_parameters():
if not param.is_contiguous():
param.data = param.data.contiguous()

for name, param in self.model.model.named_buffers():
if not param.is_contiguous():
param.data = param.data.contiguous()
for name, param in self.model.model.named_buffers():
if not param.is_contiguous():
param.data = param.data.contiguous()

@torch.no_grad()
def save_model(self, path):
Expand Down
124 changes: 67 additions & 57 deletions llmc/models/qwen2vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import torch
import torch.nn as nn
from accelerate import Accelerator, DistributedType
from lmms_eval.api.model import lmms
from lmms_eval.models.qwen2_vl import Qwen2_VL
from loguru import logger
from transformers import AutoConfig, AutoProcessor, AutoTokenizer

Expand Down Expand Up @@ -211,63 +209,75 @@ def forward(self, *args, **kwargs):
return Catcher


@MODEL_REGISTRY
class Qwen2VLEval(Qwen2_VL):
def __init__(
self,
llmc_model,
pretrained: str = 'Qwen/Qwen2-VL-7B-Instruct',
device: Optional[str] = 'cuda',
device_map: Optional[str] = 'cuda',
batch_size: Optional[Union[int, str]] = 1,
use_cache=True,
use_flash_attention_2: Optional[bool] = False,
max_pixels: int = 12845056,
min_pixels: int = 3136,
max_num_frames: int = 32,
**kwargs,
) -> None:
lmms.__init__(self)
# Do not use kwargs for now
assert kwargs == {}, f'Unexpected kwargs: {kwargs}'
try:
from lmms_eval.api.model import lmms
from lmms_eval.models.qwen2_vl import Qwen2_VL

accelerator = Accelerator()
if accelerator.num_processes > 1:
self._device = torch.device(f'cuda:{accelerator.local_process_index}')
self.device_map = f'cuda:{accelerator.local_process_index}'
elif accelerator.num_processes == 1 and device_map == 'auto':
self._device = torch.device(device)
self.device_map = device_map
else:
self._device = torch.device(f'cuda:{accelerator.local_process_index}')
self.device_map = f'cuda:{accelerator.local_process_index}'
@MODEL_REGISTRY
class Qwen2VLEval(Qwen2_VL):
def __init__(
self,
llmc_model,
pretrained: str = 'Qwen/Qwen2-VL-7B-Instruct',
device: Optional[str] = 'cuda',
device_map: Optional[str] = 'cuda',
batch_size: Optional[Union[int, str]] = 1,
use_cache=True,
use_flash_attention_2: Optional[bool] = False,
max_pixels: int = 12845056,
min_pixels: int = 3136,
max_num_frames: int = 32,
**kwargs,
) -> None:
lmms.__init__(self)
# Do not use kwargs for now
assert kwargs == {}, f'Unexpected kwargs: {kwargs}'

accelerator = Accelerator()
if accelerator.num_processes > 1:
self._device = torch.device(f'cuda:{accelerator.local_process_index}')
self.device_map = f'cuda:{accelerator.local_process_index}'
elif accelerator.num_processes == 1 and device_map == 'auto':
self._device = torch.device(device)
self.device_map = device_map
else:
self._device = torch.device(f'cuda:{accelerator.local_process_index}')
self.device_map = f'cuda:{accelerator.local_process_index}'

self._model = llmc_model.eval().cuda()
self.processor = AutoProcessor.from_pretrained(pretrained,
max_pixels=max_pixels, min_pixels=min_pixels)
self.max_pixels = max_pixels
self.min_pixels = min_pixels
self.max_num_frames = max_num_frames
self._tokenizer = AutoTokenizer.from_pretrained(pretrained)
self._model = llmc_model.eval().cuda()
self.processor = AutoProcessor.from_pretrained(
pretrained,
max_pixels=max_pixels,
min_pixels=min_pixels
)
self.max_pixels = max_pixels
self.min_pixels = min_pixels
self.max_num_frames = max_num_frames
self._tokenizer = AutoTokenizer.from_pretrained(pretrained)

self._config = self.model.config
self.batch_size_per_gpu = int(batch_size)
self.use_cache = use_cache
self._config = self.model.config
self.batch_size_per_gpu = int(batch_size)
self.use_cache = use_cache

if accelerator.num_processes > 1:
assert accelerator.distributed_type in [
DistributedType.FSDP,
DistributedType.MULTI_GPU,
], 'Unsupported distributed type provided. Only DDP and FSDP are supported.'
if accelerator.distributed_type == DistributedType.FSDP:
self._model = accelerator.prepare(self.model)
if accelerator.num_processes > 1:
assert accelerator.distributed_type in [
DistributedType.FSDP,
DistributedType.MULTI_GPU,
], 'Unsupported distributed type provided. Only DDP and FSDP are supported.'
if accelerator.distributed_type == DistributedType.FSDP:
self._model = accelerator.prepare(self.model)
else:
self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
logger.info(f'Using {accelerator.num_processes} devices with data parallelism')
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
else:
self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
logger.info(f'Using {accelerator.num_processes} devices with data parallelism')
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
else:
self._rank = 0
self._word_size = 1
self._rank = 0
self._word_size = 1
except Exception:
logger.warning(
'Can not import lmms_eval. '
'If you need it, please upgrade transformers.'
)

0 comments on commit c99b2b7

Please sign in to comment.