generative-af

This repository contains code associated with Generative Assignment Flows for Representing and Learning Joint Distributions of Discrete Data. It was developed by the Image & Pattern Analysis Group at Heidelberg University.

Runtime Environment

The quickest way to get up and running is to create a new conda environment gen-af from the provided environment file:

conda env create --file environment.yaml

Configuration

Training hyperparameters are specified by YAML configuration files in config/. We use Hydra to parse and compose these files hierarchically, which also allows overriding individual values from the command line.
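For intuition, a dotted override like key.subkey=value simply replaces one leaf of the composed configuration tree. Below is a plain-Python sketch of that mechanism (not Hydra itself; the config keys used here are invented for illustration and need not exist in config/):

```python
# Sketch of a dotted command-line override in the spirit of Hydra's
# "key.subkey=value" syntax. Keys are hypothetical examples.
def apply_override(cfg, key, value):
    *path, leaf = key.split(".")      # "training.lr" -> ["training"], "lr"
    node = cfg
    for part in path:
        node = node.setdefault(part, {})
    node[leaf] = value
    return cfg

cfg = {"data": {"dataset": "mnist"}, "training": {"lr": 0.01}}
apply_override(cfg, "training.lr", 0.001)
```

Hydra additionally resolves which YAML file each config group (data, model, training, logging) points to before applying such overrides.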

Training Examples

Binarized MNIST

python train.py data=mnist logging=epochs model=unet training=mnist

Cityscapes Segmentations

python train.py data=cityscapes logging=steps model=unet training=cityscapes

Coupled Binary Variables Toy Distribution

python train.py data=coupled_binary logging=frequent model=dense training=simple
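The coupled-binary toy distribution can be pictured as pairs of correlated bits. A minimal NumPy sketch of such a distribution follows; the repository's actual dataset implementation may differ, and the parameter p_flip is invented here for illustration:

```python
import numpy as np

def sample_coupled_binary(n, p_flip=0.1, seed=None):
    """Sample n pairs of coupled bits: the second bit copies the
    first with probability 1 - p_flip. Illustrative sketch only."""
    rng = np.random.default_rng(seed)
    x0 = rng.integers(0, 2, size=n)        # first bit, uniform
    flip = rng.random(n) < p_flip          # whether the coupling is broken
    x1 = np.where(flip, 1 - x0, x0)        # second bit, mostly equal to x0
    return np.stack([x0, x1], axis=1)

samples = sample_coupled_binary(10000, p_flip=0.1, seed=0)
```

A joint distribution like this is deliberately simple: the generative model must learn that the two variables agree most of the time, which is easy to verify visually and numerically.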

Other Simple Data Distributions

python train.py data=simple data.dataset=pinwheel logging=frequent model=dense training=simple

If data=simple is set, the available options for data.dataset are pinwheel and gaussian_mixture.
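For reference, pinwheel is a standard toy distribution of spiral-shaped clusters. A hedged NumPy sketch of one common construction is shown below; the repository's generator behind data.dataset=pinwheel may use different parameters or a discretized variant, and all parameter names here are illustrative:

```python
import numpy as np

def sample_pinwheel(n, num_arms=5, radial_std=0.3,
                    tangential_std=0.05, rate=0.25, seed=None):
    """Sample n points from a pinwheel of num_arms spiral arms.
    Illustrative sketch; not the repository's implementation."""
    rng = np.random.default_rng(seed)
    arm = rng.integers(num_arms, size=n)             # which arm each point lies on
    r = 1.0 + radial_std * rng.standard_normal(n)    # noisy radius
    # base angle per arm, rotated proportionally to radius -> spiral shape
    theta = (2 * np.pi * arm / num_arms
             + rate * r
             + tangential_std * rng.standard_normal(n))
    return np.stack([r * np.cos(theta), r * np.sin(theta)], axis=1)

points = sample_pinwheel(1000, seed=0)
```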

Scaling to Many Classes

python train.py -m data=num_classes data.num_classes=5,10,20,40,60,80,100,120,140,160 logging=epochs logging.eval_interval_epochs=100 model=cnn training=num_classes

Training artifacts, including model checkpoints, TensorBoard logs, and hyperparameters, are saved in lightning_logs/.

Cityscapes Segmentation Data

To train a generative model for Cityscapes segmentations, first download the dataset to a directory of your choice and then run the preprocessing routine:

cd data/image
python scale_cityscapes.py /path/to/raw/data 0.125 train

The second argument is a scaling factor for the spatial dimensions: with 0.125, preprocessed files are subsampled by a factor of 8 using interpolation mode PIL.Image.NEAREST. The number of segment classes is also reduced, corresponding to the category level of segments in the Cityscapes torchvision dataset. Preprocessed Cityscapes segmentation data are saved to data/image/cityscapes/cityscapes_{split}_{scale}.pt.

Citation

@article{Boll:2024ac,
	author = {Boll, B. and Gonzalez-Alvarado, D. and Petra, S. and Schn\"{o}rr, C.},
	journal = {preprint arXiv:2406.04527},
	title = {{Generative Assignment Flows for Representing and Learning Joint Distributions of Discrete Data}},
	year = {2024},
	url = {https://arxiv.org/abs/2406.04527}
}
