|
87 | 87 | }, |
88 | 88 | { |
89 | 89 | "cell_type": "code", |
90 | | - "execution_count": null, |
| 90 | + "execution_count": 1, |
91 | 91 | "id": "09300f37", |
92 | 92 | "metadata": {}, |
93 | | - "outputs": [], |
| 93 | + "outputs": [ |
| 94 | + { |
| 95 | + "name": "stdout", |
| 96 | + "output_type": "stream", |
| 97 | + "text": [ |
| 98 | + "PyTorch 2.3.0.post100 | device: cpu\n", |
| 99 | + "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", |
| 100 | + "Failed to download (trying next):\n", |
| 101 | + "HTTP Error 404: Not Found\n", |
| 102 | + "\n", |
| 103 | + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz\n", |
| 104 | + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz\n" |
| 105 | + ] |
| 106 | + }, |
| 107 | + { |
| 108 | + "name": "stderr", |
| 109 | + "output_type": "stream", |
| 110 | + "text": [ |
| 111 | + "100%|███████████████████████████████████████████████████████████████████████████████████████████████| 9912422/9912422 [00:01<00:00, 8740496.78it/s]\n" |
| 112 | + ] |
| 113 | + }, |
| 114 | + { |
| 115 | + "name": "stdout", |
| 116 | + "output_type": "stream", |
| 117 | + "text": [ |
| 118 | + "Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw\n", |
| 119 | + "\n", |
| 120 | + "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", |
| 121 | + "Failed to download (trying next):\n", |
| 122 | + "HTTP Error 404: Not Found\n", |
| 123 | + "\n", |
| 124 | + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz\n", |
| 125 | + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz\n" |
| 126 | + ] |
| 127 | + }, |
| 128 | + { |
| 129 | + "name": "stderr", |
| 130 | + "output_type": "stream", |
| 131 | + "text": [ |
| 132 | + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 28881/28881 [00:00<00:00, 259427.32it/s]\n" |
| 133 | + ] |
| 134 | + }, |
| 135 | + { |
| 136 | + "name": "stdout", |
| 137 | + "output_type": "stream", |
| 138 | + "text": [ |
| 139 | + "Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw\n", |
| 140 | + "\n", |
| 141 | + "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n", |
| 142 | + "Failed to download (trying next):\n", |
| 143 | + "HTTP Error 404: Not Found\n", |
| 144 | + "\n", |
| 145 | + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz\n", |
| 146 | + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz\n" |
| 147 | + ] |
| 148 | + }, |
| 149 | + { |
| 150 | + "name": "stderr", |
| 151 | + "output_type": "stream", |
| 152 | + "text": [ |
| 153 | + "100%|███████████████████████████████████████████████████████████████████████████████████████████████| 1648877/1648877 [00:00<00:00, 2362968.53it/s]\n" |
| 154 | + ] |
| 155 | + }, |
| 156 | + { |
| 157 | + "name": "stdout", |
| 158 | + "output_type": "stream", |
| 159 | + "text": [ |
| 160 | + "Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw\n", |
| 161 | + "\n", |
| 162 | + "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n", |
| 163 | + "Failed to download (trying next):\n", |
| 164 | + "HTTP Error 404: Not Found\n", |
| 165 | + "\n", |
| 166 | + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz\n", |
| 167 | + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz\n" |
| 168 | + ] |
| 169 | + }, |
| 170 | + { |
| 171 | + "name": "stderr", |
| 172 | + "output_type": "stream", |
| 173 | + "text": [ |
| 174 | + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 4542/4542 [00:00<00:00, 2405673.54it/s]" |
| 175 | + ] |
| 176 | + }, |
| 177 | + { |
| 178 | + "name": "stdout", |
| 179 | + "output_type": "stream", |
| 180 | + "text": [ |
| 181 | + "Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw\n", |
| 182 | + "\n", |
| 183 | + "Train: 60000 | Test: 10000\n", |
| 184 | + "Image shape: torch.Size([1, 28, 28])\n" |
| 185 | + ] |
| 186 | + }, |
| 187 | + { |
| 188 | + "name": "stderr", |
| 189 | + "output_type": "stream", |
| 190 | + "text": [ |
| 191 | + "\n" |
| 192 | + ] |
| 193 | + } |
| 194 | + ], |
94 | 195 | "source": [ |
95 | 196 | "# ── Standard imports ──────────────────────────────────────────────────────\n", |
96 | 197 | "import math, time, os\n", |
|
149 | 250 | }, |
150 | 251 | { |
151 | 252 | "cell_type": "code", |
152 | | - "execution_count": null, |
| 253 | + "execution_count": 2, |
153 | 254 | "id": "1565864d", |
154 | 255 | "metadata": {}, |
155 | | - "outputs": [], |
| 256 | + "outputs": [ |
| 257 | + { |
| 258 | + "name": "stdout", |
| 259 | + "output_type": "stream", |
| 260 | + "text": [ |
| 261 | + "VAE parameters: 258,025\n", |
| 262 | + "VAE(\n", |
| 263 | + " (encoder): Sequential(\n", |
| 264 | + " (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", |
| 265 | + " (1): LeakyReLU(negative_slope=0.2)\n", |
| 266 | + " (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", |
| 267 | + " (3): LeakyReLU(negative_slope=0.2)\n", |
| 268 | + " (4): Flatten(start_dim=1, end_dim=-1)\n", |
| 269 | + " )\n", |
| 270 | + " (fc_mu): Linear(in_features=3136, out_features=20, bias=True)\n", |
| 271 | + " (fc_log_var): Linear(in_features=3136, out_features=20, bias=True)\n", |
| 272 | + " (fc_dec): Linear(in_features=20, out_features=3136, bias=True)\n", |
| 273 | + " (decoder): Sequential(\n", |
| 274 | + " (0): Unflatten(dim=1, unflattened_size=(64, 7, 7))\n", |
| 275 | + " (1): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", |
| 276 | + " (2): LeakyReLU(negative_slope=0.2)\n", |
| 277 | + " (3): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", |
| 278 | + " )\n", |
| 279 | + ")\n" |
| 280 | + ] |
| 281 | + } |
| 282 | + ], |
156 | 283 | "source": [ |
157 | 284 | "# ── VAE Model ─────────────────────────────────────────────────────────────\n", |
158 | 285 | "class VAE(nn.Module):\n", |
|
260 | 387 | "execution_count": null, |
261 | 388 | "id": "6e511305", |
262 | 389 | "metadata": {}, |
263 | | - "outputs": [], |
| 390 | + "outputs": [ |
| 391 | + { |
| 392 | + "name": "stdout", |
| 393 | + "output_type": "stream", |
| 394 | + "text": [ |
| 395 | + "Training VAE on MNIST (20 epochs)...\n", |
| 396 | + "(Uses [0,1]-normalised pixels and Bernoulli decoder)\n" |
| 397 | + ] |
| 398 | + } |
| 399 | + ], |
264 | 400 | "source": [ |
265 | 401 | "def train_vae(model, loader, epochs=20, lr=1e-3, beta=1.0):\n", |
266 | 402 | " \"\"\"\n", |
|
0 commit comments