Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ET-VK] Minor performance improvement to permute op. #9330

Merged
merged 3 commits on Mar 18, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 23 additions & 19 deletions backends/vulkan/runtime/graph/ops/glsl/permute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,8 @@ layout(push_constant) uniform PRECISION restrict Block {
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
layout(constant_id = 3) const int packed_dim = C_DIM;

#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require

void main() {
u16vec3 pos = u16vec3(gl_GlobalInvocationID);
ivec3 pos = ivec3(gl_GlobalInvocationID);

if (any(greaterThanEqual(pos, out_limits.xyz))) {
return;
Expand All @@ -48,34 +46,40 @@ void main() {
// index of packed dim in bchw format
const int in_packed_dim_bchw_index = 3 - packed_dim;

for (int j = 0; j < 4; ++j, pos[packed_dim]++) {
ivec4 in_bchw_pos = ivec4(0); // holds b,c,h,w
// determine input position based on output position and permute map
// out_ndims is in BCHW format
in_bchw_pos[out_ndims[0]] = (pos.z / channel_info.x);
in_bchw_pos[out_ndims[1]] = (pos.z % channel_info.x);
in_bchw_pos[out_ndims[2]] = pos.y;
in_bchw_pos[out_ndims[3]] = pos.x;
// determine input position based on output position and permute map
// out_ndims is in BCHW format
ivec4 in_bchw_pos = ivec4(0); // holds b,c,h,w
in_bchw_pos[out_ndims[0]] = (pos.z / channel_info.x);
in_bchw_pos[out_ndims[1]] = (pos.z % channel_info.x);
in_bchw_pos[out_ndims[2]] = pos.y;
in_bchw_pos[out_ndims[3]] = pos.x;

for (int j = 0; j < 4; ++j) {
// terminate the loop if trying to access input texture out of bounds
if (any(greaterThanEqual(in_bchw_pos.wzyx, in_sizes.xyzw))) {
break;
}
ivec3 fetch_pos;

// input tensor's packed dim pos (in xyz format) corresponding to output tensor's pos (which is also in xyz format)
const int in_packed_dim_pos = in_bchw_pos[in_packed_dim_bchw_index];
fetch_pos.xy = in_bchw_pos.wz;
// calculate input position in z axis using the batch and channel indices, which are in_bchw_pos.x and in_bchw_pos.y respectively
fetch_pos.z = in_bchw_pos.y + in_bchw_pos.x * channel_info.y;

// calculate input position in y axis using the batch and channel indices, which are in_bchw_pos.x and in_bchw_pos.y respectively
in_bchw_pos.y = in_bchw_pos.y + in_bchw_pos.x * channel_info.y;
// input tensor's packed dim lane corresponding to output tensor's pos
const int in_packed_dim_lane_index = fetch_pos[packed_dim] & 0x3;

// scale down input tensor's packed dim pos to perform fetch
in_bchw_pos[in_packed_dim_bchw_index] >>= 2;
fetch_pos[packed_dim] >>= 2;

// fetch input texel
VEC4_T inval = VEC4_T(texelFetch(image_in, u16vec3(in_bchw_pos.wzy), 0));
outval[j] = inval[in_packed_dim_pos & 0x3];
VEC4_T inval = VEC4_T(texelFetch(image_in, fetch_pos, 0));
outval[j] = inval[in_packed_dim_lane_index];

// go to next position in the input, that is mapped to the packed dim in the output
in_bchw_pos[out_ndims[in_packed_dim_bchw_index]]++;
}

pos[packed_dim] = uint16_t(gl_GlobalInvocationID[packed_dim]);
pos[packed_dim] = int(gl_GlobalInvocationID[packed_dim]);

imageStore(image_out, pos, outval);
}
Loading