@@ -22,65 +22,68 @@ void add_cat_default_node(
     ValueRef dim_ref,
     ValueRef out) {
   ValueListPtr input_list = graph.get_value_list(in_list_ref);
-
-  for (ValueRef input_ref : *input_list) {
-    vTensorPtr t_in = graph.get_tensor(input_ref);
-    VK_CHECK_COND(check_packed_dim_is(*t_in, WHCN::kChannelsDim));
-  }
-
   int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
   vTensorPtr t_out = graph.get_tensor(out);
 
+  const auto packed_dim = t_out->packed_dim();
+  const auto packed_dim_index = static_cast<DimIndex>(kWidth4D - packed_dim);
+
   DimIndex dim_index = normalize_to_dim_index(*t_out, dim);
+  // Index of dimension to be concatenated in (w, h, c * b) coordinate system
+  const auto dim_xyz_index = std::min(2, -dim_index - 1);
 
-  // TODO: Find ways to factor out the similar code for width, height, and batch
-  if (dim_index == kWidth4D) {
-    utils::ivec3 src_offset = utils::make_ivec3({0, 0, 0}, false);
-    utils::ivec3 dst_offset = utils::make_ivec3({0, 0, 0}, false);
+  if (dim_index > kWidth4D || dim_index < kBatch4D) {
+    VK_THROW("Unexpected value of dim_index=", dim_index);
+  }
 
-    for (ValueRef input_ref : *input_list) {
-      vTensorPtr t_in = graph.get_tensor(input_ref);
-      utils::ivec3 range = t_in->logical_limits();
-      add_copy_offset_node(
-          graph, input_ref, range, src_offset, dst_offset, out);
-      dst_offset[0] += range[0];
-    }
+  utils::ivec4 src_offset = utils::make_ivec4({0, 0, 0, 0}, false);
+  utils::ivec4 dst_offset = utils::make_ivec4({0, 0, 0, 0}, false);
 
-  } else if (dim_index == kHeight4D) {
-    utils::ivec3 src_offset = utils::make_ivec3({0, 0, 0}, false);
-    utils::ivec3 dst_offset = utils::make_ivec3({0, 0, 0}, false);
+  const bool is_concat_channel = (dim_index == kChannel4D);
 
-    for (ValueRef input_ref : *input_list) {
-      vTensorPtr t_in = graph.get_tensor(input_ref);
-      utils::ivec3 range = t_in->logical_limits();
-      add_copy_offset_node(
-          graph, input_ref, range, src_offset, dst_offset, out);
-      dst_offset[1] += range[1];
-    }
-  } else if (dim_index == kBatch4D) {
-    utils::ivec3 src_offset = utils::make_ivec3({0, 0, 0}, false);
-    utils::ivec3 dst_offset = utils::make_ivec3({0, 0, 0}, false);
+  // if concatenating channels
+  if (is_concat_channel) {
+    // set destination offset w as channel size of the output tensor
+    dst_offset[3] = dim_at(t_out->sizes(), kChannel4D);
+  }
 
-    for (ValueRef input_ref : *input_list) {
-      vTensorPtr t_in = graph.get_tensor(input_ref);
-      utils::ivec3 range = t_in->logical_limits();
+  for (ValueRef input_ref : *input_list) {
+    const vTensorPtr t_in = graph.get_tensor(input_ref);
+    const utils::ivec3 range = t_in->logical_limits();
+    const auto in_channel_size = dim_at(t_in->sizes(), kChannel4D);
+    // if concatenating same dimension as the packed dimension
+    if (dim_index == packed_dim_index) {
+      // if concatenating channels, use add_copy_channel_offset_node function as
+      // add_copy_packed_dim_offset_node does not support channel packing
+      if (is_concat_channel) {
+        add_copy_channel_offset_node(
+            graph,
+            input_ref,
+            in_channel_size,
+            src_offset[2],
+            dst_offset[2],
+            out);
+        dst_offset[dim_xyz_index] += in_channel_size;
+      } else {
+        // src_offset[3] is not used now but will be used in the future when
+        // add_copy_packed_dim_offset_node will support channel packing
+        //
+        // set source offset w as channel size of the output tensor if
+        // concatenating channels
+        src_offset[3] = is_concat_channel ? in_channel_size : 0;
+        add_copy_packed_dim_offset_node(
+            graph, input_ref, range, src_offset, dst_offset, out);
+        dst_offset[dim_xyz_index] += dim_at(t_in->sizes(), packed_dim_index);
+      }
+    } else {
+      // set source offset w as channel size of the output tensor if
+      // concatenating channels
+      src_offset[3] = is_concat_channel ? in_channel_size : 0;
       add_copy_offset_node(
           graph, input_ref, range, src_offset, dst_offset, out);
-      dst_offset[2] += range[2];
+      dst_offset[dim_xyz_index] +=
+          is_concat_channel ? in_channel_size : range[dim_xyz_index];
     }
-  } else if (dim_index == kChannel4D) {
-    int32_t src_offset = 0;
-    int32_t dst_offset = 0;
-
-    for (ValueRef input_ref : *input_list) {
-      vTensorPtr t_in = graph.get_tensor(input_ref);
-      int32_t range = dim_at(t_in->sizes(), kChannel4D);
-      add_copy_channel_offset_node(
-          graph, input_ref, range, src_offset, dst_offset, out);
-      dst_offset += range;
-    }
-  } else {
-    VK_THROW("Unexpected value of dim_index=", dim_index);
   }
 }
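
For reviewers, a minimal standalone sketch (not part of the change) of what the new dim_xyz_index = std::min(2, -dim_index - 1) line computes, assuming the usual WHCN DimIndex values of -1/-2/-3/-4 for width/height/channel/batch: width and height map to the x and y texel axes, while channel and batch both clamp to the z axis, which is why the unified loop can drive channel-wise and batch-wise copies through the same dst_offset[dim_xyz_index] update.

// dim_index_mapping_sketch.cpp -- illustration only; the real DimIndex enum
// lives in the Vulkan backend, and the values below are assumed to mirror it.
#include <algorithm>
#include <cstdio>

enum DimIndex : int { kWidth4D = -1, kHeight4D = -2, kChannel4D = -3, kBatch4D = -4 };

// Same expression as in the patch: a negative WHCN dim index is mapped to an
// (x, y, z) texel axis, clamped at 2 so channel and batch both address the
// combined c*b depth axis.
int dim_xyz_index(DimIndex dim_index) {
  return std::min(2, -static_cast<int>(dim_index) - 1);
}

int main() {
  std::printf("width   -> %d\n", dim_xyz_index(kWidth4D));   // 0 (x)
  std::printf("height  -> %d\n", dim_xyz_index(kHeight4D));  // 1 (y)
  std::printf("channel -> %d\n", dim_xyz_index(kChannel4D)); // 2 (z)
  std::printf("batch   -> %d\n", dim_xyz_index(kBatch4D));   // 2 (z, shares the axis with channels)
  return 0;
}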