LatentAugment

LatentAugment: Data Augmentation via Guided Manipulation of GAN's Latent Space

This repository contains the official PyTorch implementation of LatentAugment, a Data Augmentation (DA) policy that steers the latent space of a Generative Adversarial Network (GAN) to increase the diversity and quality of generated samples.

LatentAugment Features

Installation

Clone the repository

git clone
cd LatentAugment

Install dependencies

Using virtualenv

1) First create a Python virtual environment (optional) using the following command:

python -m venv latentaugment

2) Activate the virtual environment using the following command:

source latentaugment/bin/activate

3) Install the dependencies using the following command:

pip install -r requirements.txt

Using Conda

For Conda users, you can create a new Conda environment using the following command:

conda env create -f environment.yml

Then, activate the environment using the following command:

conda activate latentaugment

The code was tested with Python 3.9.12, PyTorch 1.9.1, CUDA 11.1, and Ubuntu 22.04.2 LTS. For more information about the requirements, please check the requirements.txt or environment.yml files. All experiments used a single NVIDIA RTX A5000 GPU.

Usage

Prerequisites

In our experiments we used StyleGAN2 (SG2) and the inversion procedure from the official SG2 repository, stylegan2-ada-pytorch. You are free to use any GAN or inversion procedure you prefer.
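To make the inversion step concrete, here is a deliberately simplified sketch of the underlying idea (this is NOT the SG2 projector): inversion searches for the latent code whose generated output best reconstructs a target image, typically by gradient descent on a reconstruction loss. The linear "generator" below is a stand-in so the example stays self-contained.

```python
import numpy as np

# Toy GAN inversion by latent optimisation: given a generator G and a target
# x, find the latent code z minimising ||G(z) - x||^2. Here G is a stand-in
# linear map so the example runs without a trained network.
rng = np.random.default_rng(0)

A = rng.normal(size=(64, 8))    # stand-in "generator": G(z) = A @ z
z_true = rng.normal(size=8)
x_target = A @ z_true           # target "image" to invert

z = np.zeros(8)                 # start from an arbitrary latent code
lr = 0.005
for _ in range(2000):
    residual = A @ z - x_target   # G(z) - x
    grad = 2 * A.T @ residual     # gradient of the squared error w.r.t. z
    z -= lr * grad                # gradient-descent step in latent space
```

Real inverters (e.g. the projector in stylegan2-ada-pytorch) follow the same loop but optimise in W space with perceptual losses and noise regularisation.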

To use LatentAugment in your downstream applications please follow the steps below:

from augments import create_augment
from data import create_dataset

from options.aug_options import AugOptions

opt = AugOptions().parse()  # get training options

dataset = create_dataset(opt)  # create the dataset object given opt.dataset_mode and other options
augment = create_augment(opt)  # create the augment object given opt.aug and other options

for i, data in enumerate(dataset):  # loop over training data

    # Perform augmentation.
    augment.set_input(data)          # set input for augmentation
    augment.forward()                # run the augmentation
    data_aug = augment.get_output()  # get the augmented output

    # Train the downstream model.
    # ...
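Conceptually, the augmentation step steers latent codes toward higher diversity while preserving quality. The numpy toy below sketches that idea only; the objective is a stand-in, not the paper's actual LatentAugment loss.

```python
import numpy as np

# Toy sketch of guided latent manipulation (NOT the paper's objective):
# push a batch of latent codes apart (a diversity proxy) while an anchor
# term keeps each code near its starting point (a crude quality proxy).
rng = np.random.default_rng(0)
w = rng.normal(size=(4, 16))   # batch of latent codes
w0 = w.copy()                  # original codes (e.g. inverted real samples)
lr, lam = 0.05, 5.0            # step size; anchor weight dominates for stability

for _ in range(100):
    diffs = w[:, None, :] - w[None, :, :]     # pairwise differences
    grad_div = 2.0 * diffs.sum(axis=1)        # gradient of sum of sq. distances
    grad_anchor = 2.0 * (w - w0)              # gradient of the anchor penalty
    w += lr * (grad_div - lam * grad_anchor)  # ascend diversity, stay anchored

spread0 = np.linalg.norm(w0[:, None] - w0[None, :], axis=-1).mean()
spread = np.linalg.norm(w[:, None] - w[None, :], axis=-1).mean()
# spread > spread0: the codes end up more dispersed, yet remain finite.
```

In the actual policy, the perturbed codes are decoded by the GAN to produce the augmented images returned by `augment.get_output()`.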

Results

We use UMAP dimensionality reduction to visualise the behaviour of LatentAugment and SG2 synthetic samples in relation to real image latent codes. We first fit UMAP using the latent codes of real samples, depicted as blue stars. Then, we project the latent codes of the real, LatentAugment, and SG2 samples onto this space; LatentAugment samples are represented by green circles and SG2 samples by white triangles.

We observe that SG2 samples cover only a small portion of the real manifold, while LatentAugment samples cover the entire real manifold. Moreover, LatentAugment samples do not overlap the real samples, ensuring diversity. Yet they lie near the real latent codes, suggesting high-quality generation.

For a complete visualisation of the UMAP manifold, download the interactive plot created using Bokeh. We suggest opening it in your browser.

Citation

If you use this code in your research, please cite our paper: LatentAugment: Data Augmentation via Guided Manipulation of GAN's Latent Space

Acknowledgements

This code is based on the following repositories:

Contact for Issues

If you have any questions or if you are just interested in having a virtual coffee about Generative AI, please don’t hesitate to reach out to me at: l.tronchin@unicampus.it.

May the AI be with you!

License

This code is released under the MIT License.