# trainer.py (forked from saikat107/Devign)
import copy

import numpy as np
import torch
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

from utils import debug

def evaluate_loss(model, loss_function, num_batches, data_iter, cuda=False):
    """Evaluate for `num_batches` batches; return (mean loss, accuracy in percent)."""
    model.eval()
    with torch.no_grad():
        _loss = []
        all_predictions, all_targets = [], []
        for _ in range(num_batches):
            graph, targets = data_iter()
            if cuda:
                # Honor the caller's device choice instead of always moving to GPU.
                targets = targets.cuda()
            predictions = model(graph, cuda=cuda)
            batch_loss = loss_function(predictions, targets)
            _loss.append(batch_loss.detach().cpu().item())
            predictions = predictions.detach().cpu()
            if predictions.ndim == 2:
                # Two-class scores: the predicted label is the argmax.
                all_predictions.extend(np.argmax(predictions.numpy(), axis=-1).tolist())
            else:
                # Single probability per example: threshold at 0.5.
                all_predictions.extend(predictions.ge(0.5).to(dtype=torch.int32).numpy().tolist())
            all_targets.extend(targets.detach().cpu().numpy().tolist())
    model.train()
    return np.mean(_loss).item(), accuracy_score(all_targets, all_predictions) * 100

def evaluate_metrics(model, loss_function, num_batches, data_iter, cuda=True):
    """Evaluate for `num_batches` batches; return (accuracy, precision, recall, F1) in percent.

    `loss_function` is accepted for signature parity with `evaluate_loss` but is not used.
    """
    model.eval()
    with torch.no_grad():
        all_predictions, all_targets = [], []
        for _ in range(num_batches):
            graph, targets = data_iter()
            if cuda:
                targets = targets.cuda()
            predictions = model(graph, cuda=cuda).detach().cpu()
            if predictions.ndim == 2:
                all_predictions.extend(np.argmax(predictions.numpy(), axis=-1).tolist())
            else:
                all_predictions.extend(predictions.ge(0.5).to(dtype=torch.int32).numpy().tolist())
            all_targets.extend(targets.detach().cpu().numpy().tolist())
    model.train()
    return accuracy_score(all_targets, all_predictions) * 100, \
           precision_score(all_targets, all_predictions) * 100, \
           recall_score(all_targets, all_predictions) * 100, \
           f1_score(all_targets, all_predictions) * 100
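
# A minimal illustration (toy tensors, not part of the pipeline) of how the two
# prediction shapes handled above collapse to class labels:
#
#     >>> torch.tensor([[0.2, 0.8], [0.9, 0.1]]).numpy().argmax(axis=-1)
#     array([1, 0])
#     >>> torch.tensor([0.3, 0.7]).ge(0.5).to(dtype=torch.int32)
#     tensor([0, 1], dtype=torch.int32)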

def train(model, dataset, max_steps, dev_every, loss_function, optimizer, save_path, log_every=50, max_patience=5):
    """Train with early stopping on validation accuracy.

    The best-performing weights are saved to `save_path + '-model.bin'` and
    restored before the final test evaluation.
    """
    debug('Start Training')
    train_losses = []
    best_model = None
    patience_counter = 0
    best_acc = 0
    try:
        for step_count in range(max_steps):
            model.train()
            model.zero_grad()
            graph, targets = dataset.get_next_train_batch()
            targets = targets.cuda()
            predictions = model(graph, cuda=True)
            batch_loss = loss_function(predictions, targets)
            if log_every is not None and (step_count % log_every == log_every - 1):
                debug('Step %d\t\tTrain Loss %10.3f' % (step_count, batch_loss.detach().cpu().item()))
            train_losses.append(batch_loss.detach().cpu().item())
            batch_loss.backward()
            optimizer.step()
            if step_count % dev_every == (dev_every - 1):
                # The early-stopping signal comes from the validation split;
                # note that evaluate_loss reports accuracy, not F1.
                valid_loss, valid_acc = evaluate_loss(model, loss_function, dataset.initialize_valid_batch(),
                                                      dataset.get_next_valid_batch, cuda=True)
                if valid_acc > best_acc:
                    patience_counter = 0
                    best_acc = valid_acc
                    best_model = copy.deepcopy(model.state_dict())
                    with open(save_path + '-model.bin', 'wb') as save_file:
                        torch.save(model.state_dict(), save_file)
                else:
                    patience_counter += 1
                debug('Step %d\t\tTrain Loss %10.3f\tValid Loss %10.3f\tValid Acc: %5.2f\tPatience %d' % (
                    step_count, np.mean(train_losses).item(), valid_loss, valid_acc, patience_counter))
                debug('=' * 100)
                train_losses = []
                if patience_counter == max_patience:
                    break
    except KeyboardInterrupt:
        debug('Training Interrupted by user!')

    if best_model is not None:
        model.load_state_dict(best_model)
        with open(save_path + '-model.bin', 'wb') as save_file:
            torch.save(model.state_dict(), save_file)
    # Final metrics are computed on the held-out test split, matching the log label.
    acc, pr, rc, f1 = evaluate_metrics(model, loss_function, dataset.initialize_test_batch(),
                                       dataset.get_next_test_batch)
    debug('%s\tTest Accuracy: %0.2f\tPrecision: %0.2f\tRecall: %0.2f\tF1: %0.2f' % (save_path, acc, pr, rc, f1))
    debug('=' * 100)
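
if __name__ == '__main__':
    # Minimal CPU smoke test for `evaluate_loss`; a sketch, not part of the
    # original pipeline. `_ToyModel` and `_toy_iter` are hypothetical stand-ins
    # for a Devign model and the DataSet batch iterators.
    class _ToyModel(torch.nn.Module):
        def forward(self, graph, cuda=False):
            # Ignore the graph and emit fixed two-class scores for two examples.
            return torch.tensor([[0.2, 0.8], [0.9, 0.1]])

    def _toy_iter():
        # Each call yields one batch as (graph, targets).
        return None, torch.tensor([1, 0])

    loss, acc = evaluate_loss(_ToyModel(), torch.nn.CrossEntropyLoss(),
                              num_batches=2, data_iter=_toy_iter, cuda=False)
    print('toy evaluate_loss -> loss=%.3f acc=%.1f%%' % (loss, acc))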