diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute.glsl b/backends/vulkan/runtime/graph/ops/glsl/permute.glsl
index 8cbf5db294b..d4ad736a563 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/permute.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/permute.glsl
@@ -31,10 +31,8 @@ layout(push_constant) uniform PRECISION restrict Block {
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 layout(constant_id = 3) const int packed_dim = C_DIM;
 
-#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
-
 void main() {
-  u16vec3 pos = u16vec3(gl_GlobalInvocationID);
+  ivec3 pos = ivec3(gl_GlobalInvocationID);
 
   if (any(greaterThanEqual(pos, out_limits.xyz))) {
     return;
@@ -48,34 +46,40 @@ void main() {
   // index of packed dim in bchw format
   const int in_packed_dim_bchw_index = 3 - packed_dim;
 
-  for (int j = 0; j < 4; ++j, pos[packed_dim]++) {
-    ivec4 in_bchw_pos = ivec4(0); // holds b,c,h,w
-    // determine input position based on output position and permute map
-    // out_ndims is in BCHW format
-    in_bchw_pos[out_ndims[0]] = (pos.z / channel_info.x);
-    in_bchw_pos[out_ndims[1]] = (pos.z % channel_info.x);
-    in_bchw_pos[out_ndims[2]] = pos.y;
-    in_bchw_pos[out_ndims[3]] = pos.x;
+  // determine input position based on output position and permute map
+  // out_ndims is in BCHW format
+  ivec4 in_bchw_pos = ivec4(0); // holds b,c,h,w
+  in_bchw_pos[out_ndims[0]] = (pos.z / channel_info.x);
+  in_bchw_pos[out_ndims[1]] = (pos.z % channel_info.x);
+  in_bchw_pos[out_ndims[2]] = pos.y;
+  in_bchw_pos[out_ndims[3]] = pos.x;
 
+  for (int j = 0; j < 4; ++j) {
+    // terminate the loop if trying to access input texture out of bounds
     if (any(greaterThanEqual(in_bchw_pos.wzyx, in_sizes.xyzw))) {
       break;
     }
+    ivec3 fetch_pos;
 
-    // input tensor's packed dim pos (in xyz format) corresponding to output tensor's pos (which is also in xyz format)
-    const int in_packed_dim_pos = in_bchw_pos[in_packed_dim_bchw_index];
+    fetch_pos.xy = in_bchw_pos.wz;
+    // calculate input position in z axis using batch and channel index which is in_bchw_pos.x and in_bchw_pos.y respectively
+    fetch_pos.z = in_bchw_pos.y + in_bchw_pos.x * channel_info.y;
 
-    // calculate input position in y axis using batch and channel index which is in_bchw_pos.x and in_bchw_pos.y respectively
-    in_bchw_pos.y = in_bchw_pos.y + in_bchw_pos.x * channel_info.y;
+    // input tensor's packed dim lane corresponding to output tensor's pos
+    const int in_packed_dim_lane_index = fetch_pos[packed_dim] & 0x3;
 
     // scale down input tensor's packed dim pos to perform fetch
-    in_bchw_pos[in_packed_dim_bchw_index] >>= 2;
+    fetch_pos[packed_dim] >>= 2;
 
     // fetch input texel
-    VEC4_T inval = VEC4_T(texelFetch(image_in, u16vec3(in_bchw_pos.wzy), 0));
-    outval[j] = inval[in_packed_dim_pos & 0x3];
+    VEC4_T inval = VEC4_T(texelFetch(image_in, fetch_pos, 0));
+    outval[j] = inval[in_packed_dim_lane_index];
+
+    // go to next position in the input, that is mapped to the packed dim in the output
+    in_bchw_pos[out_ndims[in_packed_dim_bchw_index]]++;
   }
 
-  pos[packed_dim] = uint16_t(gl_GlobalInvocationID[packed_dim]);
+  pos[packed_dim] = int(gl_GlobalInvocationID[packed_dim]);
 
   imageStore(image_out, pos, outval);
 }