Transformer Block Implementation gives NotImplementedError #203
Description
Hi,
I implemented a basic transformer block with residual connections and am getting the following error:
NotImplementedError: `FanInSum` is only implemented for the case where all input layers guaranteed to be mean-zero Gaussian, i.e. having all `is_gaussian` set to `True`, got [True, False].
It appears to be due to stax.Identity().
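If it helps, I think the failure can be reduced to any skip connection that sums a Gaussian branch with the raw (non-Gaussian) input; a minimal sketch that I expect reproduces the same error, with just a Dense branch instead of attention (names here are only for illustration):
from neural_tangents import stax
import jax.numpy as jnp

# Sum a Dense branch (is_gaussian=True) with the raw input carried by
# Identity (is_gaussian=False); FanInSum should then raise the same
# NotImplementedError about [True, False].
_, _, skip_kernel_fn = stax.serial(
    stax.FanOut(2),
    stax.parallel(stax.Dense(1), stax.Identity()),
    stax.FanInSum()
)
x = jnp.ones((3, 1))
skip_kernel_fn(x, x, 'nngp')  # expected: NotImplementedError from FanInSum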
Here is the implementation:
import jax.numpy as jnp
from jax import random, jit
from neural_tangents import stax

def FeedForwardNetwork(hidden_dim, output_dim):
    return stax.serial(stax.Dense(hidden_dim), stax.Relu(),
                       stax.Dense(output_dim))
AttnBlock = stax.serial(
    stax.FanOut(2),
    stax.parallel(
        stax.serial(
            stax.GlobalSelfAttention(
                n_chan_out=1,
                n_chan_key=1,
                n_chan_val=1,
                pos_emb_type='SUM',
                W_pos_emb_std=1,
                # pos_emb_decay_fn=lambda d: 1 / (1 + d**2),
                attention_mechanism='SOFTMAX',
                linear_scaling=True,
                n_heads=1)
        ),
        stax.Identity()
    ),
    stax.FanInSum()
)
def TransformerBlock(ff_dim, d_model):
    return stax.serial(
        AttnBlock,
        stax.LayerNorm(),
        stax.FanOut(2),
        stax.parallel(
            FeedForwardNetwork(ff_dim, d_model),
            stax.Identity()
        ),
        stax.FanInSum(),
        stax.LayerNorm()
    )
def Transformer(num_layers, ff_dim, d_model):
    layers = []
    for _ in range(num_layers):
        layers.append(TransformerBlock(ff_dim, d_model))
    layers.append(stax.Dense(out_dim=1))
    return stax.serial(*layers)
num_layers = 1
ff_dim = 128
d_model = 256
init_fn, apply_fn, kernel_fn = Transformer(num_layers, ff_dim, d_model)
And then taking the example data from the cookbook:
key = random.PRNGKey(10)
train_points = 5
test_points = 50
noise_scale = 1e-1
target_fn = lambda x: jnp.sin(x)
key, x_key, y_key = random.split(key, 3)
train_xs = random.uniform(x_key, (train_points, 1), minval=-jnp.pi, maxval=jnp.pi)
train_ys = target_fn(train_xs)
train_ys += noise_scale * random.normal(y_key, (train_points, 1))
train = (train_xs, train_ys)
test_xs = jnp.linspace(-jnp.pi, jnp.pi, test_points)
test_xs = jnp.reshape(test_xs, (test_points, 1))
test_ys = target_fn(test_xs)
test = (test_xs, test_ys)
apply_fn = jit(apply_fn)
kernel_fn = jit(kernel_fn, static_argnames='get')
kernel = kernel_fn(test_xs, test_xs, 'nngp')
std_dev = jnp.sqrt(jnp.diag(kernel))
The error occurs in the kernel_fn call above.
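To double-check where the [True, False] comes from, I looked at the is_gaussian flag of each branch's kernel in isolation (assuming that calling a kernel_fn without a get argument returns the Kernel object carrying is_gaussian, and that the attention kernel can be evaluated on these inputs the same way it is inside the full model):
# Attention branch, same settings as in AttnBlock above.
_, _, attn_kernel_fn = stax.GlobalSelfAttention(
    n_chan_out=1, n_chan_key=1, n_chan_val=1,
    pos_emb_type='SUM', W_pos_emb_std=1,
    attention_mechanism='SOFTMAX', linear_scaling=True, n_heads=1)
print(attn_kernel_fn(test_xs, test_xs).is_gaussian)  # expected True, matching the error

# Identity branch: just passes the raw (non-Gaussian) inputs through.
_, _, id_kernel_fn = stax.Identity()
print(id_kernel_fn(test_xs, test_xs).is_gaussian)    # expected False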
What is odd is that the ResBlock works in the cookbook:
ResBlock = stax.serial(
    stax.FanOut(2),
    stax.parallel(
        stax.serial(
            stax.Erf(),
            stax.Dense(512, W_std=1.1, b_std=0),
        ),
        stax.Identity()
    ),
    stax.FanInSum()
)
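One difference I notice (and I may be misreading the cookbook) is that there the ResBlock is used after an initial stax.Dense, so the Identity branch sums a Gaussian pre-activation rather than the raw input. Roughly (hyperparameters approximate, from memory):
# Cookbook-style usage: the leading Dense makes the signal entering each
# ResBlock Gaussian, so the Identity branch inside FanInSum should also
# have is_gaussian=True.
_, _, res_kernel_fn = stax.serial(
    stax.Dense(512, W_std=1, b_std=0),
    ResBlock, ResBlock,
    stax.Erf(),
    stax.Dense(1)
)
If that is the relevant difference, would prepending a stax.Dense before the first AttnBlock be the recommended workaround?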
And it appears that with linear_scaling=True, is_gaussian=True is set by this line:
https://github.com/google/neural-tangents/blob/c17e770bb74f1771da7be4a69fabfa68b6078960/neural_tangents/_src/stax/linear.py#L2464C14-L2468C39
Eventually I would also like to include causal masking, and any pointers there would be great too, as it is not clear how to apply an upper-triangular mask over the sequence length in the infinite-width case.