-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy path eval.py
76 lines (68 loc) · 3.42 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import argparse
import time
from common import set_custom_seed
from core.plugins.storage import ModelLoader
from distances import EuclideanDistance, CosineDistance
from experiments.semeval import SemEvalEmbeddingEvaluationExperiment, SemEvalBaselineModelEvaluationExperiment
from experiments.voxceleb import VoxCeleb1ModelEvaluationExperiment
# Evaluate a saved model on either the speaker task (VoxCeleb1, reports EER)
# or the STS task (SemEval, reports Spearman), on the dev or test partition.
#
# Example:
#   python eval.py --task speaker --model <path> --partition test --distance cosine

# Launch timestamp in the locale's '%c' format; only used to build the default
# --exp-id below (spaces replaced by dashes).
launch_datetime = time.strftime('%c')

# Script arguments
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, required=True, help='speaker / sts')
parser.add_argument('--model', type=str, required=True, help='The path to the saved model to evaluate')
parser.add_argument('--partition', type=str, required=True, help='dev / test')
parser.add_argument('--distance', type=str, default='euclidean', help='cosine / euclidean. Default: euclidean')
parser.add_argument('--batch-size', type=int, default=100, help='Batch size for training and testing')
# The three STS-only paths below default to None; they are forwarded as-is to
# the SemEval experiment and are not read for the speaker task.
parser.add_argument('--sts-path', type=str, default=None, help='Path to SemEval dataset')
parser.add_argument('--vocab', type=str, default=None, help='Path to vocabulary file for STS')
parser.add_argument('--word2vec', type=str, default=None, help='Path to word embeddings for STS')
parser.add_argument('--log-interval', type=int, default=10,
                    help='Steps (in percentage) to show evaluation progress, only for STS. Default: 10')
parser.add_argument('--seed', type=int, default=None, help='Random seed')
# NOTE(review): --exp-id is parsed but not referenced anywhere else in this script.
parser.add_argument('--exp-id', type=str, default=f"EXP-{launch_datetime.replace(' ', '-')}",
                    help='An identifier for the experiment')
args = parser.parse_args()

# Set custom seed (args.seed may be None; how that is handled is up to set_custom_seed)
set_custom_seed(args.seed)

if args.distance == 'cosine':
    distance = CosineDistance()
elif args.distance == 'euclidean':
    distance = EuclideanDistance()
else:
    raise ValueError("Distance can only be: cosine / euclidean")

# Validate the partition before any model is loaded or experiment is built,
# so a typo fails immediately rather than after the (possibly slow) setup below.
if args.partition not in ('dev', 'test'):
    raise ValueError('Partition can only be: dev / test')

print(f"[Task: {args.task.upper()}]")
print('[Preparing...]')

if args.task == 'speaker':
    # nfeat is fixed by this script (256 for speaker); presumably the model's
    # output/embedding size -- confirm against the experiment class.
    experiment = VoxCeleb1ModelEvaluationExperiment(model_path=args.model,
                                                    nfeat=256,
                                                    distance=distance,
                                                    batch_size=args.batch_size)
    metric_name = 'EER'
elif args.task == 'sts':
    model_loader = ModelLoader(args.model)
    # Models trained with the 'kldiv' loss are evaluated with the baseline
    # experiment; every other loss uses the embedding-based experiment.
    loss_name = model_loader.get_trained_loss()
    if loss_name == 'kldiv':
        experiment_type = SemEvalBaselineModelEvaluationExperiment
    else:
        experiment_type = SemEvalEmbeddingEvaluationExperiment
    # NOTE(review): base_dir='tmp' is presumably a scratch directory used by the
    # experiment for intermediate files -- confirm in the experiment class.
    experiment = experiment_type(model_loader=model_loader,
                                 nfeat=500,
                                 data_path=args.sts_path,
                                 word2vec_path=args.word2vec,
                                 vocab_path=args.vocab,
                                 distance=distance,
                                 log_interval=args.log_interval,
                                 batch_size=args.batch_size,
                                 base_dir='tmp')
    metric_name = 'Spearman'
else:
    raise ValueError("Task can only be 'speaker' or 'sts'")

print('[Started Evaluation...]')
# args.partition is guaranteed to be 'dev' or 'test' by the check above.
if args.partition == 'dev':
    # NOTE(review): meaning of the positional True flag is defined by the
    # experiment's evaluate_on_dev -- confirm there.
    metric = experiment.evaluate_on_dev(True)
else:
    metric = experiment.evaluate_on_test()

print("[Evaluation Finished]")
print(f"[{args.partition.upper()} {metric_name} = {metric}]")