Skip to content

Commit d5a6f9e

Browse files
authored
[ET-VK] Minor performance improvement to permute op.
Differential Revision: D70917659. Pull Request resolved: #9330.
1 parent c884c38 commit d5a6f9e

File tree

1 file changed

+23
-19
lines changed

1 file changed

+23
-19
lines changed

backends/vulkan/runtime/graph/ops/glsl/permute.glsl

+23-19
Original file line number | Diff line number | Diff line change
@@ -31,10 +31,8 @@ layout(push_constant) uniform PRECISION restrict Block {
3131
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3232
layout(constant_id = 3) const int packed_dim = C_DIM;
3333

34-
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
35-
3634
void main() {
37-
u16vec3 pos = u16vec3(gl_GlobalInvocationID);
35+
ivec3 pos = ivec3(gl_GlobalInvocationID);
3836

3937
if (any(greaterThanEqual(pos, out_limits.xyz))) {
4038
return;
@@ -48,34 +46,40 @@ void main() {
4846
// index of packed dim in bchw format
4947
const int in_packed_dim_bchw_index = 3 - packed_dim;
5048

51-
for (int j = 0; j < 4; ++j, pos[packed_dim]++) {
52-
ivec4 in_bchw_pos = ivec4(0); // holds b,c,h,w
53-
// determine input position based on output position and permute map
54-
// out_ndims is in BCHW format
55-
in_bchw_pos[out_ndims[0]] = (pos.z / channel_info.x);
56-
in_bchw_pos[out_ndims[1]] = (pos.z % channel_info.x);
57-
in_bchw_pos[out_ndims[2]] = pos.y;
58-
in_bchw_pos[out_ndims[3]] = pos.x;
49+
// determine input position based on output position and permute map
50+
// out_ndims is in BCHW format
51+
ivec4 in_bchw_pos = ivec4(0); // holds b,c,h,w
52+
in_bchw_pos[out_ndims[0]] = (pos.z / channel_info.x);
53+
in_bchw_pos[out_ndims[1]] = (pos.z % channel_info.x);
54+
in_bchw_pos[out_ndims[2]] = pos.y;
55+
in_bchw_pos[out_ndims[3]] = pos.x;
5956

57+
for (int j = 0; j < 4; ++j) {
58+
// terminate the loop if trying to access input texture out of bounds
6059
if (any(greaterThanEqual(in_bchw_pos.wzyx, in_sizes.xyzw))) {
6160
break;
6261
}
62+
ivec3 fetch_pos;
6363

64-
// input tensor's packed dim pos (in xyz format) corresponding to output tensor's pos (which is also in xyz format)
65-
const int in_packed_dim_pos = in_bchw_pos[in_packed_dim_bchw_index];
64+
fetch_pos.xy = in_bchw_pos.wz;
65+
// calculate input position in z axis using batch and channel index which is in_bchw_pos.x and in_bchw_pos.y respectively
66+
fetch_pos.z = in_bchw_pos.y + in_bchw_pos.x * channel_info.y;
6667

67-
// calculate input position in y axis using batch and channel index which is in_bchw_pos.x and in_bchw_pos.y respectively
68-
in_bchw_pos.y = in_bchw_pos.y + in_bchw_pos.x * channel_info.y;
68+
// input tensor's packed dim lane corresponding to output tensor's pos
69+
const int in_packed_dim_lane_index = fetch_pos[packed_dim] & 0x3;
6970

7071
// scale down input tensor's packed dim pos to perform fetch
71-
in_bchw_pos[in_packed_dim_bchw_index] >>= 2;
72+
fetch_pos[packed_dim] >>= 2;
7273

7374
// fetch input texel
74-
VEC4_T inval = VEC4_T(texelFetch(image_in, u16vec3(in_bchw_pos.wzy), 0));
75-
outval[j] = inval[in_packed_dim_pos & 0x3];
75+
VEC4_T inval = VEC4_T(texelFetch(image_in, fetch_pos, 0));
76+
outval[j] = inval[in_packed_dim_lane_index];
77+
78+
// go to next position in the input, that is mapped to the packed dim in the output
79+
in_bchw_pos[out_ndims[in_packed_dim_bchw_index]]++;
7680
}
7781

78-
pos[packed_dim] = uint16_t(gl_GlobalInvocationID[packed_dim]);
82+
pos[packed_dim] = int(gl_GlobalInvocationID[packed_dim]);
7983

8084
imageStore(image_out, pos, outval);
8185
}

0 commit comments

Comments
 (0)