-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclassifier_io.py
60 lines (51 loc) · 2.41 KB
/
classifier_io.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import json
import numpy as np
from sklearn import __version__ as sklearn_version
# This class saves and loads classifier models in JSON format.
# The model is saved as a dictionary containing the model parameters and the sklearn version.
# A saved model can be loaded back into a sklearn model object.
class ClassifierModel:
    """
    Save and load the attributes of a classifier model as JSON.

    The wrapped model's ``__dict__`` is written out as a flat dictionary: numpy arrays are stored as (nested)
    lists and turned back into arrays on load. The sklearn version is stored with the parameters under the
    "sklearn_version" key.
    """

    def __init__(self, model, sklearn_version=None):
        """
        Parameters:
        model - model object whose attribute dictionary is saved / restored.
        sklearn_version - version string recorded in the saved file. When omitted, the installed sklearn
            version is used (None if sklearn cannot be imported).
        """
        self.model = model
        # Same mapping object as the model's attribute dict (not a copy), so entries added through
        # set_sklearn_version / define_features also appear on the model itself.
        self.model_params = model.__dict__
        if sklearn_version is None:
            # The parameter shadows the module-level import of the same name, so the installed
            # version has to be looked up explicitly.
            sklearn_version = self._installed_sklearn_version()
        # Always assigned, so save_json can rely on the attribute existing.
        self.sklearn_version = sklearn_version

    @staticmethod
    def _installed_sklearn_version():
        """Return the installed sklearn version string, or None if sklearn is not importable."""
        try:
            from sklearn import __version__ as installed_version
        except ImportError:
            return None
        return installed_version

    @staticmethod
    def _to_jsonable(value):
        """Convert numpy arrays / numpy scalars to plain Python objects; return anything else unchanged."""
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            return value.item()
        return value

    def save_json(self, filepath):
        """
        Write the model parameters (plus the sklearn version) to `filepath` as JSON.

        The live model is left untouched: arrays are converted to lists on a copy, not in place.

        Raises:
        Warning - if no sklearn version is known (see set_sklearn_version).
        """
        # Record the version first; this raises before anything is written when it is undefined.
        self.set_sklearn_version(self.sklearn_version)
        serializable = {
            name: self._to_jsonable(value) for name, value in self.model_params.items()
        }
        # Serialise before opening the file so a failed dump does not leave a truncated file behind.
        json_text = json.dumps(serializable)
        with open(filepath, "w", encoding="utf-8") as file:
            file.write(json_text)

    def load_json(self, filepath):
        """
        Read parameters from `filepath` and install them as the wrapped model's attribute dictionary.

        Returns self, so calls can be chained (e.g. ``ClassifierModel(model).load_json(path).get()``).
        """
        with open(filepath, "r", encoding="utf-8") as file:
            loaded = json.load(file)
        # NOTE: every top-level list becomes an ndarray, including the "dims" entry written by
        # define_features; lists nested inside dictionaries (e.g. "coords") stay as lists.
        self.model_params = {
            name: np.asarray(value) if isinstance(value, list) else value
            for name, value in loaded.items()
        }
        self.model.__dict__ = self.model_params
        return self

    def get(self):
        """Return the wrapped model object."""
        return self.model

    def set_sklearn_version(self, vstring):
        """
        Store `vstring` in the model parameters under the "sklearn_version" key.

        Raises:
        Warning - if `vstring` is empty or None.
        """
        if not vstring:
            raise Warning(
                "The sklearn version is undefined. This information is critical for resolving future compatibility issues."
            )
        self.model_params.update({"sklearn_version": vstring})

    def define_features(self, dims, coords):
        """
        xarray-style definition of matrix dimensions and coordinates, for interpretability of classifier weights
        Classifier weights are stored as a 1d list representing a multidimensional feature set. first weight corresponds
        to dim_0[0]-dim_1[0] feature pair, second weight corresponds to dim_0[0]-dim_1[1] and so on.
        Parameters:
        dims - list (ordered) of feature names or "dimensions", like "channel" or "frequency". The first dimension
        designates the "outer loop" for iterating through the feature set.
        coords - dictionary mapping dims to lists of values.
        Raises:
        IndexError - if any entry of dims has no matching key in coords.
        """
        for dim in dims:
            if dim not in coords:
                raise IndexError(f"dimension {dim} not found in coords")
        self.model_params.update({"dims": dims, "coords": coords})