From a1185324114f8b081f59a0141b8e46c641d1f8a4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Moritz=20Th=C3=BCning?=
Date: Fri, 8 Mar 2024 00:23:28 +0100
Subject: [PATCH] add mistral 7b

---
 .gitignore  |   1 +
 README.md   |   7 +-
 convert.py  |  70 +++++++++++++-------
 install.sh  |  10 ++-
 src/lib.rs  | 180 +++++++++++++++++++++++++++++-----------------------
 src/main.rs |   6 +-
 6 files changed, 167 insertions(+), 107 deletions(-)

diff --git a/.gitignore b/.gitignore
index eb5a316..661a7da 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
 target
+env
\ No newline at end of file
diff --git a/README.md b/README.md
index bf57c36..e3c53a5 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,14 @@
-Like grep but for natural language questions. Based on Mixtral 8x7B. ~15 tokens/s on Nvidia RTX 3070 with 8GB memory.
+Like grep but for natural language questions. Based on Mistral 7B or Mixtral 8x7B. ~15 tokens/s on an Nvidia RTX 3070 with 8GB memory.
 
 # Installation
 ## Linux x86_64
-If nvidia driver that supports cuda 12.1 exists, it installs cuda version, else cpu version. It's ~48GB.
+If an Nvidia driver that supports CUDA 12.1 is present, the installer uses the CUDA build, otherwise the CPU build. Replace `small` with `large` to install Mixtral 8x7B instead. The download is ~7GB or ~48GB, respectively.
 ```bash
-curl https://raw.githubusercontent.com/moritztng/fltr/main/install.sh -o install.sh && bash install.sh && source ~/.bashrc
+curl https://raw.githubusercontent.com/moritztng/fltr/main/install.sh -o install.sh && bash install.sh small && source ~/.bashrc
 ```
 
 # Quickstart
+Add `--large` to use Mixtral 8x7B.
 ```bash
 fltr --file emails.txt --prompt "Is the following email spam? Email:" --batch-size 32
 ```
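As a usage sketch, the README quickstart combined with the new flag would look like this (nothing new here, just the quickstart command plus `--large` as defined in the README above and `src/main.rs` below):

```bash
fltr --file emails.txt --prompt "Is the following email spam? Email:" --batch-size 32 --large
```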
diff --git a/convert.py b/convert.py
index cf73405..074afa3 100644
--- a/convert.py
+++ b/convert.py
@@ -2,18 +2,21 @@
 import numpy as np
 from argparse import ArgumentParser
 
+
 def serialize_fp32(file, tensor):
-    """ writes one fp32 tensor to file that is open in wb mode """
+    """writes one fp32 tensor to file that is open in wb mode"""
     d = tensor.detach().cpu().view(-1).to(torch.float32).numpy()
-    b = struct.pack(f'{len(d)}f', *d)
+    b = struct.pack(f"{len(d)}f", *d)
     file.write(b)
 
+
 def serialize_int8(file, tensor):
-    """ writes one int8 tensor to file that is open in wb mode """
+    """writes one int8 tensor to file that is open in wb mode"""
     d = tensor.detach().cpu().view(-1).numpy().astype(np.int8)
-    b = struct.pack(f'{len(d)}b', *d)
+    b = struct.pack(f"{len(d)}b", *d)
     file.write(b)
 
+
 def quantize_serialize(f, w, group_size):
     """
     takes a tensor and returns the Q8_0 quantized version
@@ -21,54 +24,75 @@
     """
     assert w.numel() % group_size == 0
     ori_shape = w.shape
-    w = w.float() # convert to float32
+    w = w.float()  # convert to float32
     w = w.reshape(-1, group_size)
     # find the max in each group
     wmax = torch.abs(w).max(dim=1).values
     # calculate the scaling factor such that float = quant * scale
     scale = wmax / 127.0
     # scale into range [-127, 127]
-    quant = w / scale[:,None]
+    quant = w / scale[:, None]
     # round to nearest integer
     int8val = torch.round(quant).to(torch.int8)
     # dequantize by rescaling
-    fp32val = (int8val.float() * scale[:,None]).view(-1)
+    fp32val = (int8val.float() * scale[:, None]).view(-1)
     fp32valr = fp32val.reshape(-1, group_size)
     # calculate the max error in each group
     err = torch.abs(fp32valr - w).max(dim=1).values
     # find the max error across all groups
     maxerr = err.max().item()
-    
+
     serialize_int8(f, int8val)
     serialize_fp32(f, scale)
     return maxerr
 
-parser = ArgumentParser(prog="llama 2 converter")
+
+parser = ArgumentParser(prog="mistral converter")
 parser.add_argument("output_path", type=str)
 parser.add_argument("--checkpoint", type=str)
 parser.add_argument("--group-size", default=64, type=int)
+parser.add_argument("--moe", action="store_true")
+parser.add_argument("--cuda", action="store_true")
 args = parser.parse_args()
-state_dict = torch.load(args.checkpoint, map_location="cpu", mmap=True)
+state_dict = torch.load(args.checkpoint, map_location="cuda" if args.cuda else "cpu", mmap=True)
 with open(args.output_path, "wb") as f:
-    serialize_fp32(f, state_dict['norm.weight'])
+    serialize_fp32(f, state_dict["norm.weight"])
     print("norm.weight")
-    err = quantize_serialize(f, state_dict['tok_embeddings.weight'], args.group_size)
+    err = quantize_serialize(f, state_dict["tok_embeddings.weight"], args.group_size)
     print(f"tok_embeddings.weight, error: {err}")
-    err = quantize_serialize(f, state_dict['output.weight'], args.group_size)
+    err = quantize_serialize(f, state_dict["output.weight"], args.group_size)
     print(f"output.weight, error: {err}")
     for i in range(32):
-        layer_prefix = f'layers.{i}.'
+        layer_prefix = f"layers.{i}."
         print(layer_prefix)
-        for name in ['attention_norm.weight', 'ffn_norm.weight']:
+        for name in ["attention_norm.weight", "ffn_norm.weight"]:
             serialize_fp32(f, state_dict[layer_prefix + name])
             print(name)
-        for name in ['attention.wq.weight', 'attention.wk.weight', 'attention.wv.weight', 'attention.wo.weight', 'feed_forward.gate.weight']:
-            err = quantize_serialize(f, state_dict[layer_prefix + name], args.group_size)
+        for name in [
+            "attention.wq.weight",
+            "attention.wk.weight",
+            "attention.wv.weight",
+            "attention.wo.weight",
+        ] + (
+            ["feed_forward.gate.weight"]
+            if args.moe
+            else [
+                "feed_forward.w1.weight",
+                "feed_forward.w2.weight",
+                "feed_forward.w3.weight",
+            ]
+        ):
+            err = quantize_serialize(
+                f, state_dict[layer_prefix + name], args.group_size
+            )
             print(f"{name}, error: {err}")
-        for e in range(8):
-            expert_prefix = layer_prefix + f"feed_forward.experts.{e}."
-            print(expert_prefix)
-            for name in ['w1.weight', 'w2.weight', 'w3.weight']:
-                err = quantize_serialize(f, state_dict[expert_prefix + name], args.group_size)
-                print(f"{name}, error: {err}")
+        if args.moe:
+            for e in range(8):
+                expert_prefix = layer_prefix + f"feed_forward.experts.{e}."
+                print(expert_prefix)
+                for name in ["w1.weight", "w2.weight", "w3.weight"]:
+                    err = quantize_serialize(
+                        f, state_dict[expert_prefix + name], args.group_size
+                    )
+                    print(f"{name}, error: {err}")
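A minimal invocation of the converter above might look like the following. This is only a sketch: the checkpoint paths are placeholders, while the flags come from the argparse section above and the `small.bin`/`large.bin` output names match what `install.sh` and `src/lib.rs` expect:

```bash
# Mistral 7B (dense feed-forward, no gate/expert weights): no --moe flag needed
python convert.py small.bin --checkpoint /path/to/mistral-7b/checkpoint.pth

# Mixtral 8x7B (mixture of experts): add --moe, optionally --cuda and --group-size
python convert.py large.bin --checkpoint /path/to/mixtral-8x7b/checkpoint.pth --moe --cuda
```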
diff --git a/install.sh b/install.sh
index 62dd79d..5f903f2 100755
--- a/install.sh
+++ b/install.sh
@@ -10,7 +10,15 @@ fi
 INSTALL_DIR=~/Fltr
 mkdir -p "$INSTALL_DIR"
 curl -sSL https://github.com/moritztng/fltr/releases/download/v0.1-alpha/fltr-0.1-x86_64-${processor}.gz | gunzip > "$INSTALL_DIR/fltr"
-curl -L https://huggingface.co/moritztng/Mixtral-8x7B-Instruct-v0.1/resolve/main/{weights.bin,tokenizer.json} -o "$INSTALL_DIR/weights.bin" -o "$INSTALL_DIR/tokenizer.json"
+
+MODEL_URL=https://huggingface.co/moritztng/fltr/resolve/main
+curl -L "$MODEL_URL/tokenizer.json" -o "$INSTALL_DIR/tokenizer.json"
+if [[ ",$1," == *",small,"* ]]; then
+    curl -L "$MODEL_URL/mistral-7b-instruct-v0.2.bin" -o "$INSTALL_DIR/small.bin"
+fi
+if [[ ",$1," == *",large,"* ]]; then
+    curl -L "$MODEL_URL/mixtral-8x7b-instruct-v0.1.bin" -o "$INSTALL_DIR/large.bin"
+fi
 chmod +x "$INSTALL_DIR/fltr"
 
 if [[ ":$PATH:" != *":$INSTALL_DIR:"* ]]; then
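Because the script wraps `$1` in commas before matching, the argument appears to be treated as a comma-separated list, so fetching both models in one run should presumably work (untested sketch):

```bash
bash install.sh small,large
```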
diff --git a/src/lib.rs b/src/lib.rs
index 50977b8..da5670c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -77,7 +77,7 @@ struct Layer {
     heads: QuantizedSlice<'static>,
     rms_attention: &'static [f32],
     rms_feedforward: &'static [f32],
-    gate: QuantizedSlice<'static>,
+    gate: Option<QuantizedSlice<'static>>,
     experts: Vec<Expert>,
 }
 
@@ -304,13 +304,13 @@ fn cache_kv(cache: &mut [f32], kv: &[f32], cache_lens: &[usize], kv_lens: &[usiz
 }
 
 impl Model {
-    pub fn from_dir(path: &Path) -> Model {
+    pub fn from_dir(path: &Path, multiple_experts: bool) -> Model {
         #[cfg(feature = "cuda")]
         unsafe {
             cuda_init()
         };
         let mmap: MmapRaw = MmapOptions::new()
-            .map_raw_read_only(&File::open(path.join("weights.bin")).unwrap())
+            .map_raw_read_only(&File::open(path.join(format!("{}.bin", if multiple_experts {"large"} else {"small"}))).unwrap())
             .unwrap();
         let mut weights_ptr = mmap.as_ptr() as *const u8;
         let rms_final = ptr_to_slice::(&mut weights_ptr);
@@ -328,9 +328,15 @@ impl Model {
             let value =
                 QuantizedSlice::from_ptr::<{ DIM * N_KV_HEADS * HEAD_SIZE }>(&mut weights_ptr);
             let heads = QuantizedSlice::from_ptr::<{ DIM * DIM }>(&mut weights_ptr);
-            let gate = QuantizedSlice::from_ptr::<{ DIM * N_EXPERTS }>(&mut weights_ptr);
+            let gate = if multiple_experts {
+                Some(QuantizedSlice::from_ptr::<{ DIM * N_EXPERTS }>(
+                    &mut weights_ptr,
+                ))
+            } else {
+                None
+            };
             let mut experts = Vec::new();
-            for _ in 0..N_EXPERTS {
+            for _ in 0..(if multiple_experts { N_EXPERTS } else { 1 }) {
                 experts.push(Expert {
                     ff1: QuantizedSlice::from_ptr::<{ DIM * HIDDEN_DIM }>(&mut weights_ptr),
                     ff2: QuantizedSlice::from_ptr::<{ HIDDEN_DIM * DIM }>(&mut weights_ptr),
@@ -537,91 +543,109 @@ impl Model {
                 &weights.rms_feedforward,
                 DIM,
             );
-            
+
             quantize(&mut buffer.qstate, &buffer.state2);
-            matmul::(
-                &mut buffer.expert_logits,
-                &buffer.qstate.slice_full(),
-                &weights.gate,
-            );
-
-            let mut expert_tokens: [Vec<(usize, f32)>; 8] = Default::default();
-            for (p, token_expert_logits) in
-                buffer.expert_logits.chunks_exact_mut(N_EXPERTS).enumerate()
-            {
-                let mut indices_logits: Vec<_> = token_expert_logits.iter().enumerate().collect();
-                indices_logits
-                    .sort_unstable_by(|(_, logit1), (_, logit2)| logit2.total_cmp(logit1));
-                let (expert_indices, mut expert_weights): (Vec<_>, Vec<_>) =
-                    indices_logits.into_iter().take(N_EXPERTS_PER_TOKEN).unzip();
-                softmax(&mut expert_weights);
-                for (expert_index, expert_weight) in
-                    expert_indices.iter().zip(expert_weights.iter())
+            if let Some(gate_weights) = &weights.gate {
+                matmul::(
+                    &mut buffer.expert_logits,
+                    &buffer.qstate.slice_full(),
+                    gate_weights,
+                );
+                let mut expert_tokens: [Vec<(usize, f32)>; 8] = Default::default();
+                for (p, token_expert_logits) in
+                    buffer.expert_logits.chunks_exact_mut(N_EXPERTS).enumerate()
                 {
-                    expert_tokens[*expert_index].push((p, *expert_weight));
+                    let mut indices_logits: Vec<_> =
+                        token_expert_logits.iter().enumerate().collect();
+                    indices_logits
+                        .sort_unstable_by(|(_, logit1), (_, logit2)| logit2.total_cmp(logit1));
+                    let (expert_indices, mut expert_weights): (Vec<_>, Vec<_>) =
+                        indices_logits.into_iter().take(N_EXPERTS_PER_TOKEN).unzip();
+                    softmax(&mut expert_weights);
+                    for (expert_index, expert_weight) in
+                        expert_indices.iter().zip(expert_weights.iter())
+                    {
+                        expert_tokens[*expert_index].push((p, *expert_weight));
+                    }
                 }
-            }
 
-            for (expert_index, token_weights) in expert_tokens.iter().enumerate() {
-                if token_weights.is_empty() {
-                    continue;
-                }
+                for (expert_index, token_weights) in expert_tokens.iter().enumerate() {
+                    if token_weights.is_empty() {
+                        continue;
+                    }
 
-                let expert = &weights.experts[expert_index];
-                let n_tokens = token_weights.len();
-                let expert_qstate = buffer.qstate2.slice_mut(0, n_tokens * DIM);
-                for ((state_values, state_scales), (token_index, _)) in expert_qstate
-                    .values
-                    .chunks_exact_mut(DIM)
-                    .zip(expert_qstate.scales.chunks_exact_mut(DIM / Q_GROUP_SIZE))
-                    .zip(token_weights.iter())
-                {
-                    state_values.copy_from_slice(
-                        &buffer
-                            .qstate
-                            .values
-                            .chunks_exact(DIM)
-                            .nth(*token_index)
-                            .unwrap(),
-                    );
-                    state_scales.copy_from_slice(
-                        &buffer
-                            .qstate
-                            .scales
-                            .chunks_exact(DIM / Q_GROUP_SIZE)
-                            .nth(*token_index)
-                            .unwrap(),
+                    let expert = &weights.experts[expert_index];
+                    let n_tokens = token_weights.len();
+                    let expert_qstate = buffer.qstate2.slice_mut(0, n_tokens * DIM);
+                    for ((state_values, state_scales), (token_index, _)) in expert_qstate
+                        .values
+                        .chunks_exact_mut(DIM)
+                        .zip(expert_qstate.scales.chunks_exact_mut(DIM / Q_GROUP_SIZE))
+                        .zip(token_weights.iter())
+                    {
+                        state_values.copy_from_slice(
+                            &buffer
+                                .qstate
+                                .values
+                                .chunks_exact(DIM)
+                                .nth(*token_index)
+                                .unwrap(),
+                        );
+                        state_scales.copy_from_slice(
+                            &buffer
+                                .qstate
+                                .scales
+                                .chunks_exact(DIM / Q_GROUP_SIZE)
+                                .nth(*token_index)
+                                .unwrap(),
+                        );
+                    }
+                    let expert_qstate = buffer.qstate2.slice(0, n_tokens * DIM);
+                    let expert_ff_hidden = &mut buffer.ff_hidden[..n_tokens * HIDDEN_DIM];
+                    let expert_swiglu = &mut buffer.swiglu[..n_tokens * HIDDEN_DIM];
+                    matmul::(expert_ff_hidden, &expert_qstate, &expert.ff1);
+                    matmul::(expert_swiglu, &expert_qstate, &expert.swiglu);
+                    for (hidden_x, swiglu_x) in
+                        expert_ff_hidden.iter_mut().zip(expert_swiglu.iter())
+                    {
+                        *hidden_x *= 1f32 / (1f32 + (-*hidden_x).exp());
+                        *hidden_x *= swiglu_x;
+                    }
+                    quantize(&mut buffer.qhidden, &expert_ff_hidden);
+                    matmul::(
+                        &mut buffer.state2[..n_tokens * DIM],
+                        &buffer.qhidden.slice(0, n_tokens * HIDDEN_DIM),
+                        &expert.ff2,
                     );
+                    for (token_state, (token_index, weight)) in buffer.state2[..n_tokens * DIM]
+                        .chunks_exact_mut(DIM)
+                        .zip(token_weights.iter())
+                    {
+                        smul(token_state, *weight);
+                        add(
+                            &mut buffer
+                                .state
+                                .chunks_exact_mut(DIM)
+                                .nth(*token_index)
+                                .unwrap(),
+                            token_state,
+                        );
+                    }
                 }
-                let expert_qstate = buffer.qstate2.slice(0, n_tokens * DIM);
-                let expert_ff_hidden = &mut buffer.ff_hidden[..n_tokens * HIDDEN_DIM];
-                let expert_swiglu = &mut buffer.swiglu[..n_tokens * HIDDEN_DIM];
-                matmul::(expert_ff_hidden, &expert_qstate, &expert.ff1);
-                matmul::(expert_swiglu, &expert_qstate, &expert.swiglu);
-                for (hidden_x, swiglu_x) in expert_ff_hidden.iter_mut().zip(expert_swiglu.iter()) {
+            } else {
+                matmul::(&mut buffer.ff_hidden, &buffer.qstate.slice_full(), &weights.experts[0].ff1);
+                matmul::(&mut buffer.swiglu, &buffer.qstate.slice_full(), &weights.experts[0].swiglu);
+                for (hidden_x, swiglu_x) in buffer.ff_hidden.iter_mut().zip(buffer.swiglu.iter()) {
                     *hidden_x *= 1f32 / (1f32 + (-*hidden_x).exp());
                     *hidden_x *= swiglu_x;
                 }
-                quantize(&mut buffer.qhidden, &expert_ff_hidden);
+                quantize(&mut buffer.qhidden, &buffer.ff_hidden);
                 matmul::(
-                    &mut buffer.state2[..n_tokens * DIM],
-                    &buffer.qhidden.slice(0, n_tokens * HIDDEN_DIM),
-                    &expert.ff2,
+                    &mut buffer.state2,
+                    &buffer.qhidden.slice_full(),
+                    &weights.experts[0].ff2,
                 );
-                for (token_state, (token_index, weight)) in buffer.state2[..n_tokens * DIM]
-                    .chunks_exact_mut(DIM)
-                    .zip(token_weights.iter())
-                {
-                    smul(token_state, *weight);
-                    add(
-                        &mut buffer
-                            .state
-                            .chunks_exact_mut(DIM)
-                            .nth(*token_index)
-                            .unwrap(),
-                        token_state,
-                    );
-                }
+                add(&mut buffer.state, &buffer.state2);
             }
         }
diff --git a/src/main.rs b/src/main.rs
index 871d8ab..1f11c4d 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -17,6 +17,8 @@ struct Args {
     #[arg(long, default_value = "32")]
     batch_size: Option,
     #[arg(long)]
+    large: bool,
+    #[arg(long)]
     debug: bool,
     #[command(subcommand)]
     command: Option,
 }
@@ -44,10 +46,10 @@ fn main() {
         autostop,
     }) = args.command
     {
-        let mut model = Model::from_dir(model_path);
+        let mut model = Model::from_dir(model_path, args.large);
         model.generate(&prompts, length - 1, true, autostop, None);
     } else {
-        let mut model = Model::from_dir(model_path);
+        let mut model = Model::from_dir(model_path, args.large);
         let (cache, _) = model.generate(
            &[format!("[INST] {}", args.prompt.unwrap())],
            0,