Skip to content

Commit

Permalink
Merge pull request #131 from jubatus/add-tensorboard-example
Browse files Browse the repository at this point in the history
add TensorBoard visualization example
  • Loading branch information
rimms authored Jan 24, 2019
2 parents 25aa596 + 5834259 commit e588ae0
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 1 deletion.
2 changes: 2 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ See the `example <https://github.com/jubatus/jubakit/tree/master/example>`_ dire
+-----------------------------------+-----------------------------------------------+-----------------------+
| classifier_sklearn_grid_search.py | Grid Search example using scikit-learn wrapper| ✓ |
+-----------------------------------+-----------------------------------------------+-----------------------+
| classifier_tensorboard.py | Visualize a training process using TensorBoard| ✓ |
+-----------------------------------+-----------------------------------------------+-----------------------+
| regression_boston.py | Regression with toy dataset (boston) ||
+-----------------------------------+-----------------------------------------------+-----------------------+
| regression_csv.py | Regression with CSV file | |
Expand Down
104 changes: 104 additions & 0 deletions example/classifier_tensorboard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import absolute_import, division, print_function, unicode_literals

"""
Visualize training process with TensorBoard
===========================================
In this example, we show the training process of Jubatus with TensorBoard.
TensorBoard syntax is little complicated and in this example we use tensorboardX library.
tensorboardX is a simple wrapper of TensorBoard that write events with simple function call.
[How to Use]
1. Install tensorboard.
```
$ pip install tensorboardX
```
2. Run this example.
3. Check the training process using tensorboard.
```
$ tensorboard --logdir runs/***
```
4. Enjoy!
"""

from sklearn.datasets import load_digits
from sklearn.metrics import (
accuracy_score, f1_score, precision_score, recall_score, log_loss)

from tensorboardX import SummaryWriter

import jubakit
from jubakit.classifier import Classifier, Dataset, Config
from jubakit.model import JubaDump

# Load the digits dataset.
digits = load_digits()

# Create a dataset.
dataset = Dataset.from_array(digits.data, digits.target)
n_samples = len(dataset)
n_train_samples = int(n_samples * 0.7)
train_ds = dataset[:n_train_samples]
test_ds = dataset[n_train_samples:]

# Create a classifier.
config = Config(method='AROW',
parameter={'regularization_weight': 0.1})
classifier = Classifier.run(config)

model_name = 'classifier_digits'
model_path = '/tmp/{}_{}_classifier_{}.jubatus'.format(
classifier._host, classifier._port, model_name)

# show the feature weights of the target label.
target_label = 4

# Initialize summary writer.
writer = SummaryWriter()

# train and test the classifier.
epochs = 100
for epoch in range(epochs):
# train
for _ in classifier.train(train_ds): pass

# test
y_true, y_pred = [], []
for (_, label, result) in classifier.classify(test_ds):
y_true.append(label)
y_pred.append(result[0][0])

# save model to check the feature weights
classifier.save(model_name)

model = JubaDump.dump_file(model_path)
weights = model['storage']['storage']['weight']
for feature, label_values in weights.items():
for label, value in label_values.items():
if str(label) != str(target_label):
continue
writer.add_scalar('weights/{}'.format(feature), value['v1'], epoch)

# write scores to tensorboardX summary writer.
acc = accuracy_score(y_true, y_pred)
prec = precision_score(y_true, y_pred, average='macro')
recall = recall_score(y_true, y_pred, average='macro')
f1 = f1_score(y_true, y_pred, average='macro')
writer.add_scalar('metrics/accuracy', acc, epoch)
writer.add_scalar('metrics/precision', prec, epoch)
writer.add_scalar('metrics/recall', recall, epoch)
writer.add_scalar('metrics/f1_score', f1, epoch)

writer.close()
classifier.stop()
3 changes: 2 additions & 1 deletion tools/run_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from urllib.request import urlopen

# Accept exceptions in following examples.
BLACK_LIST = ['classifier_twitter.py', 'classifier_hyperopt_tuning.py']
BLACK_LIST = ['classifier_twitter.py', 'classifier_hyperopt_tuning.py',
'classifier_tensorboard.py']

def download_bzip2(path, url):
if os.path.exists(path): return
Expand Down

0 comments on commit e588ae0

Please sign in to comment.