added highway flow #1270
base: main
Conversation
Thanks for your pull request. It looks like this may be your first contribution to a Google open source project (if not, look below for help). Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). 📝 Please visit https://cla.developers.google.com/ to sign. Once you've signed (or fixed any issues), please reply here with @googlebot I signed it! and we'll verify it.
ℹ️ Googlers: Go here for more info.
@googlebot I signed it!
CLAs look good, thanks! ℹ️ Googlers: Go here for more info.
      name=name)

  self.width = width
  self.W = tf.Variable(np.random.normal(0, 0.01, (self.width, self.width)), trainable=True, dtype=tf.float32)
In TFP classes we generally assume that `__init__` is a lightweight method that has no 'graph side effects', meaning that it doesn't create any new Tensors or tf.Variables. This is nice because it means it's always cheap to construct an instance, but it does put some constraints on designs.
Another TFP convention (#6 in the style guide: https://github.com/tensorflow/probability/blob/master/STYLE_GUIDE.md) is that we prefer to use descriptive, plain-English names in place of mathematical notation, even if doing so makes the code painfully verbose. For example, we'd prefer `weights` rather than `W`, `bias` rather than `b`, `diagonal` (or similar) rather than `d`, `residual_fraction` (or similar) rather than `lambda`, etc. This makes it easier for a future reader or maintainer of this code (who might not have full mathematical context) to glance at a line of code and at least have some idea of what's going on. (If needed, we can also add a comment to describe the math specifically.)
Here I'd suggest that the bijector `__init__` could have a signature along the lines of `__init__(self, weights, bias, diagonal_elements, unconstrained_residual_fractions, activation_fn=tf.nn.sigmoid, validate_args=False, name='tri_res_net')`, where we pass in the variables explicitly using descriptive names (it'd also be nice to accept arbitrary activation functions or None, rather than just sigmoid, though that's not essential for a first pass).
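A minimal sketch of what such a lightweight constructor might look like (argument names taken from the suggested signature above; the class body and attribute names are illustrative, not the PR's final code):
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors


class TriResNet(tfb.Bijector):
  """Sketch only: stores caller-provided variables and creates none itself."""

  def __init__(self, weights, bias, diagonal_elements,
               unconstrained_residual_fractions,
               activation_fn=tf.nn.sigmoid,
               validate_args=False, name='tri_res_net'):
    super(TriResNet, self).__init__(
        validate_args=validate_args,
        forward_min_event_ndims=1,  # Acts on vectors; see the later comments.
        name=name)
    # No Tensors or tf.Variables are created here -- just references.
    self.weights = weights
    self.bias = bias
    self.diagonal_elements = diagonal_elements
    self.unconstrained_residual_fractions = unconstrained_residual_fractions
    self.activation_fn = activation_fn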
To clarify re taking variables as `__init__` arguments: of course variables need to be created somewhere, but we'd probably move that to a wrapper function, which could have a signature close to what you have here. Something along the lines of:
def build_highway_flow_layer(width, activation):
  return TriResNet(
      weights=tf.Variable(np.random.normal(0, 0.01, (width, width)).astype(np.float32)),
      bias=tf.Variable(np.random.normal(0, 0.01, (width,)).astype(np.float32)),
      ...,
      activation=activation)
Hi Gianluigi -- Thanks for the PR, this looks like a great start! I left some initial comments.
def __init__(self, width, activation=True, validate_args=False, name='tri_res_net'):
  super(TriResNet, self).__init__(
      validate_args=validate_args,
      forward_min_event_ndims=0,
'forward_min_event_ndims' refers to the minimum rank of the input (or the difference in rank between the input and the output of log_det_jacobian). Here, if I understand correctly, the minimum input rank is 1 (a vector) and the LDJ is a scalar, so this should be 1.
      name=name)

  self.width = width
  self.W = tf.Variable(np.random.normal(0, 0.01, (self.width, self.width)), trainable=True, dtype=tf.float32)
Could you use more descriptive variable names? E.g. `weights`, `bias`, `convex_update` instead of `w`, `b`, `cu`.
  self.masku = tfe.numpy.tril(tf.ones((self.width, self.width)), -1)
  self.maskl = tfe.numpy.triu(tf.ones((self.width, self.width)), 1)

def get_l(self):
`tfp.util.TransformedVariable` could be useful here. Instead of defining `pre_l` in the constructor and applying `get_l` to transform it, you could define `prior_weight = tfp.util.TransformedVariable(initial_value=np.random(...), bijector=tfb.Sigmoid())` in the constructor and just refer to that (here, naming the lambdas `prior_weight` as we did in `build_asvi_surrogate_posterior`). `TransformedVariable` defers the transformation so as not to break gradients.
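A minimal sketch of that deferral (toy values; `prior_weight` as named above): the variable is stored unconstrained and read through the Sigmoid, so gradients reach the underlying variable.
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

# Reads as sigmoid(unconstrained_variable); the inverse of the initial value
# seeds the underlying unconstrained tf.Variable.
prior_weight = tfp.util.TransformedVariable(
    initial_value=0.75, bijector=tfb.Sigmoid())

with tf.GradientTape() as tape:
  loss = (tf.convert_to_tensor(prior_weight) - 0.5) ** 2
# The gradient flows through the sigmoid to the unconstrained variable.
grads = tape.gradient(loss, prior_weight.trainable_variables)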
Update after reading Dave's comment: It could also work to pass predefined `TransformedVariable`s into the constructor (or just not use them, if you and he prefer) -- e.g. we could add a factory function that creates the `TransformedVariable`s and builds the bijector with the lighter-weight constructor. (For some more complex bijectors -- e.g. `tfb.Glow` -- we do define variables in the constructor, so that might also be something we could consider.)
def get_l(self):
  return tf.nn.sigmoid(self.pre_l)

def get_L(self):
Again, `TransformedVariable` could help simplify the definition of `L` and `U`. Take a look at the `FillScaleTriL` bijector: https://www.tensorflow.org/probability/api_docs/python/tfp/bijectors/FillScaleTriL?version=nightly
That should let you define `L` and `U` in the constructor with `TransformedVariable`, and also allow you to define only as many variable degrees of freedom as you need for the triangular portion, instead of defining `width^2` and masking them.
It also lets you apply a bijector to the diagonal to ensure that it's positive (`Softplus` is used by default). (Doing it this way would mean you have trainable, nonzero diagonals for both L and U, and it looks like the current implementation has ones on the diagonal of L -- this strikes me as a reasonable change, but if you want to use `TransformedVariable` + `FillScaleTriL` and keep the ones on the diagonal, we can think through how best to do that.)
My initial understanding of `TransformedVariable` was that we specify our variable before the bijector, and what we get is the transformed variable according to the chosen bijector. However, it seems like it works the other way around, i.e. the `initial_value` should be defined as if it was already transformed by the bijector. Is my understanding correct?
If this is correct, then one way to define the matrix U could be:
initial_value = np.triu(np.random.normal(0., 1., (width, width)), 1) + np.diag(tf.math.softplus(np.random.normal(0, 1, (width,))))
fill_scale_tril_bijector = tfb.FillScaleTriL(diag_bijector=tfb.Softplus(), diag_shift=None)
transpose_bijector = tfb.Transpose(rightmost_transposed_ndims=2)
chain_of_bijectors = tfb.Chain([transpose_bijector, fill_scale_tril_bijector])
upper_diagonal_weights_matrix = tfp.util.TransformedVariable(
    initial_value=initial_value,
    bijector=chain_of_bijectors,
    trainable=True)
where we obtain `initial_value` by doing the whole transformation ourselves.
Following the same procedure, the `residual_fraction` (lambda) could be defined as:
self.residual_fraction = tfp.util.TransformedVariable(
    initial_value=tf.math.sigmoid(np.random.normal(0, 0.1, (1,))),
    bijector=tfb.Sigmoid(), trainable=True)
Is this the correct way to use `TransformedVariable`? And should the definition of the `initial_value` be done differently?
Good questions -- you're right that `initial_value` (somewhat counterintuitively) is the initial value after the transformation. Since `FillScaleTriL.inverse` ignores entries above the diagonal, you can define `initial_value` just as `np.random.uniform(0., 1., (width, width))` and the value of the `TransformedVariable` will be triangular (here, using `uniform` instead of `normal` to avoid negative values).
Instead of the `Transpose` bijector, it might be more efficient to pass `transpose_a=True` into `matvec`, would that work? (Also, it looks like the `matvec` args are reversed -- the matrix should be first, then the vector.)
`residual_fraction` looks good, except it might be simpler to define `initial_value` as a `uniform` sample (or as a fixed value -- in our implementation, this is a parameter we expose to the user, with a default of 0.5). Also, `tf.Variable`s are trainable by default, so you don't need to pass `trainable=True`.
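Pulling those suggestions together, a rough sketch (illustrative shapes and names, not the PR's final code) might look like:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

width = 4

# Stored as a lower-triangular matrix with a softplus-constrained diagonal.
# FillScaleTriL.inverse ignores entries above the diagonal, so a plain
# uniform square matrix works as the (post-transformation) initial value.
upper_diagonal_weights_matrix = tfp.util.TransformedVariable(
    initial_value=np.random.uniform(0., 1., (width, width)).astype(np.float32),
    bijector=tfb.FillScaleTriL(diag_bijector=tfb.Softplus(), diag_shift=None))

# Gating parameter constrained to (0, 1); `trainable=True` is the default.
residual_fraction = tfp.util.TransformedVariable(
    initial_value=0.5, bijector=tfb.Sigmoid())

# Use the stored lower-triangular value as U^T: transpose at use time via
# `transpose_a=True` rather than a Transpose bijector (matrix first, then vector).
x = tf.ones([width])
y = tf.linalg.matvec(
    tf.convert_to_tensor(upper_diagonal_weights_matrix), x, transpose_a=True)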
  y = self.inv_f(y)

  y = tf.linalg.matmul(y - self.b, tf.linalg.inv(self.cu(self.get_U())))
  y = tf.linalg.matmul(y, tf.linalg.inv(self.cu(self.get_L())))
Would it work to use `triangular_solve` instead of inverting and multiplying? https://www.tensorflow.org/api_docs/python/tf/linalg/triangular_solve
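For illustration only (toy matrix, not the PR's actual expressions), the pattern being suggested is:
import tensorflow as tf

upper = tf.constant([[2., 1., 0.5],
                     [0., 1., 3.0],
                     [0., 0., 1.5]])     # upper-triangular matrix
rhs = tf.constant([[1.], [2.], [3.]])    # shape [3, 1]

via_inverse = tf.linalg.matmul(tf.linalg.inv(upper), rhs)
# Same result, but cheaper and more numerically stable; `lower=False` tells
# the solver that `upper` is upper triangular.
via_solve = tf.linalg.triangular_solve(upper, rhs, lower=False)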
  return x

def _forward(self, x):
  x = tf.linalg.matmul(x, self.cu(self.get_L()))
Is `x` a vector? If so, let's use `matvec` instead of `matmul`.
  return y

def _inverse_log_det_jacobian(self, y):
  return self._forward_log_det_jacobian(y)  # TODO: is this correct?
You only need to implement one or the other of ILDJ and FLDJ, so you can remove this method (and the `Bijector` base class will calculate ILDJ as negative FLDJ).
Thanks for sharing this, it looks great! I've left some initial comments just to give an idea of what we'll expect in a final version.
For a full PR, we'll also eventually need:
- A file `highway_flow_test.py` with unit tests.
- Updates to the `BUILD` and `__init__.py` files to refer to the new code.
- Docstrings on all new classes and methods (can be one-line for short/trivial methods, but generally we'll want to describe the arguments and return values in a similar style to the rest of TFP).
You might want to hold off on, e.g., writing docstrings until we've converged further on the interface, but I'm happy to talk more about what any of these would look like.
def __init__(self, width, activation=True, validate_args=False, name='tri_res_net'):
  super(TriResNet, self).__init__(
      validate_args=validate_args,
      forward_min_event_ndims=0,
Assuming that the values transformed by this bijector (arguments `x` and `y` to the forward and inverse methods, respectively) are intended to be vectors, then this should be `forward_min_event_ndims=1`. (Otherwise the base class will screw up your log-det-Jacobian calculations by trying to sum them.)
  return y

def _inverse_log_det_jacobian(self, y):
  return self._forward_log_det_jacobian(y)  # TODO: is this correct?
If you don't implement an inverse_log_det_jacobian, the Bijector base class defaults to
inverse_log_det_jacobian(y) = - forward_log_det_jacobian(x=inverse(y))
(which is always correct), so there's no need for an explicit implementation unless you have an approach that's more efficient or more numerically stable.
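A minimal sketch of that default behavior (toy bijector, nothing to do with the PR's class): only FLDJ is implemented, and the base class supplies ILDJ.
import tensorflow as tf
import tensorflow_probability as tfp


class ScaleByTwo(tfp.bijectors.Bijector):
  """Toy bijector with no explicit _inverse_log_det_jacobian."""

  def __init__(self, name='scale_by_two'):
    super(ScaleByTwo, self).__init__(forward_min_event_ndims=0, name=name)

  def _forward(self, x):
    return 2. * x

  def _inverse(self, y):
    return y / 2.

  def _forward_log_det_jacobian(self, x):
    return tf.math.log(2.) * tf.ones_like(x)


b = ScaleByTwo()
y = tf.constant([2., 4.])
# The base class computes -forward_log_det_jacobian(inverse(y)) = -log(2).
ildj = b.inverse_log_det_jacobian(y, event_ndims=0)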
  return l * I + (1 - l) * M

def inv_f(self, y, N=20):
  # inverse with Newton iteration
If you want to support activation functions other than sigmoid, we have root-finding methods that can generically invert scalar functions:
https://www.tensorflow.org/probability/api_docs/python/tfp/math/find_root_chandrupatla
https://www.tensorflow.org/probability/api_docs/python/tfp/math/find_root_secant
The experimental `ScalarFunctionWithInferredInverse` bijector https://www.tensorflow.org/probability/api_docs/python/tfp/experimental/bijectors/ScalarFunctionWithInferredInverse?version=nightly is a wrapper around these methods that implements the bijector API, e.g., it'll automatically compute the log_det_jacobian (which for a scalar function is just the log absolute derivative) to give you an analogue of the derivative method that you wrote here for a sigmoid activation (`self.df`).
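As a rough illustration of the wrapper (toy values, not the PR's code):
import tensorflow as tf
import tensorflow_probability as tfp

# Wrap an arbitrary monotonic scalar activation; its inverse comes from a
# generic root search and its log-det-Jacobian from autodiff.
activation = tfp.experimental.bijectors.ScalarFunctionWithInferredInverse(
    tf.math.softplus)

x = tf.constant([0.5, 1.0, 2.0])
y = activation.forward(x)
x_back = activation.inverse(y)  # recovered numerically, no hand-written inverse
fldj = activation.forward_log_det_jacobian(x, event_ndims=0)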
  l = self.get_l()
  return l + (1 - l) * tf.math.sigmoid(x) * (1 - tf.math.sigmoid(x))

def cu(self, M):
as above, we prefer descriptive, readable names: `convex_update` instead of `cu`, etc.
Also, most of these methods should probably be private (`_convex_update`, etc.) unless you expect them to be used directly by external code.
from tensorflow_probability.python import bijectors as tfb
from tensorflow_probability.python import util
from tensorflow_probability.python.experimental.bijectors import scalar_function_with_inferred_inverse


def build_highway_flow_layer(width, residual_fraction_initial_value=0.5, activation_fn=None):
  # FIXME: should everything be in float32 or float64?
We typically try to infer the dtype that the user wants from the dtype of float-valued inputs. In this case you could write at the top:
residual_fraction_initial_value = tf.convert_to_tensor(
    residual_fraction_initial_value,
    dtype_hint=tf.float32,
    name='residual_fraction_initial_value')
dtype = residual_fraction_initial_value.dtype
This will use float32 by default, but adapt to float64 / float16 / etc. if the caller passes a Tensor value with explicit dtype (`convert_to_tensor` is a no-op if the input is already a Tensor). Then use `dtype` throughout when defining variables, e.g.,
bias=tf.Variable(tf.random.normal([width], mean=0., stddev=0.01, dtype=dtype))
etc.
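A quick illustration of how that idiom adapts (standalone toy snippet, not part of the PR):
import tensorflow as tf

def _inferred_dtype(residual_fraction_initial_value):
  residual_fraction_initial_value = tf.convert_to_tensor(
      residual_fraction_initial_value, dtype_hint=tf.float32,
      name='residual_fraction_initial_value')
  return residual_fraction_initial_value.dtype

print(_inferred_dtype(0.5))                                 # tf.float32 (default)
print(_inferred_dtype(tf.constant(0.5, dtype=tf.float64)))  # tf.float64 (adapts)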
def build_highway_flow_layer(width, residual_fraction_initial_value=0.5, activation_fn=None):
  # FIXME: should everything be in float32 or float64?
  # TODO: add control that residual_fraction_initial_value is between 0 and 1
  return HighwayFlow(
      width=width,
      residual_fraction=util.TransformedVariable(
          initial_value=np.asarray(residual_fraction_initial_value, dtype='float32'),
          bijector=tfb.Sigmoid()),
      activation_fn=activation_fn,
      bias=tf.Variable(np.random.normal(0., 0.01, (width,)), dtype=tf.float32),
We've been moving in TFP towards using stateless RNGs, which are required by JAX and now also supported by TF, and tend to be easier to reason about (see https://github.com/tensorflow/probability/blob/master/PRNGS.md if you want to learn more). The idiom is along the lines of:
# At top
from tensorflow_probability.python.internal import samplers

# Function defn
def build_highway_flow_layer(..., seed=None):
  # Ensure that each part gets a different initialization.
  bias_seed, upper_seed, lower_seed = samplers.split_seed(seed, n=3)
  ...
  # Drawing initial values.
  bias = tf.Variable(samplers.normal([width], mean=0., stddev=0.01, dtype=dtype, seed=bias_seed)),
  upper_diagonal_weights_matrix = TransformedVariable(samplers.normal(..., seed=upper_seed), ...),
  lower_diagonal_weights_matrix = TransformedVariable(samplers.normal(..., seed=lower_seed), ...))
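A small runnable illustration of the seed-splitting part (values and shapes are arbitrary):
import tensorflow as tf
from tensorflow_probability.python.internal import samplers

width = 4
seed = tf.constant([0, 42], dtype=tf.int32)  # a stateless seed: shape-[2] int Tensor

# Derive independent per-parameter seeds from the single user-provided seed.
bias_seed, upper_seed, lower_seed = samplers.split_seed(seed, n=3)

bias_init = samplers.normal([width], mean=0., stddev=0.01, seed=bias_seed)
upper_init = samplers.normal([width, width], mean=0., stddev=1., seed=upper_seed)
# Re-running with the same `seed` reproduces the same initial values.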
Hi Gianluigi, thanks for the updates; this looks great! Almost all of my comments at this point are nitpicking about stylistic details---unfortunately, as a new contributor you're subject to some up-front overhead in conforming to our (somewhat arbitrary) style rules. Hopefully it's not too bad; I can attest that if you write enough TFP code it eventually becomes mostly automatic. :-)
Other than my specific comments, I think the code and tests look to be in generally good shape, so I think we're really close to being able to pull this in. Let me know if anything I said doesn't make sense (this is very likely), or if there's anything else I can help with.
Dave
  return HighwayFlow(
      width=width,
      residual_fraction=util.TransformedVariable(
          initial_value=tf.convert_to_tensor(residual_fraction_initial_value),
This should already have been converted to a Tensor; there's no need to do it again here.
      bias=tf.Variable(samplers.normal((width,), 0., 0.01, seed=bias_seed), dtype=dtype),
      upper_diagonal_weights_matrix=util.TransformedVariable(
          initial_value=tf.experimental.numpy.tril(
              samplers.normal((width, width), 0., 1., seed=upper_seed), -1) + tf.linalg.diag(
We prefer `axis=-1` here (and in general, passing args by keyword outside of trivial cases).
  #### References

  [1]: ADD REFERENCE
Don't forget to add the reference, and ideally a simple usage example.
             name='highway_flow'):
  super(HighwayFlow, self).__init__(
      validate_args=validate_args,
      forward_min_event_ndims=1,  # FIXME: should this also be an argument of HighwayFlow __init__?
I think this is fine as is (i.e., no need to expose it in HighwayFlow `__init__`). It's standard for flows to operate on rank-1 values (vectors), and a user who wants to pass other types of Tensors can always use the Reshape bijector.
        -bijector.inverse_log_det_jacobian(tf.identity(bijector.forward(x)), event_ndims=dim+1))

  def testJacobianWithActivation(self):
    #activations = ['sigmoid', 'softplus', 'tanh', 'none']
as elsewhere, we can delete the commented code until we're ready to actually support this.
    )

    self.evaluate([v.initializer for v in bijector.trainable_variables])
    x = tf.ones((batch_size, width))  # * samplers.uniform((batch_size, width), -10., 10., seed=seed)
I have a very slight preference for using (seeded) random samples here instead of ones (here and elsewhere).
So far I manually set the seed to 1. When using `seed = test_util.test_seed(sampler_type='stateless')` I get the error `absl.flags._exceptions.UnparsedFlagAccessError: Trying to access flag --fixed_seed before flags were parsed`, which I haven't solved yet.
    manual_grad = (y1 - y2) / (2 * h)

    diff = tf.math.abs(tf_grad - manual_grad) / tf.reduce_max((eps, tf.math.abs(tf_grad) + tf.math.abs(manual_grad)))
It's probably clearer (and easier) to write just `self.assertAllClose(tf_grad, manual_grad, rtol=1e-4)` in place of this explicit calculation.
I have changed this but then to pass the test I needed to increase h to 1e-3. Could you please double-check that the gradient test I implemented is reliable?
        activation_fn=tf.nn.softplus,
        bias=tf.Variable(0., dtype=dtype),
        upper_diagonal_weights_matrix=tf.Variable(tf.eye(width)),
        lower_diagonal_weights_matrix=tf.Variable(tf.eye(width))
I don't think any of these parameters (including residual_fraction) need to be Variables, since you're not actually updating them. You can just write `residual_fraction = tf.constant(0.5)`, `upper_diagonal_weights_matrix=tf.eye(width)`, etc. Variables are special in that they're automatically watched by gradient tapes, but you don't need that since you're already asking the tape to watch `residual_fraction` specifically.
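A short sketch of that point (stand-in loss, not the actual test):
import tensorflow as tf

width = 3
residual_fraction = tf.constant(0.5)           # plain Tensors, no tf.Variable
upper_diagonal_weights_matrix = tf.eye(width)

with tf.GradientTape() as tape:
  tape.watch(residual_fraction)  # explicitly watched, so gradients still flow
  loss = residual_fraction * tf.reduce_sum(upper_diagonal_weights_matrix)
grad = tape.gradient(loss, residual_fraction)  # == 3.0 here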
@@ -17,9 +17,13 @@
from tensorflow_probability.python.bijectors.ldj_ratio import inverse_log_det_jacobian_ratio
from tensorflow_probability.python.experimental.bijectors.distribution_bijectors import make_distribution_bijector
from tensorflow_probability.python.experimental.bijectors.scalar_function_with_inferred_inverse import ScalarFunctionWithInferredInverse
from tensorflow_probability.python.experimental.bijectors.highway_flow import build_highway_flow_layer, HighwayFlow
Google's Python style prefers only one import per line, even if that means redundant lines. Here, you'll need to write:
from tensorflow_probability.python.experimental.bijectors.highway_flow import build_highway_flow_layer
from tensorflow_probability.python.experimental.bijectors.highway_flow import HighwayFlow
…FunctionWithInferredInverse. This originally arose as an issue in #1270 where we wanted to define the bijector that interpolates between a sigmoid and the identity function: y = fn(x, a) = a * x + (1 - a) * sigmoid(x) (for a in [0, 1]) with a trainable coefficient `a`, but the checked-in version of this bijector doesn't provide a gradient to `a`.
This version only accepts scalar parameters, because:
1. This was already complicated enough.
2. I *think* that supporting parameters of arbitrary rank would require the caller to tell us the rank of the parameter (so we know how much of it to treat as 'batch shape'), which seems like a mess.
3. Going vector-only instead of scalar-only would work too, and allows faster math if there are lots of parameters, but it's awkward to deal with boxing and unboxing vectors in the common case where the parameter really is semantically a scalar.
4. In dire straits, you could simulate a vector of fixed size by passing multiple scalars.
I'm not certain this current API will ultimately be exactly the right thing, but that's what experimental is for. :-)
PiperOrigin-RevId: 369308653
One final (I think) round of comments.
    activation_fn=False,
    gate_first_n=-1,
    seed=None):
  """Builds HighwayFlow making sure that all the requirements ar satisfied.
nit: ar -> are
    width: Input dimension of the bijector.
    residual_fraction_initial_value: Initial value for gating parameter, must be
      between 0 and 1.
      activation_fn: Whether or not use SoftPlus activation function.
nit: unindent this to match the other args
      name='residual_fraction_initial_value')
  dtype = residual_fraction_initial_value.dtype

  bias_seed, upper_seed, lower_seed, diagonal_seed = samplers.split_seed(
looks like diagonal_seed is unused, can it be removed?
  For more details on Highway Flow and Cascading Flows see [1].

  #### Usage example:
nit: no need for a colon at the end of a heading, this can just be `#### Usage example`.
"Automatic variational inference with | ||
cascading flows." arXiv preprint arXiv:2102.04801 (2021). | ||
|
||
Attributes: |
Since these are args to `__init__`, our convention is to put them in the `__init__` docstring below, rather than the class-level docstring.
        axis=0) * self.bias
    y = tf.linalg.triangular_solve(tf.transpose(
        self._convex_update(self.upper_diagonal_weights_matrix)),
        tf.linalg.matrix_transpose(y),
Why are we transposing `y` here? Transposes are expensive, since they actually change the data layout in memory---as opposed to other ops like reshape, expand_dims, squeeze, etc., which just change metadata---so it's good to avoid them when possible. (Also it just feels weird to be transposing `y` since it's conceptually a vector.)
I think you can remove all `transpose` and `matrix_transpose` calls from this method, using a combination of the two points below (a sketch follows after the list):
- Work with `y` of shape `[..., width, 1]` throughout, so that you can feed it directly to `triangular_solve` calls. That is, instead of `y = expand_dims(y, 0)` at the beginning, do `y = y[..., tf.newaxis]` to add the extra dimension at the end rather than at the beginning. (As a bonus, you can do this unconditionally so you don't need the `added_batch` logic any more.) You might need to sprinkle in a few more `[..., tf.newaxis]` expansions for other Tensors, e.g. `bias`, to get the dimensions to line up, but this is costless.
- Use the `adjoint` flag to `triangular_solve`. Since transposes are expensive, most TF matrix ops give you the option to treat a matrix as transposed without explicitly constructing the transpose. For example, `triangular_solve(A, B, adjoint=True)` is equivalent to, and more efficient than, `triangular_solve(tf.linalg.matrix_transpose(A), B)`.
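A rough sketch of those two points (toy values and hypothetical names, not the PR's actual method): keep `y` with a trailing length-1 dimension and let `adjoint=True` stand in for the transpose.
import tensorflow as tf

width = 3
upper = tf.constant([[2., 1., 0.5],
                     [0., 1., 3.0],
                     [0., 0., 1.5]])  # upper-triangular weights (assumed)
bias = tf.constant([0.1, 0.2, 0.3])
y = tf.constant([1., 2., 3.])

y = y[..., tf.newaxis]                # shape [..., width, 1]; no added_batch logic
rhs = y - bias[..., tf.newaxis]

# Solves upper^T @ z = rhs without tf.transpose / tf.linalg.matrix_transpose.
z = tf.linalg.triangular_solve(upper, rhs, lower=False, adjoint=True)
z = tf.squeeze(z, axis=-1)            # back to a vector of shape [width]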
    for dim in range(2):
      if dim == 0:
        # Test generic case with scalar input
        x = tf.ones((width,)) * samplers.uniform((width,), minval=-1.,
Multiplying the samples by `tf.ones` here doesn't do anything; you can remove it (also in similar lines elsewhere).
    # pylint: disable=protected-access
    bijector._residual_fraction = residual_fraction + h
    y1 = tf.reduce_mean(target.log_prob(bijector.forward(x)))
I think the bijector cache will give a wrong answer here (and two lines below), since it doesn't know that `_residual_fraction` has changed from the previous call. You'll probably need to pass `tf.identity(x)`.
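The reason, roughly (toy bijector; illustration only): bijectors cache forward/inverse results keyed on the input Tensor object, and `tf.identity` produces a distinct Tensor, so the result is recomputed instead of served from the cache.
import tensorflow as tf
import tensorflow_probability as tfp

bijector = tfp.bijectors.Scale(2.)
x = tf.constant([1., 2.])

y_first = bijector.forward(x)               # computed and cached against `x`
y_cached = bijector.forward(x)              # served from the cache
y_fresh = bijector.forward(tf.identity(x))  # new Tensor object -> recomputed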
  HighwayFlow interpolates the input `X` with the transformations at each step
  of the bjiector. The Highway Flow can be used as building block for a
  Cascading flow [1 or as a generic normalizing flow.
nit: [1]
  of the bjiector. The Highway Flow can be used as building block for a
  Cascading flow [1 or as a generic normalizing flow.

  The transformation consists in a convex update between the input `X` and a
nit: 'consists in' -> 'consists of'
  def __init__(self, residual_fraction, activation_fn, bias,
               upper_diagonal_weights_matrix,
               lower_diagonal_weights_matrix,
               gate_first_n,
Thanks for doing this! Can you also
- Make `gate_first_n=None` the default in the signature, and
- Update the docstring to document the default behavior?
    # Log determinant term from the upper matrix. Note that the log determinant
    # of the lower matrix is zero.

    fldj = tf.zeros(x.shape[:-1], dtype=self.dtype) + tf.reduce_sum(
`x.shape` should be `ps.shape(x)` (also in the analogous line of `_augmented_inverse`).
For the comments on the default behaviour, I have added a simple line for the HighwayFlow `__init__`, but maybe you meant something more sophisticated?
    if self.activation_fn:
      fldj += tf.reduce_sum(tf.math.log(self._derivative_of_softplus(x[0])),
                            -1)
nit: prefer the keyword argument, `axis=-1` (similarly elsewhere).
PiperOrigin-RevId: 374964013
No description provided.