-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathutils.py
executable file
·64 lines (52 loc) · 2.88 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
import pickle
import time
from datetime import datetime
import torch
from config import experiment_directory
def model_data_filenames(exp_dir, objective_name):
    """Create a new timestamped run folder under ``exp_dir`` and return paths inside it.

    The run folder is named ``objective_name + '_' + <YYYYmmdd-HH:MM:SS:ffffff>``.
    It is created together with a ``log`` sub-directory. ``os.makedirs`` is called
    without ``exist_ok``, so an already-existing folder raises ``FileExistsError``.

    NOTE(review): the timestamp contains ':' characters, which are not valid in
    Windows path components.

    Args:
        exp_dir: base directory in which the run folder is created.
        objective_name: prefix of the run folder name (concatenated with ``+``).

    Returns:
        Tuple ``(model_filename, cfg_data_filename, logfile_dir)``. These are the paths of
        ``model.pt`` and ``data_config.pkl`` inside the run folder, and the created
        ``log`` directory. Only the two directories are created; the files are not.
    """
    stamp = datetime.now().strftime('%Y%m%d-%H:%M:%S:%f')
    run_dir = os.path.join(exp_dir, objective_name + '_' + stamp)
    log_dir = os.path.join(run_dir, 'log')
    # Parent first, then the log directory inside it.
    for new_dir in (run_dir, log_dir):
        os.makedirs(new_dir)
    return (os.path.join(run_dir, 'model.pt'),
            os.path.join(run_dir, 'data_config.pkl'),
            log_dir)
def load_model_data(path, exp_dir=experiment_directory()):
    """Load a model and its data configuration from an experiment folder.

    The folder is expected to contain ``model.pt`` (written with ``torch.save``)
    and ``data_config.pkl`` (a single pickled object), as produced by
    ``save_model_data``.

    Args:
        path: experiment folder. If it does not exist as given, it is resolved
            relative to ``exp_dir``.
        exp_dir: base experiment directory. The default is the value of
            ``experiment_directory()`` computed once, when this module is imported.

    Returns:
        Tuple ``(model, cfg_data, logfile_dir)``. ``model`` is the object returned
        by ``torch.load``, ``cfg_data`` is the unpickled configuration object, and
        ``logfile_dir`` is ``<path>/log``.

    Raises:
        FileNotFoundError: if ``data_config.pkl`` does not exist (``torch.load``
            raises on a missing ``model.pt`` as well).
    """
    if not os.path.exists(path):
        path = os.path.join(exp_dir, path)
    logfile_dir = os.path.join(path, 'log')
    model_filename = os.path.join(path, 'model.pt')
    cfg_data_filename = os.path.join(path, 'data_config.pkl')
    # NOTE(review): save_model_data pickles the whole model object. Newer torch
    # releases default torch.load to weights_only=True, which may reject such a
    # checkpoint -- confirm against the torch version in use.
    model = torch.load(model_filename)
    # Pickles are binary data, so the file must be opened in 'rb' (text mode fails
    # under Python 3). Exactly one object is stored, so it is loaded exactly once.
    # pickle.load can execute arbitrary code: only load trusted experiment folders.
    with open(cfg_data_filename, 'rb') as cfg_data_file:
        cfg_data = pickle.load(cfg_data_file)
    return model, cfg_data, logfile_dir
def save_model_data(model, model_filename, cfg_data, cfg_data_filename):
    """Persist a model and its data configuration.

    Args:
        model: object saved with ``torch.save`` to ``model_filename``.
        model_filename: destination path for the model.
        cfg_data: picklable configuration object written to ``cfg_data_filename``.
        cfg_data_filename: destination path for the pickled configuration.
    """
    torch.save(model, model_filename)
    # pickle writes bytes: the file must be opened in binary mode ('w' raises
    # TypeError under Python 3). ``with`` closes the handle even on error.
    with open(cfg_data_filename, 'wb') as cfg_data_file:
        pickle.dump(cfg_data, cfg_data_file)
def displaying_and_logging(logfile_dir, eval_inputs, eval_outputs, pred_mean_list, pred_std_list, pred_var_list,
                           time_list, elapse_list, hyper_samples, log_beta_samples, lengthscale_samples,
                           log_order_var_samples, store_data=True):
    """Print a per-evaluation progress table and write it to ``logfile_dir``.

    One line is emitted for each row of ``eval_inputs``. It contains:
      - the formatted time and elapsed time;
      - the observed output;
      - the predictive mean/std/var;
      - the running minimum of the outputs so far, with the 1-based index where
        it was attained.
    Lines where the current evaluation is the running minimum are tagged
    ``<==== IMPROVED``. Each line is printed and also written to
    ``<logfile_dir>/NNNN.out``, where ``NNNN`` is the zero-padded number of
    evaluations.

    Args:
        logfile_dir: existing directory receiving ``NNNN.out`` (and ``NNNN.pkl``).
        eval_inputs: tensor whose first dimension is the number of evaluations.
        eval_outputs: tensor with one observed value per evaluation (presumably
            shape ``(n,)`` or ``(n, 1)`` -- TODO confirm against callers).
        pred_mean_list, pred_std_list, pred_var_list: per-evaluation predictive
            statistics, indexed by evaluation number.
        time_list: per-evaluation times in seconds, rendered via ``time.gmtime``.
        elapse_list: per-evaluation durations in seconds, rendered the same way.
        hyper_samples, log_beta_samples, lengthscale_samples, log_order_var_samples:
            stored verbatim in the ``.pkl`` file when ``store_data`` is true.
        store_data: if true, also ``torch.save`` inputs, outputs and samples to
            ``<logfile_dir>/NNNN.pkl``.
    """
    n_evals = eval_inputs.size(0)
    file_stem = os.path.join(logfile_dir, str(n_evals).zfill(4))
    # reshape(-1) rather than squeeze(): squeezing a (1, 1) or (1,) tensor gives a
    # 0-dim tensor that cannot be indexed, which crashed the single-evaluation case.
    flat_outputs = eval_outputs.reshape(-1)
    with open(file_stem + '.out', 'w') as logfile:
        for i in range(n_evals):
            # Running minimum over the evaluations seen so far (rows 0..i).
            min_val, min_ind = torch.min(eval_outputs[:i + 1], 0)
            time_str = time.strftime('%H:%M:%S', time.gmtime(time_list[i])) \
                + '(' + time.strftime('%H:%M:%S', time.gmtime(elapse_list[i])) + ') '
            data_str = ('%3d-th : %+12.4f, mean : %+.4E, std : %.4E, var : %.4E, min : %+8.4f(%3d)' %
                        (i + 1, flat_outputs[i].item(),
                         pred_mean_list[i], pred_std_list[i], pred_var_list[i],
                         min_val.item(), min_ind.item() + 1))
            min_str = ' <==== IMPROVED' if i == min_ind.item() else ''
            line = time_str + data_str + min_str
            print(line)
            logfile.write(line + '\n')
    if store_data:
        torch.save({'inputs': eval_inputs, 'outputs': eval_outputs, 'hyper_samples': hyper_samples,
                    'log_beta_samples': log_beta_samples, 'lengthscale_samples': lengthscale_samples,
                    'lov': log_order_var_samples},
                   file_stem + '.pkl')