Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ET-VK] Add buffer support for binary ops #9063

Merged
merged 4 commits into from
Mar 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions backends/vulkan/runtime/api/containers/Tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,19 +245,19 @@ class vTensor final {
TextureLimits logical_limits;
// Contains the number of elements in the tensor according to the canonical
// sizes.
size_t numel;
int32_t numel;

friend class vTensor;

UniformData(
const std::vector<int64_t>& sizes,
const std::vector<int64_t>& strides,
const TextureLimits& logical_limits,
const size_t numel)
const size_t numel_ll)
: sizes_v(utils::make_whcn_ivec4(sizes)),
strides_v(utils::make_whcn_ivec4(strides)),
logical_limits(logical_limits),
numel(numel) {}
numel(utils::safe_downcast<int32_t>(numel_ll)) {}

public:
/*
Expand Down
76 changes: 62 additions & 14 deletions backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -11,37 +11,83 @@
#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}
#define T ${buffer_scalar_type(DTYPE)}

#define op(X, Y, A) ${OPERATOR}

${define_active_storage_type(STORAGE)}
${define_required_extensions(DTYPE)}

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_other", DTYPE, STORAGE)}

$if STORAGE == "buffer":
layout(push_constant) uniform restrict Block {
ivec4 in_sizes;
ivec4 other_sizes;
ivec4 out_strides;
ivec4 in_strides;
ivec4 other_strides;
int out_numel;
float alpha;
};
$else:
layout(push_constant) uniform restrict Block {
ivec4 out_sizes;
ivec4 in_sizes;
ivec4 other_sizes;
ivec2 broadcast_params;
float alpha;
};

#include "broadcasting_utils.h"
#include "indexing_utils.h"

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
const lowp int packed_dim = unhash_packed_dim(out_layout);
$if STORAGE == "buffer":
${layout_declare_spec_const(C, "int", "out_packed_dim", "DEFAULT_LAYOUT")}
${layout_declare_spec_const(C, "int", "in_packed_dim", "DEFAULT_LAYOUT")}
${layout_declare_spec_const(C, "int", "other_packed_dim", "DEFAULT_LAYOUT")}
$else:
${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
const lowp int packed_dim = unhash_packed_dim(out_layout);

${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 in_axis_map = unhash_axis_map(in_layout);
${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 in_axis_map = unhash_axis_map(in_layout);

${layout_declare_spec_const(C, "int", "other_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 other_axis_map = unhash_axis_map(other_layout);
${layout_declare_spec_const(C, "int", "other_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 other_axis_map = unhash_axis_map(other_layout);

layout(push_constant) uniform restrict Block {
ivec4 out_sizes;
ivec4 in_sizes;
ivec4 other_sizes;
ivec2 broadcast_params;
float alpha;
};
#ifdef USING_BUFFER

void main() {
const int out_bufi = ivec3(gl_GlobalInvocationID).x;
if (out_bufi >= out_numel) {
return;
}

// Simple case; no broadcasting
if (in_sizes == other_sizes) {
t_out[out_bufi] = T(op(t_in[out_bufi], t_other[out_bufi], T(alpha)));
return;
}

const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_packed_dim);
const ivec4 in_tidx = min(out_tidx, in_sizes - 1);
const ivec4 other_tidx = min(out_tidx, other_sizes - 1);

const int in_bufi = tidx_to_bufi(in_tidx, in_strides);
const int other_bufi = tidx_to_bufi(other_tidx, other_strides);

t_out[out_bufi] = T(op(t_in[in_bufi], t_other[other_bufi], T(alpha)));
}

#else // USING_TEXTURE

void main() {
const ivec3 lpos = ivec3(gl_GlobalInvocationID);
Expand Down Expand Up @@ -79,3 +125,5 @@ void main() {
VEC4_T(op(in_texel, other_texel, alpha)),
out_axis_map);
}

#endif
4 changes: 3 additions & 1 deletion backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ binary_op:
NDIM: 3
DTYPE: float
PACKING: C_packed
STORAGE: texture3d
generate_variant_forall:
STORAGE:
- VALUE: texture3d
- VALUE: buffer
DTYPE:
- VALUE: half
- VALUE: float
Expand Down
64 changes: 63 additions & 1 deletion backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ void resize_binary_op_node(
out->virtual_resize(new_out_sizes);
}

void add_binary_op_node(
void add_binary_op_texture_node(
ComputeGraph& graph,
const ValueRef in1,
const ValueRef in2,
Expand Down Expand Up @@ -75,6 +75,7 @@ void add_binary_op_node(
std::string kernel_name("binary_");
kernel_name.reserve(kShaderNameReserve);
kernel_name += op_name;
add_storage_type_suffix(kernel_name, *t_out);
add_dtype_suffix(kernel_name, *t_out);

graph.execute_nodes().emplace_back(new DispatchNode(
Expand All @@ -98,6 +99,67 @@ void add_binary_op_node(
PushConstantDataInfo(&binary_ops_params, sizeof(binary_ops_params))}}));
}

void add_binary_op_buffer_node(
ComputeGraph& graph,
const ValueRef in1,
const ValueRef in2,
const ValueRef alpha,
const ValueRef out,
const std::string& op_name) {
// check_binary_op_args(*t_in1, *t_in2, *t_out);

float alpha_val = 1.0f;
// String is checked since floor_div passes in an unused string argument in
// place of alpha
if (is_valid(alpha) && !graph.val_is_string(alpha)) {
alpha_val = graph.extract_scalar<float>(alpha);
}

std::string kernel_name("binary_");
kernel_name.reserve(kShaderNameReserve);
kernel_name += op_name;
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
add_dtype_suffix(kernel_name, graph.dtype_of(out));

graph.execute_nodes().emplace_back(new DispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
graph.create_global_wg_size(out),
graph.create_local_wg_size(out),
// Inputs and Outputs
{{out, vkapi::MemoryAccessType::WRITE},
{{in1, in2}, vkapi::MemoryAccessType::READ}},
// Shader params buffers
{},
// Specialization Constants
{graph.packed_dim_of(out), graph.packed_dim_of(in1), graph.packed_dim_of(in2)},
// Resizing Logic
resize_binary_op_node,
{},
{{graph.sizes_pc_of(in1),
graph.sizes_pc_of(in2),
graph.strides_pc_of(out),
graph.strides_pc_of(in1),
graph.strides_pc_of(in2),
graph.numel_pc_of(out),
PushConstantDataInfo(&alpha_val, sizeof(float)),
}}));
}

void add_binary_op_node(
ComputeGraph& graph,
const ValueRef in1,
const ValueRef in2,
const ValueRef alpha,
const ValueRef out,
const std::string& op_name) {
if (graph.is_buffer_storage(out)) {
add_binary_op_buffer_node(graph, in1, in2, alpha, out, op_name);
} else {
add_binary_op_texture_node(graph, in1, in2, alpha, out, op_name);
}
}

#define DEFINE_BINARY_OP_WITH_ALPHA_FN(op_name) \
void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
return add_binary_op_node( \
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def get_binary_elementwise_inputs():
"utils::kWidthPacked",
"utils::kChannelsPacked",
]
test_suite.storage_types = ["utils::kBuffer", "utils::kTexture3D"]
return test_suite


Expand Down
Loading