-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvariable_elimination.py
70 lines (60 loc) · 2.6 KB
/
variable_elimination.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import pandas as pd
from BayesNet import BayesNet
from typing import Union
from typing import List
from BNReasonerOrig import BNReasoner
from ordering import Ordering
class VariableEliminator(Ordering):
    """Variable elimination over the CPTs of a Bayesian network.

    The elimination-order heuristics (``min_fill`` / ``min_degree``) and the
    ``marginalize`` helper are inherited from ``Ordering``. The network is read
    from ``self.bn``, which is presumably set by ``Ordering.__init__`` (TODO
    confirm).
    """

    def __init__(self, net: Union[str, BayesNet]):
        """Create an eliminator for ``net``.

        :param net: a network file path (e.g. a BIFXML file) or a ``BayesNet``;
            forwarded unchanged to ``Ordering.__init__``.
        """
        super().__init__(net)

    @staticmethod
    def factor_multiplication(f: pd.DataFrame, g: pd.DataFrame) -> pd.DataFrame:
        """Given two factors, compute their product h = fg.

        A factor is a DataFrame with one column per variable plus a
        probability column ``'p'``. Rows are matched on the variables that
        both factors share. If they share none, the result is their Cartesian
        product. Neither input DataFrame is modified.

        :param f: factor 1
        :param g: factor 2
        :return: the product factor, with f's variable columns, then g's
            remaining variable columns, then ``'p'``
        """
        # Rename g's probability column on a *copy* so it does not collide
        # with f's 'p'. An in-place rename would corrupt the caller's
        # DataFrame (e.g. a CPT stored in the network).
        right = g.rename(columns={'p': '_p_right'})
        common_variables = [v for v in f.columns if v in right.columns]
        if common_variables:
            # A left join keeps f's row order.
            h = f.merge(right, on=common_variables, how='left')
        else:
            # No shared variables: pair every row of f with every row of g.
            h = f.merge(right, how='cross')
        # Pop both probability columns and re-insert their product as the last column.
        h['p'] = h.pop('p') * h.pop('_p_right')
        return h

    def variable_elimination(self, to_remove: set[str], heuristic='manual') -> pd.DataFrame:
        """Sum out a set of variables from the network's joint distribution.

        Every CPT of ``self.bn`` starts out as a factor. For each variable in
        the elimination order:
        1. all factors mentioning that variable are multiplied together;
        2. the variable is summed out with ``self.marginalize``;
        3. the result replaces those factors.
        Finally the remaining factors are multiplied into a single factor.

        :param to_remove: variables to sum out. With ``'manual'`` they are
            eliminated in iteration order, which is arbitrary for a ``set``;
            pass a list to fix the order.
        :param heuristic: ``'manual'``, ``'min_fill'`` or ``'min_degree'``
        :return: the resulting factor (an empty DataFrame if the network has
            no variables)
        :raises TypeError: if ``heuristic`` is not one of the allowed names
        """
        if heuristic == 'manual':
            order = list(to_remove)
        elif heuristic == 'min_fill':
            order = self.min_fill(self.bn, to_remove)
        elif heuristic == 'min_degree':
            order = self.min_degree(self.bn, to_remove)
        else:
            raise TypeError("Only 'manual', 'min_fill', 'min_degree' as elimination heuristics are allowed")

        # Work on copies so nothing below can alter the CPTs held by the network.
        factors = [self.bn.get_cpt(v).copy() for v in self.bn.get_all_variables()]

        for variable in order:
            related = [fac for fac in factors if variable in fac.columns]
            if not related:
                # Already summed out (e.g. listed twice) or not in any factor.
                continue
            factors = [fac for fac in factors if variable not in fac.columns]
            product = related[0]
            for fac in related[1:]:
                product = self.factor_multiplication(product, fac)
            # Assumes marginalize(factor, var) returns a DataFrame factor with
            # var's column summed out. TODO confirm for the case where var is
            # the factor's last variable.
            factors.append(self.marginalize(product, variable))

        if not factors:
            return pd.DataFrame()
        result = factors[0]
        for fac in factors[1:]:
            result = self.factor_multiplication(result, fac)
        return result
if __name__ == "__main__":
    # Demo: eliminate every variable except 'SleepPar' from the
    # sleep-paralysis network and print the resulting factor.
    # Guarded so that importing this module does not load the file.
    # The path is relative to the current working directory.
    obj = VariableEliminator("testing/sleep_paralysis.BIFXML")
    variables = set(obj.bn.get_all_variables()) - {'SleepPar'}
    print(variables)
    print(obj.variable_elimination(variables))