-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
60 lines (47 loc) · 1.84 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
# from bnn import BayesianNN
from bfn import BayesianFlowNetwork
from dataset import Dataset
tf.get_logger().setLevel('ERROR')
# ----------------------------------------------
# Data loading
# ----------------------------------------------
print("------------------------------------")
print()
print("Making Data Class")
print()
# NOTE(review): the meaning of the third argument (0.1) is defined in
# dataset.Dataset -- presumably a held-out split fraction; confirm there.
dataset = Dataset('../amanda/test/output_csv/', "", 0.1)
print("Loading Data")
print()
input_data = dataset.load_data("dataset_arrays_3000")
# ----------------------------------------------
# Model construction
# ----------------------------------------------
print("Building BayesianNN Model")
# The shape of input_data[0][0] is passed as the model's input shape; the
# meaning of 64 is defined by BayesianFlowNetwork (TODO confirm in bfn.py).
model = BayesianFlowNetwork(input_data[0][0].shape, input_data, 64)
model.dataset_info()
model.normalize_dataset()
# ----------------------------------------------
# Training
# ----------------------------------------------
print("Started Training")
model.train()
# Persist the trained model before evaluation, so that an error while
# predicting or plotting does not throw away the training run.
print("Saving the model")
model.save_model('models/model')
# ----------------------------------------------
# Evaluation
# ----------------------------------------------
samples = model.predict()
# Number of ground-truth rows to plot against the predictions; this name
# must be defined before slicing y_test below.
# NOTE(review): assumes model.predict() returns one entry per test example
# along its first axis, aligned with model.y_test -- confirm in bfn.py.
n = len(samples)
model.plot(samples, model.y_test[:n], "figures", "bfn_20")