Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sparsegpt #374

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 128 additions & 0 deletions lit_llama/sparsification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# This adapts SparseGPT process: https://github.com/IST-DASLab/sparsegpt
# E. Frantar et al SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot, https://arxiv.org/abs/2301.00774
# portions copyright by the authors licensed under the Apache License 2.0


import torch
import os
from contextlib import contextmanager
import warnings
import math


class SparseGPT:

def __init__(
self,
linear_module,
sparsity,
prunen=0,
prunem=0,
blocksize=128,
percdamp=0.01,

):
assert isinstance(linear_module, torch.nn.Linear)

self.linear_module = linear_module
self.dev = self.linear_module.weight.device
self.rows = linear_module.weight.shape[0]
self.columns = linear_module.weight.shape[1]
self.H = torch.zeros((self.columns, self.columns), device=self.dev)
self.nsamples = 0
self.blocksize = blocksize
self.sparsity = sparsity
self.percdamp = percdamp
self.prunen = prunen
self.prunem = prunem

def collect_input_stats(self, _1, inp, _2):
inp = inp[0].detach()
self.last_inp = inp
if len(inp.shape) == 2:
inp = inp.unsqueeze(0)
tmp = inp.shape[0]
if len(inp.shape) == 3:
inp = inp.reshape((-1, inp.shape[-1]))
inp = inp.t()
self.H *= self.nsamples / (self.nsamples + tmp)
self.nsamples += tmp
inp = math.sqrt(2 / self.nsamples) * inp.float()
self.H += inp.matmul(inp.t())

def sparsify(self):

W = self.linear_module.weight.detach().to(dtype=torch.float, copy=True)
H = self.H
del self.H
dead = torch.diag(H) == 0
H[dead, dead] = 1
W[:, dead] = 0

Losses = torch.zeros_like(W)
Q = torch.zeros_like(W)

damp = self.percdamp * torch.mean(torch.diag(H))
diag = torch.arange(self.columns, device=self.dev)
H[diag, diag] += damp
H = torch.linalg.cholesky(H)
H = torch.cholesky_inverse(H)
H = torch.linalg.cholesky(H, upper=True)
Hinv = H

mask = None

for i1 in range(0, self.columns, self.blocksize):
i2 = min(i1 + self.blocksize, self.columns)
count = i2 - i1

W1 = W[:, i1:i2].clone()
Q1 = torch.zeros_like(W1)
Err1 = torch.zeros_like(W1)
Losses1 = torch.zeros_like(W1)
Hinv1 = Hinv[i1:i2, i1:i2]


if self.prunen == 0:
if mask is not None:
mask1 = mask[:, i1:i2]
else:
tmp = W1 ** 2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * self.sparsity)]
mask1 = tmp <= thresh
else:
mask1 = torch.zeros_like(W1) == 1

for i in range(count):
w = W1[:, i]
d = Hinv1[i, i]

if self.prunen != 0 and i % self.prunem == 0:
tmp = W1[:, i:(i + self.prunem)] ** 2 / (torch.diag(Hinv1)[i:(i + self.prunem)].reshape((1, -1))) ** 2
mask1.scatter_(1, i + torch.topk(tmp, self.prunen, dim=1, largest=False)[1], True)

q = w.clone()
q[mask1[:, i]] = 0


Q1[:, i] = q
Losses1[:, i] = (w - q) ** 2 / d ** 2

err1 = (w - q) / d
W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
Err1[:, i] = err1

Q[:, i1:i2] = Q1
Losses[:, i1:i2] = Losses1 / 2

W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

pruned_weights = Q.reshape(self.linear_module.weight.shape).to(
self.linear_module.weight.data.dtype
)

# set the linear module weights to pruned weights
self.linear_module.weight.data = pruned_weights
error = torch.sum(Losses).item()
return error

257 changes: 257 additions & 0 deletions sparsify/sparsegpt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
# This adapts SparseGPT process: https://github.com/IST-DASLab/sparsegpt
# E. Frantar et al SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot, https://arxiv.org/abs/2301.00774
# portions copyright by the authors licensed under the Apache License 2.0


import gc
import sys
import time
from pathlib import Path
from typing import Optional

import torch
from datasets import load_dataset

wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from lit_llama import LLaMA, Tokenizer
from lit_llama.sparsification import SparseGPT

from lit_llama.utils import EmptyInitOnDevice, llama_model_lookup


def get_sample_data():
traindata = load_dataset(
"allenai/c4",
"allenai--c4",
data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
split="train",
)
# heuristic for the data size?
txt = "\n".join(
traindata[i]["text"] for i in torch.randperm(len(traindata))[:1000].tolist()
)
return txt

@torch.no_grad()
def llama_blockwise_sparsification(
model,
sample_inputs,
working_device,
*,
sparsity,
prunen=0,
prunem=0,

):

print('Getting Inputs for the first block')
model.transformer.wte.to(working_device)
sample_inputs = sample_inputs.to(working_device)
inps = model.transformer.wte(sample_inputs)
model.transformer.wte.to("cpu")
torch.cuda.empty_cache()

