The Geometry of Deep Generative Models
Here, we study how the geometry of deep generative models (DGMs) can inform our understanding of phenomena such as the likelihood out-of-distribution (OOD) paradox. As a supplement to these topics, we also study algorithms for estimating the local intrinsic dimension (LID) of datapoints. To get started, navigate to our repository and follow the steps below.
1 Installation
We use a conda environment for this project. To create the environment, run:
# this will create an environment named dgm_geometry:
conda env create -f env.yaml
# this will activate the environment:
conda activate dgm_geometry
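Before running any scripts, it can help to confirm that the environment is actually active. The following is a minimal sketch, not part of the repository: `check_env` is a hypothetical helper, and it relies only on the `CONDA_DEFAULT_ENV` variable that `conda activate` sets.

```shell
# Sketch: confirm the expected conda environment is active.
# check_env is a hypothetical helper, not something shipped with this repository.
check_env() {
  expected="$1"
  if [ "${CONDA_DEFAULT_ENV:-}" = "$expected" ]; then
    echo "active environment: $expected"
  else
    # Report the mismatch on stderr and signal failure to the caller.
    echo "expected '$expected' but found '${CONDA_DEFAULT_ENV:-none}'; run: conda activate $expected" >&2
    return 1
  fi
}

check_env dgm_geometry || echo "environment check failed"
```

If the check fails, re-run the activation command above in the same shell session.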
Finally, download all the checkpoints and resources, and set up the appropriate environment variables, using the following:
python scripts/download_resources.py
You may skip this step if you want to train your own models, but running it is recommended because some of the notebooks use the downloaded checkpoints.
2 Training a Deep Generative Model
Most of the capabilities in the codebase go through the training script. We use PyTorch Lightning for training and Lightning callbacks for monitoring the behaviour and properties of the manifold induced by the generative model. Even when no training is involved, we still use the training script, but load checkpoints and set the epoch count to zero.
Training involves running scripts/train.py alongside a dataset and an experiment configuration. To get started, you can run the following examples for training flows or diffusions on image datasets:
# train a greyscale diffusion; replace <grayscale-data> with, for example, mnist or fmnist
python scripts/train.py dataset=<grayscale-data> +experiment=train_diffusion_greyscale
# train an RGB diffusion; replace <rgb-data> with, for example, cifar10
python scripts/train.py dataset=<rgb-data> +experiment=train_diffusion_rgb
# train a greyscale flow; replace <grayscale-data> with, for example, mnist or fmnist
python scripts/train.py dataset=<grayscale-data> +experiment=train_flow_greyscale
# train an RGB flow; replace <rgb-data> with, for example, cifar10
python scripts/train.py dataset=<rgb-data> +experiment=train_flow_rgb
For example:
python scripts/train.py dataset=mnist +experiment=train_diffusion_greyscale
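To train the same experiment on several datasets, the command above can be generated in a loop. The sketch below is one possible way to do this: `train_commands` is a hypothetical helper (not part of the repository) that only prints the commands, using the dataset and experiment names listed above, so they can be reviewed before anything is launched.

```shell
# Sketch: print one training command per dataset for a given experiment.
# train_commands is a hypothetical helper, not something shipped with this repository.
train_commands() {
  experiment="$1"
  shift
  for dataset in "$@"; do
    echo "python scripts/train.py dataset=${dataset} +experiment=${experiment}"
  done
}

# Dry run: show the commands for the two greyscale datasets mentioned above.
train_commands train_diffusion_greyscale mnist fmnist

# To launch the runs one after another, pipe the printed lines to a shell:
# train_commands train_diffusion_greyscale mnist fmnist | sh
```

Printing first and piping second keeps the dry run and the real run identical, which makes it easy to spot a mistyped dataset or experiment name before any training starts.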
3 Our Work
This repository contains code for reproducing the results in our papers “A geometric view of data complexity” (Kamkari, Ross, Hosseinzadeh, et al. 2024) and “A geometric explanation of the likelihood OOD detection paradox” (Kamkari, Ross, Cresswell, et al. 2024). You may cite our repository as follows:
@misc{dgm_geometry_github,
author = {Hamidreza Kamkari and Brendan Leigh Ross and Jesse C. Cresswell and Gabriel Loaiza-Ganem},
title = {DGM Geometry},
year = {2024},
howpublished = {\url{https://github.com/blross/dgm-geometry}},
note = {GitHub repository},
}