Remove NULLs from domain: always predict a non-NULL value.
All of the PR in one commit.
richardwu authored and minafarid committed Feb 20, 2019
1 parent 8513f14 commit 2d43f34
Showing 17 changed files with 383 additions and 197 deletions.
1 change: 1 addition & 0 deletions .travis.yml
@@ -19,6 +19,7 @@ before_script:
- psql -U postgres -c 'create database holo;'
- psql -U postgres -c 'CREATE USER holocleanuser;'
- psql -U postgres -c "ALTER USER holocleanuser WITH PASSWORD 'abcd1234';"
- psql -U postgres -c 'ALTER USER holocleanuser WITH SUPERUSER;'
- psql -U postgres -c 'GRANT ALL PRIVILEGES ON DATABASE holo TO holocleanuser;'
- psql -U postgres -d holo -c 'ALTER SCHEMA public OWNER TO holocleanuser;'

25 changes: 19 additions & 6 deletions dataset/dataset.py
@@ -7,7 +7,7 @@

from .dbengine import DBengine
from .table import Table, Source
from utils import dictify_df
from utils import dictify_df, NULL_REPR


class AuxTables(Enum):
@@ -100,8 +100,8 @@ def load_data(self, name, fpath, na_values=None, entity_col=None, src_col=None):
# use entity IDs as _tid_'s directly
df.rename({entity_col: '_tid_'}, axis='columns', inplace=True)

# Use '_nan_' to represent NULL values
df.fillna('_nan_', inplace=True)
# Use NULL_REPR to represent NULL values
df.fillna(NULL_REPR, inplace=True)

logging.info("Loaded %d rows with %d cells", self.raw_data.df.shape[0], self.raw_data.df.shape[0] * self.raw_data.df.shape[1])

@@ -209,6 +209,12 @@ def get_statistics(self):
<val1>: all values of <attr1>
<val2>: values of <attr2> that appear at least once with <val1>.
<count>: frequency (# of entities) where attr1=val1 AND attr2=val2
NB: neither single_attr_stats nor pair_attr_stats contains frequencies
for NULL values (NULL_REPR). Callers must explicitly check whether a
value is NULL before looking it up.
Also, values that co-occur only with NULLs will NOT be in pair_attr_stats.
"""
if not self.stats_ready:
logging.debug('computing frequency and co-occurrence statistics from raw data...')
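A note on the caveat above: because the statistics exclude NULLs, lookups must guard on NULL first. A minimal sketch (NULL_REPR's concrete value and the dict shapes are assumptions based on this diff):

NULL_REPR = '_nan_'  # assumed: matches utils.NULL_REPR introduced in this PR

# single_attr_stats maps attr -> {value -> frequency}, NULLs excluded.
single_attr_stats = {'city': {'NYC': 2, 'LA': 1}}

def freq(attr, val):
    """Frequency of val in attr; 0 for NULLs and unseen values."""
    if val == NULL_REPR:  # guard before lookup, as the docstring warns
        return 0
    return single_attr_stats.get(attr, {}).get(val, 0)

assert freq('city', 'NYC') == 2
assert freq('city', NULL_REPR) == 0
assert freq('city', 'SF') == 0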
@@ -242,7 +248,7 @@ def collect_stats(self):
self.pair_attr_stats[cond_attr] = {}
for trg_attr in self.get_attributes():
if trg_attr != cond_attr:
self.pair_attr_stats[cond_attr][trg_attr] = self.get_stats_pair(cond_attr,trg_attr)
self.pair_attr_stats[cond_attr][trg_attr] = self.get_stats_pair(cond_attr, trg_attr)

def get_stats_single(self, attr):
"""
@@ -251,16 +257,23 @@ def get_stats_single(self, attr):
"""
# need to decode values into unicode strings since we do lookups via
# unicode strings from Postgres
return self.get_raw_data()[[attr]].groupby([attr]).size().to_dict()
data_df = self.get_raw_data()
return data_df[[attr]].loc[data_df[attr] != NULL_REPR].groupby([attr]).size().to_dict()

def get_stats_pair(self, first_attr, second_attr):
"""
Returns a dictionary {first_val -> {second_val -> count } } where:
<first_val>: all possible values for first_attr
<second_val>: all values for second_attr that appear at least once with <first_val>
<count>: frequency (# of entities) where first_attr=<first_val> AND second_attr=<second_val>
Filters out NULL values so that no entry in the dictionary contains NULLs.
"""
tmp_df = self.get_raw_data()[[first_attr, second_attr]].groupby([first_attr, second_attr]).size().reset_index(name="count")
data_df = self.get_raw_data()
tmp_df = data_df[[first_attr, second_attr]]\
.loc[(data_df[first_attr] != NULL_REPR) & (data_df[second_attr] != NULL_REPR)]\
.groupby([first_attr, second_attr])\
.size()\
.reset_index(name="count")
return dictify_df(tmp_df)
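To see what this produces, a self-contained sketch with toy data; dictify_df_sketch is a hypothetical stand-in for utils.dictify_df (the real helper's implementation is not shown in this diff):

import pandas as pd

NULL_REPR = '_nan_'  # assumed sentinel

data_df = pd.DataFrame({'city':  ['NYC', 'NYC', NULL_REPR],
                        'state': ['NY',  NULL_REPR, 'NY']})

tmp_df = (data_df[['city', 'state']]
          .loc[(data_df['city'] != NULL_REPR) & (data_df['state'] != NULL_REPR)]
          .groupby(['city', 'state'])
          .size()
          .reset_index(name='count'))

def dictify_df_sketch(df):
    # Nest the two key columns into {first_val: {second_val: count}}.
    out = {}
    for first, second, count in df.itertuples(index=False):
        out.setdefault(first, {})[second] = count
    return out

print(dictify_df_sketch(tmp_df))  # {'NYC': {'NY': 1}}
# Rows where either value is NULL contribute nothing, so 'NYC' and 'NY'
# co-occur only once here despite each appearing twice.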

def get_domain_info(self):
3 changes: 2 additions & 1 deletion detect/nulldetector.py
@@ -1,6 +1,7 @@
import pandas as pd

from .detector import Detector
from utils import NULL_REPR


class NullDetector(Detector):
@@ -28,7 +29,7 @@ def detect_noisy_cells(self):
attributes = self.ds.get_attributes()
errors = []
for attr in attributes:
tmp_df = self.df[self.df[attr] == '_nan_']['_tid_'].to_frame()
tmp_df = self.df[self.df[attr] == NULL_REPR]['_tid_'].to_frame()
tmp_df.insert(1, "attribute", attr)
errors.append(tmp_df)
errors_df = pd.concat(errors, ignore_index=True)
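A self-contained toy run of the same loop (column names hypothetical), showing the shape of errors_df:

import pandas as pd

NULL_REPR = '_nan_'  # assumed sentinel

df = pd.DataFrame({'_tid_': [0, 1, 2],
                   'city':  ['NYC', NULL_REPR, 'LA'],
                   'state': [NULL_REPR, 'NY', 'CA']})

errors = []
for attr in ['city', 'state']:
    tmp_df = df[df[attr] == NULL_REPR]['_tid_'].to_frame()
    tmp_df.insert(1, 'attribute', attr)
    errors.append(tmp_df)
errors_df = pd.concat(errors, ignore_index=True)
print(errors_df)
#    _tid_ attribute
# 0      1      city
# 1      0     state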
165 changes: 93 additions & 72 deletions domain/domain.py
@@ -9,6 +9,7 @@

from dataset import AuxTables, CellStatus
from .estimators import NaiveBayes
from utils import NULL_REPR


class DomainEngine:
@@ -186,11 +187,12 @@ def get_corr_attributes(self, attr, thres):
that are correlated with attr with magnitude at least self.cor_strength
(init parameter).
:param attr: (string) the original attribute to get the correlated attributes for.
:param thres: (float) correlation threshold (absolute) for returned attributes.
"""
# Not memoized: find correlated attributes from correlation dictionary.
if (attr, thres) not in self._corr_attrs:
self._corr_attrs[(attr,thres)] = []
self._corr_attrs[(attr, thres)] = []

if attr in self.correlations:
attr_correlations = self.correlations[attr]
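A stand-alone sketch of the memoization pattern here; the fold hides the filtering step, so the abs(corr) >= thres test below is an assumption based on the docstring:

_corr_attrs = {}  # memo: (attr, thres) -> [correlated attributes]

def get_corr_attributes_sketch(correlations, attr, thres):
    if (attr, thres) not in _corr_attrs:
        _corr_attrs[(attr, thres)] = []
        if attr in correlations:
            attr_correlations = correlations[attr]
            _corr_attrs[(attr, thres)] = [
                other for other, corr in attr_correlations.items()
                if other != attr and abs(corr) >= thres
            ]
    return _corr_attrs[(attr, thres)]

corrs = {'city': {'city': 1.0, 'state': 0.8, 'zip': 0.95, 'name': 0.1}}
print(get_corr_attributes_sketch(corrs, 'city', 0.5))  # ['state', 'zip']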
@@ -203,10 +205,8 @@ def get_corr_attributes(self, attr, thres):
def generate_domain(self):
"""
Generates the domain for each cell in the active attributes as well
as assigns variable IDs (_vid_): an incrementing key from 0 onwards that
depends on the iteration order of rows/entities in the raw data and attributes.
Note that _vid_ has a 1-1 correspondence with _cid_.
as assigns a random variable ID (_vid_) for cells that have
a domain of size >= 2.
See get_domain_cell for how the domain is generated from co-occurrence
and correlated attributes.
@@ -216,8 +216,8 @@
:return: DataFrame with columns
_tid_: entity/tuple ID
_cid_: cell ID (unique for every entity-attribute)
_vid_: variable ID (1-1 correspondence with _cid_)
_cid_: cell ID (one for every cell in the raw data in active attributes)
_vid_: random variable ID (one for every cell with a domain of at least size 2)
attribute: attribute name
domain: ||| separated string of domain values
domain_size: length of domain
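For concreteness, one hypothetical row of the returned DataFrame under this schema (all values illustrative):

example_row = {
    '_tid_': 42,
    '_cid_': 1337,               # unique per (entity, attribute) cell
    '_vid_': 0,                  # assigned 0, 1, 2, ... over kept cells
    'attribute': 'city',
    'domain': 'LA|||NYC|||SF',   # '|||'-separated candidate values
    'domain_size': 3,
    'init_value': 'NYC',
    'init_index': 1,             # index in the sorted domain; -1 for NULL
    'weak_label': 'NYC',
    'weak_label_idx': 1,
    'fixed': 0,                  # e.g. CellStatus.NOT_SET (exact value assumed)
}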
@@ -239,46 +239,59 @@ def generate_domain(self):
self.all_attrs = list(records.dtype.names)
for row in tqdm(list(records)):
tid = row['_tid_']
app = []
for attr in self.active_attributes:
init_value, init_value_idx, dom = self.get_domain_cell(attr, row)
# We will use an estimator model for additional weak labelling
# below, which requires an initial pruned domain first.
weak_label = init_value
weak_label_idx = init_value_idx
if len(dom) > 1:
cid = self.ds.get_cell_id(tid, attr)
app.append({"_tid_": tid,
"attribute": attr,
"_cid_": cid,
"_vid_": vid,
"domain": "|||".join(dom),
"domain_size": len(dom),
"init_value": init_value,
"init_index": init_value_idx,
"weak_label": weak_label,
"weak_label_idx": weak_label_idx,
"fixed": CellStatus.NOT_SET.value})
vid += 1
else:
add_domain = self.get_random_domain(attr, init_value)
# Check if the attribute has more than one unique value.
if len(add_domain) > 0:
dom.extend(add_domain)
cid = self.ds.get_cell_id(tid, attr)
app.append({"_tid_": tid,
"attribute": attr,
"_cid_": cid,
"_vid_": vid,
"domain": "|||".join(dom),
"domain_size": len(dom),
"init_value": init_value,
"init_index": init_value_idx,
"weak_label": init_value,
"weak_label_idx": init_value_idx,
"fixed": CellStatus.SINGLE_VALUE.value})
vid += 1
cells.extend(app)
# Weak labels will be trained on the init values.
cid = self.ds.get_cell_id(tid, attr)

# Originally, all cells have a NOT_SET status to be considered
# in weak labelling.
cell_status = CellStatus.NOT_SET.value

if len(dom) <= 1:
# The initial value is NULL and we cannot come up with
# a domain; a random domain probably won't help us, so we
# ignore this cell completely and continue.
if init_value == NULL_REPR:
continue

# Not enough domain values; we need to sample some random
# values (other than 'init_value') for training. However,
# this might still yield zero domain values.
rand_dom_values = self.get_random_domain(attr, init_value)

# rand_dom_values might still be empty. In this case,
# there are no other possible values for this cell. There
# is no point in using this cell for training, and no point
# in running inference on it, since we cannot even generate
# a random domain. Therefore, we simply omit it from the
# final tensor.
if len(rand_dom_values) == 0:
continue

# Otherwise, just add the random domain values to the domain
# and set the cell status accordingly.
dom.extend(rand_dom_values)

# Mark the cell as a single-value cell that was randomly
# assigned other values in its domain. These cells will
# not be modified by the estimator.
cell_status = CellStatus.SINGLE_VALUE.value

cells.append({"_tid_": tid,
"attribute": attr,
"_cid_": cid,
"_vid_": vid,
"domain": "|||".join(dom),
"domain_size": len(dom),
"init_value": init_value,
"init_index": init_value_idx,
"weak_label": init_value,
"weak_label_idx": init_value_idx,
"fixed": cell_status})
vid += 1
domain_df = pd.DataFrame(data=cells).sort_values('_vid_')
logging.debug('DONE generating initial set of domain values in %.2f', time.clock() - tic)
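The per-cell control flow introduced above, condensed into a stand-alone sketch (helper names and CellStatus values are stand-ins):

NULL_REPR = '_nan_'            # assumed sentinel
NOT_SET, SINGLE_VALUE = 0, 1   # stand-ins for CellStatus values

def resolve_cell_domain(init_value, dom, get_random_domain):
    """Return (domain, cell_status), or None to drop the cell entirely."""
    if len(dom) > 1:
        return dom, NOT_SET          # enough candidates for weak labelling
    if init_value == NULL_REPR:
        return None                  # NULL with no domain: skip the cell
    rand = get_random_domain(init_value)
    if not rand:
        return None                  # no alternative values exist: skip too
    return dom + rand, SINGLE_VALUE  # padded domain; estimator leaves it fixed

# A NULL cell with an empty domain is dropped:
assert resolve_cell_domain(NULL_REPR, [], lambda v: []) is None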

@@ -319,14 +332,18 @@ def generate_domain(self):
domain_values = [val for val, proba in sorted(preds, key=lambda pred: pred[1], reverse=True)[:self.max_domain]]

# ensure the initial value is included even if its probability is low.
if row['init_value'] not in domain_values:
if row['init_value'] not in domain_values and row['init_value'] != NULL_REPR:
domain_values.append(row['init_value'])
domain_values = sorted(domain_values)
# update our memoized domain values for this row again
row['domain'] = '|||'.join(domain_values)
row['domain_size'] = len(domain_values)
row['weak_label_idx'] = domain_values.index(row['weak_label'])
row['init_index'] = domain_values.index(row['init_value'])
# update init index based on new domain
if row['init_value'] in domain_values:
row['init_index'] = domain_values.index(row['init_value'])
# update weak label index based on new domain
if row['weak_label'] != NULL_REPR:
row['weak_label_idx'] = domain_values.index(row['weak_label'])

weak_label, weak_label_prob = max(preds, key=lambda pred: pred[1])
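A sketch of the re-indexing step above; it assumes a non-NULL weak label always appears in the rebuilt domain (it is drawn from the estimator's predictions):

NULL_REPR = '_nan_'  # assumed sentinel

def reindex(row, domain_values):
    row['domain'] = '|||'.join(domain_values)
    row['domain_size'] = len(domain_values)
    if row['init_value'] in domain_values:          # NULL stays at -1
        row['init_index'] = domain_values.index(row['init_value'])
    if row['weak_label'] != NULL_REPR:
        row['weak_label_idx'] = domain_values.index(row['weak_label'])
    return row

row = {'init_value': NULL_REPR, 'init_index': -1,
       'weak_label': 'NYC', 'weak_label_idx': 0,
       'domain': '', 'domain_size': 0}
print(reindex(row, ['LA', 'NYC']))  # init_index stays -1; weak_label_idx -> 1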

@@ -354,7 +371,7 @@ def generate_domain(self):
def get_domain_cell(self, attr, row):
"""
get_domain_cell returns a list of all domain values for the given
entity (row) and attribute.
entity (row) and attribute. The domain never has null as a possible value.
We define domain values as values in 'attr' that co-occur with values
in attributes ('cond_attr') that are correlated with 'attr' at least in
@@ -374,10 +391,10 @@ def get_domain_cell(self, attr, row):
"""

domain = set()
init_value = row[attr]
correlated_attributes = self.get_corr_attributes(attr, self.cor_strength)
# Iterate through all attributes correlated at least self.cor_strength ('cond_attr')
# and take the top K co-occurrence values for 'attr' with the current
# row's 'cond_attr' value.
# Iterate through all correlated attributes and take the top K co-occurrence values
# for 'attr' with the current row's 'cond_attr' value.
for cond_attr in correlated_attributes:
# Ignore correlations with index, tuple id or the same attribute.
if cond_attr == attr or cond_attr == '_tid_':
@@ -386,34 +403,35 @@ def get_domain_cell(self, attr, row):
logging.warning("domain generation could not find pair_statistics between attributes: {}, {}".format(cond_attr, attr))
continue
cond_val = row[cond_attr]
# Ignore correlations with null values.
if cond_val == '_nan_':
# Ignore co-occurrence with a NULL cond init value since we do not
# store such co-occurrences.
# It also does not make sense to retrieve the top co-occurring
# values for a NULL value.
# cond_val may be missing from the pair stats if it only co-occurs
# with NULL values.
continue
s = self.pair_stats[cond_attr][attr]
try:
candidates = s[cond_val]
domain.update(candidates)
except KeyError as missing_val:
if row[attr] != '_nan_':
# Error since co-occurrence must be at least 1 (since
# the current row counts as one co-occurrence).
logging.error('value missing from statistics: {}'.format(missing_val))
raise

# Add the initial value to the domain.
init_value = row[attr]
domain.add(init_value)

# Remove _nan_ if added due to correlated attributes, only if it was not the initial value.
if init_value != '_nan_':
domain.discard('_nan_')
# Update the domain with the top co-occurring values for the cond init value.
candidates = self.pair_stats[cond_attr][attr][cond_val]
domain.update(candidates)

# We should not have any NULLs since we do not store co-occurring NULL
# values.
assert NULL_REPR not in domain

# Add the initial value to the domain if it is not NULL.
if init_value != NULL_REPR:
domain.add(init_value)

# Convert to ordered list to preserve order.
domain_lst = sorted(list(domain))

# Get the index of the initial value. This should never raise a ValueError since we made sure
# that 'init_value' was added.
init_value_idx = domain_lst.index(init_value)
# Get the index of the initial value.
# NULL values are not in the domain, so we set their index to -1.
init_value_idx = -1
if init_value != NULL_REPR:
init_value_idx = domain_lst.index(init_value)

return init_value, init_value_idx, domain_lst
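Putting the pieces together, a self-contained sketch of the NULL-free domain construction (pair-stats shape as documented earlier; data illustrative):

NULL_REPR = '_nan_'  # assumed sentinel

def domain_cell_sketch(attr, row, correlated, pair_stats):
    domain = set()
    for cond_attr in correlated:
        cond_val = row[cond_attr]
        if cond_val == NULL_REPR or cond_val not in pair_stats[cond_attr][attr]:
            continue
        domain.update(pair_stats[cond_attr][attr][cond_val])
    assert NULL_REPR not in domain       # co-occurrences never include NULLs
    init_value = row[attr]
    if init_value != NULL_REPR:
        domain.add(init_value)
    domain_lst = sorted(domain)
    init_idx = domain_lst.index(init_value) if init_value != NULL_REPR else -1
    return init_value, init_idx, domain_lst

row = {'city': NULL_REPR, 'state': 'NY'}
stats = {'state': {'city': {'NY': {'NYC', 'Albany'}}}}
print(domain_cell_sketch('city', row, ['state'], stats))
# ('_nan_', -1, ['Albany', 'NYC'])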

@@ -423,6 +441,9 @@ def get_random_domain(self, attr, cur_value):
'self.max_sample' of domain values for 'attr' that is NOT 'cur_value'.
"""
domain_pool = set(self.single_stats[attr].keys())
# We should not have any NULLs since we do not keep track of their
# counts.
assert NULL_REPR not in domain_pool
domain_pool.discard(cur_value)
domain_pool = sorted(list(domain_pool))
size = len(domain_pool)
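A sketch of the sampling described here; the hunk folds before the sampling call itself, so the use of random.sample below is an assumption:

import random

def get_random_domain_sketch(single_stats, attr, cur_value, max_sample=5):
    """Sample up to max_sample values for attr other than cur_value."""
    domain_pool = set(single_stats[attr].keys())  # NULLs never appear here
    domain_pool.discard(cur_value)
    domain_pool = sorted(domain_pool)
    size = len(domain_pool)
    return random.sample(domain_pool, min(size, max_sample))

stats = {'city': {'NYC': 2, 'LA': 1, 'SF': 1}}
print(get_random_domain_sketch(stats, 'city', 'NYC', max_sample=2))
# e.g. ['SF', 'LA'] -- never includes 'NYC' or NULL_REPR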
6 changes: 1 addition & 5 deletions domain/estimator.py
@@ -34,14 +34,10 @@ def predict_pp(self, row, attr, values):
raise NotImplementedError

@abstractmethod
def predict_pp_batch(self, raw_records_by_tid, cell_domain_rows):
def predict_pp_batch(self):
"""
predict_pp_batch is like predict_pp but with a batch of cells.
:param raw_records_by_tid: (dict) maps TID to its corresponding row (record) in the raw data
:param cell_domain_rows: (list[pd.record]) list of records from the cell domain DF. Each
record should include the field '_tid_', 'attribute', and 'domain'
:return: iterator of iterator of tuples (value, proba) (one iterator per cell/row in cell_domain_rows)
"""
raise NotImplementedError
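A toy estimator satisfying the slimmed-down interface might look like this (the class and its uniform probabilities are illustrative only):

class UniformEstimator:
    """Assigns equal probability to every value in a cell's domain."""
    def __init__(self, domain_rows):
        # domain_rows: records with a '|||'-separated 'domain' field,
        # mirroring the cell domain DataFrame.
        self._rows = domain_rows

    def predict_pp_batch(self):
        # One iterator of (value, proba) tuples per cell, in row order.
        for row in self._rows:
            values = row['domain'].split('|||')
            p = 1.0 / len(values)
            yield [(val, p) for val in values]

est = UniformEstimator([{'domain': 'LA|||NYC'}])
print(list(est.predict_pp_batch()))  # [[('LA', 0.5), ('NYC', 0.5)]]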