rope_cache = model.build_rope_cache(sample_inputs)
mask_cache = model.build_mask_cache(sample_inputs)

print('Starting to sparsify block')
outs = torch.zeros_like(inps)


submodules_to_process = [
"attn.c_attn",
"attn.c_proj",
"mlp.c_fc1",
"mlp.c_fc2",
"mlp.c_proj",
]


for i, block in enumerate(model.transformer.h):

block.to(working_device)

for name in submodules_to_process:
print(i, name, end=" ")
t0 = time.perf_counter()
print("collecting stats", end=" ")
sys.stdout.flush()
module = block.get_submodule(name)

sparsegpt = SparseGPT(
module,
sparsity=sparsity,
prunen=prunen,
prunem=prunem,
)

handle = module.register_forward_hook(sparsegpt.collect_input_stats)

for j in range(inps.size(0)):
outs[j : j + 1], _ = block(
inps[j : j + 1],
rope=rope_cache,
mask=mask_cache,
max_seq_length=model.config.block_size
)

handle.remove()


print("sparsifying", end=" ")
sys.stdout.flush()
error = sparsegpt.sparsify()

del sparsegpt
gc.collect()
torch.cuda.empty_cache()
t1 = time.perf_counter()
print(f"time {int(t1 - t0 + 0.5)}s sparsification error {error:.1f}")


for j in range(inps.size(0)):
outs[j : j + 1], _ = block(
inps[j : j + 1],
rope=rope_cache,
mask=mask_cache,
max_seq_length=model.config.block_size
)

block.cpu()
gc.collect()
torch.cuda.empty_cache()

inps, outs = outs, inps

model.transformer.ln_f.to(working_device)
for j in range(inps.size(0)):
outs[j : j + 1] = model.transformer.ln_f(inps[j : j + 1])
model.transformer.ln_f.to("cpu")

# normalised out will be input to the LM head
inps, outs = outs, inps

model.lm_head.to(working_device)
sparsegpt = SparseGPT(
model.lm_head,
sparsity=sparsity,
prunen=prunen,
prunem=prunem,
)

# During the forward pass, the collect_input_stats function collects input statistics and updates the Hessian matrix.
handle = model.lm_head.register_forward_hook(sparsegpt.collect_input_stats)
for j in range(inps.size(0)):
model.lm_head(inps[j : j + 1])
handle.remove()
# After the forward pass, the sparsify function can be called to perform the sparsification based on the collected statistics.
error = sparsegpt.sparsify()
model.lm_head.to("cpu")

def main(
*,
checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
output_path: Optional[Path] = None,
tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
n_samples: int = 128,
dtype: str = "float32",
sparsity: float = 0.1,
prunem: int = 0,
prunen: int = 0
) -> None:
"""
Generates text samples based on a pre-trained LLaMA model and tokenizer.

Args:
checkpoint_path: The checkpoint path to load.
output_path: Path to write the sparsified model's state dict to.
tokenizer_path: The tokenizer path to load.
n_samples: Number of example inputs to use for statistics (default: 128)
dtype: The dtype to use to load the model.
sparsity: Target sparsity
prunem: M for N:M pruning.
prunen: N for N:M pruning.
"""
assert checkpoint_path.is_file()
assert tokenizer_path.is_file()
if output_path is None:
output_path = checkpoint_path.parent / "llama-gpt-sparsified.pth"
assert output_path.parent.is_dir() and (not output_path.exists() or output_path.is_file())

device = "cuda"

dt = getattr(torch, dtype, None)
if not isinstance(dt, torch.dtype):
raise ValueError(f"{dtype} is not a valid dtype.")
dtype = dt


# we avoid loading the entire model on the GPU and do this block by block
with EmptyInitOnDevice(
device="cpu",
dtype=dtype,
):
print("Loading model ...", file=sys.stderr)
t0 = time.time()
checkpoint = torch.load(checkpoint_path)
name = llama_model_lookup(checkpoint)
model = LLaMA.from_name(name)
model.load_state_dict(checkpoint)
print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)

model.eval()

tokenizer = Tokenizer(tokenizer_path)

test_string = get_sample_data()
encoded_text = tokenizer.encode(
test_string,
bos=True,
eos=False,
)

block_size = 2048
# truncate the text and reshape to batch by sequence length
encoded_text = encoded_text[: n_samples * block_size].reshape(n_samples, block_size)

t0 = time.perf_counter()
llama_blockwise_sparsification(
model=model,
sample_inputs=encoded_text,
working_device=device,
sparsity=sparsity,
prunen=prunen,
prunem=prunem
)
t = time.perf_counter() - t0

print(
f"\n\nTime for sparsification: {t:.02f} sec total",
file=sys.stderr,
)
print(
f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB",
file=sys.stderr,
)

torch.save(model.state_dict(), output_path)


if __name__ == "__main__":
from jsonargparse import CLI

torch.set_float32_matmul_precision("high")
CLI(main)