This repository was archived by the owner on May 6, 2025. It is now read-only.

Transformer Block Implementation gives NotImplementedError #203

@esnvidia

Hi,

I implemented a basic transformer block with residual connections and am getting the following error:

NotImplementedError: `FanInSum` is only implemented for the case where all input layers guaranteed to be mean-zero Gaussian, i.e. having all `is_gaussian` set to `True`, got [True, False].

It appears to be caused by stax.Identity().

Here is the implementation:

import jax.numpy as jnp
from jax import jit, random
from neural_tangents import stax

def FeedForwardNetwork(hidden_dim, output_dim):
    return stax.serial(stax.Dense(hidden_dim), stax.Relu(),
                       stax.Dense(output_dim))

AttnBlock = stax.serial(
    stax.FanOut(2),
    stax.parallel(
        stax.serial(
            stax.GlobalSelfAttention(
                n_chan_out=1,
                n_chan_key=1,
                n_chan_val=1,
                pos_emb_type='SUM',
                W_pos_emb_std=1,
                # pos_emb_decay_fn=lambda d: 1 / (1 + d**2),
                attention_mechanism='SOFTMAX',
                linear_scaling=True,
                n_heads=1)
        ),
        stax.Identity()
    ),
    stax.FanInSum()
)

def TransformerBlock(ff_dim, d_model):
    return stax.serial(AttnBlock,
                       stax.LayerNorm(),
                       stax.FanOut(2),
                       stax.parallel(
                           FeedForwardNetwork(ff_dim, d_model),
                           stax.Identity()
                       ),
                       stax.FanInSum(),
                       stax.LayerNorm()
                      )

def Transformer(num_layers, ff_dim, d_model):
    layers = []
    for _ in range(num_layers):
        layers.append(TransformerBlock(ff_dim, d_model))
    layers.append(stax.Dense(out_dim=1))
    return stax.serial(*layers)

num_layers = 1
ff_dim = 128
d_model = 256

init_fn, apply_fn, kernel_fn = Transformer(num_layers, ff_dim, d_model)

I then use the example data from the cookbook:

key = random.PRNGKey(10)
train_points = 5
test_points = 50
noise_scale = 1e-1

target_fn = lambda x: jnp.sin(x)

key, x_key, y_key = random.split(key, 3)

train_xs = random.uniform(x_key, (train_points, 1), minval=-jnp.pi, maxval=jnp.pi)

train_ys = target_fn(train_xs)
train_ys += noise_scale * random.normal(y_key, (train_points, 1))
train = (train_xs, train_ys)

test_xs = jnp.linspace(-jnp.pi, jnp.pi, test_points)
test_xs = jnp.reshape(test_xs, (test_points, 1))

test_ys = target_fn(test_xs)
test = (test_xs, test_ys)

apply_fn = jit(apply_fn)
kernel_fn = jit(kernel_fn, static_argnames='get')

kernel = kernel_fn(test_xs, test_xs, 'nngp')
std_dev = jnp.sqrt(jnp.diag(kernel))

The error occurs in the kernel_fn call.

What is odd is that the ResBlock works in the cookbook:

ResBlock = stax.serial(
    stax.FanOut(2),
    stax.parallel(
        stax.serial(
            stax.Erf(),
            stax.Dense(512, W_std=1.1, b_std=0),
        ),
        stax.Identity()
    ),
    stax.FanInSum()
)

It also appears that with linear_scaling=True the attention layer sets is_gaussian=True, judging from this line:
https://github.com/google/neural-tangents/blob/c17e770bb74f1771da7be4a69fabfa68b6078960/neural_tangents/_src/stax/linear.py#L2464C14-L2468C39

Eventually I would also like to include causal masking. Pointers there would be great too, since it is not clear to me how to apply an upper-triangular mask over the sequence length in the infinite-width case.
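
For reference, this is the finite-width behaviour I would like to reproduce. It is a minimal sketch in plain JAX, not neural_tangents: causal_attention is a hypothetical helper, and it says nothing about how the mask should be expressed in the infinite-width kernel.

```python
import jax
import jax.numpy as jnp

def causal_attention(q, k, v):
    """Single-head scaled dot-product attention with a causal mask.

    q, k, v have shape (seq_len, d). Position i may only attend to
    positions j <= i, so the entries above the diagonal are masked out.
    """
    seq_len, d = q.shape
    logits = q @ k.T / jnp.sqrt(d)
    # Lower-triangular boolean mask: True where attention is allowed.
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    # Masked logits become -inf, so they get zero weight after the softmax.
    logits = jnp.where(mask, logits, -jnp.inf)
    weights = jax.nn.softmax(logits, axis=-1)
    return weights @ v, weights

key = jax.random.PRNGKey(0)
q, k, v = jax.random.normal(key, (3, 4, 8))  # three (seq_len=4, d=8) arrays
out, weights = causal_attention(q, k, v)
print(out.shape)
```

Every row keeps at least its diagonal entry unmasked, so the softmax is well defined for each position.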
