From 90290211fe35e3f9488212fbb7702ef5c461eec6 Mon Sep 17 00:00:00 2001 From: riccardo Date: Sun, 12 Apr 2020 12:23:59 +0700 Subject: [PATCH] Add support for Upsample op in CAFFE Add resize nearest neighbor output shape computation Update dockerfile and caffe patch Update resize_nearest_neighbor implementation and conversion Fix bugs and update GPU implementation Align resize_nearest_neighbor with resize_bilinear code Use scale value Fix typo Align kernel compute arguments Remove comments Align caffe_converter Remove unused code --- .../opencl/image/resize_nearest_neighbor.cc | 15 +- .../opencl/image/resize_nearest_neighbor.h | 9 +- mace/ops/opencl/resize_nearest_neighbor.h | 2 - mace/ops/resize_nearest_neighbor.cc | 35 +- third_party/caffe/Dockerfile | 9 +- third_party/caffe/caffe.proto | 5 + third_party/caffe/upsample.patch | 347 ++++++++++++++++++ tools/python/transform/caffe_converter.py | 15 + tools/python/transform/shape_inference.py | 15 + 9 files changed, 415 insertions(+), 37 deletions(-) create mode 100644 third_party/caffe/upsample.patch diff --git a/mace/ops/opencl/image/resize_nearest_neighbor.cc b/mace/ops/opencl/image/resize_nearest_neighbor.cc index 9f9dd1c8d..b85048ca9 100644 --- a/mace/ops/opencl/image/resize_nearest_neighbor.cc +++ b/mace/ops/opencl/image/resize_nearest_neighbor.cc @@ -24,24 +24,15 @@ namespace image { MaceStatus ResizeNearestNeighborKernel::Compute( OpContext *context, const Tensor *input, - const Tensor *size, - const std::vector &dims, Tensor *output) { const index_t batch = input->dim(0); const index_t in_height = input->dim(1); const index_t in_width = input->dim(2); const index_t channels = input->dim(3); - index_t out_height = 0; - index_t out_width = 0; - if (dims.size() < 2) { - Tensor::MappingGuard size_mapper(size); - out_height = size->data()[0]; - out_width = size->data()[1]; - } else { - out_height = dims[0]; - out_width = dims[1]; - } + const index_t channel_blocks = RoundUpDiv4(channels); + const index_t out_height = in_height*scale_; + const index_t out_width = in_width*scale_; const uint32_t gws[3] = {static_cast(channel_blocks), static_cast(out_width), diff --git a/mace/ops/opencl/image/resize_nearest_neighbor.h b/mace/ops/opencl/image/resize_nearest_neighbor.h index 9e2cec61a..601788f86 100644 --- a/mace/ops/opencl/image/resize_nearest_neighbor.h +++ b/mace/ops/opencl/image/resize_nearest_neighbor.h @@ -66,18 +66,19 @@ inline std::vector LocalWS(OpenCLRuntime *runtime, class ResizeNearestNeighborKernel : public OpenCLResizeNearestNeighborKernel { public: - explicit ResizeNearestNeighborKernel(bool align_corners) - : align_corners_(align_corners) {} + ResizeNearestNeighborKernel(bool align_corners, + const index_t scale) + : align_corners_(align_corners), + scale_(scale) {} MaceStatus Compute( OpContext *context, const Tensor *input, - const Tensor *size, - const std::vector &dims, Tensor *output) override; private: bool align_corners_; + index_t scale_; cl::Kernel kernel_; uint32_t kwg_size_; std::vector input_shape_; diff --git a/mace/ops/opencl/resize_nearest_neighbor.h b/mace/ops/opencl/resize_nearest_neighbor.h index c98fc955e..e352ea1cb 100644 --- a/mace/ops/opencl/resize_nearest_neighbor.h +++ b/mace/ops/opencl/resize_nearest_neighbor.h @@ -32,8 +32,6 @@ class OpenCLResizeNearestNeighborKernel { virtual MaceStatus Compute( OpContext *context, const Tensor *input, - const Tensor *size, - const std::vector &dims, Tensor *output) = 0; MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLResizeNearestNeighborKernel); }; diff --git a/mace/ops/resize_nearest_neighbor.cc b/mace/ops/resize_nearest_neighbor.cc index 6ac6b9e71..2313cdbaf 100644 --- a/mace/ops/resize_nearest_neighbor.cc +++ b/mace/ops/resize_nearest_neighbor.cc @@ -77,28 +77,27 @@ class ResizeNearestNeighborOp : public Operation { public: explicit ResizeNearestNeighborOp(OpConstructContext *context) : Operation(context), - align_corners_(Operation::GetOptionalArg("align_corners", - false)) {} + align_corners_(Operation::GetOptionalArg("align_corners", false)), + size_(Operation::GetRepeatedArgs("size", {-1})) {} MaceStatus Run(OpContext *context) override { MACE_UNUSED(context); const Tensor *input = this->Input(0); - const Tensor *size = this->Input(1); - Tensor::MappingGuard size_mapper(size); Tensor *output = this->Output(0); - MACE_CHECK(input->dim_size() == 4 && size->dim_size() == 1, - "input must be 4-dimensional and size must be 1-dimensional. ", - input->dim_size(), size->dim_size()); + MACE_CHECK(input->dim_size() == 4, + "input must be 4-dimensional. ", + input->dim_size()); const index_t batch = input->dim(0); const index_t channels = input->dim(1); const index_t in_height = input->dim(2); const index_t in_width = input->dim(3); - const index_t out_height = size->data()[0]; - const index_t out_width = size->data()[1]; - MACE_CHECK(out_height > 0 && out_width > 0, out_height, out_width); + index_t scale = size_[0]; + MACE_CHECK(scale > 0); + const index_t out_height = in_height*scale; + const index_t out_width = in_width*scale; std::vector out_shape{batch, channels, out_height, out_width}; MACE_RETURN_IF_ERROR(output->Resize(out_shape)); Tensor::MappingGuard input_mapper(input); @@ -138,6 +137,7 @@ class ResizeNearestNeighborOp : public Operation { private: bool align_corners_; + std::vector size_; }; #ifdef MACE_ENABLE_OPENCL @@ -145,29 +145,30 @@ template<> class ResizeNearestNeighborOp : public Operation { public: explicit ResizeNearestNeighborOp(OpConstructContext *context) - : Operation(context), dim_(Operation::GetRepeatedArgs("dim")) { + : Operation(context) { bool align_corners = Operation::GetOptionalArg( "align_corners", false); + std::vector size = Operation::GetRepeatedArgs( + "size", {-1}); + MACE_CHECK(size.size() == 1); if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) { kernel_ = make_unique( - align_corners); + align_corners, size[0]); } else { MACE_NOT_IMPLEMENTED; } } MaceStatus Run(OpContext *context) override { const Tensor *input = this->Input(0); - const Tensor *size = this->Input(1); Tensor *output = this->Output(0); - MACE_CHECK(input->dim_size() == 4 && size->dim_size() == 1, + MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional and size must be 1-dimensional.", - input->dim_size(), size->dim_size()); + input->dim_size()); - return kernel_->Compute(context, input, size, dim_, output); + return kernel_->Compute(context, input, output); } private: - std::vector dim_; std::unique_ptr kernel_; }; #endif // MACE_ENABLE_OPENCL diff --git a/third_party/caffe/Dockerfile b/third_party/caffe/Dockerfile index 92f67f231..771a45715 100644 --- a/third_party/caffe/Dockerfile +++ b/third_party/caffe/Dockerfile @@ -35,8 +35,13 @@ ENV CLONE_TAG=1.0 # https://github.com/pypa/pip/issues/5599 RUN git clone -b ${CLONE_TAG} --depth 1 https://github.com/BVLC/caffe.git . && \ python -m pip install --upgrade pip && \ - cd python && for req in $(cat requirements.txt) pydot; do pip install $req; done && cd .. && \ - mkdir build && cd build && \ + cd python && for req in $(cat requirements.txt) pydot; do pip install $req; done && cd .. + +COPY upsample.patch . + +RUN git apply upsample.patch + +RUN mkdir build && cd build && \ cmake -DCPU_ONLY=1 .. && \ make -j"$(nproc)" diff --git a/third_party/caffe/caffe.proto b/third_party/caffe/caffe.proto index 8be551bdb..4ce00334e 100644 --- a/third_party/caffe/caffe.proto +++ b/third_party/caffe/caffe.proto @@ -541,6 +541,7 @@ message LayerParameter { optional TanHParameter tanh_param = 127; optional ThresholdParameter threshold_param = 128; optional TileParameter tile_param = 138; + optional UpsampleParameter upsample_param = 149; optional VideoDataParameter video_data_param = 207; optional WindowDataParameter window_data_param = 129; optional ShuffleChannelParameter shuffle_channel_param = 164; @@ -1939,3 +1940,7 @@ message ShuffleChannelParameter { message L2NormalizationParameter { optional int32 axis = 1 [default = 1]; } + +message UpsampleParameter { + optional int32 scale = 1 [default = 1]; +} diff --git a/third_party/caffe/upsample.patch b/third_party/caffe/upsample.patch new file mode 100644 index 000000000..d0d18d237 --- /dev/null +++ b/third_party/caffe/upsample.patch @@ -0,0 +1,347 @@ +diff --git a/include/caffe/layers/upsample_layer.hpp b/include/caffe/layers/upsample_layer.hpp +new file mode 100644 +index 0000000..1ef6044 +--- /dev/null ++++ b/include/caffe/layers/upsample_layer.hpp +@@ -0,0 +1,44 @@ ++#ifndef CAFFE_UPSAMPLE_LAYER_HPP_ ++#define CAFFE_UPSAMPLE_LAYER_HPP_ ++ ++#include ++ ++#include "caffe/blob.hpp" ++#include "caffe/layer.hpp" ++#include "caffe/proto/caffe.pb.h" ++ ++namespace caffe { ++ ++template ++class UpsampleLayer : public Layer { ++ public: ++ explicit UpsampleLayer(const LayerParameter& param) ++ : Layer(param) {} ++ virtual void LayerSetUp(const vector*>& bottom, ++ const vector*>& top); ++ virtual void Reshape(const vector*>& bottom, ++ const vector*>& top); ++ ++ virtual inline const char* type() const { return "Upsample"; } ++ virtual inline int MinBottomBlobs() const { return 1; } ++ virtual inline int MaxBottomBlobs() const { return 1; } ++ virtual inline int ExactNumTopBlobs() const { return 1; } ++ ++ virtual void Forward_cpu(const vector*>& bottom, ++ const vector*>& top); ++ virtual void Forward_gpu(const vector*>& bottom, ++ const vector*>& top); ++ virtual void Backward_cpu(const vector*>& top, ++ const vector& propagate_down, const vector*>& bottom); ++ virtual void Backward_gpu(const vector*>& top, ++ const vector& propagate_down, const vector*>& bottom); ++ ++ private: ++ int scale_; ++}; ++ ++ ++ ++} // namespace caffe ++ ++#endif // CAFFE_UPSAMPLE_LAYER_HPP_ +diff --git a/src/caffe/layers/upsample_layer.cpp b/src/caffe/layers/upsample_layer.cpp +new file mode 100644 +index 0000000..46b2ed9 +--- /dev/null ++++ b/src/caffe/layers/upsample_layer.cpp +@@ -0,0 +1,89 @@ ++#include ++#include "caffe/layers/upsample_layer.hpp" ++ ++namespace caffe { ++ ++template ++void UpsampleLayer::LayerSetUp( ++ const vector*>& bottom, const vector*>& top) { ++ UpsampleParameter upsample_param = this->layer_param_.upsample_param(); ++ scale_ = upsample_param.scale(); ++} ++ ++template ++void UpsampleLayer::Reshape( ++ const vector*>& bottom, const vector*>& top) { ++ vector out_shape; ++ for (int i = 0; i < bottom[0]->num_axes(); i++) { ++ out_shape.push_back(bottom[0]->shape(i)); ++ } ++ ++ out_shape[bottom[0]->num_axes() - 1] *= scale_; ++ out_shape[bottom[0]->num_axes() - 2] *= scale_; ++ top[0]->Reshape(out_shape); ++} ++ ++template ++void UpsampleLayer::Forward_cpu(const vector*>& bottom, ++ const vector*>& top) { ++ ++ int N = top[0]->shape(0); ++ int C = top[0]->shape(1); ++ int H = top[0]->shape(2); ++ int W = top[0]->shape(3); ++ ++ const Dtype *input = bottom[0]->cpu_data(); ++ Dtype *output = top[0]->mutable_cpu_data(); ++ for (int n = 0; n < N; n++) { ++ for (int c = 0; c < C; c++) { ++ for (int h = 0; h < H; h++) { ++ for (int w = 0; w < W; w++) { ++ int nw = w/scale_; ++ int nh = h/scale_; ++ int out_idx = (((n * C + c) * H) + h) * W + w; ++ int in_idx = (((n * C + c) * (H / scale_)) + nh) * (W / scale_) + nw; ++ output[out_idx] = input[in_idx]; ++ } ++ } ++ } ++ } ++} ++ ++template ++void UpsampleLayer::Backward_cpu(const vector*>& top, ++ const vector& propagate_down, const vector*>& bottom) { ++ int N = bottom[0]->shape(0); ++ int C = bottom[0]->shape(1); ++ int H = bottom[0]->shape(2); ++ int W = bottom[0]->shape(3); ++ const Dtype *output_grad = top[0]->cpu_diff(); ++ Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); ++ caffe_set(bottom[0]->count(), Dtype(0), bottom_diff); ++ for (int n = 0; n < N; n++) { ++ for (int c = 0; c < C; c++) { ++ for (int h = 0; h < H; h++) { ++ for (int w = 0; w < W; w++) { ++ for (int i = 0; i < scale_; i++) { ++ for (int j = 0; j < scale_; j++) { ++ int nw = w * scale_ + i; ++ int nh = h * scale_ + j; ++ int out_idx = (((n * C + c) * H) + h) * W + w; ++ int in_idx = (((n * C + c) * (H * scale_)) ++ + nh) * (W * scale_) + nw; ++ bottom_diff[out_idx] += output_grad[in_idx]; ++ } ++ } ++ } ++ } ++ } ++ } ++} ++ ++#ifdef CPU_ONLY ++STUB_GPU(UpsampleLayer); ++#endif ++ ++INSTANTIATE_CLASS(UpsampleLayer); ++REGISTER_LAYER_CLASS(Upsample); ++ ++} // namespace caffe +diff --git a/src/caffe/layers/upsample_layer.cu b/src/caffe/layers/upsample_layer.cu +new file mode 100644 +index 0000000..aade4b3 +--- /dev/null ++++ b/src/caffe/layers/upsample_layer.cu +@@ -0,0 +1,101 @@ ++#include ++ ++#include "caffe/filler.hpp" ++#include "caffe/layers/upsample_layer.hpp" ++#include "caffe/util/math_functions.hpp" ++ ++namespace caffe { ++ ++__device__ int translate_idx(int ii, int d1, int d2, int d3, int scale_factor) { ++ int x, y, z, w; ++ w = ii % d3; ++ ii = ii/d3; ++ z = ii % d2; ++ ii = ii/d2; ++ y = ii % d1; ++ ii = ii/d1; ++ x = ii; ++ w = w/scale_factor; ++ z = z/scale_factor; ++ d2 /= scale_factor; ++ d3 /= scale_factor; ++ return (((x*d1+y)*d2)+z)*d3+w; ++} ++ ++__device__ int translate_idx_inv( ++ int ii, int d1, int d2, int d3, int scale_factor, int off_x, int off_y) { ++ int x, y, z, w; ++ w = ii % d3; ++ ii = ii/d3; ++ z = ii % d2; ++ ii = ii/d2; ++ y = ii % d1; ++ ii = ii/d1; ++ x = ii; ++ w = w*scale_factor+off_x; ++ z = z*scale_factor+off_y; ++ d2 *= scale_factor; ++ d3 *= scale_factor; ++ return (((x*d1+y)*d2)+z)*d3+w; ++} ++ ++template ++__global__ void upscale(const Dtype *input, Dtype *output, ++ int no_elements, int scale_factor, int d1, int d2, int d3) { ++ int ii = threadIdx.x + blockDim.x * blockIdx.x; ++ if (ii >= no_elements) return; ++ int ipidx = translate_idx(ii, d1, d2, d3, scale_factor); ++ output[ii]=input[ipidx]; ++} ++ ++template ++__global__ void downscale(Dtype *gradInput_data, const Dtype *gradOutput_data, ++ int no_elements, int scale_factor, int d1, int d2, ++ int d3) { ++ int ii = threadIdx.x + blockDim.x * blockIdx.x; ++ if (ii >= no_elements) return; ++ for (int i = 0; i < scale_factor; i++) { ++ for (int j = 0; j < scale_factor; j++) { ++ int ipidx = translate_idx_inv(ii, d1, d2, d3, scale_factor, i, j); ++ gradInput_data[ii] += gradOutput_data[ipidx]; ++ } ++ } ++} ++ ++ ++ ++template ++void UpsampleLayer::Forward_gpu(const vector*>& bottom, ++ const vector*>& top) { ++ int d1, d2, d3; ++ ++ d1 = top[0]->shape(1); ++ d2 = top[0]->shape(2); ++ d3 = top[0]->shape(3); ++ ++ int no_elements = top[0]->count(); ++ ++ upscale // NOLINT_NEXT_LINE(whitespace/operators) ++ <<>>( ++ bottom[0]->gpu_data(), ++ top[0]->mutable_gpu_data(), no_elements, scale_, d1, d2, d3); ++} ++ ++template ++void UpsampleLayer::Backward_gpu(const vector*>& top, ++ const vector& propagate_down, const vector*>& bottom) { ++ int d1, d2, d3; ++ Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); ++ d1 = bottom[0]->shape(1); ++ d2 = bottom[0]->shape(2); ++ d3 = bottom[0]->shape(3); ++ int no_elements = bottom[0]->count(); ++ caffe_gpu_set(bottom[0]->count(), Dtype(0), bottom_diff); ++ downscale // NOLINT_NEXT_LINE(whitespace/operators) ++ <<>>( ++ bottom_diff, top[0]->gpu_diff(), no_elements, scale_, d1, d2, d3); ++} ++ ++INSTANTIATE_LAYER_GPU_FUNCS(UpsampleLayer); ++ ++} // namespace caffe +diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto +index 3dcad69..6c587e8 100644 +--- a/src/caffe/proto/caffe.proto ++++ b/src/caffe/proto/caffe.proto +@@ -322,7 +322,7 @@ message ParamSpec { + // NOTE + // Update the next available ID when you add a new LayerParameter field. + // +-// LayerParameter next available layer-specific ID: 149 (last added: clip_param) ++// LayerParameter next available layer-specific ID: 150 (last added: upsample_param) + message LayerParameter { + optional string name = 1; // the layer name + optional string type = 2; // the layer type +@@ -421,6 +421,7 @@ message LayerParameter { + optional ThresholdParameter threshold_param = 128; + optional TileParameter tile_param = 138; + optional WindowDataParameter window_data_param = 129; ++ optional UpsampleParameter upsample_param = 149; + } + + // Message that stores parameters used to apply transformation +@@ -1447,3 +1448,7 @@ message PReLUParameter { + // Whether or not slope parameters are shared across channels. + optional bool channel_shared = 2 [default = false]; + } ++ ++message UpsampleParameter { ++ optional int32 scale = 1 [default = 1]; ++} +diff --git a/src/caffe/test/test_upsample_layer.cpp b/src/caffe/test/test_upsample_layer.cpp +new file mode 100644 +index 0000000..e114db9 +--- /dev/null ++++ b/src/caffe/test/test_upsample_layer.cpp +@@ -0,0 +1,60 @@ ++#include ++#include ++ ++#include "boost/scoped_ptr.hpp" ++#include "gtest/gtest.h" ++ ++#include "caffe/blob.hpp" ++#include "caffe/filler.hpp" ++#include "caffe/layers/upsample_layer.hpp" ++#include "caffe/util/io.hpp" ++ ++#include "caffe/test/test_caffe_main.hpp" ++#include "caffe/test/test_gradient_check_util.hpp" ++ ++using boost::scoped_ptr; ++ ++namespace caffe { ++ ++template ++class UpsampleLayerTest : public MultiDeviceTest { ++ typedef typename TypeParam::Dtype Dtype; ++ ++ protected: ++ UpsampleLayerTest() ++ : blob_bottom_data_(new Blob(2, 5, 2, 2)), ++ blob_top_data_(new Blob()) { ++ // fill the values ++ FillerParameter filler_param; ++ filler_param.set_std(10); ++ GaussianFiller filler(filler_param); ++ filler.Fill(this->blob_bottom_data_); ++ blob_bottom_vec_.push_back(blob_bottom_data_); ++ blob_top_vec_.push_back(blob_top_data_); ++ } ++ ++ virtual ~UpsampleLayerTest() { ++ delete blob_bottom_data_; ++ delete blob_top_data_; ++ } ++ Blob* const blob_bottom_data_; ++ Blob* const blob_top_data_; ++ vector*> blob_bottom_vec_; ++ vector*> blob_top_vec_; ++}; ++ ++TYPED_TEST_CASE(UpsampleLayerTest, TestDtypesAndDevices); ++ ++TYPED_TEST(UpsampleLayerTest, TestGradient) { ++ typedef typename TypeParam::Dtype Dtype; ++ LayerParameter layer_param; ++ UpsampleParameter* upsample_param = ++ layer_param.mutable_upsample_param(); ++ upsample_param->set_scale(2); ++ UpsampleLayer layer(layer_param); ++ GradientChecker checker(1e-2, 1e-2, 1701); ++ checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, ++ this->blob_top_vec_, 0); ++} ++ ++} // namespace caffe diff --git a/tools/python/transform/caffe_converter.py b/tools/python/transform/caffe_converter.py index 3eb6e4293..921402a70 100644 --- a/tools/python/transform/caffe_converter.py +++ b/tools/python/transform/caffe_converter.py @@ -195,6 +195,7 @@ def __init__(self, option, src_model_file, src_weight_file): 'L1Normalization': self.convert_lpnorm, 'MVN': self.convert_MVN, 'Bias': self.convert_Bias, + 'Upsample': self.convert_resize_nearest_neighbor, } self._option = option self._mace_net_def = mace_pb2.NetDef() @@ -820,6 +821,20 @@ def convert_reshape(self, caffe_op): if param.HasField('num_axes'): num_axes_arg.i = param.num_axes + def convert_resize_nearest_neighbor(self, caffe_op): + op = self.convert_general_op(caffe_op) + param = caffe_op.layer.upsample_param + op.type = MaceOp.ResizeNearestNeighbor.name + + size_arg = op.arg.add() + size_arg.name = MaceKeyword.mace_resize_size_str + size_value = [int(param.scale)] + size_arg.ints.extend(size_value) + + align_corners_arg = op.arg.add() + align_corners_arg.name = MaceKeyword.mace_align_corners_str + align_corners_arg.i = 1 + def convert_lpnorm(self, caffe_op): op = self.convert_general_op(caffe_op) param = caffe_op.layer.l2normalization_param diff --git a/tools/python/transform/shape_inference.py b/tools/python/transform/shape_inference.py index 2a7b43b6b..c919651fa 100644 --- a/tools/python/transform/shape_inference.py +++ b/tools/python/transform/shape_inference.py @@ -52,6 +52,7 @@ def __init__(self, net, input_nodes): MaceOp.PriorBox.name: self.infer_shape_prior_box, MaceOp.Reshape.name: self.infer_shape_reshape, MaceOp.ResizeBilinear.name: self.infer_shape_resize_bilinear, + MaceOp.ResizeNearestNeighbor.name: self.infer_shape_resize_nearest_neighbor, MaceOp.LpNorm.name: self.infer_shape_general, MaceOp.MVNorm.name: self.infer_shape_general, } @@ -310,3 +311,17 @@ def infer_shape_resize_bilinear(self, op): mace_check(False, "format %s is not supported" % ConverterUtil.data_format(op)) self.add_output_shape(op, [output_shape]) + + def infer_shape_resize_nearest_neighbor(self, op): + input_shape = self._output_shape_cache[op.input[0]] + size = ConverterUtil.get_arg( + op, MaceKeyword.mace_resize_size_str).ints + if ConverterUtil.data_format(op) == DataFormat.NCHW: + output_shape = [input_shape[0], input_shape[1], size[0]*input_shape[2], size[0]*input_shape[3]] + elif ConverterUtil.data_format(op) == DataFormat.NHWC: + output_shape = [input_shape[0], size[0]*input_shape[1], size[0]*input_shape[2], input_shape[3]] + else: + output_shape = [] + mace_check(False, "format %s is not supported" + % ConverterUtil.data_format(op)) + self.add_output_shape(op, [output_shape])