Skip to content
This repository has been archived by the owner on Dec 21, 2023. It is now read-only.

Commit

Permalink
Drawing Classifier: Improved error handling for Validation Set (#1623)
Browse files Browse the repository at this point in the history
Added error handling for the case where 'validation_set' is a string other than 'auto', and for datasets with fewer than 100 rows.
  • Loading branch information
shantanuchhabra authored Mar 19, 2019
1 parent 11ad82a commit e7a5ed0
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -263,4 +263,3 @@ def get_model_path(self):
)[0]
return model_path


Original file line number Diff line number Diff line change
Expand Up @@ -153,26 +153,34 @@ def create(input_dataset, target, feature=None, validation_set='auto',
classes = sorted(classes)
class_to_index = {name: index for index, name in enumerate(classes)}

validation_set_corrective_string = ("'validation_set' parameter must be "
+ "an SFrame, or None, or must be set to 'auto' for the toolkit to "
+ "automatically create a validation set.")
if isinstance(validation_set, _tc.SFrame):
_raise_error_if_not_drawing_classifier_input_sframe(
validation_set, feature, target)
is_validation_stroke_input = (validation_set[feature].dtype != _tc.Image)
validation_dataset = _extensions._drawing_classifier_prepare_data(
validation_set, feature) if is_validation_stroke_input else validation_set
elif isinstance(validation_set, str):
assert (validation_set == 'auto')
if dataset.num_rows() >= 100:
if verbose:
print ( "PROGRESS: Creating a validation set from 5 percent of training data. This may take a while.\n"
" You can set ``validation_set=None`` to disable validation tracking.\n")
dataset, validation_dataset = dataset.random_split(
TRAIN_VALIDATION_SPLIT)
if validation_set == 'auto':
if dataset.num_rows() >= 100:
if verbose:
print ( "PROGRESS: Creating a validation set from 5 percent of training data. This may take a while.\n"
" You can set ``validation_set=None`` to disable validation tracking.\n")
dataset, validation_dataset = dataset.random_split(
TRAIN_VALIDATION_SPLIT)
else:
validation_set = None
validation_dataset = _tc.SFrame()
else:
validation_dataset = _tc.SFrame()
raise _ToolkitError("Unrecognized value for 'validation_set'. "
+ validation_set_corrective_string)
elif validation_set is None:
validation_dataset = _tc.SFrame()
else:
raise TypeError("Unrecognized type for 'validation_set'.")
raise TypeError("Unrecognized type for 'validation_set'."
+ validation_set_corrective_string)

train_loader = _SFrameClassifierIter(dataset, batch_size,
feature_column=feature,
Expand Down

0 comments on commit e7a5ed0

Please sign in to comment.