Signed-off-by: daquexian <[email protected]>
Showing 8 changed files with 558 additions and 34 deletions.
@@ -0,0 +1,124 @@
#include "ATen/ATen.h"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

#include "element_wise.h"
#include "util.h"

// Equivalent Python code:
// ww = t_first + k
// p = torch.maximum(pp, ww)
// e1 = torch.exp(pp - p)
// e2 = torch.exp(ww - p)
// wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype)
// ww = t_decay + pp
// p = torch.maximum(ww, k)
// e1 = torch.exp(ww - p)
// e2 = torch.exp(k - p)
// t1 = e1 * aa + e2 * v
// t2 = e1 * bb + e2
// r = r * wkv
// return t1, t2, p, r
struct WkvForwardOne {
  const float *t_first;
  const float *k;
  const float *pp;
  const float *aa;
  const float *bb;
  const float *t_decay;
  const float *v;
  /* out */ float *t1;
  /* out */ float *t2;
  /* out */ float *p;
  /* in & out */ half *r;

  __device__ void operator()(int i) const {
    float ww = t_first[i] + k[i];
    float pp_ = pp[i];
    float p_ = (pp_ > ww) ? pp_ : ww;
    float e1 = expf(pp_ - p_);
    float e2 = expf(ww - p_);
    float aa_ = aa[i];
    float bb_ = bb[i];
    float v_ = v[i];
    r[i] = __hmul(r[i], __float2half(((e1 * aa_ + e2 * v_) / (e1 * bb_ + e2))));
    ww = t_decay[i] + pp_;
    float k_ = k[i];
    p_ = (ww > k_) ? ww : k_;
    e1 = expf(ww - p_);
    e2 = expf(k_ - p_);
    t1[i] = e1 * aa_ + e2 * v_;
    t2[i] = e1 * bb_ + e2;
    p[i] = p_;
  }
};

/*
   Equivalent Python code:
   kx = xx * k_mix + sx * (1 - k_mix)
   vx = xx * v_mix + sx * (1 - v_mix)
   rx = xx * r_mix + sx * (1 - r_mix)
*/
struct Mix {
  const half *xx;
  const half *sx;
  const half *k_mix;
  const half *v_mix;
  const half *r_mix;
  /* out */ half *kx;
  /* out */ half *vx;
  /* out */ half *rx;

  __device__ void operator()(int i) const {
    half xx_ = xx[i];
    half sx_ = sx[i];
    half k_mix_ = k_mix[i];
    half v_mix_ = v_mix[i];
    half r_mix_ = r_mix[i];
    kx[i] = __hadd(__hmul(xx_, k_mix_),
                   __hmul(sx_, __hsub(__float2half(1), k_mix_)));
    vx[i] = __hadd(__hmul(xx_, v_mix_),
                   __hmul(sx_, __hsub(__float2half(1), v_mix_)));
    rx[i] = __hadd(__hmul(xx_, r_mix_),
                   __hmul(sx_, __hsub(__float2half(1), r_mix_)));
  }
};

using torch::Tensor;

void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);

Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
               Tensor v_mix, Tensor r_mix, Tensor kw,
               /* imm */ Tensor kx, Tensor vw, /* imm */ Tensor vx, Tensor rw,
               /* imm */ Tensor rx, Tensor ow, Tensor t_first,
               /* imm */ Tensor k, Tensor pp, Tensor ww, Tensor aa, Tensor bb,
               Tensor t_decay, /* imm */ Tensor v, /* in & out */ Tensor r,
               /* out */ Tensor x_plus_out, /* out */ Tensor t1,
               /* out */ Tensor t2, /* out */ Tensor p) {
  Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b);
  element_wise(Mix{data_ptr<half>(xx), data_ptr<half>(sx),
                   data_ptr<half>(k_mix), data_ptr<half>(v_mix),
                   data_ptr<half>(r_mix), data_ptr<half>(kx),
                   data_ptr<half>(vx), data_ptr<half>(rx)},
               x.numel());

  gemm_fp16_cublas(kx, kw, k);
  gemm_fp16_cublas(vx, vw, v);
  gemm_fp16_cublas(rx, rw, r);
  at::sigmoid_(r);

  element_wise(WkvForwardOne{data_ptr<float>(t_first), data_ptr<float>(k),
                             data_ptr<float>(pp), data_ptr<float>(aa),
                             data_ptr<float>(bb), data_ptr<float>(t_decay),
                             data_ptr<float>(v), data_ptr<float>(t1),
                             data_ptr<float>(t2), data_ptr<float>(p),
                             data_ptr<half>(r)},
               x.numel());

  gemm_fp16_cublas(r, ow, x_plus_out);
  x_plus_out += x;
  return xx;
}
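This file only defines the host-side entry point; the Python-facing registration presumably lives in one of the other changed files that this page does not show. As a rough orientation, a minimal torch-extension binding for att_one could look like the sketch below; the module setup and docstring are assumptions, not part of this commit.

// Hypothetical binding sketch, not part of the shown diff.
#include <torch/extension.h>

torch::Tensor att_one(torch::Tensor x, torch::Tensor ln_w, torch::Tensor ln_b,
                      torch::Tensor sx, torch::Tensor k_mix, torch::Tensor v_mix,
                      torch::Tensor r_mix, torch::Tensor kw, torch::Tensor kx,
                      torch::Tensor vw, torch::Tensor vx, torch::Tensor rw,
                      torch::Tensor rx, torch::Tensor ow, torch::Tensor t_first,
                      torch::Tensor k, torch::Tensor pp, torch::Tensor ww,
                      torch::Tensor aa, torch::Tensor bb, torch::Tensor t_decay,
                      torch::Tensor v, torch::Tensor r, torch::Tensor x_plus_out,
                      torch::Tensor t1, torch::Tensor t2, torch::Tensor p);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("att_one", &att_one, "RWKV attention for a single token");
}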
@@ -0,0 +1,180 @@
#include "ATen/ATen.h"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

#include "util.h"
#include "element_wise.h"

using torch::Tensor;

void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
void gemm_fp16_cublas(const void *a, const void *b, void *c, int m,
                      int n, int k, bool output_fp32);

__global__ void kernel_wkv_forward_new(
    const int B, const int T, const int C, const float *__restrict__ const _w,
    const float *__restrict__ const _u, const float *__restrict__ const _k,
    const float *__restrict__ const _v, const half *__restrict__ const r,
    half *__restrict__ const _y, float *__restrict__ const _aa,
    float *__restrict__ const _bb, float *__restrict__ const _pp) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int _b = idx / C;
  const int _c = idx % C;
  const int _offset = _b * T * C + _c;
  const int _state_offset = _b * C + _c;

  float u = _u[_c];
  float w = _w[_c];
  const float *__restrict__ const k = _k + _offset;
  const float *__restrict__ const v = _v + _offset;
  half *__restrict__ const y = _y + _offset;

  float aa = _aa[_state_offset];
  float bb = _bb[_state_offset];
  float pp = _pp[_state_offset];
  for (int i = 0; i < T; i++) {
    const int ii = i * C;
    const float kk = k[ii];
    const float vv = v[ii];
    float ww = u + kk;
    float p = max(pp, ww);
    float e1 = exp(pp - p);
    float e2 = exp(ww - p);
    // y[ii] = __hmul(__float2half((e1 * aa + e2 * vv) / (e1 * bb + e2)),
    //                r[ii]);
    y[ii] = __float2half((e1 * aa + e2 * vv) / (e1 * bb + e2));
    ww = w + pp;
    p = max(ww, kk);
    e1 = exp(ww - p);
    e2 = exp(kk - p);
    aa = e1 * aa + e2 * vv;
    bb = e1 * bb + e2;
    pp = p;
  }
  _aa[_state_offset] = aa;
  _bb[_state_offset] = bb;
  _pp[_state_offset] = pp;
}

void cuda_wkv_forward_new(int B, int T, int C, float *w, float *u, float *k,
                          float *v, half *r, half *y, float *aa, float *bb,
                          float *pp) {
  dim3 threadsPerBlock(min(C, 32));
  assert(B * C % threadsPerBlock.x == 0);
  dim3 numBlocks(B * C / threadsPerBlock.x);
  kernel_wkv_forward_new<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, r,
                                                         y, aa, bb, pp);
}

__global__ void _att_mix(const half *xx, const half *sx, const half *k_mix,
                         const half *v_mix, const half *r_mix,
                         const int outer_size, const int inner_size, half *kx,
                         half *vx, half *rx) {
  for (int idx2 = blockIdx.x * blockDim.x + threadIdx.x; idx2 < inner_size;
       idx2 += blockDim.x * gridDim.x) {
    half k_mix_ = k_mix[idx2];
    half v_mix_ = v_mix[idx2];
    half r_mix_ = r_mix[idx2];
    for (int row = 0; row < outer_size; ++row) {
      int idx1 = row * inner_size + idx2;
      half xx_ = xx[idx1];
      half sx_ = sx[idx1];
      kx[idx1] = __hadd(__hmul(xx_, k_mix_),
                        __hmul(sx_, __hsub(__float2half(1), k_mix_)));
      vx[idx1] = __hadd(__hmul(xx_, v_mix_),
                        __hmul(sx_, __hsub(__float2half(1), v_mix_)));
      rx[idx1] = __hadd(__hmul(xx_, r_mix_),
                        __hmul(sx_, __hsub(__float2half(1), r_mix_)));
    }
  }
}

void att_mix(const half *xx, const half *sx, const half *k_mix,
             const half *v_mix, const half *r_mix, const int outer_size,
             const int inner_size, half *kx, half *vx, half *rx) {
  // 256 is good enough on most GPUs
  const int32_t BLOCK_SIZE = 256;
  assert(inner_size % BLOCK_SIZE == 0);
  _att_mix<<<inner_size / BLOCK_SIZE, BLOCK_SIZE>>>(
      xx, sx, k_mix, v_mix, r_mix, outer_size, inner_size, kx, vx, rx);
}

struct InplaceSigmoid {
  __device__ __forceinline__ void operator()(int i) const {
    ptr[i] = __float2half(1.0 / (1.0 + exp(-__half2float(ptr[i]))));
  }
  half *ptr;
};

struct InplaceMul {
  __device__ __forceinline__ void operator()(int i) const {
    y[i] = __hmul(x[i], y[i]);
  }
  half *y;
  half *x;
};

/*
   Equivalent Python code:
   xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
   sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))
   kx = xx * k_mix + sx * (1 - k_mix)
   vx = xx * v_mix + sx * (1 - v_mix)
   rx = xx * r_mix + sx * (1 - r_mix)
   r = torch.sigmoid(gemm(rx, rw))
   k = gemm(kx, kw, output_dtype=torch.float32)
   v = gemm(vx, vw, output_dtype=torch.float32)
   T = x.shape[0]
   for t in range(T):
       kk = k[t]
       vv = v[t]
       ww = t_first + kk
       p = torch.maximum(pp, ww)
       e1 = torch.exp(pp - p)
       e2 = torch.exp(ww - p)
       sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype)
       ww = t_decay + pp
       p = torch.maximum(ww, kk)
       e1 = torch.exp(ww - p)
       e2 = torch.exp(kk - p)
       aa = e1 * aa + e2 * vv
       bb = e1 * bb + e2
       pp = p
   out = gemm(r * sx, ow)
   return x + out, xx[-1,:], aa, bb, pp
*/
Tensor att_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
               Tensor v_mix, Tensor r_mix, Tensor kw, Tensor vw, Tensor rw,
               Tensor ow, Tensor t_first, Tensor pp, Tensor aa, Tensor bb,
               Tensor t_decay, /* imm */ Tensor buf,
               /* out */ Tensor x_plus_out) {
  Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b);
  sx = at::cat({sx.unsqueeze(0), xx.slice(0, 0, -1)}, 0);
  char *buf_ptr = (char *)buf.data_ptr();
  half *kx = (half *)buf_ptr;
  half *vx = kx + x.numel();
  half *rx = vx + x.numel();
  half *wkv_y = rx + x.numel();
  att_mix(data_ptr<half>(xx), data_ptr<half>(sx), data_ptr<half>(k_mix),
          data_ptr<half>(v_mix), data_ptr<half>(r_mix), xx.size(0), xx.size(1),
          kx, vx, rx);
  float *k = reinterpret_cast<float *>(wkv_y + x.numel());
  float *v = k + x.size(0) * kw.size(1);
  half *r = reinterpret_cast<half *>(v + x.size(0) * vw.size(1));

  gemm_fp16_cublas(kx, kw.data_ptr(), k, x.size(0), kw.size(1), kw.size(0), true);
  gemm_fp16_cublas(vx, vw.data_ptr(), v, x.size(0), vw.size(1), vw.size(0), true);
  gemm_fp16_cublas(rx, rw.data_ptr(), r, x.size(0), rw.size(1), rw.size(0), false);
  element_wise(InplaceSigmoid{r}, x.size(0) * rw.size(1));
  cuda_wkv_forward_new(1, x.size(0), x.size(1), data_ptr<float>(t_decay),
                       data_ptr<float>(t_first), k, v, r,
                       wkv_y, data_ptr<float>(aa),
                       data_ptr<float>(bb), data_ptr<float>(pp));
  element_wise(InplaceMul{wkv_y, r}, x.numel());
  gemm_fp16_cublas(wkv_y, ow.data_ptr(), x_plus_out.data_ptr(), x.size(0),
                   ow.size(1), ow.size(0), false);
  x_plus_out += x;
  return xx;
}
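att_seq carves all of its intermediates out of a single caller-provided scratch tensor buf, following the pointer arithmetic above: kx, vx, rx, and wkv_y are fp16, the k and v gemm outputs are fp32, and r is fp16. The helper below is a sketch of how an assumed caller might size that buffer; it is not part of the shown diff, and the function name is illustrative.

// Sketch only: bytes required for the scratch tensor `buf`, following the
// pointer layout inside att_seq (assumed helper, not part of the diff).
#include <cstdint>
#include <torch/extension.h>

int64_t att_seq_buf_bytes(const torch::Tensor &x, const torch::Tensor &kw,
                          const torch::Tensor &vw, const torch::Tensor &rw) {
  const int64_t T = x.size(0);
  int64_t bytes = 4 * x.numel() * sizeof(at::Half);        // kx, vx, rx, wkv_y
  bytes += T * (kw.size(1) + vw.size(1)) * sizeof(float);  // k, v (fp32)
  bytes += T * rw.size(1) * sizeof(at::Half);              // r (fp16)
  return bytes;
}

// e.g. Tensor buf = torch::empty({att_seq_buf_bytes(x, kw, vw, rw)},
//                                x.options().dtype(torch::kUInt8));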
@@ -0,0 +1,21 @@
#include <cassert>
#include <cstddef>
#include <cstdint>

template <typename Func> __global__ void _element_wise(Func func, int n) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    func(i);
  }
}

// NOTE: a packed data type (e.g. float4) is overkill for the current sizes
// (4096 in the 7B model and 768 in the 0.1B model),
// and is not faster than the plain float version.
template <typename Func>
void element_wise(Func func, int n) {
  // 256 is good enough on most GPUs
  const int32_t BLOCK_SIZE = 256;
  assert(n % BLOCK_SIZE == 0);
  _element_wise<<<n / BLOCK_SIZE, BLOCK_SIZE>>>(func, n);
}
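Any struct with a __device__ operator()(int) can be dispatched through this launcher; the Mix, WkvForwardOne, InplaceSigmoid, and InplaceMul functors above are used exactly this way. A minimal usage sketch follows; the InplaceScale functor and the d_buf pointer are made up for illustration and are not part of the commit.

// Usage sketch: scale a half buffer in place via the element_wise launcher.
// InplaceScale and d_buf are illustrative, not part of the commit.
#include <cuda_fp16.h>
#include "element_wise.h"

struct InplaceScale {
  __device__ __forceinline__ void operator()(int i) const {
    ptr[i] = __hmul(ptr[i], __float2half(scale));
  }
  half *ptr;
  float scale;
};

// e.g. element_wise(InplaceScale{d_buf, 0.5f}, n);  // n must be a multiple of 256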