Hello,
I implemented a brutally simple infinite-width model, calling the kernel_fn with a batch of a single vector.
When I run this on the CPU, I don't see any excessive memory usage.
However, when I run it on an A100 GPU, just under 60GB gets allocated for this single tiny kernel evaluation!
The same thing happens (again only on the GPU) when I use infinite-width CNNs on image datasets (CIFAR, MNIST, etc.); a minimal sketch of that case is included after the memory figures below.
Does anyone know what could be causing this?
import jax
import numpy as np
import neural_tangents as nt

print(jax.devices())  # confirm the GPU is visible

def linear_model():
    return nt.stax.serial(
        nt.stax.Dense(512), nt.stax.Relu(),
        nt.stax.Dense(512), nt.stax.Relu(),
        nt.stax.Dense(1)
    )

init_fn, apply_fn, kernel_fn = linear_model()

total = 1
X = np.ones((total, 200), dtype=np.float32)  # a single 200-dimensional input

!nvidia-smi  # memory before the kernel call
ntk = kernel_fn(X, None, 'ntk')
!nvidia-smi  # memory after the kernel call
print(ntk)
Usage from first SMI call: 426MiB / 81920MiB
Usage from second SMI call: 61352MiB / 81920MiB
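
For reference, here is a minimal sketch of the CNN case mentioned above. The architecture and input shape are only illustrative (not my exact setup): a small convolutional stack applied to a single CIFAR-sized image.

import numpy as np
import neural_tangents as nt

# Illustrative infinite-width CNN, not my exact architecture.
_, _, cnn_kernel_fn = nt.stax.serial(
    nt.stax.Conv(64, (3, 3), padding='SAME'), nt.stax.Relu(),
    nt.stax.Conv(64, (3, 3), padding='SAME'), nt.stax.Relu(),
    nt.stax.Flatten(),
    nt.stax.Dense(1)
)

X_img = np.ones((1, 32, 32, 3), dtype=np.float32)  # one CIFAR-shaped image (NHWC)
ntk_img = cnn_kernel_fn(X_img, None, 'ntk')  # shows the same GPU memory jump for me
print(ntk_img)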
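
In case it helps with diagnosis, here is a rough sketch of how one could check whether nvidia-smi is reporting memory the kernel computation actually needs, or just a pool the allocator reserves up front. The environment variables are JAX's documented GPU memory flags; whether preallocation is actually involved here is only an assumption on my part.

import os
# These must be set before jax is imported.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'   # don't reserve a large pool up front
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'  # allocate/free on demand (slower, but exact)

import jax
import numpy as np
import neural_tangents as nt

_, _, kernel_fn = nt.stax.serial(
    nt.stax.Dense(512), nt.stax.Relu(),
    nt.stax.Dense(512), nt.stax.Relu(),
    nt.stax.Dense(1)
)

X = np.ones((1, 200), dtype=np.float32)
ntk = kernel_fn(X, None, 'ntk')
# Compare nvidia-smi against the figures above: if usage stays at a few hundred MiB,
# the ~60GB was preallocation rather than the kernel computation itself.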