Description
Hi all,
I'm trying to compute a posterior predictive distribution from samples of a posterior distribution (Colab here), using TFP 0.25 with the JAX backend.
My model specification (an MRE, and therefore contrived) is:

```python
import jax
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

@tfd.JointDistributionCoroutineAutoBatched
def model_autobatched():
    theta = yield tfd.Normal(loc=0., scale=1., name="theta")
    yield tfd.Normal(loc=theta, scale=0.1, name="y")
```

i.e., a normally distributed observation model with a normally distributed mean. To compute the posterior predictive distribution, I want to sample the `y` component conditional on a vector of `theta` samples.
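For reference, the model and the Monte Carlo posterior predictive it is meant to produce can be written as below. The notation ($\tilde y$, $\theta^{(s)}$, $\varepsilon^{(s)}$, $S$, and $\mathcal{D}$ for whatever data the posterior samples were conditioned on, which this MRE omits) is introduced here for illustration and is not TFP's:

```math
\begin{aligned}
\theta &\sim \mathcal{N}(0, 1), \qquad y \mid \theta \sim \mathcal{N}(\theta, 0.1^2), \\
p(\tilde y \mid \mathcal{D}) &= \int \mathcal{N}(\tilde y;\, \theta, 0.1^2)\, p(\theta \mid \mathcal{D})\, d\theta
\approx \frac{1}{S} \sum_{s=1}^{S} \mathcal{N}(\tilde y;\, \theta^{(s)}, 0.1^2), \\
\tilde y^{(s)} &= \theta^{(s)} + 0.1\, \varepsilon^{(s)}, \qquad \varepsilon^{(s)} \overset{\text{iid}}{\sim} \mathcal{N}(0, 1).
\end{aligned}
```

The last line is what matters for this issue: each $\varepsilon^{(s)}$ should be an independent draw, so $\tilde y^{(s)} - \theta^{(s)}$ should vary with $s$.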
```python
import numpy as np

theta_samples = np.arange(5.)
model_autobatched.sample(theta=theta_samples, seed=jax.random.key(0))
```

giving

```
StructTuple(
  theta=Array([0., 1., 2., 3., 4.], dtype=float32),
  y=Array([0.06215769, 1.0621576 , 2.0621576 , 3.0621576 , 4.0621576 ], dtype=float32)
)
```
Oh dear: `y - theta` is the same constant (about 0.0622) for every element. This suggests that a single PRNG key is being used for every draw of `y`, rather than an independent draw for each `theta` sample.
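To make the symptom concrete, here is a pure-JAX sketch (no TFP involved) contrasting one scalar noise draw broadcast across all `theta` values with one independent draw per element. It only illustrates what the two behaviours look like in the output; it is an assumption, not a verified description of what TFP does internally.

```python
import jax
import jax.numpy as jnp

theta = jnp.arange(5.)
key = jax.random.key(0)

# One scalar standard-normal draw, broadcast against every theta value:
# every element of y receives the same noise, so y - theta is constant.
eps_shared = jax.random.normal(key)
y_shared = theta + 0.1 * eps_shared

# One standard-normal draw per theta value (shape matches theta):
# each element gets its own noise, so y - theta varies.
eps_indep = jax.random.normal(key, shape=theta.shape)
y_indep = theta + 0.1 * eps_indep

print(y_shared - theta)
print(y_indep - theta)
```

The first pattern reproduces the constant offset seen above; the second is what a correct posterior predictive draw should look like.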
Moreover, this approach fails entirely if `sample_distributions` is called:

```python
model_autobatched.sample_distributions(theta=theta_samples, seed=jax.random.key(0))
```

```
ValueError: Attempt to convert a value (<object object at 0x7a53561590d0>) with an unsupported type (<class 'object'>) to a Tensor.
```
As a workaround, we could use the older `JointDistributionCoroutine` with a `Root` annotation, which works as desired (see Colab).
[edit] Actually, JDCoroutine/Root only works because the whole `theta` vector is passed to `y`'s constructor, not because the whole model is vectorised over the `theta` samples.
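One way to guarantee an independent key per posterior sample, regardless of how a joint distribution handles a single `seed`, is to split the key explicitly and `jax.vmap` over `(theta, key)` pairs. The pure-JAX sketch below shows that pattern; `sample_y` is a hypothetical stand-in for the `y` step of the model, written by hand rather than via TFP.

```python
import jax
import jax.numpy as jnp

def sample_y(theta, key):
    """Hypothetical stand-in for y ~ Normal(theta, 0.1) given one theta sample."""
    return theta + 0.1 * jax.random.normal(key)

theta_samples = jnp.arange(5.)
# Split one root key into one key per posterior sample.
keys = jax.random.split(jax.random.key(0), theta_samples.shape[0])
# Map over (theta, key) pairs so each draw uses its own key.
y = jax.vmap(sample_y)(theta_samples, keys)

print(y - theta_samples)
```

Applying the same idea to the full model would mean vmapping the conditional `model_autobatched.sample(theta=..., seed=...)` call over per-sample keys; whether that interacts cleanly with `JointDistributionCoroutineAutoBatched` in TFP 0.25 is an untested assumption.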
Do we have a bug or a feature, I wonder?
Regards,
Chris