diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..7056319 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,4 @@ +{ + "python.pythonPath": "/home/liux3941/miniconda3/envs/deepivMY/bin/python", + "jupyter.jupyterServerType": "local" +} \ No newline at end of file diff --git a/README.rst b/README.rst index c9e72cc..6b17784 100644 --- a/README.rst +++ b/README.rst @@ -18,7 +18,7 @@ DeepIV :alt: Updates -IMPORTANT: Newer versions of Keras have broken this implementation. This code currently only support Keras 2.0.6 (which is what will be installed if you use the pip install instructions described below). +IMPORTANT: Newer versions of Keras have broken this implementation. This code currently supports Keras 2.3.1, and the tensorflow version is 2.5. See details in ``setup.py``. A package for counterfactual prediction using deep instrument variable methods that builds on Keras_. diff --git a/deepiv/architectures.py b/deepiv/architectures.py index 46bed1e..306a04b 100644 --- a/deepiv/architectures.py +++ b/deepiv/architectures.py @@ -1,16 +1,17 @@ from __future__ import absolute_import, division, print_function, unicode_literals - -import keras -import keras.backend as K -from keras.layers import (Convolution2D, Dense, Dropout, Flatten, - MaxPooling2D) -from keras.models import Sequential -from keras.regularizers import l2 -from keras.constraints import maxnorm -from keras.utils import np_utils - +import tensorflow as tf +import tensorflow.keras as keras +import tensorflow.keras.backend as K +from tensorflow.keras.layers import (Convolution2D, Dense, Dropout, Flatten, + MaxPooling2D) +from tensorflow.keras.models import Sequential +from tensorflow.keras.regularizers import l2 +from tensorflow.keras.constraints import MaxNorm as maxnorm +#from tensorflow.keras.utils import np_utils +from tensorflow.python.keras import utils as np_utils import numpy as np + def feed_forward_net(input, output, hidden_layers=[64, 64], activations='relu', dropout_rate=0., l2=0., constrain_norm=False): ''' @@ -26,18 +27,19 @@ def feed_forward_net(input, output, hidden_layers=[64, 64], activations='relu', state = input if isinstance(activations, str): activations = [activations] * len(hidden_layers) - + for h, a in zip(hidden_layers, activations): if l2 > 0.: - w_reg = keras.regularizers.l2(l2) + w_reg = tf.keras.regularizers.l2(l2) else: w_reg = None - const = maxnorm(2) if constrain_norm else None + const = maxnorm(2) if constrain_norm else None state = Dense(h, activation=a, kernel_regularizer=w_reg, kernel_constraint=const)(state) if dropout_rate > 0.: state = Dropout(dropout_rate)(state) return output(state) + def convnet(input, output, dropout_rate=0., input_shape=(1, 28, 28), batch_size=100, l2_rate=0.001, nb_epoch=12, img_rows=28, img_cols=28, nb_filters=64, pool_size=(2, 2), kernel_size=(3, 3), activations='relu', constrain_norm=False): @@ -73,6 +75,7 @@ def convnet(input, output, dropout_rate=0., input_shape=(1, 28, 28), batch_size= state = Dropout(dropout_rate)(state) return output(state) + def feature_to_image(features, height=28, width=28, channels=1, backend=K): ''' Reshape a flattened image to the input format for convolutions. 
@@ -86,4 +89,3 @@ def feature_to_image(features, height=28, width=28, channels=1, backend=K): return backend.reshape(features, (-1, channels, height, width)) else: return backend.reshape(features, (-1, height, width, channels)) - diff --git a/deepiv/custom_gradients.py b/deepiv/custom_gradients.py index 47fa5c3..f859fab 100644 --- a/deepiv/custom_gradients.py +++ b/deepiv/custom_gradients.py @@ -1,26 +1,33 @@ from __future__ import absolute_import, division, print_function, unicode_literals +import types + +import tensorflow.keras +from tensorflow.keras import backend as K + -import keras -from keras import backend as K if K.backend() == "theano": import theano.tensor as tensor Lop = tensor.Lop elif K.backend() == "tensorflow": import tensorflow as tf - def Lop(output, wrt, eval_points): - grads = tf.gradients(output, wrt, grad_ys=eval_points) - return grads -import types + + +def Lop(output, wrt, eval_points): + grads = tf.gradients(output, wrt, grad_ys=eval_points) + return grads + # Used to modify the default keras Optimizer object to allow # for custom gradient computation. + def get_gradients(self, loss, params): ''' Replacement for the default keras get_gradients() function. Modification: checks if the object has the attribute grads and returns that rather than calculating the gradients using automatic differentiation. + In keras, it is gradients = K.gradients(outputTensor, listOfVariableTensors) ''' if hasattr(self, 'grads'): grads = self.grads @@ -33,40 +40,122 @@ def get_gradients(self, loss, params): grads = [K.clip(g, -self.clipvalue, self.clipvalue) for g in grads] return grads -def replace_gradients_mse(model, opt, batch_size, n_samples = 1): + +def replace_gradients_mse(model, opt, batch_size, n_samples=1): ''' Replace the gradients of a Keras model with mean square error loss. + # + # TODO: check model components, only work with py2.7 ''' # targets has been repeated twice so the below creates two identical columns # of the target values - we'll only use the first column. targets = K.reshape(model.targets[0], (batch_size, n_samples * 2)) - output = K.mean(K.reshape(model.outputs[0], (batch_size, n_samples, 2)), axis=1) + output = K.mean(K.reshape(model.outputs[0], (batch_size, n_samples, 2)), axis=1) # compute d Loss / d output - dL_dOutput = (output[:,0] - targets[:,0]) * (2.) / batch_size + dL_dOutput = (output[:, 0] - targets[:, 0]) * (2.) / batch_size # compute (d Loss / d output) (d output / d theta) for each theta trainable_weights = model.trainable_weights - grads = Lop(output[:,1], wrt=trainable_weights, eval_points=dL_dOutput) + # grads = tf.gradients(output, wrt, grad_ys=eval_points) + grads = Lop(output[:, 1], wrt=trainable_weights, eval_points=dL_dOutput) # compute regularizer gradients # add loss with respect to regularizers reg_loss = model.total_loss * 0. 
for r in model.losses: - reg_loss += r + reg_loss += r reg_grads = K.gradients(reg_loss, trainable_weights) - grads = [g+r for g,r in zip(grads, reg_grads)] - + grads = [g+r for g, r in zip(grads, reg_grads)] + opt = keras.optimizers.get(opt) # Patch keras gradient calculation to allow for user defined gradients - opt.get_gradients = types.MethodType( get_gradients, opt ) + opt.get_gradients = types.MethodType(get_gradients, opt) opt.grads = grads model.optimizer = opt return model + +def custom_mse_unbiased_gradients(model, y_true, y_pred): + """ + in the unbiased case, we sample two independent samples each time, and y_ture has already been repeated 2 times, + """ + (batch_size, n_samples) = y_true.shape + batch_size //= 2 + + targets = K.reshape(y_true, (batch_size, n_samples * 2)) + output = K.mean(K.reshape(y_pred, (batch_size, n_samples, 2)), axis=1) + targets = tf.cast(targets, dtype=output.dtype) + # compute d Loss / d output + dL_dOutput = (output[:, 0] - targets[:, 0]) * (2.) / batch_size + # compute (d Loss / d output) (d output / d theta) for each theta + trainable_weights = model.trainable_weights + # grads = tf.gradients(output, wrt, grad_ys=eval_points) + + grads = Lop(output[:, 1], wrt=trainable_weights, eval_points=dL_dOutput) + # compute regularizer gradients + + with tf.GradientTape() as tape2: + # add loss with respect to regularizers + #reg_loss = model.total_loss * 0. + reg_loss = 0. + for r in model.losses: + reg_loss += r + + reg_grads = tape.gradient(reg_loss, trainable_weights) + grads = [g+r for g, r in zip(grads, reg_grads)] + + opt = keras.optimizers.get(opt) + # Patch keras gradient calculation to allow for user defined gradients + opt.get_gradients = types.MethodType(get_gradients, opt) + opt.grads = grads + model.optimizer = opt + return model + + def build_mc_mse_loss(n_samples): - def mc_mse(y_true, y_predicted): - n_examples = y_true.shape[0] / n_samples / 2 - targets = y_true.reshape((n_examples , n_samples * 2)) - output = y_predicted.reshape((n_examples, n_samples * 2)).mean(axis=1) - return K.mean(K.square(targets[:,0] - output)) + """ + return MC mse loss function + """ + def mc_mse(y_true, y_pred): + n_examples = y_true.shape[0] / n_samples / 2 + targets = y_true.reshape((n_examples, n_samples * 2)) + output = y_pred.reshape((n_examples, n_samples, 2)).mean(axis=1) + return K.mean(K.square(targets[:, 0] - output)) return mc_mse + +def unbiased_mse_loss_and_gradients(model, y_true, y_pred, batch_size, n_samples=1): + """ + In custom loss function, ytrue and y_pred need to be tensor with same dtype + n_samples is B in equattion (10) + """ + + # total_size = y_pred.shape[0] + # batch_size = total_size//n_samples//2 + targets = K.reshape(y_true, (batch_size, n_samples*2)) + output = K.mean(K.reshape(y_pred, (batch_size, n_samples, 2)), axis=1) + targets = tf.cast(targets, dtype=output.dtype) + + # compute d Loss / d output + dL_dOutput = (output[:, 0] - targets[:, 0]) * (2.) / batch_size + # compute (d Loss / d output) (d output / d theta) for each theta + trainable_weights = model.trainable_weights + grads = tf.gradients(output[:, 1], trainable_weights, grad_ys=dL_dOutput) + + # # add loss with respect to regularizers + # reg_loss = 0. 
+ # for r in model.losses: + # reg_loss += r + # reg_grads = K.gradients(reg_loss, trainable_weights) + + # grads = [g+r for g, r in zip(grads, reg_grads)] + + # opt = tensorflow.keras.optimizers.get(optimizer) + # opt.apply_gradients(zip(grads, trainable_weights)) + # Patch keras gradient calculation to allow for user defined gradients + # opt.get_gradients = types.MethodType(get_gradients, opt) + # opt.grads = grads + # model.optimizer = opt + + # loss = tf.math.multiply(output[:, 1] - targets[:, 1], output[:, 0] - targets[:, 0]) + + return grads diff --git a/deepiv/densities.py b/deepiv/densities.py index 74fd631..b041bb2 100644 --- a/deepiv/densities.py +++ b/deepiv/densities.py @@ -1,12 +1,12 @@ from __future__ import absolute_import, division, print_function, unicode_literals import numpy -import keras -from keras import backend as K +import tensorflow.keras as keras +from tensorflow.keras import backend as K -from keras.layers.merge import Concatenate -from keras.layers import Lambda -from keras.layers.core import Reshape +from tensorflow.keras.layers import Concatenate +from tensorflow.keras.layers import Lambda +from tensorflow.keras.layers import Reshape def split(start, stop): return Lambda(lambda x: x[:, start:stop], output_shape=(None, stop-start)) diff --git a/deepiv/models.py b/deepiv/models.py index 157e813..c0b8ebe 100644 --- a/deepiv/models.py +++ b/deepiv/models.py @@ -6,25 +6,28 @@ import deepiv.samplers as samplers import deepiv.densities as densities -from deepiv.custom_gradients import replace_gradients_mse - -from keras.models import Model -from keras import backend as K -from keras.layers import Lambda, InputLayer -from keras.engine import topology +from deepiv.custom_gradients import * +import tensorflow as tf +from tensorflow.keras.models import Model +import tensorflow.keras.metrics as Metrics + +from tensorflow.keras import backend as K +from tensorflow.keras.layers import Lambda, InputLayer +# from tensorflow.keras.engine import topology try: import h5py except ImportError: h5py = None -import keras.utils +import tensorflow.keras.utils import numpy from sklearn import linear_model from sklearn.decomposition import PCA from scipy.stats import norm + class Treatment(Model): ''' Adds sampling functionality to a Keras model and extends the losses to support @@ -64,6 +67,7 @@ def sample_binomial(inputs, use_dropout=False): elif loss in ["mean_absolute_error", "mae", "MAE"]: output += samplers.random_laplace(K.shape(output), mu=0.0, b=1.0) draw_sample = K.function(inputs + [K.learning_phase()], [output]) + def sample_laplace(inputs, use_dropout=False): ''' Helper to draw samples from a Laplacian distribution @@ -74,9 +78,27 @@ def sample_laplace(inputs, use_dropout=False): elif loss == "mixture_of_gaussians": pi, mu, log_sig = densities.split_mixture_of_gaussians(output, self.n_components) - samples = samplers.random_gmm(pi, mu, K.exp(log_sig)) - draw_sample = K.function(inputs + [K.learning_phase()], [samples]) - return lambda inputs, use_dropout: draw_sample(inputs + [int(use_dropout)])[0] + samples = samplers.random_gmm(pi, mu, K.exp(log_sig)) # samples shape None + + draw_sample = Model(inputs=inputs, outputs=[samples]) # symbolic_learning_phase + + def sample_gmm(inputs, use_dropout=False): + ''' + Another option is: + >>> draw_sample = K.function(inputs, [samples]) + >>> with K.set_learning_phase(int(use_dropout)): + ... 
return draw_sample(inputs + [int(use_dropout)])[0] + + For the follow `draw_sample`, the return is not a list anymore, so I no longer use `outs[0]`. + if you want to use the list, please look it up in https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/backend.py (line: 4109). + ''' + outs = draw_sample(inputs, training=use_dropout) + return outs + + # I didn't use the lambda function + # return lambda inputs, use_dropout: draw_sample(inputs + [int(use_dropout)])[0] + + return sample_gmm else: raise NotImplementedError("Unrecognised loss: %s.\ @@ -107,9 +129,15 @@ def compile(self, optimizer, loss, metrics=None, loss_weights=None, supply n_components argument") self.n_components = n_components self._prepare_sampler(loss) - loss = lambda y_true, y_pred: densities.mixture_of_gaussian_loss(y_true, - y_pred, - n_components) + + def loss(y_true, y_pred): + """ + customized loss function returns a scalar for each data-point and takes the following two arguments: + y_true: True labels. TensorFlow/Theano tensor. + y_pred: Predictions. TensorFlow/Theano tensor of the same shape as y_true. + """ + return densities.mixture_of_gaussian_loss(y_true, y_pred, + n_components) def predict_mean(x, batch_size=32, verbose=0): ''' @@ -139,41 +167,117 @@ def sample(self, inputs, n_samples=1, use_dropout=False): else: raise Exception("Compile model with loss before sampling") + +@tf.custom_gradient +def custom_op(model, targets, output, batch_size, n_samples): + result = ... # do forward computation + + def custom_grad(targets, output, batch_size, n_samples): + # grad = ... # compute gradient + dL_dOutput = (output[:, 0] - targets[:, 0]) * (2.) / batch_size + # compute (d Loss / d output) (d output / d theta) for each theta + trainable_weights = model.trainable_weights + grads = tf.gradients(output[:, 1], trainable_weights, grad_ys=dL_dOutput) + #grads = Lop(output[:, 1], wrt=trainable_weights, eval_points=dL_dOutput) + return grads + return result, custom_grad + + class Response(Model): ''' Extends the Keras Model class to support sampling from the Treatment model during training. - Overwrites the existing fit_generator function. + Overwrites the existing fit_generator function. where + -> object : the Keras Object model. + -> generator : a generator whose output must be a list of the form: + - (inputs, targets) + - (input, targets, sample_weights) # Arguments In addition to the standard model arguments, a Response object takes a Treatment object as input so that it can sample from the fitted treatment distriubtion during training. ''' + def __init__(self, treatment, **kwargs): + super(Response, self).__init__(**kwargs) if isinstance(treatment, Treatment): self.treatment = treatment else: raise TypeError("Expected a treatment model of type Treatment. \ Got a model of type %s. Remember to train your\ treatment model first." 
% type(treatment)) - super(Response, self).__init__(**kwargs) - def compile(self, optimizer, loss, metrics=None, loss_weights=None, sample_weight_mode=None, - unbiased_gradient=False,n_samples=1, batch_size=None): + def compile_old(self, optimizer, loss, metrics=None, loss_weights=None, sample_weight_mode=None, + unbiased_gradient=False, n_samples=1, batch_size=None): super(Response, self).compile(optimizer=optimizer, loss=loss, loss_weights=loss_weights, sample_weight_mode=sample_weight_mode) self.unbiased_gradient = unbiased_gradient if unbiased_gradient: if loss in ["MSE", "mse", "mean_squared_error"]: if batch_size is None: - raise ValueError("Must supply a batch_size argument if using unbiased gradients. Currently batch_size is None.") + raise ValueError( + "Must supply a batch_size argument if using unbiased gradients. Currently batch_size is None.") replace_gradients_mse(self, optimizer, batch_size=batch_size, n_samples=n_samples) else: warnings.warn("Unbiased gradient only implemented for mean square error loss. It is unnecessary for\ logistic losses and currently not implemented for absolute error losses.") - + + def train_step(self, data): + + x, y_true = data + with tf.GradientTape() as tape: + y_pred = self(x, training=True) # Forward pass + # Compute the loss value + # (the loss function is configured in `compile()`) + loss = self.compiled_loss(y_true, y_pred, regularization_losses=self.losses) + trainable_vars = self.trainable_variables + if self.unbiased_gradient: + grads = unbiased_mse_loss_and_gradients( + self, y_true, y_pred, self.batch_size, self.n_samples) + else: + grads = tape.gradient(loss, trainable_vars) + self.optimizer.apply_gradients(zip(grads, trainable_vars)) + self.compiled_metrics.update_state(y_true, y_pred) + # Return a dict mapping metric names to current value + return {m.name: m.result() for m in self.metrics} + + def compile(self, optimizer, loss, metrics=None, loss_weights=None, sample_weight_mode=None, + unbiased_gradient=False, n_samples=1, batch_size=None): + # super(Response, self).compile(optimizer=optimizer, loss=loss, loss_weights=loss_weights, + # sample_weight_mode=sample_weight_mode) + + self.unbiased_gradient = unbiased_gradient + self.n_samples = n_samples + if loss in ["MSE", "mse", "mean_squared_error"]: + metrics = [Metrics.MeanSquaredError(name="mse")] + # if unbiased_gradient: + # if loss in ["MSE", "mse", "mean_squared_error"]: + # if batch_size is None: + # raise ValueError( + # "Must supply a batch_size argument if using unbiased gradients. Currently batch_size is None.") + # self._train_step(data) + + # # def unbiased_loss(y_true, y_pred): + # # return "mse", unbiased_mse_loss_and_gradients( + # # self, optimizer, y_true, y_pred, batch_size, n_samples=1) + # # def unbiased_loss_grad_opt(y_true, y_pred): + # # return unbiased_mse_loss_and_gradients(self, optimizer, y_true, y_pred, batch_size, n_samples=1) + # # # return unbiased_mse_loss_and_gradients(self, y_true, y_pred, batch_size, n_samples) + # # super(Response, self).compile(optimizer=optimizer, loss=loss, loss_weights=loss_weights, + # # sample_weight_mode=sample_weight_mode) + + # # unbiased_mse_gradients(self, y_true, y_pred, batch_size, n_samples=1) + # # replace_gradients_mse(self, optimizer, batch_size=batch_size, n_samples=n_samples) + # else: + # warnings.warn("Unbiased gradient only implemented for mean square error loss. 
It is unnecessary for\ + # logistic losses and currently not implemented for absolute error losses.") + super(Response, self).compile(optimizer=optimizer, + loss=loss, + loss_weights=loss_weights, + metrics=metrics, + sample_weight_mode=sample_weight_mode) def fit(self, x=None, y=None, batch_size=512, epochs=1, verbose=1, callbacks=None, validation_data=None, class_weight=None, initial_epoch=0, samples_per_batch=None, @@ -187,6 +291,7 @@ def fit(self, x=None, y=None, batch_size=512, epochs=1, verbose=1, callbacks=Non The remainder of the arguments correspond to the Keras definitions. ''' batch_size = numpy.minimum(y.shape[0], batch_size) + self.batch_size = batch_size if seed is None: seed = numpy.random.randint(0, 1e6) if samples_per_batch is None: @@ -200,24 +305,16 @@ def fit(self, x=None, y=None, batch_size=512, epochs=1, verbose=1, callbacks=Non else: generator = OnesidedUnbaised(x[1:], x[0], y, observed_treatments, batch_size, self.treatment.sample, samples_per_batch) - - steps_per_epoch = y.shape[0] // batch_size - super(Response, self).fit_generator(generator=generator, - steps_per_epoch=steps_per_epoch, - epochs=epochs, verbose=verbose, - callbacks=callbacks, validation_data=validation_data, - class_weight=class_weight, initial_epoch=initial_epoch) - - def fit_generator(self, **kwargs): - ''' - We use override fit_generator to support sampling from the treatment model during training. - If you need this functionality, you'll need to build a generator that samples from the - treatment and performs whatever transformations you're performing. Please submit a pull - request if you implement this. - ''' - raise NotImplementedError("We use override fit_generator to support sampling from the\ - treatment model during training.") + steps_per_epoch = int(y.shape[0] // batch_size) + + super(Response, self).fit(generator, + steps_per_epoch=steps_per_epoch, + epochs=epochs, verbose=verbose, + # callbacks=callbacks, + validation_data=validation_data, + class_weight=class_weight, + initial_epoch=initial_epoch) def expected_representation(self, x, z, n_samples=100, batch_size=None, seed=None): inputs = [z, x] @@ -230,7 +327,7 @@ def expected_representation(self, x, z, n_samples=100, batch_size=None, seed=Non intermediate_layer_model = Model(inputs=self.inputs, outputs=self.layers[-2].output) - + def pred(inputs, n_samples=100, seed=None): features = inputs[1] @@ -245,7 +342,7 @@ def pred(inputs, n_samples=100, seed=None): def conditional_representation(self, x, p): inputs = [x, p] - if not hasattr(self, "_c_representation"): + if not hasattr(self, "_c_representation"): intermediate_layer_model = Model(inputs=self.inputs, outputs=self.layers[-2].output) @@ -260,11 +357,11 @@ def dropout_predict(self, x, z, n_samples=100): else: inputs = [z, x] if not hasattr(self, "_dropout_predict"): - + predict_with_dropout = K.function(self.inputs + [K.learning_phase()], [self.layers[-1].output]) - def pred(inputs, n_samples = 100): + def pred(inputs, n_samples=100): # draw samples from the treatment network with dropout turned on samples = self.treatment.sample(inputs, n_samples, use_dropout=True) # prepare inputs for the response network @@ -292,14 +389,13 @@ def credible_interval(self, x, z, n_samples=100, p=0.95): def _add_constant(self, X): return numpy.concatenate((numpy.ones((X.shape[0], 1)), X), axis=1) - + def predict_confidence(self, x, p): if hasattr(self, "_predict_confidence"): return self._predict_confidence(x, p) else: raise Exception("Call fit_confidence_interval before running 
predict_confidence") - def fit_confidence_interval(self, x_lo, z_lo, p_lo, y_lo, n_samples=100, alpha=0.): eta_bar = self.expected_representation(x=x_lo, z=z_lo, n_samples=n_samples) pca = PCA(1-1e-16, svd_solver="full", whiten=True) @@ -322,17 +418,15 @@ def fit_confidence_interval(self, x_lo, z_lo, p_lo, y_lo, n_samples=100, alpha=0 V = numpy.dot(numpy.dot(hhi, heh), hhi) def pred(xx, pp): - H = self._add_constant(pca.transform(self.conditional_representation(xx,pp))) + H = self._add_constant(pca.transform(self.conditional_representation(xx, pp))) sdhb = numpy.sqrt(numpy.diag(numpy.dot(numpy.dot(H, V), H.T))) hb = ols2.predict(H).flatten() return hb, sdhb - - self._predict_confidence = pred - + self._predict_confidence = pred -class SampledSequence(keras.utils.Sequence): +class SampledSequence(tensorflow.keras.utils.Sequence): def __init__(self, features, instruments, outputs, batch_size, sampler, n_samples=1, seed=None): self.rng = numpy.random.RandomState(seed) if not isinstance(features, list): @@ -359,21 +453,23 @@ def __len__(self): def shuffle(self): idx = self.rng.permutation(numpy.arange(self.instruments.shape[0])) - self.instruments = self.instruments[idx,:] - self.outputs = self.outputs[idx,:] - self.features = [f[idx,:] for f in self.features] - - def __getitem__(self,idx): + self.instruments = self.instruments[idx, :] + self.outputs = self.outputs[idx, :] + self.features = [f[idx, :] for f in self.features] + + def __getitem__(self, idx): instruments = [self.instruments[idx*self.batch_size:(idx+1)*self.batch_size, :]] features = [inp[idx*self.batch_size:(idx+1)*self.batch_size, :] for inp in self.features] sampler_input = instruments + features samples = self.sampler(sampler_input, self.n_samples) - batch_features = [f[idx*self.batch_size:(idx+1)*self.batch_size].repeat(self.n_samples, axis=0) for f in self.features] + [samples] + batch_features = [f[idx*self.batch_size:(idx+1)*self.batch_size].repeat(self.n_samples, axis=0) + for f in self.features] + [samples] batch_y = self.outputs[idx*self.batch_size:(idx+1)*self.batch_size].repeat(self.n_samples, axis=0) if idx == (len(self) - 1): self.shuffle() return batch_features, batch_y + class OnesidedUnbaised(SampledSequence): def __init__(self, features, instruments, outputs, treatments, batch_size, sampler, n_samples=1, seed=None): self.rng = numpy.random.RandomState(seed) @@ -393,10 +489,10 @@ def __init__(self, features, instruments, outputs, treatments, batch_size, sampl def shuffle(self): idx = self.rng.permutation(numpy.arange(self.instruments.shape[0])) - self.instruments = self.instruments[idx,:] - self.outputs = self.outputs[idx,:] - self.features = [f[idx,:] for f in self.features] - self.treatments = self.treatments[idx,:] + self.instruments = self.instruments[idx, :] + self.outputs = self.outputs[idx, :] + self.features = [f[idx, :] for f in self.features] + self.treatments = self.treatments[idx, :] def __getitem__(self, idx): instruments = [self.instruments[idx*self.batch_size:(idx+1)*self.batch_size, :]] @@ -405,19 +501,20 @@ def __getitem__(self, idx): sampler_input = instruments + features samples = self.sampler(sampler_input, self.n_samples // 2) samples = numpy.concatenate([observed_treatments, samples], axis=0) - batch_features = [f[idx*self.batch_size:(idx+1)*self.batch_size].repeat(self.n_samples, axis=0) for f in self.features] + [samples] + batch_features = [f[idx*self.batch_size:(idx+1)*self.batch_size].repeat(self.n_samples, axis=0) + for f in self.features] + [samples] batch_y = 
self.outputs[idx*self.batch_size:(idx+1)*self.batch_size].repeat(self.n_samples, axis=0) if idx == (len(self) - 1): self.shuffle() return batch_features, batch_y + def load_weights(filepath, model): if h5py is None: raise ImportError('`load_weights` requires h5py.') - with h5py.File(filepath, mode='r') as f: + # with h5py.File(filepath, mode='r') as f: # set weights - topology.load_weights_from_hdf5_group(f['model_weights'], model.layers) - + # topology.load_weights_from_hdf5_group(f['model_weights'], model.layers) + model.load_weights(filepath) return model - diff --git a/deepiv/samplers.py b/deepiv/samplers.py index 0f39526..930ed23 100644 --- a/deepiv/samplers.py +++ b/deepiv/samplers.py @@ -1,8 +1,9 @@ from __future__ import absolute_import, division, print_function, unicode_literals +import tensorflow as tf import numpy -from keras import backend as K -from keras.engine.topology import InputLayer +from tensorflow.keras import backend as K +# from tensorflow.keras.layers import InputLayer if K.backend() == "theano": from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams @@ -10,18 +11,22 @@ else: import tensorflow as tf + def random_laplace(shape, mu=0., b=1.): ''' Draw random samples from a Laplace distriubtion. See: https://en.wikipedia.org/wiki/Laplace_distribution#Generating_random_variables_according_to_the_Laplace_distribution ''' - U = K.random_uniform(shape, -0.5, 0.5) + U = K.random_uniform(shape, -0.5, 0.5) # alias to return mu - b * K.sign(U) * K.log(1 - 2 * K.abs(U)) + def random_normal(shape, mean=0.0, std=1.0): + # Returns: A tensor with normal distribution of values. return K.random_normal(shape, mean, std) + def random_multinomial(logits, seed=None): ''' Theano function for sampling from a multinomal with probability given by `logits` @@ -32,8 +37,12 @@ def random_multinomial(logits, seed=None): rng = RandomStreams(seed=seed) return rng.multinomial(n=1, pvals=logits, ndim=None, dtype=_FLOATX) elif K.backend() == "tensorflow": - return tf.one_hot(tf.squeeze(tf.multinomial(K.log(logits), num_samples=1)), - int(logits.shape[1])) + samples_multi = tf.random.categorical(logits=K.log(logits), num_samples=1) + #samples_multi = tf.compat.v1.multinomial(logits=K.log(logits), num_samples=1) + # smples_multi = tfp.distributions.Multinomial(total_count=1,logits=K.log(logits)).sample(1) + sample_squeeze = tf.squeeze(samples_multi) + return tf.one_hot(sample_squeeze, int(logits.shape[1])) + def random_gmm(pi, mu, sig): ''' @@ -42,8 +51,6 @@ def random_gmm(pi, mu, sig): the matrices n times if you want to get n samples), but makes it easy to implment code where the parameters vary as they are conditioned on different datapoints. 
''' - normals = random_normal(K.shape(mu), mu, sig) - k = random_multinomial(pi) + normals = random_normal(K.shape(mu), mu, sig) # [None,n_components] + k = random_multinomial(pi) # shape None return K.sum(normals * k, axis=1, keepdims=True) - - diff --git a/experiments/data_generator.py b/experiments/data_generator.py index 33fd032..ebf3b0e 100644 --- a/experiments/data_generator.py +++ b/experiments/data_generator.py @@ -1,34 +1,75 @@ from __future__ import absolute_import, division, print_function, unicode_literals -import numpy as np -#from causenet.datastream import DataStream, prepare_datastream +import numpy as np +# from causenet.datastream import DataStream, prepare_datastream from sklearn.preprocessing import OneHotEncoder - +import os X_mnist = None y_mnist = None -def monte_carlo_error(g_hat, data_fn, ntest=5000, has_latent=False, debug=False): - seed = np.random.randint(1e9) +# def training_loss(x,z,y_hat,y_true): +# return ((y_hat - y_true)**2).mean() + + +def load_test_data(test_fn=None, data_fn=None, ntest=5000, ypcor=0.5, ynoise=1.): + #seed = np.random.randint(1e9) + seed = 1234 + if test_fn: + test_data = np.load(test_fn) + x, z, t, y = test_data['x'], test_data['z'], test_data['t'], test_data['y_true'] + else: + + assert data_fn + x, z, t, y, g_true = data_fn(ntest, seed, test=True, ypcor=ypcor, ynoise=ynoise) + # re-draw to get new independent treatment and implied response, because we care about the change of price's effects + t = np.linspace(np.percentile(t, 2.5), np.percentile(t, 97.5), ntest).reshape(-1, 1) + + y = g_true(x, z, t) # g = lambda x, z, p: storeg(x, p) # doesn't use z + y_true = y.flatten() + if x.size == 0: + test_fn = os.path.join(os.getcwd(), "experiments", "data", + "demand_test_univariate_N{}_cor{}".format(ntest, ypcor)) + else: + test_fn = os.path.join(os.getcwd(), "experiments", "data", "demand_test_N{}_cor{}".format(ntest, ypcor)) + + np.savez(test_fn, x=x, z=z, t=t, y_true=y_true) + return x, z, t, y_true + + +def monte_carlo_error( + g_hat, data_fn, ntest=40, has_latent=False, debug=False, results_fn=None, resume_fn=None, train_treat=None): + #seed = np.random.randint(1e9) + seed = 123 # fix the seed to generate same test data oos try: # test = True ensures we draw test set images - x, z, t, y, g_true = data_fn(ntest, seed, test=True) + x, z, t, y, g_true = data_fn(ntest, seed, test=True) # the y generated here has error term e except ValueError: warnings.warn("Too few images, reducing test set size") ntest = int(ntest * 0.7) # test = True ensures we draw test set images x, z, t, y, g_true = data_fn(ntest, seed, test=True) - ## re-draw to get new independent treatment and implied response - t = np.linspace(np.percentile(t, 2.5),np.percentile(t, 97.5),ntest).reshape(-1, 1) - ## we need to make sure z _never_ does anything in these g functions (fitted and true) - ## above is necesary so that reduced form doesn't win + # re-draw to get new independent treatment and implied response, because we care about the change of price's effects + t = np.linspace(np.percentile(t, 2.5), np.percentile(t, 97.5), ntest).reshape(-1, 1) + # we need to make sure z _never_ does anything in these g functions (fitted and true) + # above is necesary so that reduced form doesn't win if has_latent: x_latent, _, _, _, _ = data_fn(ntest, seed, images=False) y = g_true(x_latent, z, t) else: - y = g_true(x, z, t) + y = g_true(x, z, t) # g = lambda x, z, p: storeg(x, p) # doesn't use z, without error term e y_true = y.flatten() y_hat = g_hat(x, z, t).flatten() + + if 
results_fn: + # when using results file, we will compare the performance on the same test data + # if not resume_fn: + # np.savez(os.path.join(os.getcwd(), "experiments", "demand_test_N{}".format(ntest)), x, z, t, y_true) + + # results_file like demand_results_N + results_fn = os.path.join(os.getcwd(), "experiments", "results", results_fn) + np.savez(results_fn, y_hat=y_hat, y_true=y_true) + return ((y_hat - y_true)**2).mean() @@ -36,7 +77,7 @@ def loadmnist(): ''' Load the mnist data once into global variables X_mnist and y_mnist. ''' - from keras.datasets import mnist + from tensorflow.keras.datasets import mnist global X_mnist global y_mnist train, test = mnist.load_data() @@ -50,6 +91,7 @@ def loadmnist(): X_mnist.append(X[idx, :, :]) y_mnist.append(y[idx]) + def get_images(digit, n, seed=None, testset=False): if X_mnist is None: loadmnist() @@ -61,13 +103,15 @@ def get_images(digit, n, seed=None, testset=False): if n > n_i: raise ValueError('You requested %d images of digit %d when there are \ only %d unique images in the %s set.' % (n, digit, n_i, 'test' if testset else 'training')) - return X_i[perm[0:n], :, :].reshape((n,i*j)) + return X_i[perm[0:n], :, :].reshape((n, i*j)) + def one_hot(col, **kwargs): - z = col.reshape(-1,1) + z = col.reshape(-1, 1) enc = OneHotEncoder(sparse=False, **kwargs) return enc.fit_transform(z) + def get_test_valid_train(generator, n, batch_size=128, seed=123, **kwargs): x, z, t, y, g = generator(n=int(n*0.6), seed=seed, **kwargs) train = prepare_datastream(x, z, t, y, True, batch_size, **kwargs) @@ -77,32 +121,37 @@ def get_test_valid_train(generator, n, batch_size=128, seed=123, **kwargs): test = prepare_datastream(x, z, t, y, False, batch_size, **kwargs) return train, valid, test, g + def sensf(x): return 2.0*((x - 5)**4 / 600 + np.exp(-((x - 5)/0.5)**2) + x/10. - 2) + def emocoef(emo): emoc = (emo * np.array([1., 2., 3., 4., 5., 6., 7.])[None, :]).sum(axis=1) return emoc + psd = 3.7 pmu = 17.779 -ysd = 158.#292. +ysd = 158. # 292. ymu = -292.1 + def storeg(x, price): emoc = emocoef(x[:, 1:]) time = x[:, 0] - g = sensf(time)*emoc*10. + (emoc*sensf(time)-2.0)*(psd*price.flatten() + pmu) + g = sensf(time)*emoc*10. + (emoc*sensf(time)-2.0)*(psd*price.flatten() + pmu) # h(x,p) y = (g - ymu)/ysd return y.reshape(-1, 1) + def demand(n, seed=1, ynoise=1., pnoise=1., ypcor=0.8, use_images=False, test=False): rng = np.random.RandomState(seed) # covariates: time and emotion time = rng.rand(n) * 10 emotion_id = rng.randint(0, 7, size=n) - emotion = one_hot(emotion_id, n_values=7) + emotion = one_hot(emotion_id) if use_images: idx = np.argsort(emotion_id) emotion_feature = np.zeros((0, 28*28)) @@ -119,19 +168,21 @@ def demand(n, seed=1, ynoise=1., pnoise=1., ypcor=0.8, use_images=False, test=Fa # z -> price v = rng.randn(n)*pnoise - price = sensf(time)*(z + 3) + 25. + price = sensf(time)*(z + 3) + 25. 
price = price + v price = (price - pmu)/psd # true observable demand function x = np.concatenate([time.reshape((-1, 1)), emotion_feature], axis=1) x_latent = np.concatenate([time.reshape((-1, 1)), emotion], axis=1) - g = lambda x, z, p: storeg(x, p) # doesn't use z - # errors + def g(x, z, p): + return storeg(x, p) # doesn't use z + + # errors e = (ypcor*ynoise/pnoise)*v + rng.randn(n)*ynoise*np.sqrt(1-ypcor**2) e = e.reshape(-1, 1) - + # response y = g(x_latent, None, price) + e @@ -143,10 +194,14 @@ def demand(n, seed=1, ynoise=1., pnoise=1., ypcor=0.8, use_images=False, test=Fa def linear_data(n, seed=None, sig_d=0.5, sig_y=2, sig_t=1.5, - alpha=4, noiseless_t=False, **kwargs): + alpha=4, noiseless_t=False, **kwargs): rng = np.random.RandomState(seed) - nox = lambda z, d: z + 2*d - house_price = lambda alpha, d, nox_val: alpha + 4*d + 2*nox_val + + def nox(z, d): + return z + 2*d + + def house_price(alpha, d, nox_val): + return alpha + 4*d + 2*nox_val d = rng.randn(n) * sig_d law = rng.randint(0, 2, n) @@ -158,13 +213,56 @@ def linear_data(n, seed=None, sig_d=0.5, sig_y=2, sig_t=1.5, z = law.reshape((-1, 1)) x = np.zeros((n, 0)) y = (house_price(alpha, d, t) + sig_y*rng.randn(n) - 5.)/5. - g_true = lambda x, z, t: house_price(alpha, 0, t) + def g_true(x, z, t): return house_price(alpha, 0, t) return x, z, t.reshape((-1, 1)), y.reshape((-1, 1)), g_true +pmu1 = 2 +psd1 = 2.25 +ymu1 = 8.1 +ysd1 = 4.5 + + +def storeg_uniivariate(p): + g = 4+2*(psd1*p.flatten() + pmu1) + y = (g - ymu1)/ysd1 + return y.reshape(-1, 1) + + +def demand_univariate(n, seed=1, ynoise=1., pnoise=1., ypcor=0.8, use_images=False, test=False): + rng = np.random.RandomState(seed) + + # random instrument + z = rng.randn(n) + + # z -> price + v = rng.randn(n)*pnoise + price = 2*z+2 + price = price + v + price = (price - pmu1)/psd1 + + def g(x, z, p): + return storeg_uniivariate(p) # doesn't use z + + x = np.zeros((n, 0)) + # errors + e = (ypcor*ynoise/pnoise)*v + rng.randn(n)*ynoise*np.sqrt(1-ypcor**2) + e = e.reshape(-1, 1) + + # response + y = g(x, None, price) + e + + return (x, + z.reshape((-1, 1)), + price.reshape((-1, 1)), + y.reshape((-1, 1)), + g) + + def main(): pass + if __name__ == '__main__': import sys sys.exit(int(main() or 0)) diff --git a/experiments/demand_simulation.py b/experiments/demand_simulation.py index 892e081..b9894ba 100644 --- a/experiments/demand_simulation.py +++ b/experiments/demand_simulation.py @@ -1,6 +1,11 @@ from __future__ import print_function - +import data_generator +import numpy +import argparse import warnings +import time +import os +import pickle from deepiv.models import Treatment, Response import deepiv.architectures as architectures @@ -8,32 +13,49 @@ import tensorflow as tf -from keras.layers import Input, Dense -from keras.models import Model -from keras.layers.merge import Concatenate +from tensorflow.keras.layers import Input, Dense +from tensorflow.keras.models import Model +from tensorflow.keras.layers import Concatenate +# tf.executing_eagerly() -import numpy +parser = argparse.ArgumentParser(description='deman simulation') +parser.add_argument('--n', help='Number of training samples', default=5000, type=int) +parser.add_argument('--n_test', help='Number of test samples', default=5000, type=int) +parser.add_argument('--ypcor', help='correlation between p and e', default=0.5, type=float) +parser.add_argument('--seed', help='Random seed', default=1, type=int) +parser.add_argument('--unbiased', default=False, action="store_true") 
+parser.add_argument('--samples_per_batch', default=2, type=int) + +parser.add_argument('--results_fn', help='Results file', default='', type=str) +parser.add_argument('--test_fn', default='', type=str) +args = parser.parse_args() -import data_generator -n = 5000 +n = args.n dropout_rate = min(1000./(1000. + n), 0.5) -epochs = int(1500000./float(n)) # heuristic to select number of epochs -epochs = 300 +epochs = int(1500000./float(n)) # heuristic to select number of epochs +# epochs = 300 # 300 batch_size = 100 images = False -def datafunction(n, s, images=images, test=False): - return data_generator.demand(n=n, seed=s, ypcor=0.5, use_images=images, test=test) -x, z, t, y, g_true = datafunction(n, 1) +def datafunction(n, s, images=images, test=False, ypcor=0.5, ynoise=1.): + return data_generator.demand(n=n, seed=s, ypcor=ypcor, ynoise=ynoise, use_images=images, test=test) + + +# g_true is the ture function, t is the treatment (price in the paper) +x, z, t, y, g_true = datafunction(n, 1, ypcor=args.ypcor) + +x_test, z_test, t_test, y_test = data_generator.load_test_data( + test_fn=args.test_fn, data_fn=datafunction, ntest=args.n_test, ypcor=args.ypcor) # to keep consistent, using same seed as 1234 + print("Data shapes:\n\ Features:{x},\n\ Instruments:{z},\n\ Treament:{t},\n\ -Response:{y}".format(**{'x':x.shape, 'z':z.shape, - 't':t.shape, 'y':y.shape})) +Response:{y}".format(**{'x': x.shape, 'z': z.shape, + 't': t.shape, 'y': y.shape})) # Build and fit treatment model instruments = Input(shape=(z.shape[1],), name="instruments") @@ -45,12 +67,23 @@ def datafunction(n, s, images=images, test=False): act = "relu" n_components = 10 +# first step + + +def treatment_output(x): + return densities.mixture_of_gaussian_output(x, n_components) # Concatenate(axis=1)([pi, mu, log_sig]) -est_treat = architectures.feed_forward_net(treatment_input, lambda x: densities.mixture_of_gaussian_output(x, n_components), + +print(treatment_input) +start = time.time() +# est_treat is Concatenate(axis=1)([pi, mu, log_sig]) +est_treat = architectures.feed_forward_net(treatment_input, treatment_output, hidden_layers=hidden, dropout_rate=dropout_rate, l2=0.0001, activations=act) +print("Input", instruments.shape) +print("est_treat", est_treat.shape) treatment_model = Treatment(inputs=[instruments, features], outputs=est_treat) treatment_model.compile('adam', loss="mixture_of_gaussians", @@ -58,11 +91,11 @@ def datafunction(n, s, images=images, test=False): treatment_model.fit([z, x], t, epochs=epochs, batch_size=batch_size) -# Build and fit response model +# Build and fit response model, t is the treatment -treatment = Input(shape=(t.shape[1],), name="treatment") +treatment = Input(shape=(t.shape[1],), name="treatment") # placeholder for treatment from treatment model response_input = Concatenate(axis=1)([features, treatment]) - +print("response input shape:", response_input.shape) est_response = architectures.feed_forward_net(response_input, Dense(1), activations=act, hidden_layers=hidden, @@ -72,10 +105,45 @@ def datafunction(n, s, images=images, test=False): response_model = Response(treatment=treatment_model, inputs=[features, treatment], outputs=est_response) -response_model.compile('adam', loss='mse') -response_model.fit([z, x], y, epochs=epochs, verbose=1, - batch_size=batch_size, samples_per_batch=2) +# response_model.compile('adam', loss='mse') # unbiased_gradient=True, batch_size=batch_size) +response_model.compile('adam', loss='mse', unbiased_gradient=True, batch_size=batch_size) +response_history = 
response_model.fit([z, x], y, epochs=epochs, verbose=1, batch_size=batch_size, + samples_per_batch=2, validation_data=([x_test, t_test], y_test)) + +if not response_history: + response_history = response_model.history.history +else: + response_history = response_history.history + +response_fn = "./experiments/results/response_history_N{}_cor{}Dict".format(args.n, args.ypcor) + +response_fn = response_fn+"_unbiased" if args.unbiased else response_fn +response_fn = response_fn+"_S{}".format(args.samples_per_batch) if args.samples_per_batch > 2 else response_fn + +print("response fn: {}".format(response_fn)) +with open(response_fn, "wb") as file_pi: + pickle.dump(response_history, file_pi) -oos_perf = data_generator.monte_carlo_error(lambda x,z,t: response_model.predict([x,t]), datafunction, has_latent=images, debug=False) + +end = time.time() +print("total training time of sample size: {} is {}".format(args.n, end-start)) + +results_fn = args.results_fn +if args.results_fn: + results_fn = args.results_fn + "_N{}_P{}".format(args.n, args.ypcor) + results_fn = results_fn+"_unbiased" if args.unbiased else results_fn + results_fn = results_fn+"_S{}".format(args.samples_per_batch) if args.samples_per_batch > 2 else results_fn + print("results_fn : {}".format(results_fn)) + + +# monte_carlo_error(g_hat, data_fn, ntest=5000, has_latent=False, debug=False): +oos_perf = data_generator.monte_carlo_error(lambda x, z, t: response_model.predict( + [x, t]), datafunction, ntest=args.n_test, has_latent=images, debug=False, results_fn=results_fn) print("Out of sample performance evaluated against the true function: %f" % oos_perf) + + +# prepare_file("./results/DeepIV_results.csv") +# with open("DeepIV_results.csv", 'a') as f: +# f.write('%d,%d,%f,%f\n' % (args.n_samples, args.seed, args.endo, oos_perf)) +# diff --git a/experiments/demand_simulation_mnist.py b/experiments/demand_simulation_mnist.py index 5153ad6..04f5531 100644 --- a/experiments/demand_simulation_mnist.py +++ b/experiments/demand_simulation_mnist.py @@ -6,39 +6,42 @@ import deepiv.architectures as architectures import deepiv.densities as densities -from keras.layers import Input, Dense, Reshape -from keras.models import Model -from keras.layers.merge import Concatenate -import keras.backend as K +from tensorflow.keras.layers import Input, Dense, Reshape +from tensorflow.keras.models import Model +from tensorflow.keras.layers import Concatenate +import tensorflow.keras.backend as K import numpy import data_generator -def conv_embedding(images, output, other_features = [], dropout_rate=0.1, + +def conv_embedding(images, output, other_features=[], dropout_rate=0.1, embedding_dropout=0.1, embedding_l2=0.05, constrain_norm=True): print("Building conv net") x_embedding = architectures.convnet(images, Dense(64, activation='linear'), - dropout_rate=embedding_dropout, - activations='relu', - l2_rate=embedding_l2, constrain_norm=constrain_norm) + dropout_rate=embedding_dropout, + activations='relu', + l2_rate=embedding_l2, constrain_norm=constrain_norm) if len(other_features) > 0: embedd = Concatenate(axis=1)([x_embedding] + other_features) else: embedd = x_embedding out = architectures.feed_forward_net(embedd, output, - hidden_layers=[32], - dropout_rate=dropout_rate, - activations='relu', constrain_norm=constrain_norm) + hidden_layers=[32], + dropout_rate=dropout_rate, + activations='relu', constrain_norm=constrain_norm) return out + n = 5000 dropout_rate = min(1000./(1000. 
+ n), 0.5) embedding_dropout = 0.1 embedding_l2 = 0.1 epochs = int(1500000./float(n)) +epochs = 3 batch_size = 100 x, z, t, y, g_true = data_generator.demand(n=n, seed=1, ypcor=0.5, use_images=True, test=False) @@ -47,8 +50,8 @@ def conv_embedding(images, output, other_features = [], dropout_rate=0.1, Features:{x},\n\ Instruments:{z},\n\ Treament:{t},\n\ -Response:{y}".format(**{'x':x.shape, 'z':z.shape, - 't':t.shape, 'y':y.shape})) +Response:{y}".format(**{'x': x.shape, 'z': z.shape, + 't': t.shape, 'y': y.shape})) # Build and fit treatment model if K.image_data_format() == "channels_first": @@ -57,14 +60,16 @@ def conv_embedding(images, output, other_features = [], dropout_rate=0.1, image_shape = (28, 28, 1) images = Input(shape=(28 * 28,), name='treat_images') -image_reshaped = Reshape(image_shape)(images) # reshape +image_reshaped = Reshape(image_shape)(images) # reshape time = Input(shape=(1,), name='treat_time') instruments = Input(shape=(z.shape[1],), name='treat_instruments') -mix_gaussian_output = lambda x: densities.mixture_of_gaussian_output(x, 10) + +def mix_gaussian_output(x): return densities.mixture_of_gaussian_output(x, 10) + treatment_output = conv_embedding(image_reshaped, mix_gaussian_output, - [time, instruments], + [time, instruments], dropout_rate=dropout_rate, embedding_dropout=embedding_dropout, embedding_l2=embedding_l2) @@ -75,7 +80,7 @@ def conv_embedding(images, output, other_features = [], dropout_rate=0.1, loss="mixture_of_gaussians", n_components=10) -treatment_model.fit([z, x[:,0:1], x[:,1:]], t, epochs=epochs, batch_size=batch_size) +treatment_model.fit([z, x[:, 0:1], x[:, 1:]], t, epochs=epochs, batch_size=batch_size) treatment_model.save("demand_mnist_treatment.hd5") # Build and fit response model @@ -83,19 +88,23 @@ def conv_embedding(images, output, other_features = [], dropout_rate=0.1, treatment = Input(shape=(t.shape[1],), name="treatment") out_res = conv_embedding(image_reshaped, Dense(1, activation='linear'), [time, treatment], - dropout_rate=dropout_rate, embedding_dropout=embedding_dropout, embedding_l2=embedding_l2) + dropout_rate=dropout_rate, embedding_dropout=embedding_dropout, embedding_l2=embedding_l2) +# THIS PART IS WRONG response_model = Response(treatment=treatment_model, inputs=[time, images, treatment], outputs=out_res) response_model.compile('adam', loss='mse', unbiased_gradient=True, batch_size=batch_size) -response_model.fit([z, x[:,0:1], x[:,1:]], y, epochs=epochs, verbose=1, +response_model.fit([z, x[:, 0:1], x[:, 1:]], y, epochs=epochs, verbose=1, batch_size=batch_size, samples_per_batch=2) treatment_model.save("demand_mnist_response.hd5") + def datafunction(n, s, images=True, test=False): return data_generator.demand(n=n, seed=s, ypcor=0.5, use_images=images, test=test) -oos_perf = data_generator.monte_carlo_error(lambda x,z,t: response_model.predict([x[:,0:1], x[:,1:],t]), datafunction, has_latent=True, debug=False) + +oos_perf = data_generator.monte_carlo_error(lambda x, z, t: response_model.predict( + [x[:, 0:1], x[:, 1:], t]), datafunction, has_latent=True, debug=False) print("Out of sample performance evaluated against the true function: %f" % oos_perf) diff --git a/experiments/linear.py b/experiments/linear.py index 9036777..dee68ea 100644 --- a/experiments/linear.py +++ b/experiments/linear.py @@ -3,8 +3,8 @@ from deepiv.models import Treatment, Response import deepiv.densities as densities -from keras.layers import Input, Dense -from keras.models import Model +from tensorflow.keras.layers import Input, Dense +from 
tensorflow.keras.models import Model import data_generator @@ -13,12 +13,12 @@ x, z, t, y, g_true = data_generator.linear_data(n=1000, seed=1) print("Starting experiment with linear data\n" + "-"*50 + -"\nData shapes:\n\ + "\nData shapes:\n\ Features:{x},\n\ Instruments:{z},\n\ Treament:{t},\n\ -Response:{y}".format(**{'x':x.shape, 'z':z.shape, - 't':t.shape, 'y':y.shape})) +Response:{y}".format(**{'x': x.shape, 'z': z.shape, + 't': t.shape, 'y': y.shape})) # Build and fit treatment model instruments = Input(shape=(z.shape[1],)) @@ -31,7 +31,7 @@ treatment_model.compile('adam', loss="mixture_of_gaussians", n_components=10) -treatment_model.fit([z],t, epochs=epochs) +treatment_model.fit([z], t, epochs=epochs) # Build and fit response model x = Dense(64, activation='relu')(instruments) @@ -42,10 +42,14 @@ outputs=est_resp) response_model.compile('adam', loss='mse') response_model.fit([z], y, epochs=epochs, verbose=1, - batch_size=100, samples_per_batch=2) + batch_size=100, samples_per_batch=2) + def datafunction(n, s, images=False, test=False): return data_generator.linear_data(n=n, seed=s) -oos_perf = data_generator.monte_carlo_error(lambda x,z,t: response_model.predict([t]), datafunction, has_latent=False, debug=False) + +oos_perf = data_generator.monte_carlo_error( + lambda x, z, t: response_model.predict([t]), + datafunction, has_latent=False, debug=False) print("Out of sample performance evaluated against the true function: %f" % oos_perf) diff --git a/experiments/twosls.py b/experiments/twosls.py index 7da3813..3c1be98 100644 --- a/experiments/twosls.py +++ b/experiments/twosls.py @@ -13,56 +13,60 @@ import data_generator parser = argparse.ArgumentParser(description='Description of your program') -parser.add_argument('-n','--n_samples', help='Number of training samples', default=1000, type=int) -parser.add_argument('-s','--seed', help='Random seed', default=1, type=int) +parser.add_argument('-n', '--n_samples', help='Number of training samples', default=5000, type=int) +parser.add_argument('-s', '--seed', help='Random seed', default=1, type=int) parser.add_argument('--endo', help='Endogeneity', default=0.5, type=float) parser.add_argument('--heartbeat', help='Use philly heartbeat', action='store_true') parser.add_argument('--results', help='Results file', default='twosls.csv') args = parser.parse_args() + def fit_twosls(x, z, t, y): ''' Two stage least squares with polynomial basis function. 
''' - params = dict(poly__degree=range(1,4), + params = dict(poly__degree=range(1, 4), ridge__alpha=np.logspace(-5, 5, 11)) pipe = Pipeline([('poly', PolynomialFeatures()), - ('ridge', Ridge())]) + ('ridge', Ridge())]) stage_1 = GridSearchCV(pipe, param_grid=params, cv=5) if z.shape[1] > 0: - X = np.concatenate([x,z], axis=1) + X = np.concatenate([x, z], axis=1) else: X = z - stage_1.fit(X,t) + stage_1.fit(X, t) t_hat = stage_1.predict(X) - print("First stage paramers: " + str(stage_1.best_params_ )) + print("First stage paramers: " + str(stage_1.best_params_)) pipe2 = Pipeline([('poly', PolynomialFeatures()), - ('ridge', Ridge())]) + ('ridge', Ridge())]) stage_2 = GridSearchCV(pipe2, param_grid=params, cv=5) - X2 = np.concatenate([x,t_hat], axis=1) + X2 = np.concatenate([x, t_hat], axis=1) stage_2.fit(X2, y) print("Best in sample score: %f" % stage_2.score(X2, y)) - print("Second stage paramers: " + str(stage_2.best_params_ )) + print("Second stage paramers: " + str(stage_2.best_params_)) - def g_hat(x,z,t): + def g_hat(x, z, t): X_new = np.concatenate([x, t], axis=1) return stage_2.predict(X_new) return g_hat + def prepare_file(filename): if not os.path.exists(filename): with open(filename, 'w') as f: f.write('n,seed,endo,mse\n') -df = lambda n, s, test: data_generator.demand(n, s, ypcor=args.endo, test=test) -x,z,t,y,g = df(args.n_samples, args.seed, False) -g_hat = fit_twosls(x,z,t,y) + +def df(n, s, test): return data_generator.demand(n, s, ypcor=args.endo, test=test) + + +x, z, t, y, g = df(args.n_samples, args.seed, False) +g_hat = fit_twosls(x, z, t, y) oos_perf = data_generator.monte_carlo_error(g_hat, df, has_latent=False, debug=False) print("Out of sample performance evaluated against the true function: %f" % oos_perf) prepare_file(args.results) with open(args.results, 'a') as f: - f.write('%d,%d,%f,%f\n' % (args.n_samples, args.seed,args.endo, oos_perf)) - + f.write('%d,%d,%f,%f\n' % (args.n_samples, args.seed, args.endo, oos_perf)) diff --git a/experiments/univariate.py b/experiments/univariate.py new file mode 100644 index 0000000..f286669 --- /dev/null +++ b/experiments/univariate.py @@ -0,0 +1,150 @@ +from __future__ import print_function +import data_generator +import numpy +import argparse +import warnings +import time +import os +import pickle + +from deepiv.models import Treatment, Response +import deepiv.architectures as architectures +import deepiv.densities as densities + +import tensorflow as tf + +from tensorflow.keras.layers import Input, Dense +from tensorflow.keras.models import Model +from tensorflow.keras.layers import Concatenate +# tf.executing_eagerly() + +parser = argparse.ArgumentParser(description='deman simulation') +parser.add_argument('--n', help='Number of training samples', default=5000, type=int) +parser.add_argument('--n_test', help='Number of test samples', default=5000, type=int) +parser.add_argument('--ypcor', help='correlation between p and e', default=0.5, type=float) +parser.add_argument('--seed', help='Random seed', default=1, type=int) +parser.add_argument('--unbiased', default=False, action="store_true") +parser.add_argument('--samples_per_batch', default=2, type=int) + +parser.add_argument('--results_fn', help='Results file', default='', type=str) +parser.add_argument('--test_fn', default='', type=str) +args = parser.parse_args() + + +n = args.n +dropout_rate = min(1000./(1000. 
+ n), 0.5) +epochs = int(1500000./float(n)) # heuristic to select number of epochs +epochs = 30 # 300 +batch_size = 100 +images = False + + +def datafunction(n, s, images=images, test=False, ypcor=0.5, ynoise=1.): + return data_generator.demand_univariate(n=n, seed=s, ypcor=ypcor, ynoise=ynoise, use_images=images, test=test) + + +# g_true is the ture function, t is the treatment (price in the paper) +x, z, t, y, g_true = datafunction(n, 1, ypcor=args.ypcor) + +x_test, z_test, t_test, y_test = data_generator.load_test_data( + test_fn=args.test_fn, data_fn=datafunction, ntest=args.n_test, ypcor=args.ypcor) # to keep consistent, using same seed as 1234 + + +print("Data shapes:\n\ +Features:{x},\n\ +Instruments:{z},\n\ +Treament:{t},\n\ +Response:{y}".format(**{'x': x.shape, 'z': z.shape, + 't': t.shape, 'y': y.shape})) + +# Build and fit treatment model +instruments = Input(shape=(z.shape[1],), name="instruments") +features = Input(shape=(x.shape[1],), name="features") +treatment_input = Concatenate(axis=1)([instruments, features]) + +hidden = [128, 64, 32] + +act = "relu" + +n_components = 10 +# first step + + +def treatment_output(x): + return densities.mixture_of_gaussian_output(x, n_components) # Concatenate(axis=1)([pi, mu, log_sig]) + + +print(treatment_input) +start = time.time() +# est_treat is Concatenate(axis=1)([pi, mu, log_sig]) +est_treat = architectures.feed_forward_net(treatment_input, treatment_output, + hidden_layers=hidden, + dropout_rate=dropout_rate, l2=0.0001, + activations=act) + +print("Input", instruments.shape) +print("est_treat", est_treat.shape) +treatment_model = Treatment(inputs=[instruments, features], outputs=est_treat) +treatment_model.compile('adam', + loss="mixture_of_gaussians", + n_components=n_components) + +treatment_model.fit([z, x], t, epochs=epochs, batch_size=batch_size) + +# Build and fit response model, t is the treatment + +treatment = Input(shape=(t.shape[1],), name="treatment") # placeholder for treatment from treatment model +response_input = Concatenate(axis=1)([features, treatment]) +print("response input shape:", response_input.shape) +est_response = architectures.feed_forward_net(response_input, Dense(1), + activations=act, + hidden_layers=hidden, + l2=0.001, + dropout_rate=dropout_rate) + +response_model = Response(treatment=treatment_model, + inputs=[features, treatment], + outputs=est_response) +# response_model.compile('adam', loss='mse') # unbiased_gradient=True, batch_size=batch_size) +response_model.compile('adam', loss='mse', unbiased_gradient=True, batch_size=batch_size) +response_history = response_model.fit([z, x], y, epochs=epochs, verbose=1, batch_size=batch_size, + samples_per_batch=2, validation_data=([x_test, t_test], y_test)) + +if not response_history: + response_history = response_model.history.history +else: + response_history = response_history.history + +response_fn = "./experiments/results/response_history_univariate_N{}_cor{}Dict".format(args.n, args.ypcor) + +response_fn = response_fn+"_unbiased" if args.unbiased else response_fn +response_fn = response_fn+"_S{}".format(args.samples_per_batch) if args.samples_per_batch > 2 else response_fn + + +print("response fn: {}".format(response_fn)) +with open(response_fn, "wb") as file_pi: + pickle.dump(response_history, file_pi) + + +end = time.time() +print("total training time of sample size: {} is {}".format(args.n, end-start)) + +results_fn = args.results_fn +if args.results_fn: + results_fn += "_univariate" + results_fn = args.results_fn + "_N{}_P{}".format(args.n, 
args.ypcor) + results_fn = results_fn+"_unbiased" if args.unbiased else results_fn + results_fn = results_fn+"_S{}".format(args.samples_per_batch) if args.samples_per_batch > 2 else results_fn + print("results_fn : {}".format(results_fn)) + + +# monte_carlo_error(g_hat, data_fn, ntest=5000, has_latent=False, debug=False): +oos_perf = data_generator.monte_carlo_error(lambda x, z, t: response_model.predict( + [x, t]), datafunction, ntest=args.n_test, has_latent=images, debug=False, results_fn=results_fn) +print("Out of sample performance evaluated against the true function: %f" % oos_perf) + + +# prepare_file("./results/DeepIV_results.csv") +# with open("DeepIV_results.csv", 'a') as f: +# f.write('%d,%d,%f,%f\n' % (args.n_samples, args.seed, args.endo, oos_perf)) +# diff --git a/setup.py b/setup.py index c2f8d42..6a59a49 100644 --- a/setup.py +++ b/setup.py @@ -10,10 +10,10 @@ history = history_file.read() requirements = [ - "keras==2.0.6", - "tensorflow", - "sklearn", # required for comparing to linear - "h5py" # required for saving models + "keras==2.3.1", + "tensorflow==2.5", + "sklearn", # required for comparing to linear + "h5py" # required for saving models ] optimal_packages = { @@ -26,10 +26,10 @@ setup( name='deepiv', - version='0.1.0', + version='0.1.2', description="A package for counterfactual prediction using deep instrument variable methods", long_description=readme + '\n\n' + history, - author="Jason Hartford", + author="Jason Hartford, Xingrui Wang", author_email='jasonhar@cs.ubc.ca', url='https://github.com/jhartford/deepiv', packages=[ @@ -55,6 +55,7 @@ 'Programming Language :: Python :: 3.3', 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.7', ], test_suite='tests', tests_require=test_requirements diff --git a/tox.ini b/tox.ini index 146651c..bf17c9c 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py26, py27, py33, py34, py35, flake8 +envlist = py37, flake8 [testenv:flake8] basepython=python @@ -10,7 +10,7 @@ commands=flake8 deepiv setenv = PYTHONPATH = {toxinidir}:{toxinidir}/deepiv -commands = python setup.py test +; commands = python setup.py test ; If you want to make tox run the tests with the same versions, create a ; requirements.txt with the pinned versions and uncomment the following lines:
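
For reference, the central change in this patch is moving `Response` off the Keras 2.0-era optimizer patching (`replace_gradients_mse`) and onto a TF2-style `train_step` override that can swap in the unbiased MSE gradient (the estimate the comments above call equation (10), built from two independent sets of sampled treatments). Below is a minimal, self-contained sketch of that pattern — not the repository's exact code — assuming TensorFlow 2.x, a fixed `batch_size`, and a generator that repeats each target so the outputs reshape to `(batch_size, n_samples, 2)` the way `SampledSequence` does; the class name `UnbiasedMSEModel` and its constructor arguments are illustrative.

import tensorflow as tf


class UnbiasedMSEModel(tf.keras.Model):
    """Sketch of a Response-style model with an unbiased-MSE train_step (illustrative)."""

    def __init__(self, batch_size, n_samples=1, unbiased_gradient=True, **kwargs):
        super(UnbiasedMSEModel, self).__init__(**kwargs)
        self.batch_size = batch_size            # examples per batch, before sample repetition
        self.n_samples = n_samples              # Monte Carlo samples per independent draw ("B")
        self.unbiased_gradient = unbiased_gradient

    def train_step(self, data):
        x, y_true = data                        # generator repeats targets n_samples * 2 times
        b, s = self.batch_size, self.n_samples
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)                               # (b * s * 2, 1)
            out = tf.reduce_mean(tf.reshape(y_pred, (b, s, 2)), axis=1)   # two independent estimates
            targets = tf.cast(tf.reshape(y_true, (b, s * 2)), out.dtype)
            if self.unbiased_gradient:
                # d Loss / d output of the first estimate, held constant via stop_gradient ...
                dl_dout = tf.stop_gradient(2.0 * (out[:, 0] - targets[:, 0]) / b)
                # ... multiplied into the second, independent estimate: differentiating this
                # surrogate reproduces what Lop / tf.gradients(..., grad_ys=...) computed before.
                loss = tf.reduce_sum(dl_dout * out[:, 1])
                if self.losses:                  # fold in regularizer terms
                    loss = loss + tf.add_n(self.losses)
            else:
                loss = self.compiled_loss(y_true, y_pred, regularization_losses=self.losses)
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        self.compiled_metrics.update_state(y_true, y_pred)
        return {m.name: m.result() for m in self.metrics}


# Illustrative usage (hypothetical shapes):
# inp = tf.keras.Input(shape=(3,))
# out = tf.keras.layers.Dense(1)(tf.keras.layers.Dense(16, activation="relu")(inp))
# model = UnbiasedMSEModel(batch_size=100, n_samples=1, inputs=inp, outputs=out)
# model.compile(optimizer="adam", loss="mse", metrics=["mse"])

Routing the gradient through a stop_gradient-scaled surrogate inside GradientTape gives the same weight gradients as the old graph-mode Lop / tf.gradients(grad_ys=...) construction, but it also works when train_step is traced as a tf.function, and the regularization terms in self.losses are simply added to the surrogate instead of being differentiated in a separate pass.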