From c02d0862872a18fc1bd6bbe09af3135b1847ea61 Mon Sep 17 00:00:00 2001
From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com>
Date: Mon, 17 Mar 2025 13:11:12 -0700
Subject: [PATCH] [ET-VK] Adding all tensor packing support to cat op.

This diff updates the ExecuTorch Vulkan backend's cat operation to support
width-, height-, and channel-packed tensors. It also updates the
op_registry.py file to indicate that the cat operation supports all packed
dimensions, and adds new test cases to the cases.py file to test the
operation.

Differential Revision: [D71230768](https://our.internmc.facebook.com/intern/diff/D71230768/)

[ghstack-poisoned]
---
 backends/vulkan/op_registry.py                     |   3 +-
 .../runtime/graph/ops/glsl/copy_offset.glsl        |  19 +++-
 .../ops/glsl/copy_packed_dim_offset.glsl           | 106 ++++++++++++++++++
 .../ops/glsl/copy_packed_dim_offset.yaml           |  12 ++
 .../vulkan/runtime/graph/ops/impl/Cat.cpp          |  97 ++++++++--------
 .../vulkan/runtime/graph/ops/impl/Copy.cpp         |  92 +++++++++++++--
 backends/vulkan/runtime/graph/ops/impl/Copy.h      |  21 +++-
 .../vulkan/runtime/graph/ops/impl/Repeat.cpp       |  16 +--
 .../vulkan/runtime/graph/ops/impl/Split.cpp        |  12 +-
 backends/vulkan/test/op_tests/cases.py             |  11 ++
 10 files changed, 311 insertions(+), 78 deletions(-)
 create mode 100644 backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.glsl
 create mode 100644 backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml

diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index 57b0b828a27..f2b80c2e544 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -528,7 +528,6 @@ def register_view_op(features: OpFeatures):
         exir_ops.edge.aten.index_select.default,
         exir_ops.edge.aten.select_copy.int,
         # Tensor combination
-        exir_ops.edge.aten.cat.default,
         exir_ops.edge.aten.split_with_sizes_copy.default,
         exir_ops.edge.aten.split.Tensor,
         exir_ops.edge.aten.repeat.default,
@@ -562,6 +561,8 @@ def register_ported_op(features: OpFeatures):
         exir_ops.edge.aten.squeeze_copy.dims,
         exir_ops.edge.aten.unsqueeze_copy.default,
         exir_ops.edge.aten.permute_copy.default,
+        # Tensor combination
+        exir_ops.edge.aten.cat.default,
     ]
 )
 def register_ported_op_all_packed_dims(features: OpFeatures):
diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_offset.glsl b/backends/vulkan/runtime/graph/ops/glsl/copy_offset.glsl
index a42a592762b..a23822765a3 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/copy_offset.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/copy_offset.glsl
@@ -19,8 +19,10 @@ ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
 
 layout(push_constant) uniform restrict Block {
   ivec3 range;
-  ivec3 src_offset;
-  ivec3 dst_offset;
+  // xyz is the source offset; w is the channel size
+  ivec4 src_offset;
+  // xyz is the destination offset; w is the channel size
+  ivec4 dst_offset;
 };
 
 #include "indexing_utils.h"
@@ -36,13 +38,20 @@ const lowp ivec4 in_axis_map = unhash_axis_map(in_layout);
 void main() {
   const ivec3 pos = ivec3(gl_GlobalInvocationID);
 
-  const ivec3 out_pos = pos + dst_offset;
-  const ivec3 in_pos = pos + src_offset;
-
   if (any(greaterThanEqual(pos, range))) {
     return;
   }
 
+  const ivec3 in_pos = pos + src_offset.xyz;
+  ivec3 out_pos = pos + dst_offset.xyz;
+
+  // If the source channel size is specified, compose output z from the channel and batch indices
+  if (src_offset.w > 0) {
+    const int channel_index = in_pos.z % src_offset.w;
+    const int batch_index = in_pos.z / src_offset.w;
+    out_pos.z = channel_index + dst_offset.z + batch_index * dst_offset.w;
+  }
+
   write_texel_lpos(
     t_out,
     out_pos,
diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.glsl b/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.glsl
new file mode 100644
index 00000000000..02ea6405b4a
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.glsl
@@ -0,0 +1,106 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define VEC4_T ${texel_type(DTYPE)}
+
+layout(std430) buffer;
+
+${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
+${layout_declare_tensor(B, "r", "existing_out", DTYPE, STORAGE)}
+${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
+
+layout(push_constant) uniform restrict Block {
+  ivec4 range;
+  // xyz is the source offset; w is the channel size
+  ivec4 src_offset;
+  // xyz is the destination offset; w is the channel size
+  ivec4 dst_offset;
+};
+
+#include "indexing_utils.h"
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
+const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
+const lowp int packed_dim = unhash_packed_dim(out_layout);
+
+${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
+const lowp ivec4 in_axis_map = unhash_axis_map(in_layout);
+
+void main() {
+  const ivec3 pos = ivec3(gl_GlobalInvocationID);
+
+  if (any(greaterThanEqual(pos, range.xyz))) {
+    return;
+  }
+
+  // Starting lane offset to write at within a texel
+  const int out_lane_offset = dst_offset[packed_dim] & 0x3;
+  const bool has_lane_offset = out_lane_offset != 0;
+
+  // Position in the input tensor
+  const ivec3 in_pos = pos + src_offset.xyz;
+
+  // Read the input texel that maps to this output texel
+  const VEC4_T in_value = load_texel_lpos(t_in, in_pos, in_axis_map);
+
+  ivec3 out_pos = pos + dst_offset.xyz;
+  out_pos[packed_dim] = pos[packed_dim] + (dst_offset[packed_dim] >> 2);
+
+  VEC4_T out_value;
+
+  // If the lane offset is non-zero, i.e. the packed output texel is composed
+  // from multiple source texels
+  if (has_lane_offset) {
+    // When the position in the packed dim is > 0
+    if (pos[packed_dim] > 0) {
+      // Boundary values come from the previous input texel in the packed dim.
+      ivec3 prev_in_pos = in_pos;
+      prev_in_pos[packed_dim] = in_pos[packed_dim] - 1;
+      VEC4_T prev_value = load_texel_lpos(t_in, prev_in_pos, in_axis_map);
+
+      // Shift values toward the beginning based on out_lane_offset:
+      // offset 1 means the last lane of the previous texel is part of the
+      // output texel, offset 2 means its last 2 lanes are, and so on
+      if (out_lane_offset == 1) {
+        out_value.x = prev_value.w;
+      } else if (out_lane_offset == 2) {
+        out_value.xy = prev_value.zw;
+      } else {
+        out_value.xyz = prev_value.yzw;
+      }
+    } else {
+      // When the position in the packed dim is == 0
+      // Boundary values are taken from the existing output texel.
+      out_value = load_texel_lpos(existing_out, out_pos, out_axis_map);
+    }
+
+    // Copy input values toward the end of the output texel, based on the lane
+    // offset: offset 1 means the input texel's first 3 lanes fill lanes yzw of
+    // the output texel, offset 2 means its first 2 lanes fill lanes zw, and so on
+    if (out_lane_offset == 1) {
+      out_value.yzw = in_value.xyz;
+    } else if (out_lane_offset == 2) {
+      out_value.zw = in_value.xy;
+    } else {
+      out_value.w = in_value.x;
+    }
+  } else {
+    out_value = in_value;
+  }
+
+  write_texel_lpos(
+    t_out,
+    out_pos,
+    out_value,
+    out_axis_map);
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml
new file mode 100644
index 00000000000..e872d64e3c3
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml
@@ -0,0 +1,12 @@
+copy_packed_dim_offset:
+  parameter_names_with_default_values:
+    DTYPE: float
+    NDIM: 3
+    STORAGE: texture3d
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: half
+      - VALUE: float
+      - VALUE: int
+  shader_variants:
+    - NAME: copy_packed_dim_offset
diff --git a/backends/vulkan/runtime/graph/ops/impl/Cat.cpp b/backends/vulkan/runtime/graph/ops/impl/Cat.cpp
index d5cfd5f4505..5f172454121 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Cat.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Cat.cpp
@@ -22,65 +22,68 @@ void add_cat_default_node(
     ValueRef dim_ref,
     ValueRef out) {
   ValueListPtr input_list = graph.get_value_list(in_list_ref);
-
-  for (ValueRef input_ref : *input_list) {
-    vTensorPtr t_in = graph.get_tensor(input_ref);
-    VK_CHECK_COND(check_packed_dim_is(*t_in, WHCN::kChannelsDim));
-  }
-
   int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
   vTensorPtr t_out = graph.get_tensor(out);
 
+  const auto packed_dim = t_out->packed_dim();
+  const auto packed_dim_index = static_cast<DimIndex>(kWidth4D - packed_dim);
+
   DimIndex dim_index = normalize_to_dim_index(*t_out, dim);
+  // Index of the dimension to be concatenated in the (w, h, c * b) coordinate
+  // system
+  const auto dim_xyz_index = std::min(2, -dim_index - 1);
 
-  // TODO: Find ways to factor out the similar code for width, height, and batch
-  if (dim_index == kWidth4D) {
-    utils::ivec3 src_offset = utils::make_ivec3({0, 0, 0}, false);
-    utils::ivec3 dst_offset = utils::make_ivec3({0, 0, 0}, false);
+  if (dim_index > kWidth4D || dim_index < kBatch4D) {
+    VK_THROW("Unexpected value of dim_index=", dim_index);
+  }
 
-    for (ValueRef input_ref : *input_list) {
-      vTensorPtr t_in = graph.get_tensor(input_ref);
-      utils::ivec3 range = t_in->logical_limits();
-      add_copy_offset_node(
-          graph, input_ref, range, src_offset, dst_offset, out);
-      dst_offset[0] += range[0];
-    }
+  utils::ivec4 src_offset = utils::make_ivec4({0, 0, 0, 0}, false);
+  utils::ivec4 dst_offset = utils::make_ivec4({0, 0, 0, 0}, false);
 
-  } else if (dim_index == kHeight4D) {
-    utils::ivec3 src_offset = utils::make_ivec3({0, 0, 0}, false);
-    utils::ivec3 dst_offset = utils::make_ivec3({0, 0, 0}, false);
+  const bool is_concat_channel = (dim_index == kChannel4D);
 
-    for (ValueRef input_ref : *input_list) {
-      vTensorPtr t_in = graph.get_tensor(input_ref);
-      utils::ivec3 range = t_in->logical_limits();
-      add_copy_offset_node(
-          graph, input_ref, range, src_offset, dst_offset, out);
-      dst_offset[1] += range[1];
-    }
-  } else if (dim_index == kBatch4D) {
-    utils::ivec3 src_offset = utils::make_ivec3({0, 0, 0}, false);
-    utils::ivec3 dst_offset = utils::make_ivec3({0, 0, 0}, false);
+  // If concatenating channels
+  if (is_concat_channel) {
+    // Set destination offset w to the channel size of the output tensor
+    dst_offset[3] = dim_at(t_out->sizes(), kChannel4D);
+  }
 
-    for (ValueRef input_ref : *input_list) {
-      vTensorPtr t_in = graph.get_tensor(input_ref);
-      utils::ivec3 range = t_in->logical_limits();
+  for (ValueRef input_ref : *input_list) {
+    const vTensorPtr t_in = graph.get_tensor(input_ref);
+    const utils::ivec3 range = t_in->logical_limits();
+    const auto in_channel_size = dim_at(t_in->sizes(), kChannel4D);
+    // If concatenating along the same dimension as the packed dimension
+    if (dim_index == packed_dim_index) {
+      // If concatenating channels, use add_copy_channel_offset_node, since
+      // add_copy_packed_dim_offset_node does not support channel packing
+      if (is_concat_channel) {
+        add_copy_channel_offset_node(
+            graph,
+            input_ref,
+            in_channel_size,
+            src_offset[2],
+            dst_offset[2],
+            out);
+        dst_offset[dim_xyz_index] += in_channel_size;
+      } else {
+        // src_offset[3] is unused for now, but will be used once
+        // add_copy_packed_dim_offset_node supports channel packing
+        //
+        // Set source offset w to the channel size of the input tensor if
+        // concatenating channels
+        src_offset[3] = is_concat_channel ? in_channel_size : 0;
+        add_copy_packed_dim_offset_node(
+            graph, input_ref, range, src_offset, dst_offset, out);
+        dst_offset[dim_xyz_index] += dim_at(t_in->sizes(), packed_dim_index);
+      }
+    } else {
+      // Set source offset w to the channel size of the input tensor if
+      // concatenating channels
+      src_offset[3] = is_concat_channel ? in_channel_size : 0;
       add_copy_offset_node(
           graph, input_ref, range, src_offset, dst_offset, out);
-      dst_offset[2] += range[2];
+      dst_offset[dim_xyz_index] +=
+          is_concat_channel ? in_channel_size : range[dim_xyz_index];
     }
-  } else if (dim_index == kChannel4D) {
-    int32_t src_offset = 0;
-    int32_t dst_offset = 0;
-
-    for (ValueRef input_ref : *input_list) {
-      vTensorPtr t_in = graph.get_tensor(input_ref);
-      int32_t range = dim_at(t_in->sizes(), kChannel4D);
-      add_copy_channel_offset_node(
-          graph, input_ref, range, src_offset, dst_offset, out);
-      dst_offset += range;
-    }
-  } else {
-    VK_THROW("Unexpected value of dim_index=", dim_index);
   }
 }
diff --git a/backends/vulkan/runtime/graph/ops/impl/Copy.cpp b/backends/vulkan/runtime/graph/ops/impl/Copy.cpp
index 69378524afb..4b09fbe8619 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Copy.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Copy.cpp
@@ -16,14 +16,15 @@ namespace vkcompute {
 
 using utils::ivec3;
+using utils::ivec4;
 using utils::uvec3;
 
 void add_copy_offset_node(
     ComputeGraph& graph,
     const ValueRef in,
     const ivec3& range,
-    const ivec3& src_offset,
-    const ivec3& dst_offset,
+    const ivec4& src_offset,
+    const ivec4& dst_offset,
     const ValueRef out) {
   vTensorPtr t_in = graph.get_tensor(in);
   vTensorPtr t_out = graph.get_tensor(out);
@@ -52,11 +53,81 @@ void add_copy_offset_node(
       nullptr,
       {},
       {
-          PushConstantDataInfo(&range, sizeof(range), sizeof(utils::ivec4)),
-          PushConstantDataInfo(
-              &src_offset, sizeof(src_offset), sizeof(utils::ivec4)),
+          PushConstantDataInfo(&range, sizeof(range), sizeof(ivec4)),
+          PushConstantDataInfo(&src_offset, sizeof(src_offset), sizeof(ivec4)),
+          PushConstantDataInfo(&dst_offset, sizeof(dst_offset), sizeof(ivec4)),
+      }));
+}
+
+void add_copy_packed_dim_offset_node(
+    ComputeGraph& graph,
+    const ValueRef in,
+    const ivec3& range,
+    const ivec4& src_offset,
+    const ivec4& dst_offset,
+    const ValueRef out) {
+  vTensorPtr t_in = graph.get_tensor(in);
+  vTensorPtr t_out = graph.get_tensor(out);
+
+  // Check that the packed dimension is the same for both tensors, and that it
+  // is the width or height dim, since this function does not support channel
+  // packing.
+  VK_CHECK_COND(
+      check_same_packed_dim(*t_in, *t_out) &&
+      (check_packed_dim_is(*t_in, WHCN::kWidthDim) ||
+       check_packed_dim_is(*t_in, WHCN::kHeightDim)));
+
+  std::string kernel_name = "copy_packed_dim_offset";
+  kernel_name.reserve(kShaderNameReserve);
+  add_dtype_suffix(kernel_name, *t_out);
+
+  const auto packed_dim = t_in->packed_dim();
+  // A copy of range with the last element set to the batch size of the input
+  // tensor
+  ivec4 final_range = {
+      range[0], range[1], range[2], dim_at(t_in->sizes(), kBatch4D)};
+  ivec3 global_wg_size = t_out->logical_limits();
+  // The starting offset within a texel where this tensor will start copying to
+  const auto dst_lane_offset = dst_offset[packed_dim] & 0x3;
+  // The total number of packed texels this tensor will be copied into.
+  // The first texel of tensor data in the packed dimension is copied into the
+  // remaining lanes from the previous write; hence (4 - dst_lane_offset) is
+  // added to the tensor size in the packed dimension.
+  const auto dst_packed_size = utils::div_up_4(
+      (4 - dst_lane_offset) +
+      dim_at(t_in->sizes(), normalize_to_dim_index(*t_in, packed_dim)));
+
+  // If the starting offset is not 0, and the total packed texels exceed the
+  // source texel range
+  if (dst_lane_offset != 0 && dst_packed_size > final_range[packed_dim]) {
+    global_wg_size[packed_dim]++; // Increase the global work group size in the
+                                  // packed dimension
+    final_range[packed_dim]++; // Increase the range in the packed dimension
+  }
+
+  auto shader = VK_KERNEL_FROM_STR(kernel_name);
+
+  graph.execute_nodes().emplace_back(new DispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      global_wg_size,
+      graph.create_local_wg_size(global_wg_size),
+      // Inputs and Outputs
+      {
+          {out, vkapi::MemoryAccessType::WRITE},
+          {out, vkapi::MemoryAccessType::READ},
+          {in, vkapi::MemoryAccessType::READ},
+      },
+      // Parameter buffers
+      {},
+      // Specialization Constants
+      {graph.hashed_layout_of(out), graph.hashed_layout_of(in)},
+      nullptr,
+      {},
+      {
           PushConstantDataInfo(
-              &dst_offset, sizeof(dst_offset), sizeof(utils::ivec4)),
+              &final_range, sizeof(final_range), sizeof(ivec4)),
+          PushConstantDataInfo(&src_offset, sizeof(src_offset), sizeof(ivec4)),
+          PushConstantDataInfo(&dst_offset, sizeof(dst_offset), sizeof(ivec4)),
       }));
 }
@@ -140,7 +211,7 @@ void add_copy_channel_offset_node(
         static_cast<int32_t>(global_size[2]),
         channel_range};
 
-    const utils::ivec4 offset_params = {
+    const ivec4 offset_params = {
         dst_offset[0], dst_offset[1], dst_offset[2], dst_channel_offset};
 
     auto shader = VK_KERNEL_FROM_STR(kernel_name);
@@ -179,8 +250,11 @@ void add_copy_offset_node(
     ValueRef dst_offset_ref,
     ValueRef out) {
   ivec3 range = utils::make_ivec3(*graph.get_int_list(range_ref));
-  ivec3 src_offset = utils::make_ivec3(*graph.get_int_list(src_offset_ref));
-  ivec3 dst_offset = utils::make_ivec3(*graph.get_int_list(dst_offset_ref));
+  ivec3 src = utils::make_ivec3(*graph.get_int_list(src_offset_ref));
+  ivec3 dst = utils::make_ivec3(*graph.get_int_list(dst_offset_ref));
+
+  ivec4 src_offset = {src[0], src[1], src[2], 0};
+  ivec4 dst_offset = {dst[0], dst[1], dst[2], 0};
 
   add_copy_offset_node(graph, in, range, src_offset, dst_offset, out);
 }
diff --git a/backends/vulkan/runtime/graph/ops/impl/Copy.h b/backends/vulkan/runtime/graph/ops/impl/Copy.h
index 60bb20eedf0..d4b4c0dcc03 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Copy.h
+++ b/backends/vulkan/runtime/graph/ops/impl/Copy.h
@@ -17,6 +17,7 @@ namespace vkcompute {
 // add_copy_offset_node resembles the vkCmdCopyImage command. It copies the
 // texture extents specified by the range, src_offset, and dst_offset (all in
 // texture coordinates (x, y, z)) from the input image to the output image.
+// src_offset.w and dst_offset.w may contain channel size information.
 //
 // It is possible to have input and output to point to the same image
 // object. But when the source range and destination range overlap, the behavior
@@ -25,8 +26,24 @@ void add_copy_offset_node(
     ComputeGraph& graph,
     const ValueRef in,
     const utils::ivec3& range,
-    const utils::ivec3& src_offset,
-    const utils::ivec3& dst_offset,
+    const utils::ivec4& src_offset,
+    const utils::ivec4& dst_offset,
+    const ValueRef out);
+
+// add_copy_packed_dim_offset_node behaves similar to add_copy_offset_node,
+// except that it is used when copying along the packed dimension of a width-
+// or height-packed tensor. src_offset.w and dst_offset.w may contain channel size information.
+//
+// It copies the texture extents specified by the range, src_offset, and
+// dst_offset (all in texture coordinates (x, y, z)) from the input image to
+// the output image.
+//
+void add_copy_packed_dim_offset_node(
+    ComputeGraph& graph,
+    const ValueRef in,
+    const utils::ivec3& range,
+    const utils::ivec4& src_offset,
+    const utils::ivec4& dst_offset,
     const ValueRef out);
 
 // add_copy_channel_offset_node behaves similar to add_copy_node, except that it
diff --git a/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp b/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp
index 00199ba7a80..49daabdcb76 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp
@@ -148,8 +148,8 @@ void add_repeat_node(
   if (int64_t channel_repeat = dim_at<kChannel4D>(repeats);
       channel_repeat == 1) {
     // If no repeat, short-cut to a direct copy
-    utils::ivec3 src_offset{0, 0, 0};
-    utils::ivec3 dst_offset{0, 0, 0};
+    utils::ivec4 src_offset{0, 0, 0, 0};
+    utils::ivec4 dst_offset{0, 0, 0, 0};
 
     add_copy_offset_node(graph, in, running_range, src_offset, dst_offset, out);
 
@@ -160,10 +160,10 @@ void add_repeat_node(
   // TODO: refactor width, height, and batch into a common helper function.
   // Width
   if (int64_t width_repeat = dim_at<kWidth4D>(repeats); width_repeat > 1) {
-    utils::ivec3 src_offset{0, 0, 0};
+    utils::ivec4 src_offset{0, 0, 0, 0};
 
     for (int i = 1; i < width_repeat; ++i) {
-      utils::ivec3 dst_offset{i * dim_at<kWidth4D>(in_sizes), 0, 0};
+      utils::ivec4 dst_offset{i * dim_at<kWidth4D>(in_sizes), 0, 0, 0};
 
       add_copy_offset_node(
           graph, out, running_range, src_offset, dst_offset, out);
@@ -174,10 +174,10 @@ void add_repeat_node(
 
   // Height
   if (int64_t height_repeat = dim_at<kHeight4D>(repeats); height_repeat > 1) {
-    utils::ivec3 src_offset{0, 0, 0};
+    utils::ivec4 src_offset{0, 0, 0, 0};
 
     for (int i = 1; i < height_repeat; ++i) {
-      utils::ivec3 dst_offset = {0, i * dim_at<kHeight4D>(in_sizes), 0};
+      utils::ivec4 dst_offset = {0, i * dim_at<kHeight4D>(in_sizes), 0, 0};
 
       add_copy_offset_node(
           graph, out, running_range, src_offset, dst_offset, out);
@@ -188,10 +188,10 @@ void add_repeat_node(
 
   // Batch
   if (int64_t batch_repeat = dim_at<kBatch4D>(repeats); batch_repeat > 1) {
-    utils::ivec3 src_offset{0, 0, 0};
+    utils::ivec4 src_offset{0, 0, 0, 0};
 
     for (int i = 1; i < batch_repeat; ++i) {
-      utils::ivec3 dst_offset = {0, 0, i * running_range[2]};
+      utils::ivec4 dst_offset = {0, 0, i * running_range[2], 0};
 
       add_copy_offset_node(
           graph, out, running_range, src_offset, dst_offset, out);
diff --git a/backends/vulkan/runtime/graph/ops/impl/Split.cpp b/backends/vulkan/runtime/graph/ops/impl/Split.cpp
index 39039e51025..ca585f1fb6d 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Split.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Split.cpp
@@ -43,8 +43,8 @@ void add_split_with_sizes_default_node(
   }
 
   if (dim_index == kWidth4D) {
-    utils::ivec3 src_offset = utils::make_ivec3({0, 0, 0}, false);
-    utils::ivec3 dst_offset = utils::make_ivec3({0, 0, 0}, false);
+    utils::ivec4 src_offset = utils::make_ivec4({0, 0, 0, 0}, false);
+    utils::ivec4 dst_offset = utils::make_ivec4({0, 0, 0, 0}, false);
 
     for (ValueRef out_ref : *out_list) {
       // Doesn't need to use split_size since we have already verified that the
@@ -56,8 +56,8 @@ void add_split_with_sizes_default_node(
       src_offset[0] += range[0];
     }
   } else if (dim_index == kHeight4D) {
-    utils::ivec3 src_offset = utils::make_ivec3({0, 0, 0}, false);
-    utils::ivec3 dst_offset = utils::make_ivec3({0, 0, 0}, false);
+    utils::ivec4 src_offset = utils::make_ivec4({0, 0, 0, 0}, false);
+    utils::ivec4 dst_offset = utils::make_ivec4({0, 0, 0, 0}, false);
 
     for (ValueRef out_ref : *out_list) {
       vTensorPtr t_out = graph.get_tensor(out_ref);
@@ -67,8 +67,8 @@ void add_split_with_sizes_default_node(
       src_offset[1] += range[1];
     }
   } else if (dim_index == kBatch4D) {
-    utils::ivec3 src_offset = utils::make_ivec3({0, 0, 0}, false);
-    utils::ivec3 dst_offset = utils::make_ivec3({0, 0, 0}, false);
+    utils::ivec4 src_offset = utils::make_ivec4({0, 0, 0, 0}, false);
+    utils::ivec4 dst_offset = utils::make_ivec4({0, 0, 0, 0}, false);
 
     for (ValueRef out_ref : *out_list) {
       vTensorPtr t_out = graph.get_tensor(out_ref);
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index 1e59c53fc79..e4f7ac15434 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -850,8 +850,11 @@ def get_cat_inputs():
     test_suite = VkTestSuite(
         [
             # Cat on Height
+            ([(M, M, 3, 5), (M, M, 0, 5)], 2),
             ([(S1, S1, 3, 5), (S1, S1, 0, 5)], 2),
+            ([(M, M, 3, 5), (M, M, 4, 5)], 2),
             ([(S1, S1, 3, 5), (S1, S1, 4, 5)], 2),
+            ([(M2, 3, 5), (M2, 4, 5)], 1),
             ([(S1, 3, 5), (S1, 4, 5)], 1),
             ([(3, 5), (4, 5)], 0),
             ([(3, 5), (4, 5), (1, 5)], 0),
@@ -860,9 +863,11 @@ def get_cat_inputs():
             (
                 [(3, 3, 5), (3, 4, 5), (3, 1, 5)],
                 0,
             ),
             # Cat on Width
+            ([(M, M, 5, 3), (M, M, 5, 4)], 3),
             ([(S1, S1, 5, 3), (S1, S1, 5, 4)], 3),
+            ([(M, 5, 3), (M, 5, 4)], 2),
             ([(S1, 5, 3), (S1, 5, 4)], 2),
             ([(5, 0), (5, 4)], 1),
             ([(5, 3), (5, 4)], 1),
@@ -871,7 +876,9 @@ def get_cat_inputs():
             ),
             ([(5,), (6,)], 0),
             # Cat on Batch
+            ([(M, S1, 5, 4), (M1, S1, 5, 4)], 0),
             ([(S, S1, 5, 4), (S1, S1, 5, 4)], 0),
+            ([(S, M, 5, 4), (S1, M, 5, 4)], 0),
             ([(S, XS, 5, 4), (S1, XS, 5, 4)], 0),
             ([(S, S2, 5, 4), (S1, S2, 5, 4)], 0),
             (
@@ -883,7 +890,9 @@ def get_cat_inputs():
                 0,
             ),
             # Cat on Channel
+            ([(M, 5, 4), (0, 5, 4), (M1, 5, 4)], 0),
             ([(S, 5, 4), (0, 5, 4), (S2, 5, 4)], 0),
+            ([(M, 5, 4), (M1, 5, 4), (M2, 5, 4)], 0),
             ([(S, 5, 4), (S1, 5, 4), (S2, 5, 4)], 0),
             ([(XS, 5, 4), (XS, 5, 4), (S2, 5, 4)], 0),
             ([(XS, S, 5, 4), (XS, S1, 5, 4), (XS, S2, 5, 4)], 1),
@@ -899,6 +908,8 @@ def get_cat_inputs():
         ]
     )
     test_suite.layouts = [
+        "utils::kWidthPacked",
+        "utils::kHeightPacked",
         "utils::kChannelsPacked",
     ]
     test_suite.data_gen = "make_seq_tensor"
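
The lane arithmetic in copy_packed_dim_offset.glsl is easiest to see on the
host side. The sketch below is an illustrative Python model, not part of the
patch: the function name, the flat-list representation, and the zero padding
are assumptions, and the range-extension check is a simplified stand-in for
the dst_packed_size logic in Copy.cpp.

    # Host-side model of copy_packed_dim_offset.glsl along one packed axis.
    # `existing_out` / `src` are flat element lists; texels are groups of 4
    # lanes. `dst_off` is the destination offset in elements.
    def copy_packed_dim_offset(existing_out, src, dst_off, pad=0):
        lane = dst_off & 0x3   # out_lane_offset: starting lane in first texel
        base = dst_off >> 2    # texel index of the first write
        out = list(existing_out)
        # Split the source into 4-lane texels, padding the tail the way
        # texture storage does.
        texels = [src[i:i + 4] + [pad] * (4 - len(src[i:i + 4]))
                  for i in range(0, len(src), 4)]
        # One extra "invocation" when the shifted data spills into one more
        # texel (the final_range / global_wg_size extension in Copy.cpp).
        if lane and (lane + len(src) + 3) // 4 > len(texels):
            texels.append([pad] * 4)
        for p, cur in enumerate(texels):  # one invocation per output texel
            if lane == 0:
                texel = cur
            elif p == 0:
                # First texel: leading lanes keep the existing output values
                texel = out[4 * base:4 * base + lane] + cur[:4 - lane]
            else:
                # Leading lanes come from the tail of the previous input texel
                texel = texels[p - 1][4 - lane:] + cur[:4 - lane]
            out[4 * (base + p):4 * (base + p) + 4] = texel  # whole-texel write
        return out

    # Copying [1..5] into a 12-element output at element offset 2:
    print(copy_packed_dim_offset([0] * 12, [1, 2, 3, 4, 5], 2))
    # -> [0, 0, 1, 2, 3, 4, 5, 0, 0, 0, 0, 0]

Like the shader, the model writes whole texels, so trailing lanes of the last
written texel take padding values; in cat, the next input's copy re-blends
those lanes by reading existing_out.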
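The z remapping added to copy_offset.glsl can be checked the same way. For
width- or height-packed tensors the texture z axis fuses batch and channel,
so the shader splits the input z using src_offset.w (the input channel count)
and re-fuses it using dst_offset.w (the output channel count). A hypothetical
worked example, with illustrative names:

    # Mirrors the out_pos.z computation in copy_offset.glsl when
    # src_offset.w > 0.
    def remap_z(in_z, src_channels, dst_z_offset, dst_channels):
        channel_index = in_z % src_channels
        batch_index = in_z // src_channels
        return channel_index + dst_z_offset + batch_index * dst_channels

    # Concatenating a 3-channel input into a 7-channel output at channel
    # offset 4: input z == 4 is (batch 1, channel 1), which lands at
    # z == 1 + 4 + 1 * 7.
    assert remap_z(4, src_channels=3, dst_z_offset=4, dst_channels=7) == 12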