@@ -31,10 +31,8 @@ layout(push_constant) uniform PRECISION restrict Block {
31
31
layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
32
32
layout (constant_id = 3 ) const int packed_dim = C_DIM;
33
33
34
- #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
35
-
36
34
void main() {
37
- u16vec3 pos = u16vec3 (gl_GlobalInvocationID);
35
+ ivec3 pos = ivec3 (gl_GlobalInvocationID);
38
36
39
37
if (any (greaterThanEqual (pos, out_limits.xyz))) {
40
38
return ;
@@ -48,34 +46,40 @@ void main() {
48
46
// index of packed dim in bchw format
49
47
const int in_packed_dim_bchw_index = 3 - packed_dim;
50
48
51
- for (int j = 0 ; j < 4 ; ++ j, pos[packed_dim]++ ) {
52
- ivec4 in_bchw_pos = ivec4 (0 ); // holds b,c,h,w
53
- // determine input position based on output position and permute map
54
- // out_ndims is in BCHW format
55
- in_bchw_pos[out_ndims[0 ]] = (pos.z / channel_info.x);
56
- in_bchw_pos[out_ndims[1 ]] = (pos.z % channel_info.x);
57
- in_bchw_pos[out_ndims[2 ]] = pos.y;
58
- in_bchw_pos[out_ndims[3 ]] = pos.x;
49
+ // determine input position based on output position and permute map
50
+ // out_ndims is in BCHW format
51
+ ivec4 in_bchw_pos = ivec4 (0 ); // holds b,c,h,w
52
+ in_bchw_pos[out_ndims[0 ]] = (pos.z / channel_info.x);
53
+ in_bchw_pos[out_ndims[1 ]] = (pos.z % channel_info.x);
54
+ in_bchw_pos[out_ndims[2 ]] = pos.y;
55
+ in_bchw_pos[out_ndims[3 ]] = pos.x;
59
56
57
+ for (int j = 0 ; j < 4 ; ++ j) {
58
+ // terminate the loop if trying to access input texture out of bounds
60
59
if (any (greaterThanEqual (in_bchw_pos.wzyx, in_sizes.xyzw))) {
61
60
break ;
62
61
}
62
+ ivec3 fetch_pos;
63
63
64
- // input tensor's packed dim pos (in xyz format) corresponding to output tensor's pos (which is also in xyz format)
65
- const int in_packed_dim_pos = in_bchw_pos[in_packed_dim_bchw_index];
64
+ fetch_pos.xy = in_bchw_pos.wz;
65
+ // calculate input position in z axis using the batch and channel indices, which are in_bchw_pos.x and in_bchw_pos.y respectively
66
+ fetch_pos.z = in_bchw_pos.y + in_bchw_pos.x * channel_info.y;
66
67
67
- // calculate input position in y axis using batch and channel index which is in_bchw_pos.x and in_bchw_pos.y respectively
68
- in_bchw_pos.y = in_bchw_pos.y + in_bchw_pos.x * channel_info.y ;
68
+ // input tensor's packed dim lane corresponding to output tensor's pos
69
+ const int in_packed_dim_lane_index = fetch_pos[packed_dim] & 0x3 ;
69
70
70
71
// scale down input tensor's packed dim pos to perform fetch
71
- in_bchw_pos[in_packed_dim_bchw_index ] >>= 2 ;
72
+ fetch_pos[packed_dim ] >>= 2 ;
72
73
73
74
// fetch input texel
74
- VEC4_T inval = VEC4_T(texelFetch(image_in, u16vec3(in_bchw_pos.wzy), 0 ));
75
- outval[j] = inval[in_packed_dim_pos & 0x3];
75
+ VEC4_T inval = VEC4_T(texelFetch(image_in, fetch_pos, 0 ));
76
+ outval[j] = inval[in_packed_dim_lane_index];
77
+
78
+ // go to the next position in the input along the dim that is mapped to the packed dim in the output
79
+ in_bchw_pos[out_ndims[in_packed_dim_bchw_index]]++ ;
76
80
}
77
81
78
- pos[packed_dim] = uint16_t (gl_GlobalInvocationID[packed_dim]);
82
+ pos[packed_dim] = int (gl_GlobalInvocationID[packed_dim]);
79
83
80
84
imageStore(image_out, pos, outval);
81
85
}
0 commit comments