-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathclassifier_model_extract.py
55 lines (42 loc) · 1.75 KB
/
classifier_model_extract.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function, unicode_literals
import sys
from jubakit.model import JubaDump
"""
Extracting Classifier Models
=============================================
This is an example to show the usage of ``jubakit.model`` package,
which allows low-level model manipulation.
To try this example, first save a model file of jubaclassifier
(hint: the ``classifier_csv.py`` example automatically saves its model under /tmp).
Then run this example like:
$ python classifier_model_extract.py /tmp/127.0.0.1_9199_example_snapshot.jubatus
to see the linear classifier weights, features and labels.
"""
# Pick the model file: the first command-line argument when one is given,
# otherwise fall back to the default file name.
modelpath = sys.argv[1] if len(sys.argv) > 1 else 'classifier_iris_model.jubatus'

# Dump the classifier model file into a nested, dict-like structure.
model = JubaDump.dump_file(modelpath)
# Print the per-label counts stored in the model, one "count<TAB>label" row each.
label_rule = '-' * 50
print('\n{0}\n{1}\n{2}'.format(label_rule, 'Label Information', label_rule))
print('Count\tLabel')
label_count = model['storage']['label']['label_count']
for label_name in label_count:
    print('{0}\t{1}'.format(label_count[label_name], label_name))
# Print the document-frequency count recorded for each feature,
# one "count<TAB>feature" row each.
feature_rule = '-' * 50
feature_banner = '\n{0}\n{1}\n{2}'.format(feature_rule, 'Feature Information', feature_rule)
print(feature_banner)
print('Count\tFeature')
feature_count = model['weights']['document_frequencies']
for feature_name in feature_count:
    print('{0}\t{1}'.format(feature_count[feature_name], feature_name))
# Print the linear classifier weights, grouped by feature: for every feature,
# list one "weight<TAB>class" row per label.
weight_rule = '-' * 50
print('\n{0}\n{1}\n{2}'.format(weight_rule, 'Weight Information', weight_rule))
weights = model['storage']['storage']['weight']
for feature_key in weights:
    print('Feature: {0}'.format(feature_key))
    print('\tWeight \tClass')
    per_class = weights[feature_key]
    for class_name in per_class:
        # The ``v1`` entry is the value shown in the "Weight" column.
        weight_value = per_class[class_name]['v1']
        print('\t{0:+.5f}\t{1}'.format(weight_value, class_name))