fast fp16 seq and one mode #157

Open · wants to merge 20 commits into base: main
140 changes: 140 additions & 0 deletions benchmark-all.py
@@ -0,0 +1,140 @@
#!/usr/bin/env python3

import subprocess
import wandb
import argparse
import json
import time
import traceback
import os

parser = argparse.ArgumentParser(description='Run benchmark')
parser.add_argument('--branch', type=str, help='branch of ChatRWKV', default='main')
parser.add_argument('--model', type=str, help='Model path', required=True)
parser.add_argument('--verbose', action='store_true', help='Print command output')
parser.add_argument('-n', type=int, help='Number of runs', required=True)
args = parser.parse_args()

models = [args.model]

strategies = ['bf16', 'fp16', 'fp32', 'fp16i8']

columns = ['Device'] + strategies

local_device = '2080'

vast_id = {}

vast_dev_names = {'1080': 'GTX_1080', '2080': 'RTX_2080', '3080': 'RTX_3080', '4090': 'RTX_4090'}


class NoInstanceError(RuntimeError):
pass


def prepare_vastai_env(device: str):
vast_device_name = vast_dev_names[device]
output = check_output(["vastai", "search", "offers", f"gpu_name={vast_device_name} cuda_vers>=11.8", "--raw"], args.verbose)
output = json.loads(output)
if len(output) == 0:
raise NoInstanceError(f"No Vast.ai offers found for {device}")
best = output[0]["id"]
print(f"Found best offer {best}")
output = check_output(f"vastai create instance {best} --image daquexian/cuda-pytorch:cu118-dev-2.0.1 --disk 32 --raw".split(), args.verbose)
output = json.loads(output)
instance_id = output["new_contract"]
print(f"Created instance {instance_id}, checking status..")
flag = False
while not flag:
time.sleep(10)
print("Checking status..")
# too verbose
output = check_output(f"vastai show instances --raw".split(), False)
output = json.loads(output)
for instance in output:
if instance["id"] == instance_id:
print(f"Instance {instance_id} is {instance['actual_status']}")
if instance["actual_status"] == "running":
vast_id[device] = (f'root@{instance["ssh_host"]}', instance["ssh_port"], instance_id)
flag = True
# sleep for a while to make sure the instance is ready
time.sleep(5)
break

ssh_prefix = f'ssh -o StrictHostKeyChecking=no -p {vast_id[device][1]} {vast_id[device][0]}'.split()
check_output(ssh_prefix + 'git clone https://github.com/BlinkDL/ChatRWKV'.split(), args.verbose)
if args.branch != 'main':
if '/' in args.branch:
user, branch = args.branch.split('/')
check_output(ssh_prefix + [f'cd ChatRWKV && git remote add daquexian https://github.com/{user}/ChatRWKV && git fetch {user}'], args.verbose)
check_output(ssh_prefix + [f'cd ChatRWKV && git checkout {args.branch}'], args.verbose)
check_output(ssh_prefix + 'pip install numpy'.split(), args.verbose)
check_output(ssh_prefix + 'apt install ninja-build'.split(), args.verbose)

    scp('benchmark-custom.py', 'ChatRWKV/v2/benchmark-custom.py', vast_id[device][0], vast_id[device][1])
return ssh_prefix


wandb.init()

table = wandb.Table(columns=columns)


def scp(src, dst, dst_ip, dst_port):
print(f"scp from {src} to {dst} of {dst_ip}:{dst_port}")
subprocess.run(['scp', '-o', 'StrictHostKeyChecking=no', '-P', str(dst_port), src, f'{dst_ip}:{dst}'], stderr=subprocess.STDOUT)


def check_output(command, print_output):
print(f'Running {" ".join(command)}')
proc = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
stdout = ""
for line in proc.stdout:
if print_output:
print(line.decode('utf-8').strip())
stdout += line.decode('utf-8')
assert proc.wait() == 0, f"Command {' '.join(command)} failed with stdout {stdout}"
return stdout.strip()


for device in ['4090', '3080', '2080', '1080', 'cpu']:
if device in ['cpu', local_device]:
ssh_prefix = []
        project_dir = os.path.expanduser('~/files/repos/ChatRWKV/')  # trailing slash: joined below as f'{project_dir}v2/...'
else:
try:
ssh_prefix = prepare_vastai_env(device)
except NoInstanceError:
print(f"No instance found for {device}, skipping")
continue
        except Exception:
            traceback.print_exc()
            continue  # skip this device rather than proceeding with a stale ssh_prefix
project_dir = 'ChatRWKV/'
device_type = 'cpu' if device == 'cpu' else 'cuda'
for model in models:
if device in vast_id:
scp(model, f'ChatRWKV/{model}', vast_id[device][0], vast_id[device][1])
data = [device]
for strategy in strategies:
for mode in ['slow']:
try:
                    latency = float('inf')
for _ in range(args.n):
command = [*ssh_prefix, 'python3', f'{project_dir}v2/benchmark-custom.py', '--model', f'{project_dir}{model}', '--strategy', f'{device_type}@{strategy}', '--custom-cuda-op', '--jit', f'--only-{mode}']
print(f'Running: {" ".join(command)}')
output = check_output(command, print_output=args.verbose)
latency = min(latency, float(output.splitlines()[-2].split(' ')[2][:-2]))
mem = float(output.splitlines()[-1].split(' ')[-2])
data.append(f'{latency * 1000:.0f}ms/{mem:.0f}MB') # type: ignore[reportUnboundVariable]
                except Exception:
data.append('N/A')
print(f'Failed to run {model} on {device} with {strategy}')
table.add_data(*data)
if device in vast_id:
check_output(['vastai', 'destroy', 'instance', str(vast_id[device][2])], args.verbose)
del vast_id[device]

wandb.log({'Latency and Memory': table})

wandb.finish()
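
The latency/memory parsing above is purely positional: it assumes benchmark-custom.py ends its output with a latency line followed by a memory line. A minimal sketch of that contract — the two sample lines below are hypothetical; only the token positions match what the script actually reads:

# Hypothetical output tail of benchmark-custom.py; only token positions matter.
tail = [
    'avg latency 0.0123s,',  # splitlines()[-2]
    'peak mem 1234 MB',      # splitlines()[-1]
]
latency = float(tail[-2].split(' ')[2][:-2])  # third token, 2-char unit suffix stripped
mem = float(tail[-1].split(' ')[-2])          # second-to-last token
print(f'{latency * 1000:.0f}ms/{mem:.0f}MB')  # -> 12ms/1234MB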
4 changes: 2 additions & 2 deletions rwkv_pip_package/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
-name = "rwkv"
-version = "0.8.9"
+name = "rwkv-music-internal"
+version = "0.8.5"
 authors = [
   { name="Bo PENG" },
 ]
126 changes: 126 additions & 0 deletions rwkv_pip_package/src/rwkv/cuda/att_one.cu
@@ -0,0 +1,126 @@
#include "ATen/ATen.h"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

#include "element_wise.h"
#include "util.h"

namespace {
// Equivalent Python code:
// ww = t_first + k
// p = torch.maximum(pp, ww)
// e1 = torch.exp(pp - p)
// e2 = torch.exp(ww - p)
// wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype)
// ww = t_decay + pp
// p = torch.maximum(ww, k)
// e1 = torch.exp(ww - p)
// e2 = torch.exp(k - p)
// t1 = e1 * aa + e2 * v
// t2 = e1 * bb + e2
// r = r * wkv
// return t1, t2, p, r
struct WkvForwardOne {
const float *t_first;
const float *k;
const float *pp;
const float *aa;
const float *bb;
const float *t_decay;
const float *v;
/* out */ float *t1;
/* out */ float *t2;
/* out */ float *p;
/* in & out */ half *r;

__device__ void operator()(int i) const {
float ww = t_first[i] + k[i];
float pp_ = pp[i];
float p_ = (pp_ > ww) ? pp_ : ww;
float e1 = expf(pp_ - p_);
float e2 = expf(ww - p_);
float aa_ = aa[i];
float bb_ = bb[i];
float v_ = v[i];
r[i] = __hmul(r[i], __float2half(((e1 * aa_ + e2 * v_) / (e1 * bb_ + e2))));
ww = t_decay[i] + pp_;
float k_ = k[i];
p_ = (ww > k_) ? ww : k_;
e1 = expf(ww - p_);
e2 = expf(k_ - p_);
t1[i] = e1 * aa_ + e2 * v_;
t2[i] = e1 * bb_ + e2;
p[i] = p_;
}
};

/*
Equivalent Python code:
kx = xx * k_mix + sx * (1 - k_mix)
vx = xx * v_mix + sx * (1 - v_mix)
rx = xx * r_mix + sx * (1 - r_mix)
*/

struct Mix {
const half *xx;
const half *sx;
const half *k_mix;
const half *v_mix;
const half *r_mix;
/* out */ half *kx;
/* out */ half *vx;
/* out */ half *rx;

__device__ void operator()(int i) const {
half xx_ = xx[i];
half sx_ = sx[i];
half k_mix_ = k_mix[i];
half v_mix_ = v_mix[i];
half r_mix_ = r_mix[i];
kx[i] = __hadd(__hmul(xx_, k_mix_),
__hmul(sx_, __hsub(__float2half(1), k_mix_)));
vx[i] = __hadd(__hmul(xx_, v_mix_),
__hmul(sx_, __hsub(__float2half(1), v_mix_)));
rx[i] = __hadd(__hmul(xx_, r_mix_),
__hmul(sx_, __hsub(__float2half(1), r_mix_)));
}
};
} // namespace

using torch::Tensor;

void gemm_cublas_tensor(const Tensor& a, const Tensor& b, const Tensor& c);

Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
Tensor v_mix, Tensor r_mix, Tensor kw,
/* imm */ Tensor kx, Tensor vw, /* imm */ Tensor vx, Tensor rw,
/* imm */ Tensor rx, Tensor ow, Tensor t_first,
/* imm */ Tensor k, Tensor pp, Tensor ww, Tensor aa, Tensor bb,
Tensor t_decay, /* imm */ Tensor v, /* in & out */ Tensor r,
/* out */ Tensor x_plus_out, /* out */ Tensor t1,
/* out */ Tensor t2, /* out */ Tensor p) {
Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b);
element_wise(Mix{data_ptr<half>(xx), data_ptr<half>(sx),
data_ptr<half>(k_mix), data_ptr<half>(v_mix),
data_ptr<half>(r_mix), data_ptr<half>(kx),
data_ptr<half>(vx), data_ptr<half>(rx)},
x.numel());

gemm_cublas_tensor(kx, kw, k);
gemm_cublas_tensor(vx, vw, v);
gemm_cublas_tensor(rx, rw, r);
at::sigmoid_(r);

element_wise(WkvForwardOne{data_ptr<float>(t_first), data_ptr<float>(k),
data_ptr<float>(pp), data_ptr<float>(aa),
data_ptr<float>(bb), data_ptr<float>(t_decay),
data_ptr<float>(v), data_ptr<float>(t1),
data_ptr<float>(t2), data_ptr<float>(p),
data_ptr<half>(r)},
x.numel());

gemm_cublas_tensor(r, ow, x_plus_out);
x_plus_out += x;
return xx;
}
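
Taken together, the two kernels and the cuBLAS calls implement the standard RWKV v4 single-token attention step. A plain PyTorch transcription of the same data flow (a sketch with the tensors as separate arguments, not the exact signature the extension exposes):

import torch
import torch.nn.functional as F

def att_one_reference(x, ln_w, ln_b, sx, k_mix, v_mix, r_mix,
                      kw, vw, rw, ow, t_first, pp, aa, bb, t_decay):
    xx = F.layer_norm(x, (x.shape[-1],), ln_w, ln_b)
    kx = xx * k_mix + sx * (1 - k_mix)   # Mix kernel
    vx = xx * v_mix + sx * (1 - v_mix)
    rx = xx * r_mix + sx * (1 - r_mix)
    k = kx @ kw                          # the three gemm_cublas_tensor calls
    v = vx @ vw
    r = torch.sigmoid(rx @ rw)
    ww = t_first + k                     # WkvForwardOne kernel from here on
    p = torch.maximum(pp, ww)
    e1, e2 = torch.exp(pp - p), torch.exp(ww - p)
    r = r * ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(x.dtype)
    ww = t_decay + pp
    p = torch.maximum(ww, k)
    e1, e2 = torch.exp(ww - p), torch.exp(k - p)
    t1 = e1 * aa + e2 * v                # updated numerator state
    t2 = e1 * bb + e2                    # updated denominator state
    x_plus_out = r @ ow + x              # output projection plus residual
    return xx, x_plus_out, t1, t2, p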
121 changes: 121 additions & 0 deletions rwkv_pip_package/src/rwkv/cuda/att_one_v5.cu
@@ -0,0 +1,121 @@
#include "ATen/ATen.h"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

#include "element_wise.h"
#include "util.h"

namespace {
// Equivalent Python code:
// s1 = t_first * a + s
// s2 = a + t_decay * s
struct Fused1 {
const float *t_first;
const float *t_decay;
const float *a;
const float *s;
const int32_t inner_size;
/* out */ float *s1;
/* out */ float *s2;

__device__ void operator()(int i) const {
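    // i runs over all H*S*S state elements; inner_size = S*S, so j picks
    // the head and the per-head scalars t_first[j] / t_decay[j] are
    // broadcast over that head's S*S block.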
const int j = i / inner_size;
s1[i] = t_first[j] * a[i] + s[i];
s2[i] = a[i] + t_decay[j] * s[i];
}
};

/*
Equivalent Python code:
kx = xx * k_mix + sx * (1 - k_mix)
vx = xx * v_mix + sx * (1 - v_mix)
rx = xx * r_mix + sx * (1 - r_mix)
*/

struct Mix {
const half *xx;
const half *sx;
const half *kvr_mix;
const int stride;
/* out */ half *kvrx;

__device__ void operator()(int i) const {
half xx_ = xx[i];
half sx_ = sx[i];
half k_mix_ = kvr_mix[i];
half v_mix_ = kvr_mix[i + stride];
half r_mix_ = kvr_mix[i + stride * 2];
kvrx[i] = __hadd(__hmul(xx_, k_mix_),
__hmul(sx_, __hsub(__float2half(1), k_mix_)));
kvrx[i + stride] = __hadd(__hmul(xx_, v_mix_),
__hmul(sx_, __hsub(__float2half(1), v_mix_)));
kvrx[i + stride * 2] = __hadd(__hmul(xx_, r_mix_),
__hmul(sx_, __hsub(__float2half(1), r_mix_)));
}
};

struct ToHalf {
const float *x;
half *y;
__device__ void operator()(int i) const { y[i] = __float2half(x[i]); }
};

struct InplaceAdd {
  // Note: returns void; the original `half` return type had no return value.
  __device__ __forceinline__ void operator()(int i) const {
    y[i] = __hadd(x[i], y[i]);
  }
  half *y;
  half *x;
};
} // namespace

using torch::Tensor;

void gemm_cublas_tensor(const Tensor &a, const Tensor &b, const Tensor &c);
void gemm_cublas(const void *a, const void *b, void *c, int batch, int ori_m,
int ori_n, int ori_k, at::ScalarType torch_input_dtype,
at::ScalarType torch_output_dtype);

Tensor att_one_v5(Tensor x, Tensor sx, Tensor s, Tensor ln_w, Tensor ln_b,
Tensor lx_w, Tensor lx_b, Tensor kvr_mix, Tensor kvrw,
Tensor ow, Tensor t_first, Tensor t_decay, Tensor tmp,
Tensor buf, /* out */ Tensor s2_t,
/* out */ Tensor x_plus_out_t) {
const int x_numel = x.numel();
Tensor xx = at::layer_norm(x, {x_numel}, ln_w, ln_b);
int H = t_decay.size(0);
int S = x_numel / H;
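  // Carve all scratch space out of the caller-provided `buf`, in order:
  // kvrx (3*x_numel half) | kvr (3*x_numel float) | a (H*S*S float)
  // | tmp2 (x_numel half) | s1 (H*S*S float).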
char *buf_ptr = (char *)buf.data_ptr();
half *kvrx = (half *)buf_ptr;
float *kvr = (float *)(kvrx + 3 * x_numel);
float *a = kvr + 3 * x_numel;
half *tmp2 = (half *)(a + H * S * S);
float *s1 = (float *)(tmp2 + x_numel);
float *s2 = data_ptr<float>(s2_t);
half *x_plus_out = data_ptr<half>(x_plus_out_t);

element_wise(Mix{data_ptr<half>(xx), data_ptr<half>(sx),
data_ptr<half>(kvr_mix), static_cast<int>(x_numel), kvrx},
x_numel);

gemm_cublas(kvrx, data_ptr<half>(kvrw), kvr, 3, 1, x_numel, x_numel,
at::kHalf, at::kFloat);
float *k = kvr;
float *v = k + x_numel;
float *r = v + x_numel;

gemm_cublas(k, v, a, H, S, S, 1, at::kFloat, at::kFloat);
element_wise(Fused1{data_ptr<float>(t_first), data_ptr<float>(t_decay), a,
data_ptr<float>(s), static_cast<int32_t>(S * S), s1, s2},
H * S * S);

gemm_cublas(r, s1, data_ptr<float>(tmp), H, 1, S, S, at::kFloat, at::kFloat);
tmp = at::group_norm(tmp, H, lx_w, lx_b);
element_wise(ToHalf{data_ptr<float>(tmp), tmp2}, tmp.numel());

gemm_cublas(tmp2, data_ptr<half>(ow), x_plus_out, 1, 1, x_numel, x_numel,
at::kHalf, at::kHalf);
element_wise(InplaceAdd{x_plus_out, data_ptr<half>(x)}, x.numel());
return xx;
}
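
The v5 path carries a per-head matrix state s of shape (H, S, S) instead of the scalar aa/bb/pp state. The same data flow in plain PyTorch (a sketch with unpacked k/v/r weights; the CUDA version packs them into kvrw and carves every intermediate out of buf):

import torch
import torch.nn.functional as F

def att_one_v5_reference(x, sx, s, ln_w, ln_b, lx_w, lx_b,
                         k_mix, v_mix, r_mix, kw, vw, rw, ow,
                         t_first, t_decay):
    H = t_decay.shape[0]
    S = x.numel() // H
    xx = F.layer_norm(x, (x.numel(),), ln_w, ln_b)
    kx = xx * k_mix + sx * (1 - k_mix)        # Mix kernel (kvr_mix unpacked)
    vx = xx * v_mix + sx * (1 - v_mix)
    rx = xx * r_mix + sx * (1 - r_mix)
    k = (kx @ kw).float().view(H, S, 1)
    v = (vx @ vw).float().view(H, 1, S)
    r = (rx @ rw).float().view(H, 1, S)
    a = k @ v                                 # per-head outer product, (H, S, S)
    s1 = t_first.view(H, 1, 1) * a + s        # Fused1: s1 = t_first * a + s
    s2 = a + t_decay.view(H, 1, 1) * s        # Fused1: next state
    out = (r @ s1).reshape(1, H * S)          # r @ s1 per head, heads flattened
    out = F.group_norm(out, H, lx_w, lx_b).to(x.dtype)
    x_plus_out = (out @ ow).view(-1) + x      # InplaceAdd residual
    return xx, x_plus_out, s2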