Skip to content

Commit

Permalink
Merge pull request #10 from joshchang/refactor
Browse files Browse the repository at this point in the history
Refactor
  • Loading branch information
jcalifornia authored Feb 18, 2021
2 parents 6ff9127 + 1a6e697 commit 603f3d6
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions autoencirt/irt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def calibrate_advi(
_data = data.batch(batch_size, drop_remainder=True)
# data = data.batch

_data = _data.prefetch(2)
_data = _data.prefetch(tf.data.experimental.AUTOTUNE)

def run_approximation(num_epochs):
losses = fit_surrogate_posterior(
Expand Down Expand Up @@ -193,14 +193,14 @@ def calibrate_mcmc(self, data=None, num_steps=1000, burnin=500,
if card is None:
card = tf_data_cardinality(data)
_data = data.batch(int(card/10))
_data = _data.prefetch(2)
_data = _data.prefetch(tf.data.experimental.AUTOTUNE)

@tf.function
def energy(*x):
energy = 0
for batch in iter(_data):
energy += self.unormalized_log_prob_list(batch, x)
return energy


samples, sampler_stat = run_chain(
init_state=initial_list,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

setuptools.setup(
name="autoencirt", # Replace with your own username
version="0.0.4",
version="0.0.5",
author="Josh Chang",
author_email="[email protected]",
description="",
Expand Down

0 comments on commit 603f3d6

Please sign in to comment.