Skip to content

Commit 99c9ec1

Browse files
authored
Comprehensive SparseZoo tests (neuralmagic#24)
Added tests for model helper functions and thorough tests of all models available in zoo. Added fixes to support yolo v3 labels
1 parent 4ff2385 commit 99c9ec1

File tree

14 files changed

+1135
-34
lines changed

14 files changed

+1135
-34
lines changed

Makefile

+26-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,31 @@ DOCDIR := docs
77
MDCHECKGLOBS := 'docs/**/*.md' 'docs/**/*.rst' 'examples/**/*.md' 'notebooks/**/*.md' 'scripts/**/*.md'
88
MDCHECKFILES := CODE_OF_CONDUCT.md CONTRIBUTING.md DEVELOPING.md README.md
99

10+
TARGETS := "" # targets for running pytests: full,efficientnet,inception,resnet,vgg,ssd,yolo
11+
PYTEST_ARGS := ""
12+
ifneq ($(findstring full,$(TARGETS)),full)
13+
PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparsezoo/models/test_zoo_extensive.py
14+
endif
15+
ifneq ($(findstring efficientnet,$(TARGETS)),efficientnet)
16+
PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparsezoo/models/classification/test_efficientnet.py
17+
endif
18+
ifneq ($(findstring inception,$(TARGETS)),inception)
19+
PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparsezoo/models/classification/test_inception.py
20+
endif
21+
ifneq ($(findstring resnet,$(TARGETS)),resnet)
22+
PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparsezoo/models/classification/test_resnet.py
23+
endif
24+
ifneq ($(findstring vgg,$(TARGETS)),vgg)
25+
PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparsezoo/models/classification/test_vgg.py
26+
endif
27+
ifneq ($(findstring ssd,$(TARGETS)),ssd)
28+
PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparsezoo/models/detection/test_ssd.py
29+
endif
30+
ifneq ($(findstring yolo,$(TARGETS)),yolo)
31+
PYTEST_ARGS := $(PYTEST_ARGS) --ignore tests/sparsezoo/models/detection/test_yolo.py
32+
endif
33+
34+
1035
# run checks on all files for the repo
1136
quality:
1237
@echo "Running copyright checks";
@@ -27,7 +52,7 @@ style:
2752
# run tests for the repo
2853
test:
2954
@echo "Running python tests";
30-
@pytest;
55+
pytest tests $(PYTEST_ARGS);
3156

3257
# create docs
3358
docs:

setup.py

+3
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
"black>=20.8b1",
2929
"flake8>=3.8.3",
3030
"isort>=5.7.0",
31+
"onnxruntime>=1.0.0",
32+
"pytest>=6.0.0",
3133
"rinohtype>=0.4.2",
3234
"recommonmark>=0.7.0",
3335
"sphinx>=3.4.0",
@@ -36,6 +38,7 @@
3638
"wheel>=0.36.2",
3739
"pytest>=6.0.0",
3840
"sphinx-rtd-theme",
41+
"wheel>=0.36.2",
3942
]
4043

4144

src/sparsezoo/models/detection/ssd.py

