generic sycl: refactor kernel mainloops #2070

Merged
merged 1 commit into from Sep 20, 2024
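This PR replaces the hand-rolled work partitioning at the top of the generic SYCL kernel main loops (work-group/sub-group offset arithmetic plus a fixed per-item block_size, with a bounds check on every block element) with plain grid-stride loops over conf_.wk_size. Below is a minimal standalone sketch of the pattern, not oneDNN code; the queue setup, wk_size, and data buffer are illustrative assumptions:

#include <sycl/sycl.hpp>

int main() {
    sycl::queue q;
    constexpr size_t wk_size = 1 << 20; // total work, independent of launch size
    float *data = sycl::malloc_shared<float>(wk_size, q);

    // Grid-stride loop: each work-item starts at its global id and advances by
    // the total global range, so any nd_range covers all wk_size elements and
    // no per-block tail guard is needed.
    q.parallel_for(sycl::nd_range<1> {sycl::range<1> {4096}, sycl::range<1> {256}},
            [=](sycl::nd_item<1> item) {
                for (size_t idx = item.get_global_id(0); idx < wk_size;
                        idx += item.get_global_range(0))
                    data[idx] *= 2.f;
            })
            .wait();

    sycl::free(data, q);
    return 0;
}

One practical upside of this shape is that launch geometry becomes a pure tuning knob: the kernel stays correct for any work-group count or size.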
76 changes: 31 additions & 45 deletions src/gpu/generic/sycl/binary_kernels.hpp
@@ -57,17 +57,6 @@ struct binary_kernel_vec_t {
         memory_plain_t src0_scale_mem(src0_scale_, scales_dt_);
         memory_plain_t src1_scale_mem(src1_scale_, scales_dt_);

-        auto sg = item.get_sub_group();
-        size_t wg_offset_t = item.get_group(0) * conf_.wg_size;
-        size_t sg_offset_t = sg.get_group_id()[0] * sg.get_local_range()[0];
-        size_t wi_offset_t = sg.get_local_id();
-        size_t offset_t = wg_offset_t + sg_offset_t + wi_offset_t;
-
-        size_t base_idx = offset_t * conf_.block_size;
-        size_t vec_base_idx = base_idx / vec_len;
-
-        size_t sg_base_idx = (wg_offset_t + sg_offset_t) * conf_.block_size;
-
         const float sm_0 = (conf_.do_scale_src0 ? src0_scale_mem.load(0) : 1.f);

         const float sm_1 = (conf_.do_scale_src1 ? src1_scale_mem.load(0) : 1.f);
@@ -98,12 +87,12 @@ struct binary_kernel_vec_t {

         if (!any_broadcast && !is_blocked_fmt
                 && conf_.post_ops.get_post_op() == 0
-                && sg_base_idx + (sg.get_local_range()[0] * conf_.block_size)
-                        < conf_.wk_size
-                && is_same_tag) {
-            for (int i = 0; i < conf_.block_size / vec_len; i++) {
-                auto src0_vec = src0_mem.load_vec<vec_len>(vec_base_idx + i);
-                auto src1_vec = src1_mem.load_vec<vec_len>(vec_base_idx + i);
+                && conf_.wk_size % vec_len == 0 && is_same_tag) {
+            for (int vec_idx = item.get_global_id(0);
+                    vec_idx < conf_.wk_size / vec_len;
+                    vec_idx += item.get_global_range(0)) {
+                auto src0_vec = src0_mem.load_vec<vec_len>(vec_idx);
+                auto src1_vec = src1_mem.load_vec<vec_len>(vec_idx);

                 if (conf_.do_scale_src0)
                     src0_vec *= ::sycl::vec<float, vec_len>(sm_0);
@@ -114,37 +103,34 @@ struct binary_kernel_vec_t {
                 // TODO: Adding post-ops seems to be interfering with compiler's
                 // optimizations. Figure out how to make the compiler to generate
                 // the right code.
-                dst_mem.store_vec(acc_vec, vec_base_idx + i);
+                dst_mem.store_vec(acc_vec, vec_idx);
             }
         } else {
-            for (int i = 0; i < conf_.block_size; i++) {
-                int idx = base_idx + i;
-                if (idx < conf_.wk_size) {
-                    auto l_offset = idx;
-                    for (int i = 0; i < conf_.ndims; i++) {
-                        const int d = conf_.ndims - 1 - i;
-                        const dim_t cur_dim = conf_.dst_md.dims()[d];
-                        off_dst[d] = l_offset % cur_dim;
-                        l_offset = l_offset / cur_dim;
-                    }
-
-                    for (int i = 0; i < max_supported_ndims; i++) {
-                        off0[i] = conf_.broadcast_dims0[i] ? 0 : off_dst[i];
-                        off1[i] = conf_.broadcast_dims1[i] ? 0 : off_dst[i];
-                    }
-
-                    auto src0 = src0_mem.load_md(off0);
-                    auto src1 = src1_mem.load_md(off1);
-
-                    if (conf_.do_scale_src0) src0 *= sm_0;
-                    if (conf_.do_scale_src1) src1 *= sm_1;
-
-                    auto acc = compute_alg_n(src0, src1, conf_.alg_kind);
-
-                    acc = conf_.post_ops.apply(
-                            acc, dst_, idx, po_args_, off_dst);
-                    dst_mem.store_md(acc, off_dst);
-                }
+            for (int idx = item.get_global_id(0); idx < conf_.wk_size;
+                    idx += item.get_global_range(0)) {
+                auto l_offset = idx;
+                for (int i = 0; i < conf_.ndims; i++) {
+                    const int d = conf_.ndims - 1 - i;
+                    const dim_t cur_dim = conf_.dst_md.dims()[d];
+                    off_dst[d] = l_offset % cur_dim;
+                    l_offset = l_offset / cur_dim;
+                }
+
+                for (int i = 0; i < max_supported_ndims; i++) {
+                    off0[i] = conf_.broadcast_dims0[i] ? 0 : off_dst[i];
+                    off1[i] = conf_.broadcast_dims1[i] ? 0 : off_dst[i];
+                }
+
+                auto src0 = src0_mem.load_md(off0);
+                auto src1 = src1_mem.load_md(off1);
+
+                if (conf_.do_scale_src0) src0 *= sm_0;
+                if (conf_.do_scale_src1) src1 *= sm_1;
+
+                auto acc = compute_alg_n(src0, src1, conf_.alg_kind);
+
+                acc = conf_.post_ops.apply(acc, dst_, idx, po_args_, off_dst);
+                dst_mem.store_md(acc, off_dst);
             }
         }
     }
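With the grid stride in place, the vectorized fast path above guards with a single wk_size % vec_len == 0 divisibility check and strides over whole vector chunks, instead of testing whether an entire sub-group's block fits below wk_size. Continuing the sketch from the top of the page (q and wk_size as before; src and dst are assumed USM float pointers standing in for oneDNN's memory wrappers, and vec_len is illustrative):

constexpr int vec_len = 8;
using vfloat = ::sycl::vec<float, vec_len>;

// Assumes wk_size % vec_len == 0, mirroring the kernel's new guard; each
// iteration moves one whole vector, so the loop bound is wk_size / vec_len.
q.parallel_for(sycl::nd_range<1> {sycl::range<1> {4096}, sycl::range<1> {256}},
        [=](sycl::nd_item<1> item) {
            auto *src_v = reinterpret_cast<const vfloat *>(src);
            auto *dst_v = reinterpret_cast<vfloat *>(dst);
            for (size_t vec_idx = item.get_global_id(0);
                    vec_idx < wk_size / vec_len;
                    vec_idx += item.get_global_range(0))
                dst_v[vec_idx] = src_v[vec_idx] + src_v[vec_idx]; // any binary op
        });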
158 changes: 73 additions & 85 deletions src/gpu/generic/sycl/convolution_kernels.hpp
@@ -72,14 +72,6 @@ struct convolution_kernel_fwd_t {
                       : data_type_t::dnnl_f32) {}

     void operator()(::sycl::nd_item<1> item) const {
-        auto sg = item.get_sub_group();
-        size_t wg_offset_t = item.get_group(0) * conf_.wg_size;
-        size_t sg_offset_t = sg.get_group_id()[0] * sg.get_local_range()[0];
-        size_t wi_offset_t = sg.get_local_id();
-        size_t offset_t = wg_offset_t + sg_offset_t + wi_offset_t;
-
-        size_t base_idx = offset_t * conf_.block_size;
-
         const float sm_data = (conf_.do_scale_data
                 ? load_float_value(scales_data_dt_, data_scale_ptr(), 0)
                 : 1.f);
@@ -132,94 +124,90 @@ struct convolution_kernel_fwd_t {
         const int DH = conf_.dilation[1];
         const int DW = conf_.dilation[2];

-        for (int i = 0; i < conf_.block_size; i++) {
-            int idx = base_idx + i;
-            if (idx < conf_.wk_size) {
-                for (int i = 0; i < max_supported_ndims; i++) {
-                    off[i] = idx / dst_strides[i] % dst_dims[i];
-                }
-
-                const int n = off[0];
-                const int oc_tot = off[1];
-                const int oc = oc_tot % OC;
-                const int g = oc_tot / OC;
-
-                const int od = off[2];
-                const int oh = off[3];
-                const int ow = off[4];
-
-                float accumulator = 0;
-                for (int ic = 0; ic < IC; ++ic) {
-                    for (int kd = 0; kd < KD; ++kd) {
-                        for (int kh = 0; kh < KH; ++kh) {
-                            for (int kw = 0; kw < KW; ++kw) {
-                                const int id = od * SD - PD + kd * (1 + DD);
-                                const int ih = oh * SH - PH + kh * (1 + DH);
-                                const int iw = ow * SW - PW + kw * (1 + DW);
-
-                                if (id < 0 || id >= data_dims[2] || ih < 0
-                                        || ih >= data_dims[3] || iw < 0
-                                        || iw >= data_dims[4]) {
-                                    continue;
-                                }
-
-                                dims_t off_data {n, g * IC + ic, id, ih, iw};
-                                const int data_idx = data_md().off_v(off_data);
-                                dims_t off_weights {g, oc, ic, kd, kh, kw};
-                                dims_t off_weights_no_groups {
-                                        oc, ic, kd, kh, kw};
-                                const int weights_idx = weights_md().off_v(
-                                        no_groups ? off_weights_no_groups
-                                                  : off_weights);
-
-                                auto data = load_float_value(
-                                        data_md().data_type(), data_ptr(),
-                                        data_idx);
-                                auto weight = load_float_value(
-                                        weights_md().data_type(), weights_ptr(),
-                                        weights_idx);
-
-                                if (conf_.use_data_zeropoints) {
-                                    int zpoint_idx = conf_.single_data_zeropoint
-                                            ? 0
-                                            : g * IC + ic;
-                                    auto data_zeropoint = load_float_value(
-                                            zeropoints_data_dt_,
-                                            data_zeropoint_ptr(), zpoint_idx);
-                                    data -= data_zeropoint;
-                                }
-                                accumulator += data * weight;
-                            }
-                        }
-                    }
-                }
-                if (conf_.do_scale_data) { accumulator *= sm_data; }
-                if (conf_.do_scale_weights) {
-                    if (!conf_.single_weight_scale) {
-                        sm_weights = load_float_value(scales_weights_dt_,
-                                weights_scale_ptr(), oc_tot);
-                    }
-                    accumulator *= sm_weights;
-                }
-
-                if (bias_md().ndims() != 0) {
-                    auto bias = load_float_value(
-                            bias_md().data_type(), bias_ptr(), oc_tot);
-                    accumulator += bias;
-                }
-
-                accumulator = conf_.post_ops.apply(accumulator, dst_, idx);
-
-                if (conf_.do_scale_dst) { accumulator /= sm_dst; }
-                if (conf_.use_dst_zeropoints) {
-                    int zpoint_idx = conf_.single_dst_zeropoint ? 0 : oc_tot;
-                    auto dst_zeropoint = load_float_value(zeropoints_dst_dt_,
-                            dst_zeropoint_ptr(), zpoint_idx);
-                    accumulator += dst_zeropoint;
-                }
-                store_float_value(
-                        dst_md().data_type(), accumulator, dst_ptr(), idx);
-            }
-        }
+        for (int idx = item.get_global_id(0); idx < conf_.wk_size;
+                idx += item.get_global_range(0)) {
+            for (int i = 0; i < max_supported_ndims; i++) {
+                off[i] = idx / dst_strides[i] % dst_dims[i];
+            }
+
+            const int n = off[0];
+            const int oc_tot = off[1];
+            const int oc = oc_tot % OC;
+            const int g = oc_tot / OC;
+
+            const int od = off[2];
+            const int oh = off[3];
+            const int ow = off[4];
+
+            float accumulator = 0;
+            for (int ic = 0; ic < IC; ++ic) {
+                for (int kd = 0; kd < KD; ++kd) {
+                    for (int kh = 0; kh < KH; ++kh) {
+                        for (int kw = 0; kw < KW; ++kw) {
+                            const int id = od * SD - PD + kd * (1 + DD);
+                            const int ih = oh * SH - PH + kh * (1 + DH);
+                            const int iw = ow * SW - PW + kw * (1 + DW);
+
+                            if (id < 0 || id >= data_dims[2] || ih < 0
+                                    || ih >= data_dims[3] || iw < 0
+                                    || iw >= data_dims[4]) {
+                                continue;
+                            }
+
+                            dims_t off_data {n, g * IC + ic, id, ih, iw};
+                            const int data_idx = data_md().off_v(off_data);
+                            dims_t off_weights {g, oc, ic, kd, kh, kw};
+                            dims_t off_weights_no_groups {oc, ic, kd, kh, kw};
+                            const int weights_idx = weights_md().off_v(no_groups
+                                            ? off_weights_no_groups
+                                            : off_weights);
+
+                            auto data = load_float_value(data_md().data_type(),
+                                    data_ptr(), data_idx);
+                            auto weight
+                                    = load_float_value(weights_md().data_type(),
+                                            weights_ptr(), weights_idx);
+
+                            if (conf_.use_data_zeropoints) {
+                                int zpoint_idx = conf_.single_data_zeropoint
+                                        ? 0
+                                        : g * IC + ic;
+                                auto data_zeropoint = load_float_value(
+                                        zeropoints_data_dt_,
+                                        data_zeropoint_ptr(), zpoint_idx);
+                                data -= data_zeropoint;
+                            }
+                            accumulator += data * weight;
+                        }
+                    }
+                }
+            }
+            if (conf_.do_scale_data) { accumulator *= sm_data; }
+            if (conf_.do_scale_weights) {
+                if (!conf_.single_weight_scale) {
+                    sm_weights = load_float_value(
+                            scales_weights_dt_, weights_scale_ptr(), oc_tot);
+                }
+                accumulator *= sm_weights;
+            }
+
+            if (bias_md().ndims() != 0) {
+                auto bias = load_float_value(
+                        bias_md().data_type(), bias_ptr(), oc_tot);
+                accumulator += bias;
+            }
+
+            accumulator = conf_.post_ops.apply(accumulator, dst_, idx);
+
+            if (conf_.do_scale_dst) { accumulator /= sm_dst; }
+            if (conf_.use_dst_zeropoints) {
+                int zpoint_idx = conf_.single_dst_zeropoint ? 0 : oc_tot;
+                auto dst_zeropoint = load_float_value(
+                        zeropoints_dst_dt_, dst_zeropoint_ptr(), zpoint_idx);
+                accumulator += dst_zeropoint;
+            }
+            store_float_value(
+                    dst_md().data_type(), accumulator, dst_ptr(), idx);
+        }
     }
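The loop nest above is a direct convolution; the part worth internalizing is the output-to-input coordinate map id = od * SD - PD + kd * (1 + DD) (and likewise for h and w), where out-of-range taps are skipped to realize zero padding. A tiny host-side check of that arithmetic with made-up stride/padding/dilation values:

#include <cstdio>

int main() {
    const int SD = 2, PD = 1, DD = 1; // stride, padding, dilation (illustrative)
    const int ID = 8; // input depth extent, i.e. data_dims[2] (illustrative)
    const int od = 3; // one output position
    for (int kd = 0; kd < 3; ++kd) {
        const int id = od * SD - PD + kd * (1 + DD);
        // Prints id = 5, 7, 9; the last tap falls outside [0, ID) and would
        // be skipped by the kernel's continue.
        std::printf("kd=%d -> id=%d%s\n", kd, id,
                (id < 0 || id >= ID) ? " (skipped)" : "");
    }
    return 0;
}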
63 changes: 22 additions & 41 deletions src/gpu/generic/sycl/eltwise_kernels.hpp
@@ -43,14 +43,6 @@ struct eltwise_fwd_kernel_vec_t {
         memory_tensor_t src_mem(src_, conf_.src_md);
         memory_tensor_t dst_mem(dst_, conf_.dst_md);

-        auto sg = item.get_sub_group();
-        size_t wg_offset_t = item.get_group(0) * conf_.wg_size;
-        size_t sg_offset_t = sg.get_group_id()[0] * sg.get_local_range()[0];
-        size_t wi_offset_t = sg.get_local_id();
-        size_t offset_t = wg_offset_t + sg_offset_t + wi_offset_t;
-
-        size_t base_idx = offset_t * conf_.block_size;
-
         auto operation = [&](dim_t &idx, dim_t &n, dim_t &c, dim_t &d, dim_t &h,
                 dim_t &w) {
             dim_t src_offset = data_offset(src_mem.md(), n, c, d, h, w);
@@ -72,22 +64,20 @@ struct eltwise_fwd_kernel_vec_t {
             dst_mem.store(acc, src_offset);
         };

-        for (dim_t blk_idx = 0; blk_idx < conf_.block_size; blk_idx++) {
-            dim_t idx = base_idx + blk_idx;
-            if (idx < conf_.wk_size) {
-                dim_t N = conf_.mb;
-                dim_t C = conf_.c;
-                dim_t D = conf_.d;
-                dim_t H = conf_.h;
-                dim_t W = conf_.w;
-
-                dim_t n = (idx / (C * D * H * W)) % N;
-                dim_t c = (idx / (D * H * W)) % C;
-                dim_t d = (idx / (H * W)) % D;
-                dim_t h = (idx / (W)) % H;
-                dim_t w = (idx / (1)) % W;
-                operation(idx, n, c, d, h, w);
-            }
+        for (dim_t idx = item.get_global_id(0); idx < conf_.wk_size;
+                idx += item.get_global_range(0)) {
+            dim_t N = conf_.mb;
+            dim_t C = conf_.c;
+            dim_t D = conf_.d;
+            dim_t H = conf_.h;
+            dim_t W = conf_.w;
+
+            dim_t n = (idx / (C * D * H * W)) % N;
+            dim_t c = (idx / (D * H * W)) % C;
+            dim_t d = (idx / (H * W)) % D;
+            dim_t h = (idx / (W)) % H;
+            dim_t w = (idx / (1)) % W;
+            operation(idx, n, c, d, h, w);
         }
     }
@@ -221,23 +211,14 @@
         memory_tensor_t diff_src_mem(diff_src_, conf_.diff_src_md);
         memory_tensor_t diff_dst_mem(diff_dst_, conf_.diff_dst_md);

-        auto sg = item.get_sub_group();
-        size_t wg_offset_t = item.get_group(0) * conf_.wg_size;
-        size_t sg_offset_t = sg.get_group_id()[0] * sg.get_local_range()[0];
-        size_t wi_offset_t = sg.get_local_id();
-        size_t offset_t = wg_offset_t + sg_offset_t + wi_offset_t;
-        size_t base_idx = offset_t * conf_.block_size;
-
-        for (dim_t i = 0; i < conf_.block_size; i++) {
-            dim_t idx = base_idx + i;
-            if (idx < conf_.wk_size) {
-                auto diff_src = diff_src_mem.load(idx);
-                auto src = src_mem.load(idx);
-
-                auto dst = compute_alg_n(
-                        diff_src, src, conf_.alpha, conf_.beta, conf_.alg_kind);
-                diff_dst_mem.store(dst, idx);
-            }
+        for (dim_t idx = item.get_global_id(0); idx < conf_.wk_size;
+                idx += item.get_global_range(0)) {
+            auto diff_src = diff_src_mem.load(idx);
+            auto src = src_mem.load(idx);
+
+            auto dst = compute_alg_n(
+                    diff_src, src, conf_.alpha, conf_.beta, conf_.alg_kind);
+            diff_dst_mem.store(dst, idx);
         }
     }
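The forward eltwise kernel recovers (n, c, d, h, w) from the flat grid-stride index by dividing by the dense NCDHW suffix products and taking the remainder at each axis. The same arithmetic on the host, with arbitrary dimension sizes, in case the mapping is easier to verify in isolation:

#include <cstdio>

int main() {
    const long long N = 2, C = 3, D = 4, H = 5, W = 6;
    const long long idx = 123; // any flat index below N * C * D * H * W
    const long long n = (idx / (C * D * H * W)) % N;
    const long long c = (idx / (D * H * W)) % C;
    const long long d = (idx / (H * W)) % D;
    const long long h = (idx / W) % H;
    const long long w = idx % W;
    // 123 = ((0 * 3 + 1) * 4 + 0) * 5 * 6 + 0 * 6 + 3 -> n=0 c=1 d=0 h=0 w=3
    std::printf("n=%lld c=%lld d=%lld h=%lld w=%lld\n", n, c, d, h, w);
    return 0;
}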