Skip to content
This repository has been archived by the owner on Dec 21, 2023. It is now read-only.

Commit

Permalink
Drawing Classifier: Last batch in training and eval was not being loaded in all contexts. (#1637)
Browse files Browse the repository at this point in the history

Drawing Classifier training and inference now works on Linux GPU.
  • Loading branch information
shantanuchhabra authored Mar 20, 2019
1 parent e7a5ed0 commit 3e22997
Showing 1 changed file with 38 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -234,14 +234,11 @@ def create(input_dataset, target, feature=None, validation_set='auto',
def get_data_and_label_from_batch(batch):
if batch.pad is not None:
size = batch_size - batch.pad
batch_data = (
[_mx.nd.slice_axis(batch.data[0], axis=0, begin=0, end=size)]
+ [None] * (len(ctx)-1)
)
batch_label = (
[_mx.nd.slice_axis(batch.label[0], axis=0, begin=0, end=size)]
+ [None] * (len(ctx)-1)
)
sliced_data = _mx.nd.slice_axis(batch.data[0], axis=0, begin=0, end=size)
sliced_label = _mx.nd.slice_axis(batch.label[0], axis=0, begin=0, end=size)
num_devices = min(sliced_data.shape[0], len(ctx))
batch_data = _mx.gluon.utils.split_and_load(sliced_data, ctx_list=ctx[:num_devices], even_split=False)
batch_label = _mx.gluon.utils.split_and_load(sliced_label, ctx_list=ctx[:num_devices], even_split=False)
else:
batch_data = _mx.gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
batch_label = _mx.gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
Expand Down Expand Up @@ -534,14 +531,18 @@ def _predict_with_probabilities(self, input_dataset, batch_size=None,

dataset_size = len(dataset)
ctx = _mxnet_utils.get_mxnet_context()

all_predicted = ['']*dataset_size
all_probabilities = _np.zeros((dataset_size, len(self.classes)),
dtype=float)

index = 0
last_time = 0
done = False

from turicreate import SArrayBuilder
from array import array

classes = self.classes
all_predicted_builder = SArrayBuilder(dtype=type(classes[0]))
all_probabilities_builder = SArrayBuilder(dtype=array)

for batch in loader:
if batch.pad is not None:
size = batch_size - batch.pad
Expand All @@ -551,35 +552,31 @@ def _predict_with_probabilities(self, input_dataset, batch_size=None,
batch_data = batch.data[0]
size = batch_size

if batch_data.shape[0] < len(ctx):
ctx0 = ctx[:batch_data.shape[0]]
else:
ctx0 = ctx

z = self._model(batch_data).asnumpy()
predicted = z.argmax(axis=1)
classes = self.classes

predicted_sa = _tc.SArray(predicted).apply(lambda x: classes[x])

all_predicted[index : index + len(predicted_sa)] = predicted_sa
all_probabilities[index : index + z.shape[0]] = z
index += z.shape[0]
if index == dataset_size - 1:
done = True

cur_time = _time.time()
# Do not print process if only a few samples are predicted
if verbose and (dataset_size >= 5
and cur_time > last_time + 10 or done):
print('Predicting {cur_n:{width}d}/{max_n:{width}d}'.format(
cur_n = index,
max_n = dataset_size,
width = len(str(dataset_size))))
last_time = cur_time

return (_tc.SFrame({self.target: _tc.SArray(all_predicted),
'probability': _tc.SArray(all_probabilities)}))
num_devices = min(batch_data.shape[0], len(ctx))
split_data = _mx.gluon.utils.split_and_load(batch_data, ctx_list=ctx[:num_devices], even_split=False)

for data in split_data:
z = self._model(data).asnumpy()
predicted = list(map(lambda x: classes[x], z.argmax(axis=1)))
split_length = z.shape[0]
all_predicted_builder.append_multiple(predicted)
all_probabilities_builder.append_multiple(z.tolist())
index += split_length
if index == dataset_size - 1:
done = True

cur_time = _time.time()
# Do not print progress if only a few samples are predicted
if verbose and (dataset_size >= 5
and cur_time > last_time + 10 or done):
print('Predicting {cur_n:{width}d}/{max_n:{width}d}'.format(
cur_n = index + 1,
max_n = dataset_size,
width = len(str(dataset_size))))
last_time = cur_time

return (_tc.SFrame({self.target: all_predicted_builder.close(),
'probability': all_probabilities_builder.close()}))

def evaluate(self, dataset, metric='auto', batch_size=None, verbose=True):
"""
Expand Down

0 comments on commit 3e22997

Please sign in to comment.