-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnaive_bayes.py
executable file
·127 lines (114 loc) · 3.11 KB
/
naive_bayes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#!/usr/bin/env python
import os
import re
from baseline import num_or_eng
import torch
alpha = 0.1
'''
create hsk dict for easy parsing
'''
def hsk_list(path="data/HSK", out_path="data/bayes_params/HSK_DICT"):
    """Build a {word: HSK level} dict from per-level vocab files and save it.

    Every file in *path* named ``HSK<level>.txt`` contributes one entry per
    non-empty line; the level is parsed from the filename.  The dict is
    persisted with ``torch.save`` to *out_path*.

    Parameters default to the original hard-coded locations, so existing
    callers are unaffected.
    """
    hsk_dict = {}
    for fname in os.listdir(path):
        # Auxiliary files living in the same directory are not vocab lists.
        if fname in ("characters", "all_vocab"):
            continue
        # e.g. "HSK4.txt" -> 4
        level = int(fname.split("HSK")[1].split(".txt")[0])
        # Use a context manager: the original open().read() leaked the handle.
        with open(os.path.join(path, fname), "r", encoding="utf-8") as f:
            for word in f.read().split("\n"):
                if word:  # skip the empty trailing line so '' is never a key
                    hsk_dict[word] = level
    # Make sure the destination directory exists before saving.
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    torch.save(hsk_dict, out_path)
'''
create naive bayes model using HSK data
HSK words as 'documents' and characters as bag of words
'''
def NB(HSK, out_path="data/bayes_params/NB", smoothing=None):
    """Train the naive-Bayes model from an HSK {word: level} dict.

    Treats each word as a "document" labelled with its HSK level and each
    character as a bag-of-words token.  Produces:
      PK            -- prior P(level): fraction of words at each level
      WK_IND_COUNTS -- per character, per level: smoothing + occurrence count
      WK_COUNTS     -- per level: total character occurrences (denominator)

    The triple is saved with ``torch.save`` to *out_path* and also returned,
    so callers can use it without reloading from disk.  *smoothing* defaults
    to the module-level ``alpha`` (matching the original behaviour).
    """
    if smoothing is None:
        smoothing = alpha
    PK = {}
    WK_IND_COUNTS = {}
    WK_COUNTS = {}
    total = len(HSK)  # one entry per word; replaces the manual counter loop
    for k in HSK.values():
        PK[k] = PK.get(k, 0) + 1
    for k in PK:
        PK[k] = PK[k] / total
    for word, k in HSK.items():
        for c in word:
            per_level = WK_IND_COUNTS.setdefault(c, {})
            if k not in per_level:
                # Initialise with the smoothing constant, then count as usual.
                per_level[k] = smoothing
            per_level[k] += 1
            WK_COUNTS[k] = WK_COUNTS.get(k, 0) + 1
    # Make sure the destination directory exists before saving.
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    torch.save((PK, WK_IND_COUNTS, WK_COUNTS), out_path)
    return PK, WK_IND_COUNTS, WK_COUNTS
'''
test NB with input word
'''
def test_word(word):
PK, WK_IND_COUNTS, WK_COUNTS = torch.load('data/bayes_params/NB')
PROBS = {}
for k in PK:
pk = PK[k]
prod_w = 1
for c in word:
try:
ckw = WK_IND_COUNTS[c][k]
except:
ckw = alpha
ckwp = WK_COUNTS[k]
prod_w *= ckw/ckwp
pkd = pk*prod_w
PROBS[k] = pkd
sol = max(PROBS, key=PROBS.get)
return sol
'''
compute f1 for NB
'''
def test(threshold=5):
    """Evaluate the classifier over the segmented test texts; return mean F1.

    For each file under data/test/segmented_text, every non-numeric/non-English
    segment whose predicted HSK level is >= *threshold* (default 5, as before)
    is flagged as "advanced" and compared against the gold vocab list in
    data/test/vocab/<same file>.  Prints per-file F1 and fp/fn/tp, then the
    mean F1 across files.
    """
    total_f1 = 0.0
    path = "data/test/segmented_text"
    file_count = 0
    for file in os.listdir(path):
        file_count += 1
        # Context managers: the original leaked both file handles.
        with open(os.path.join(path, file), "r") as f:
            segments = f.read().split(" ")
        found = []
        for s in segments:
            if num_or_eng(s):
                continue
            if test_word(s) >= threshold:
                found.append(s)
        with open(os.path.join("data/test/vocab", file), "r") as f:
            # Drop empty lines: the trailing '' previously counted as a
            # false negative in every file.
            real = [w for w in f.read().split("\n") if w]
        found_set = set(found)
        real_set = set(real)
        true_pos = sum(1 for w in found_set if w in real_set)
        false_pos = sum(1 for w in found_set if w not in real_set)
        false_neg = sum(1 for w in real if w not in found_set)
        denom = true_pos + 0.5 * (false_pos + false_neg)
        # Guard the zero-denominator case (no predictions and no gold words).
        f_score = true_pos / denom if denom else 0.0
        total_f1 += f_score
        print(f'{file} F1: {f_score}')
        print(f'fp: {false_pos}, fn: {false_neg}, tp: {true_pos}')
    # Guard an empty test directory instead of dividing by zero.
    mean_f1 = total_f1 / file_count if file_count else 0.0
    print(mean_f1)
    return mean_f1
def main():
    """Load the saved HSK dict and run the evaluation."""
    hsk_dict = torch.load('data/bayes_params/HSK_DICT')
    # NB(hsk_dict)  # uncomment to retrain the model before evaluating
    test()


if __name__ == "__main__":
    main()