# Source code for simplegan.datasets.load_lsun

import tensorflow as tf
import cv2
from tqdm import tqdm
import tensorflow_datasets as tfds
import numpy as np


# Public API of this module: only the LSUN loader is exported by ``import *``.
__all__ = ["load_lsun"]


def load_lsun(info=False, img_shape=(64, 64)):
    r"""Loads the `LSUN <https://www.yf.io/p/lsun>`_ training data without labels - used in DCGAN

    Args:
        info (bool, optional): to get info of the dataset loaded. Defaults to ``False``
        img_shape (tuple, optional): ``(height, width)`` of each returned image.
            Defaults to ``(64, 64)``

    Return:
        a ``float32`` numpy array whose first axis indexes images and whose next
        two axes are ``(height, width)`` as given by ``img_shape`` (any trailing
        channel axis is whatever ``cv2.resize`` produces). If ``info`` is ``True``,
        the tuple ``(train_data, dataset_info)`` is returned instead, where
        ``dataset_info`` is the info object returned by ``tfds.load``.

    Raises:
        AssertionError: if ``img_shape`` is not a tuple of size 2
    """
    # Check the type before calling len() so that a non-tuple (e.g. an int)
    # fails with this message rather than a TypeError from len().
    assert isinstance(img_shape, tuple) and len(img_shape) == 2, "img_shape must be a tuple of size 2"

    # Always request the dataset info: it provides the split size used to size
    # the progress bar. Bind it to its own name so the ``info`` flag is not
    # overwritten by the returned info object.
    ds_train, ds_info = tfds.load(
        name="lsun", split="train", shuffle_files=True, with_info=True
    )
    num_examples = ds_info.splits["train"].num_examples

    height, width = img_shape

    train_data = []
    with tqdm(total=num_examples, desc="preparing dataset") as pbar:
        for data in ds_train:
            img = data["image"].numpy()
            # cv2.resize expects dsize as (width, height), the reverse of numpy's
            # (rows, cols) order, so swap to honour img_shape = (height, width).
            img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
            train_data.append(img)
            pbar.update(1)

    # Convert straight to float32 instead of building an intermediate uint8
    # array and copying it.
    train_data = np.array(train_data, dtype=np.float32)

    if info:
        return train_data, ds_info
    return train_data