-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathevaluate_model.py
95 lines (76 loc) · 2.52 KB
/
evaluate_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Dec 6 14:14:31 2019
@author: alienor
"""
import numpy as np
import os
import random
import json
import copy
from plantdb import fsdb, io
from subprocess import run as call
import argparse
from scipy.integrate import simps
# Default location of the test dataset (opened below as an FSDB database);
# override it with --directory.
default_test_dir = "/home/alienor/Documents/training2D/data/dataset_vscan_constfocal/test/"
parser = argparse.ArgumentParser(description='Test directory')
# argparse %-formats help strings itself, so let it insert the default value:
# interpolating the path beforehand breaks `--help` for paths containing '%'.
parser.add_argument('--directory', dest='directory', default=default_test_dir,
                    help='test dir, default: %(default)s')
args = parser.parse_args()
db_path = args.directory

# Collect the ids of every scan in the database, releasing the connection
# even if listing the scans fails.
db = fsdb.FSDB(db_path)
db.connect()
try:
    scans = [scan.id for scan in db.get_scans()]
finally:
    db.disconnect()
# Scan that receives the aggregated results written at the end of the script.
eval_scan_id = 'Evaluation'
# Scans skipped during aggregation: 'model' (presumably the trained model,
# not a test scan) and the evaluation scan itself.
ignore_scan_ids = ['model', eval_scan_id]
# Semantic classes whose confusion counts are aggregated.
classes = ["flower", "fruit", "leaf", "pedicel", "stem"]
# Not referenced elsewhere in this script.
bins = 100
# Evaluation tasks (fileset ids) read from each scan; 'PointCloudEvaluation'
# is currently disabled.
tasks_eval = ['Segmentation2DEvaluation', 'SegmentedPointCloudEvaluation']
# Initialisation: eval[class][task] is a fresh, zeroed record of
# true/false positive/negative counts.
eval = {
    c: {
        task_eval: {"tp": 0, "fp": 0, "tn": 0, "fn": 0}
        for task_eval in tasks_eval
    }
    for c in classes
}
db = fsdb.FSDB(db_path)
db.connect()
# Iteration over the scans: for every evaluation task, add each scan's
# per-class tp/fp/tn/fn counts to the running totals in `eval`.
for scan_id in scans:
    if scan_id in ignore_scan_ids:
        continue
    scan = db.get_scan(scan_id)
    for task_eval in tasks_eval:
        evaluation = scan.get_fileset(task_eval)
        if evaluation is None:
            # This task was not run on this scan.
            continue
        files = evaluation.get_files()
        if not files:
            # The task fileset exists but holds no result file; skip it
            # instead of failing with an IndexError.
            continue
        # Only the first file of the fileset is read as the JSON results.
        results = io.read_json(files[0])
        for c in classes:
            if c in results:
                for metric in ("tp", "fp", "tn", "fn"):
                    eval[c][task_eval][metric] += results[c][metric]
# Derive precision = tp / (tp + fp) and recall = tp / (tp + fn) from the
# summed counts. Each ratio is computed independently and is left out of the
# record only when its own denominator is zero, so e.g. a class with no
# predicted positives still gets a recall.
for task_eval in tasks_eval:
    for c in classes:
        counts = eval[c][task_eval]
        tp = counts["tp"]
        precision_denominator = tp + counts["fp"]
        recall_denominator = tp + counts["fn"]
        if precision_denominator:
            counts["precision"] = tp / precision_denominator
        if recall_denominator:
            counts["recall"] = tp / recall_denominator
# Store the aggregated metrics as the "eval" file of the 'Evaluation'
# fileset in the 'Evaluation' scan (both created if missing).
try:
    eval_scan = db.get_scan(eval_scan_id, create=True)
    eval_fs = eval_scan.get_fileset(eval_scan_id, create=True)
    eval_file = eval_fs.create_file("eval")
    io.write_json(eval_file, eval)
finally:
    # Close the connection opened before the scan loop, as is done after
    # listing the scans; otherwise the script exits still connected
    # (FSDB may keep a lock on the database while connected - confirm in plantdb).
    db.disconnect()