+11
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ def ssd_resnet50_300(
3434
optim_name: str = "base",
3535
optim_category: str = "none",
3636
optim_target: Union[str, None] = None,
37+
override_folder_name: Union[str, None] = None,
38+
override_parent_path: Union[str, None] = None,
39+
force_token_refresh: bool = False,
3740
) -> Model:
3841
"""
3942
Convenience function for getting an ssd resnet50 300 model
@@ -53,6 +56,11 @@ def ssd_resnet50_300(
5356
moderate (>=99% baseline metric), aggressive (<99% baseline metric)
5457
:param optim_target: The deployment target of optimization of the model
5558
the object belongs to; e.g. edge, deepsparse, deepsparse_throughput, gpu
59+
:param override_folder_name: Override for the name of the folder to save
60+
this file under
61+
:param override_parent_path: Path to override the default save path
62+
for where to save the parent folder for this file under
63+
:param force_token_refresh: True to refresh the auth token, False otherwise
5664
:return: The created model
5765
"""
5866
return Zoo.load_model(
@@ -67,4 +75,7 @@ def ssd_resnet50_300(
6775
optim_name=optim_name,
6876
optim_category=optim_category,
6977
optim_target=optim_target,
78+
override_folder_name=override_folder_name,
79+
override_parent_path=override_parent_path,
80+
force_token_refresh=force_token_refresh,
7081
)

src/sparsezoo/models/detection/yolo.py

+11
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ def yolo_v3(
3535
optim_name: str = "base",
3636
optim_category: str = "none",
3737
optim_target: Union[str, None] = None,
38+
override_folder_name: Union[str, None] = None,
39+
override_parent_path: Union[str, None] = None,
40+
force_token_refresh: bool = False,
3841
) -> Model:
3942
"""
4043
Convenience function for getting an ssd resnet50 300 model
@@ -56,6 +59,11 @@ def yolo_v3(
5659
moderate (>=99% baseline metric), aggressive (<99% baseline metric)
5760
:param optim_target: The deployment target of optimization of the model
5861
the object belongs to; e.g. edge, deepsparse, deepsparse_throughput, gpu
62+
:param override_folder_name: Override for the name of the folder to save
63+
this file under
64+
:param override_parent_path: Path to override the default save path
65+
for where to save the parent folder for this file under
66+
:param force_token_refresh: True to refresh the auth token, False otherwise
5967
:return: The created model
6068
"""
6169
return Zoo.load_model(
@@ -70,4 +78,7 @@ def yolo_v3(
7078
optim_name=optim_name,
7179
optim_category=optim_category,
7280
optim_target=optim_target,
81+
override_folder_name=override_folder_name,
82+
override_parent_path=override_parent_path,
83+
force_token_refresh=force_token_refresh,
7384
)

src/sparsezoo/utils/numpy.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ class NumpyArrayBatcher(object):
169169

170170
def __init__(self):
171171
self._items = OrderedDict() # type: Dict[str, List[numpy.ndarray]]
172+
self._batch_index = None
172173

173174
def __len__(self):
174175
if len(self._items) == 0:
@@ -189,12 +190,20 @@ def append(self, item: Union[numpy.ndarray, Dict[str, numpy.ndarray]]):
189190
for key, val in item.items():
190191
self._items[key] = [val]
191192
elif isinstance(item, numpy.ndarray):
193+
if self._batch_index is None:
194+
self._batch_index = {NDARRAY_KEY: 0}
192195
if NDARRAY_KEY not in self._items:
193196
raise ValueError(
194197
"numpy ndarray passed for item, but prev_batch does not contain one"
195198
)
196199

197200
if item.shape != self._items[NDARRAY_KEY][0].shape:
201+
self._batch_index[NDARRAY_KEY] = 1
202+
203+
if item.shape != self._items[NDARRAY_KEY][0].shape and (
204+
item.shape[0] != self._items[NDARRAY_KEY][0].shape[0]
205+
or item.shape[2:] != self._items[NDARRAY_KEY][0].shape[2:]
206+
):
198207
raise ValueError(
199208
(
200209
f"item of numpy ndarray of shape {item.shape} does not "
@@ -215,8 +224,17 @@ def append(self, item: Union[numpy.ndarray, Dict[str, numpy.ndarray]]):
215224
)
216225
)
217226

227+
if self._batch_index is None:
228+
self._batch_index = {key: 0 for key in item}
229+
218230
for key, val in item.items():
219231
if val.shape != self._items[key][0].shape:
232+
self._batch_index[key] = 1
233+
234+
if val.shape != self._items[key][0].shape and (
235+
val.shape[0] != self._items[key][0].shape[0]
236+
or val.shape[2:] != self._items[key][0].shape[2:]
237+
):
220238
raise ValueError(
221239
(
222240
f"item with key {key} of shape {val.shape} does not "
@@ -240,8 +258,10 @@ def stack(
240258
batch_dict = OrderedDict()
241259

242260
for key, val in self._items.items():
243-
batch_dict[key] = numpy.stack(self._items[key])
244-
261+
if self._batch_index is None or self._batch_index[key] == 0:
262+
batch_dict[key] = numpy.stack(val)
263+
else:
264+
batch_dict[key] = numpy.concatenate(val, axis=self._batch_index[key])
245265
return batch_dict if not as_list else list(batch_dict.values())
246266

247267

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
from sparsezoo.models.classification import efficientnet_b0, efficientnet_b4
18+
from tests.sparsezoo.utils import model_constructor
19+
20+
21+
@pytest.mark.parametrize(
22+
(
23+
"download,framework,repo,dataset,training_scheme,"
24+
"optim_name,optim_category,optim_target"
25+
),
26+
[
27+
(True, "pytorch", "sparseml", "imagenet", None, "base", "none", None),
28+
(True, "pytorch", "sparseml", "imagenet", None, "arch", "moderate", None),
29+
],
30+
)
31+
def test_efficientnet_b0(
32+
download,
33+
framework,
34+
repo,
35+
dataset,
36+
training_scheme,
37+
optim_name,
38+
optim_category,
39+
optim_target,
40+
):
41+
model_constructor(
42+
efficientnet_b0,
43+
download,
44+
framework,
45+
repo,
46+
dataset,
47+
training_scheme,
48+
optim_name,
49+
optim_category,
50+
optim_target,
51+
)
52+
53+
54+
@pytest.mark.parametrize(
55+
(
56+
"download,framework,repo,dataset,training_scheme,"
57+
"optim_name,optim_category,optim_target"
58+
),
59+
[
60+
(True, "pytorch", "sparseml", "imagenet", None, "base", "none", None),
61+
(True, "pytorch", "sparseml", "imagenet", None, "arch", "moderate", None),
62+
],
63+
)
64+
def test_efficientnet_b4(
65+
download,
66+
framework,
67+
repo,
68+
dataset,
69+
training_scheme,
70+
optim_name,
71+
optim_category,
72+
optim_target,
73+
):
74+
model_constructor(
75+
efficientnet_b4,
76+
download,
77+
framework,
78+
repo,
79+
dataset,
80+
training_scheme,
81+
optim_name,
82+
optim_category,
83+
optim_target,
84+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
from sparsezoo.models.classification import inception_v3
18+
from tests.sparsezoo.utils import model_constructor
19+
20+
21+
@pytest.mark.parametrize(
22+
(
23+
"download,framework,repo,dataset,training_scheme,"
24+
"optim_name,optim_category,optim_target"
25+
),
26+
[
27+
(True, "pytorch", "sparseml", "imagenet", None, "base", "none", None),
28+
(True, "pytorch", "sparseml", "imagenet", None, "pruned", "conservative", None),
29+
(True, "pytorch", "sparseml", "imagenet", None, "pruned", "moderate", None),
30+
],
31+
)
32+
def test_inception_v3(
33+
download,
34+
framework,
35+
repo,
36+
dataset,
37+
training_scheme,
38+
optim_name,
39+
optim_category,
40+
optim_target,
41+
):
42+
model_constructor(
43+
inception_v3,
44+
download,
45+
framework,
46+
repo,
47+
dataset,
48+
training_scheme,
49+
optim_name,
50+
optim_category,
51+
optim_target,
52+
)

0 commit comments

Comments
 (0)