Commit
Fixed some things in domain.
richardwu committed Jun 22, 2019
1 parent 2303c31 commit f0805ef
Showing 3 changed files with 5 additions and 19 deletions.
6 changes: 4 additions & 2 deletions domain/domain.py
@@ -226,8 +226,10 @@ def generate_domain(self):
  # point to run inference on it since we cannot even generate
  # a random domain. Therefore, we just ignore it from the
  # final tensor.
- # if len(rand_dom_values) == 0:
- #     continue
+ # We do not drop NULL cells since we stil have to repair them
+ # with their 1 domain value.
+ if init_value != NULL_REPR and len(rand_dom_values) == 0:
+     continue
 
  # Otherwise, just add the random domain values to the domain
  # and set the cell status accordingly.
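For context, a minimal standalone sketch of the cell-filtering rule this hunk introduces (hypothetical names, not the repository's actual generate_domain code): a cell with an empty random domain is now dropped only when its initial value is not NULL, because NULL cells must still be repaired even if their domain holds a single value.

NULL_REPR = '_nan_'  # assumed NULL sentinel, per the "_nan_" comment in the next file's deleted block

def keep_cell(init_value, rand_dom_values):
    # Drop only non-NULL cells whose random domain came back empty;
    # NULL cells are kept so they can still be repaired with their
    # one domain value.
    if init_value != NULL_REPR and len(rand_dom_values) == 0:
        return False
    return True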
16 changes: 0 additions & 16 deletions domain/estimators/tuple_embedding.py
@@ -642,18 +642,6 @@ def __init__(self, env, dataset, domain_df,
  max(list(map(len, numerical_attr_groups)) or [0]),
  self._embed_size)
  raise Exception()
- # Convert non numerical init values in numerical attributes with _nan_.
- # if self._numerical_attrs is not None:
- # fil_attr = self.domain_df['attribute'].isin(self._numerical_attrs)
- # fil_notnull = self.domain_df['weak_label'] != NULL_REPR
- # fil_notnumeric = self.domain_df['weak_label'].str.contains(NONNUMERICS)
- # bad_numerics = fil_attr & fil_notnull & fil_notnumeric
- # if bad_numerics.sum():
- # self.domain_df.loc[bad_numerics, 'weak_label'] = NULL_REPR
- # logging.warning('%s: replaced %d non-numerical values in DOMAIN as "%s" (NULL)',
- # type(self).__name__,
- # bad_numerics.sum(),
- # NULL_REPR)
  # Remove domain for numerical attributes.
  fil_numattr = self.domain_df['attribute'].isin(self._numerical_attrs)
 
@@ -739,9 +727,6 @@ def __init__(self, env, dataset, domain_df,
  self.in_num_w1 = torch.nn.Parameter(torch.zeros(self._n_num_attr_groups, self._embed_size, self._embed_size))
  self.in_num_bias1 = torch.nn.Parameter(torch.zeros(self._n_num_attr_groups, self._embed_size))
 
- # out_num_zeros_vecs may not be necessary
- self.out_num_zero_vecs = torch.nn.Parameter(torch.zeros(self._n_train_num_attrs, self._embed_size))
-
  self.out_num_bases = torch.nn.Parameter(torch.zeros(self._n_train_num_attrs, self._embed_size, self._max_num_dim))
  # Non-linearity for combined_init for each numerical attr
  self.out_num_w1 = torch.nn.Parameter(torch.zeros(self._n_train_num_attrs, self._embed_size, self._embed_size))
@@ -783,7 +768,6 @@ def __init__(self, env, dataset, domain_df,
  torch.nn.init.xavier_uniform_(self.in_num_bias1)
 
  if self._n_train_num_attrs > 0:
-     torch.nn.init.xavier_uniform_(self.out_num_zero_vecs)
      torch.nn.init.xavier_uniform_(self.out_num_bases)
      torch.nn.init.xavier_uniform_(self.out_num_w1)
      torch.nn.init.xavier_uniform_(self.out_num_bias1)
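A minimal, self-contained sketch of the parameter pattern the deleted out_num_zero_vecs lines followed (and which the surviving out_num_bases / out_num_w1 lines still use): declare a learnable zero tensor as a torch.nn.Parameter, then Xavier-initialize it in the constructor. The sizes below are placeholders, not the model's real dimensions.

import torch

embed_size = 16          # placeholder, not the model's real _embed_size
n_train_num_attrs = 3    # placeholder, not the model's real attribute count

# Same declare-then-init pattern as the removed out_num_zero_vecs parameter.
out_num_zero_vecs = torch.nn.Parameter(
    torch.zeros(n_train_num_attrs, embed_size))
torch.nn.init.xavier_uniform_(out_num_zero_vecs)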
2 changes: 1 addition & 1 deletion tests/start_test.sh
@@ -5,4 +5,4 @@ source ../set_env.sh
 
 # Launch tests.
 echo "Launching tests..."
-pytest -n 1
+pytest
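The dropped -n flag comes from the pytest-xdist plugin (it sets the number of worker processes). A small illustration using pytest's programmatic entry point; treating it as equivalent to the shell invocations in start_test.sh is an assumption.

import pytest

# Old invocation: force a single pytest-xdist worker (requires pytest-xdist installed).
# pytest.main(["-n", "1"])

# New invocation: plain pytest, no worker count forced.
pytest.main([])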
