-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmain.py
123 lines (89 loc) · 4.09 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from sklearn_crfsuite.metrics import flat_classification_report
from module.analyze.utils import bert_labels2tokens
from data.bert_data import get_data_loader_for_predict
from data.conll2003 import conll2003_preprocess
from data import bert_data
from module.models.bert_models import BERTBiLSTMCRF
from module.train.train import NerLearner
from seqeval.metrics import classification_report
import pandas as pd
import warnings
# Silence ALL warnings globally — the BERT/sklearn stack is noisy during
# training. NOTE(review): this also hides genuinely useful warnings;
# consider narrowing the filter to specific categories.
warnings.filterwarnings("ignore")
# Root directory holding the preprocessed CoNLL-2003-style CSV splits
# (eng.train.train.csv, eng.testa.dev.csv, eng.testb.dev.csv) and the
# label-index cache. Hard-coded absolute path — machine-specific.
data_dir = "/home/muzamil/Dataset/food/Text/MyData/CoNLL_2003_dataset/project_data/"
def read_data(base_dir=data_dir, *, model_name="bert-base-cased"):
    """Load the train/dev CSV splits into a ``bert_data.LearnData`` container.

    Args:
        base_dir: Directory containing the preprocessed CSV splits and the
            label-index file. Defaults to the module-level ``data_dir`` so
            existing callers are unaffected.
        model_name: Tokenizer/model identifier forwarded to ``LearnData``.

    Returns:
        A ``bert_data.LearnData`` instance exposing train/valid datasets
        and dataloaders (``train_ds``, ``train_dl``, ``valid_ds``, ...).
    """
    return bert_data.LearnData.create(
        train_df_path=base_dir + "eng.train.train.csv",
        valid_df_path=base_dir + "eng.testa.dev.csv",
        idx2labels_path=base_dir + "idx2labels2.txt",
        markup="BIO",       # BIO tagging scheme
        clear_cache=True,   # rebuild the label index on every run
        model_name=model_name,
    )
def normalize_bio_labels(label_seqs, *, pad_to_o=False):
    """Normalize underscore-style BIO tags to seqeval's dash style, in place.

    ``'B_X'``/``'I_X'`` become ``'B-X'``/``'I-X'``; the pseudo-tags ``'B_O'``
    and ``'I_O'`` collapse to plain ``'O'``. With ``pad_to_o=True``, tags
    starting with ``'[PAD]'`` (emitted for padding positions in predictions)
    also collapse to ``'O'``.

    Args:
        label_seqs: list of lists of tag strings; mutated in place.
        pad_to_o: also map ``'[PAD]'``-prefixed tags to ``'O'``.
    """
    for seq in label_seqs:
        for i, tag in enumerate(seq):
            is_pad = pad_to_o and tag.startswith('[PAD]')
            if tag.startswith(('B', 'I')) or is_pad:
                if is_pad or tag in ('B_O', 'I_O'):
                    seq[i] = 'O'
                else:
                    seq[i] = tag.replace('_', '-')


if __name__ == '__main__':
    conll2003_preprocess(data_dir)
    data = read_data()
    print(data.train_dl.dataset.df)

    # NOTE(review): `device` is assigned but never used below — device
    # placement is presumably handled inside the model/learner; confirm.
    device = "cuda:0"
    model = BERTBiLSTMCRF.create(
        len(data.train_ds.idx2label),
        model_name="/home/muzamil/Projects/Python/NLP/MenuNER/model/FoodieBERT/cased_L-12_H-768_A-12",
        lstm_dropout=0.1, crf_dropout=0.3)
    # print_model_params(model, True, True)

    num_epochs = 10
    learner = NerLearner(
        model, data, data_dir + "conll2003-BERTBiLSTMCRF-base-IO.cpt",
        t_total=num_epochs * len(data.train_dl))
    print(model.get_n_trainable_params())
    learner.fit(epochs=num_epochs)

    # --- Validation split ---
    dl = get_data_loader_for_predict(data, df_path=data.valid_ds.config["df_path"])
    print(dl.dataset.df)
    preds = learner.predict(dl)
    pred_tokens, pred_labels = bert_labels2tokens(dl, preds)
    true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset])
    assert pred_tokens == true_tokens
    # idx2label[4:] skips the special tokens/tags at the front of the label map.
    tokens_report = flat_classification_report(
        true_labels, pred_labels, labels=data.train_ds.idx2label[4:], digits=4)
    print(tokens_report)

    # --- Test split ---
    dl = get_data_loader_for_predict(data, df_path=data_dir + "eng.testb.dev.csv")
    preds = learner.predict(dl)
    pred_tokens, pred_labels = bert_labels2tokens(dl, preds)
    true_tokens, true_labels = bert_labels2tokens(dl, [x.bert_labels for x in dl.dataset])
    assert pred_tokens == true_tokens
    tokens_report = flat_classification_report(
        true_labels, pred_labels, labels=data.train_ds.idx2label[4:], digits=4)
    print(tokens_report)

    # seqeval expects dash-separated BIO tags ('B-X'); convert the
    # underscore-style labels. Padding predictions become plain 'O'.
    normalize_bio_labels(true_labels)
    normalize_bio_labels(pred_labels, pad_to_o=True)

    # Flatten sentence-level token/label pairs for inspection and CSV export.
    t_tokens, t_labels = [], []
    p_tokens, p_labels = [], []
    for sent_idx in range(len(pred_tokens)):
        for tok_idx in range(len(pred_tokens[sent_idx])):
            t_tokens.append(true_tokens[sent_idx][tok_idx])
            t_labels.append(true_labels[sent_idx][tok_idx])
            p_tokens.append(pred_tokens[sent_idx][tok_idx])
            p_labels.append(pred_labels[sent_idx][tok_idx])
            print("[ " + pred_tokens[sent_idx][tok_idx] + " : " + pred_labels[sent_idx][tok_idx])
        print("\n")

    report = classification_report(true_labels, pred_labels, digits=4)
    # BUG FIX: the original used a logging-style call `print("\n%s", report)`,
    # which printed the literal '\n%s' and the report as a second argument.
    print("\n" + report)

    # Save per-token gold vs. predicted labels side by side.
    # (Renamed from `dict`, which shadowed the builtin.)
    results = {'true_tokens': t_tokens, 'pred_tokens': p_tokens,
               't_labels': t_labels, 'p_labels': p_labels}
    pd.DataFrame(results).to_csv('inference.csv')