-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathutils.py
52 lines (45 loc) · 1.55 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
### util functions
import time
import torch
import numpy
import scipy.stats as st
from args import args
# Timestamp (time.monotonic() seconds) of the last message display() printed.
# A monotonic clock is used because wall-clock time can jump backwards.
lastDisplay = time.monotonic()
def display(string, end = '\n', force = False):
    """Print *string*, rate-limited to at most one message every 0.1 s.

    Calls arriving less than 0.1 s after the previous printed message are
    silently dropped unless *force* is true.

    Args:
        string: object passed to ``print``.
        end: line terminator forwarded to ``print``.
        force: print regardless of the rate limit.
    """
    global lastDisplay
    # Read the clock once so the comparison and the stored timestamp agree.
    now = time.monotonic()
    if force or now - lastDisplay > 0.1:
        lastDisplay = now
        print(string, end=end)
def timeToStr(time):
    """Render a duration given in seconds as ``<h>h<mm>m<ss>s``.

    The value is truncated to whole seconds first; hours are not wrapped, so
    90061 seconds becomes ``"25h01m01s"``.
    """
    totalSeconds = int(time)
    totalMinutes, secs = divmod(totalSeconds, 60)
    hrs, mins = divmod(totalMinutes, 60)
    return "{:d}h{:02d}m{:02d}s".format(hrs, mins, secs)
def confInterval(scores, confidence = 0.95):
    """Return ``(low, up)``, a confidence interval for the mean of *scores*.

    A Student-t interval (``df = n - 1``) is used below 30 samples and a
    normal approximation otherwise. In both cases the scale is the standard
    error of the mean (``scipy.stats.sem``, ddof=1).

    Args:
        scores: 1-D torch tensor (any device, may require grad), numpy array
            or sequence of numbers.
        confidence: coverage probability of the interval (default 0.95).

    Returns:
        ``(low, up)`` as Python floats. ``(-1., -1.)`` is a sentinel for fewer
        than two samples, where no spread can be estimated. If all samples
        are equal, the interval collapses to ``(mean, mean)``.
    """
    # Tensor.numpy() raises on CUDA tensors and on tensors requiring grad,
    # so detach and copy to host memory (as float64) first.
    if isinstance(scores, torch.Tensor):
        scores = scores.detach().cpu().double().numpy()
    values = numpy.asarray(scores, dtype=numpy.float64)
    n = values.shape[0]
    if n < 2:
        return -1., -1.
    mean = float(values.mean())
    sem = float(st.sem(values))
    if sem == 0:
        # scipy returns nan for a zero scale; the degenerate interval is exact.
        return mean, mean
    # confidence is passed positionally: the keyword name changed across
    # scipy versions (alpha -> confidence).
    if n < 30:
        low, up = st.t.interval(confidence, df = n - 1, loc = mean, scale = sem)
    else:
        low, up = st.norm.interval(confidence, loc = mean, scale = sem)
    return float(low), float(up)
def createCSV(trainSet, validationSet, testSet):
    """Truncate the CSV log named by ``args.csv`` and write its header.

    Does nothing when ``args.csv`` is the empty string. The header is an
    ``epochs`` column followed by ``<name> loss`` and ``<name> accuracy``
    for every dataset in *trainSet*, *validationSet* and *testSet*, in that
    order. Each argument is an iterable of dicts with a ``"name"`` key.

    Every column name is followed by ", " (the last one included), matching
    the layout of the value rows written by ``updateCSV``. The header is not
    newline-terminated: ``updateCSV`` starts each epoch row with a newline
    itself, so a trailing newline here would leave an empty line between the
    header and the first data row.
    """
    if args.csv == "":
        return
    columns = ["epochs"]
    for datasetType in (trainSet, validationSet, testSet):
        for dataset in datasetType:
            columns.append(dataset["name"] + " loss")
            columns.append(dataset["name"] + " accuracy")
    # Build the header before opening: a malformed dataset entry must not
    # truncate an existing log. `with` closes the handle on any error.
    header = "".join(column + ", " for column in columns)
    with open(args.csv, "w", encoding="utf-8") as f:
        f.write(header)
def updateCSV(stats, epoch = -1):
    """Append a group of per-dataset statistics to the CSV log ``args.csv``.

    Does nothing when ``args.csv`` is the empty string. A non-negative
    *epoch* starts a new row (a newline, then the epoch number). With the
    default ``epoch=-1`` the values continue the current row, so several
    calls can fill a single row.

    Args:
        stats: 2-D indexable; ``stats[i, 0]`` and ``stats[i, 1]`` are written
            for every ``i < stats.shape[0]`` and must support ``.item()``
            (presumably a torch tensor of (loss, accuracy) pairs matching the
            header columns; confirm against callers).
        epoch: epoch number that starts a new row, or negative to continue
            the current row.
    """
    if args.csv == "":
        return
    pieces = []
    if epoch >= 0:
        pieces.append("\n" + str(epoch) + ", ")
    for i in range(stats.shape[0]):
        pieces.append(str(stats[i, 0].item()) + ", " + str(stats[i, 1].item()) + ", ")
    text = "".join(pieces)
    # Format everything before opening so an indexing/conversion error
    # cannot leave the file handle open; `with` guarantees it is closed.
    with open(args.csv, "a") as f:
        f.write(text)
# Import-time side effect: prints " utils," with no trailing newline.
# NOTE(review): presumably each project module appends its name to a shared
# "loading ..." progress line this way; confirm against the other modules.
print(" utils,", end="")