submit project
parent
16f0403c02
commit
f7041fd8fa
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
@ -0,0 +1,239 @@
|
|||||||
|
import math
|
||||||
|
import os
|
||||||
|
import hashlib
|
||||||
|
from urllib.request import urlretrieve
|
||||||
|
import zipfile
|
||||||
|
import gzip
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def _read32(bytestream):
    """
    Decode the next four bytes of a stream as one big-endian unsigned integer
    :param bytestream: Binary stream exposing a ``read`` method
    :return: The decoded value as a NumPy 32-bit unsigned scalar
    """
    raw = bytestream.read(4)
    # '>u4' spells out "big-endian, unsigned, 4 bytes wide"
    big_endian_u32 = np.dtype('>u4')
    return np.frombuffer(raw, dtype=big_endian_u32)[0]
|
||||||
|
|
||||||
|
|
||||||
|
def _unzip(save_path, _, database_name, data_path):
    """
    Extract every member of a zip archive into data_path
    :param save_path: Location of the zip archive
    :param _: Ignored - present only so the signature lines up with _ungzip
    :param database_name: Name of database, shown in the progress message
    :param data_path: Directory the archive members are written into
    """
    message = 'Extracting {}...'.format(database_name)
    print(message)

    archive = zipfile.ZipFile(save_path)
    with archive:
        archive.extractall(data_path)
|
||||||
|
|
||||||
|
|
||||||
|
def _ungzip(save_path, extract_path, database_name, _):
    """
    Decode a gzipped image file and write every image to extract_path as a JPEG
    The decompressed stream must start with four 32-bit integers read by
    _read32: the magic number 2051, the image count, the row count and the
    column count. One byte per pixel follows.
    :param save_path: The path of the gzip file
    :param extract_path: Existing directory that receives image_<index>.jpg files
    :param database_name: Name of database, shown in the progress bar
    :param _: Ignored - present only so the signature lines up with _unzip
    :raises ValueError: If the magic number is not 2051
    """
    with open(save_path, 'rb') as compressed:
        with gzip.GzipFile(fileobj=compressed) as bytestream:
            magic = _read32(bytestream)
            if magic != 2051:
                raise ValueError('Invalid magic number {} in file: {}'.format(magic, compressed.name))
            # Header order is fixed: image count, then rows, then columns
            num_images, rows, cols = [_read32(bytestream) for _field in range(3)]
            buf = bytestream.read(rows * cols * num_images)

    pixels = np.frombuffer(buf, dtype=np.uint8).reshape(num_images, rows, cols)

    # One greyscale JPEG per image, numbered in stream order
    progress = tqdm(pixels, unit='File', unit_scale=True, miniters=1, desc='Extracting {}'.format(database_name))
    for index, frame in enumerate(progress):
        target = os.path.join(extract_path, 'image_{}.jpg'.format(index))
        Image.fromarray(frame, 'L').save(target)
|
||||||
|
|
||||||
|
|
||||||
|
def get_image(image_path, width, height, mode):
    """
    Read image from image_path
    An image whose size already equals (width, height) is only converted to
    mode. Any other image has its central 108x108 region cropped out and
    resized to (width, height) with bilinear filtering before conversion.
    :param image_path: Path of image
    :param width: Width of image
    :param height: Height of image
    :param mode: Mode of image
    :return: Image data
    """
    # The context manager closes the underlying file as soon as the pixels
    # have been copied into the returned array, instead of leaving the
    # handle open until garbage collection.
    with Image.open(image_path) as source:
        image = source

        if image.size != (width, height):  # HACK - Check if image is from the CELEBA dataset
            # Remove most pixels that aren't part of a face
            face_width = face_height = 108
            j = (image.size[0] - face_width) // 2
            i = (image.size[1] - face_height) // 2
            image = image.crop([j, i, j + face_width, i + face_height])
            image = image.resize([width, height], Image.BILINEAR)

        return np.array(image.convert(mode))
|
||||||
|
|
||||||
|
|
||||||
|
def get_batch(image_files, width, height, mode):
    """
    Load a list of image files into one float32 array
    :param image_files: Paths of the images to load
    :param width: Width passed on to get_image
    :param height: Height passed on to get_image
    :param mode: Mode passed on to get_image
    :return: float32 array with a trailing axis appended when the stacked
        images have fewer than 4 dimensions
    """
    frames = [get_image(path, width, height, mode) for path in image_files]
    data_batch = np.array(frames).astype(np.float32)

    # Make sure the images are in 4 dimensions
    if data_batch.ndim < 4:
        data_batch = np.expand_dims(data_batch, axis=-1)

    return data_batch
|
||||||
|
|
||||||
|
|
||||||
|
def images_square_grid(images, mode):
    """
    Save images as a square grid
    Only the first save_size * save_size images are used, where save_size is
    the floor of the square root of the number of images.
    :param images: Images to be used for the grid - a 4-D array whose axis 0
        indexes images, axes 1 and 2 are each image's rows and columns, and
        axis 3 holds the channels
    :param mode: The mode to use for images
    :return: Image of images in a square grid
    """
    # Get maximum size for square grid of images
    save_size = math.floor(np.sqrt(images.shape[0]))
    tile_height, tile_width = images.shape[1], images.shape[2]

    # Scale to 0-255
    lowest = images.min()
    value_range = images.max() - lowest
    if value_range == 0:
        # Every pixel is identical: scaling would divide by zero, so render
        # the whole batch as black instead of producing NaN-derived garbage.
        images = np.zeros(images.shape, dtype=np.uint8)
    else:
        images = (((images - lowest) * 255) / value_range).astype(np.uint8)

    # Put images in a square arrangement
    images_in_square = np.reshape(
        images[:save_size*save_size],
        (save_size, save_size, tile_height, tile_width, images.shape[3]))
    if mode == 'L':
        images_in_square = np.squeeze(images_in_square, 4)

    # Combine images to grid image
    # PIL sizes and paste offsets are (x, y) = (columns, rows), so the tile
    # width comes from array axis 2 and the tile height from axis 1. For
    # square tiles this is identical to using either axis.
    new_im = Image.new(mode, (tile_width * save_size, tile_height * save_size))
    for col_i, col_images in enumerate(images_in_square):
        for image_i, image in enumerate(col_images):
            im = Image.fromarray(image, mode)
            new_im.paste(im, (col_i * tile_width, image_i * tile_height))

    return new_im
|
||||||
|
|
||||||
|
|
||||||
|
def download_extract(database_name, data_path):
    """
    Download and extract database
    Returns immediately when the extraction folder already exists. Otherwise
    the archive is downloaded (unless already present), verified against its
    MD5 hash, extracted, and then deleted.
    :param database_name: Database name, either 'celeba' or 'mnist'
    :param data_path: Directory holding the downloaded archive and the
        extracted data; created when missing
    :raises ValueError: If database_name is not a known database
    :raises AssertionError: If the archive's MD5 hash does not match
    """
    DATASET_CELEBA_NAME = 'celeba'
    DATASET_MNIST_NAME = 'mnist'

    if database_name == DATASET_CELEBA_NAME:
        url = 'https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/celeba.zip'
        hash_code = '00d2c5bc6d35e252742224ab0c1e8fcb'
        extract_path = os.path.join(data_path, 'img_align_celeba')
        save_path = os.path.join(data_path, 'celeba.zip')
        extract_fn = _unzip
    elif database_name == DATASET_MNIST_NAME:
        url = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'
        hash_code = 'f68b3c2dcbeaaa9fbdd348bbdeb94873'
        extract_path = os.path.join(data_path, 'mnist')
        save_path = os.path.join(data_path, 'train-images-idx3-ubyte.gz')
        extract_fn = _ungzip
    else:
        # Fail with a clear message instead of an UnboundLocalError below
        raise ValueError('Unknown database name: {}'.format(database_name))

    if os.path.exists(extract_path):
        print('Found {} Data'.format(database_name))
        return

    os.makedirs(data_path, exist_ok=True)

    if not os.path.exists(save_path):
        with DLProgress(unit='B', unit_scale=True, miniters=1, desc='Downloading {}'.format(database_name)) as pbar:
            urlretrieve(
                url,
                save_path,
                pbar.hook)

    # Hash the archive in chunks so it is never held in memory all at once,
    # and close the file handle deterministically.
    digest = hashlib.md5()
    with open(save_path, 'rb') as archive:
        for chunk in iter(lambda: archive.read(1 << 20), b''):
            digest.update(chunk)
    if digest.hexdigest() != hash_code:
        # Raised explicitly (not via `assert`) so the check still runs under
        # `python -O`; the exception type is kept for existing callers.
        raise AssertionError(
            '{} file is corrupted. Remove the file and try again.'.format(save_path))

    os.makedirs(extract_path)
    try:
        extract_fn(save_path, extract_path, database_name, data_path)
    except Exception:
        shutil.rmtree(extract_path)  # Remove extraction folder if there is an error
        raise

    # Remove compressed data
    os.remove(save_path)
|
||||||
|
|
||||||
|
|
||||||
|
class Dataset(object):
    """
    Dataset
    Holds a list of image files plus the image mode and the 4-tuple shape
    (file count, width, height, channels) implied by the dataset name.
    """
    def __init__(self, dataset_name, data_files):
        """
        Initialize the class
        :param dataset_name: Database name, either 'celeba' or 'mnist'
        :param data_files: List of files in the database
        :raises ValueError: If dataset_name is not a known database
        """
        DATASET_CELEBA_NAME = 'celeba'
        DATASET_MNIST_NAME = 'mnist'
        IMAGE_WIDTH = 28
        IMAGE_HEIGHT = 28

        if dataset_name == DATASET_CELEBA_NAME:
            self.image_mode = 'RGB'
            image_channels = 3

        elif dataset_name == DATASET_MNIST_NAME:
            self.image_mode = 'L'
            image_channels = 1

        else:
            # Fail with a clear message instead of an UnboundLocalError below
            raise ValueError('Unknown dataset name: {}'.format(dataset_name))

        self.data_files = data_files
        self.shape = len(data_files), IMAGE_WIDTH, IMAGE_HEIGHT, image_channels

    def get_batches(self, batch_size):
        """
        Generate batches
        Only full batches are produced; files left over after the last full
        batch are skipped. Pixel values are divided by 255 and shifted by
        -0.5.
        :param batch_size: Batch Size, must be positive
        :return: Batches of data
        :raises ValueError: If batch_size is not positive (raised when the
            generator is first advanced)
        """
        IMAGE_MAX_VALUE = 255

        if batch_size <= 0:
            # A non-positive step would never move past the end of the data
            # and the loop below would run forever.
            raise ValueError('batch_size must be positive, got {}'.format(batch_size))

        current_index = 0
        while current_index + batch_size <= self.shape[0]:
            data_batch = get_batch(
                self.data_files[current_index:current_index + batch_size],
                *self.shape[1:3],
                self.image_mode)

            current_index += batch_size

            yield data_batch / IMAGE_MAX_VALUE - 0.5
|
||||||
|
|
||||||
|
|
||||||
|
class DLProgress(tqdm):
    """
    Progress bar driven by a download report hook
    """
    # Block count reported by the previous hook call
    last_block = 0

    def hook(self, block_num=1, block_size=1, total_size=None):
        """
        Report hook: advance the bar by the bytes received since the last call.
        Called once on establishment of the network connection and once
        after each block read thereafter.
        :param block_num: A count of blocks transferred so far
        :param block_size: Block size in bytes
        :param total_size: The total size of the file. This may be -1 on older FTP servers which do not return
            a file size in response to a retrieval request.
        """
        self.total = total_size
        blocks_since_last_call = block_num - self.last_block
        self.update(blocks_since_last_call * block_size)
        self.last_block = block_num
|
@ -0,0 +1,151 @@
|
|||||||
|
from copy import deepcopy
|
||||||
|
from unittest import mock
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
def test_safe(func):
    """
    Isolate tests
    Decorator that runs func inside a fresh default tf.Graph, prints
    'Tests Passed' once func returns, and hands back func's result.
    :param func: The test function to wrap
    :return: The wrapping function
    """
    # Local import keeps this block self-contained; the module header does
    # not import functools.
    import functools

    @functools.wraps(func)  # keep the wrapped test's name and docstring
    def func_wrapper(*args, **kwargs):
        with tf.Graph().as_default():
            result = func(*args, **kwargs)
        print('Tests Passed')
        return result

    return func_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_tensor_shape(tensor, shape, display_name):
    """
    Assert that a tensor has the expected rank and dimensions
    :param tensor: Tensor to inspect
    :param shape: Expected dimensions; a None entry accepts any size
    :param display_name: Name used in the failure messages
    """
    expected_rank = len(shape)
    assert tf.assert_rank(tensor, expected_rank, message='{} has wrong rank'.format(display_name))

    # A rank-0 expectation has no dimensions to compare
    tensor_shape = tensor.get_shape().as_list() if expected_rank else []

    wrong_dimension = []
    for ten_dim, cor_dim in zip(tensor_shape, shape):
        if cor_dim is None:
            continue
        if ten_dim != cor_dim:
            wrong_dimension.append(ten_dim)

    assert not wrong_dimension, \
        '{} has wrong shape.  Found {}'.format(display_name, tensor_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_input(tensor, shape, display_name, tf_name=None):
    """
    Assert that a tensor is a placeholder with the expected shape and name
    :param tensor: Tensor to inspect
    :param shape: Expected shape, forwarded to _assert_tensor_shape
    :param display_name: Name used in the failure messages
    :param tf_name: Expected tensor name; the name is not checked when falsy
    """
    assert tensor.op.type == 'Placeholder', \
        '{} is not a Placeholder.'.format(display_name)

    # Report shape failures under this tensor's own name rather than a
    # hard-coded 'Real Input' label.
    _assert_tensor_shape(tensor, shape, display_name)

    if tf_name:
        assert tensor.name == tf_name, \
            '{} has bad name.  Found name {}'.format(display_name, tensor.name)
|
||||||
|
|
||||||
|
|
||||||
|
class TmpMock():
    """
    Mock an attribute.  Restore attribute when exiting scope.
    The attribute is replaced by a MagicMock as soon as the instance is
    constructed; leaving the ``with`` block puts the original object back.
    """
    def __init__(self, module, attrib_name):
        """
        :param module: Object whose attribute is replaced
        :param attrib_name: Name of the attribute to replace
        """
        # Keep a reference to the original object itself (not a deep copy)
        # so that exactly the same object is restored on exit.
        self.original_attrib = getattr(module, attrib_name)
        setattr(module, attrib_name, mock.MagicMock())
        self.module = module
        self.attrib_name = attrib_name

    def __enter__(self):
        """
        :return: The MagicMock currently installed on the module
        """
        return getattr(self.module, self.attrib_name)

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """
        Restore the original attribute; exceptions are not suppressed.
        """
        setattr(self.module, self.attrib_name, self.original_attrib)
|
||||||
|
|
||||||
|
|
||||||
|
@test_safe
def test_model_inputs(model_inputs):
    """
    Check the three tensors returned by model_inputs
    :param model_inputs: Function under test, called as
        model_inputs(image_width, image_height, image_channels, z_dim)
    """
    width, height, channels = 28, 28, 3
    z_dim = 100
    input_real, input_z, learn_rate = model_inputs(width, height, channels, z_dim)

    # (tensor, expected shape, name used in failure messages)
    expectations = (
        (input_real, [None, width, height, channels], 'Real Input'),
        (input_z, [None, z_dim], 'Z Input'),
        (learn_rate, [], 'Learning Rate'),
    )
    for tensor, expected_shape, label in expectations:
        _check_input(tensor, expected_shape, label)
|
||||||
|
|
||||||
|
|
||||||
|
@test_safe
def test_discriminator(discriminator, tf_module):
    """
    Check the discriminator's output shapes and its use of variable_scope
    :param discriminator: Function under test, called as discriminator(images)
        and discriminator(images, True)
    :param tf_module: Module whose variable_scope attribute is mocked
    """
    with TmpMock(tf_module, 'variable_scope') as scope_mock:
        images = tf.placeholder(tf.float32, [None, 28, 28, 3])
        expected_shape = [None, 1]

        # First build: default arguments, scope must be opened with reuse=False
        train_output, train_logits = discriminator(images)
        _assert_tensor_shape(train_output, expected_shape, 'Discriminator Training(reuse=false) output')
        _assert_tensor_shape(train_logits, expected_shape, 'Discriminator Training(reuse=false) Logits')
        assert scope_mock.called,\
            'tf.variable_scope not called in Discriminator Training(reuse=false)'
        expected_call = mock.call('discriminator', reuse=False)
        assert scope_mock.call_args == expected_call, \
            'tf.variable_scope called with wrong arguments in Discriminator Training(reuse=false)'

        scope_mock.reset_mock()

        # Second build: reuse requested, scope must be opened with reuse=True
        reuse_output, reuse_logits = discriminator(images, True)
        _assert_tensor_shape(reuse_output, expected_shape, 'Discriminator Inference(reuse=True) output')
        _assert_tensor_shape(reuse_logits, expected_shape, 'Discriminator Inference(reuse=True) Logits')
        assert scope_mock.called, \
            'tf.variable_scope not called in Discriminator Inference(reuse=True)'
        expected_call = mock.call('discriminator', reuse=True)
        assert scope_mock.call_args == expected_call, \
            'tf.variable_scope called with wrong arguments in Discriminator Inference(reuse=True)'
|
||||||
|
|
||||||
|
|
||||||
|
@test_safe
def test_generator(generator, tf_module):
    """
    Check the generator's output shape and its use of variable_scope
    :param generator: Function under test, called as
        generator(z, out_channel_dim) and generator(z, out_channel_dim, False)
    :param tf_module: Module whose variable_scope attribute is mocked
    """
    with TmpMock(tf_module, 'variable_scope') as scope_mock:
        latent = tf.placeholder(tf.float32, [None, 100])
        out_channel_dim = 5
        expected_shape = [None, 28, 28, out_channel_dim]

        # Default call: scope must be opened with reuse=False
        train_output = generator(latent, out_channel_dim)
        _assert_tensor_shape(train_output, expected_shape, 'Generator output (is_train=True)')
        assert scope_mock.called, \
            'tf.variable_scope not called in Generator Training(reuse=false)'
        expected_call = mock.call('generator', reuse=False)
        assert scope_mock.call_args == expected_call, \
            'tf.variable_scope called with wrong arguments in Generator Training(reuse=false)'

        scope_mock.reset_mock()

        # Third argument False: scope must be opened with reuse=True
        inference_output = generator(latent, out_channel_dim, False)
        _assert_tensor_shape(inference_output, expected_shape, 'Generator output (is_train=False)')
        assert scope_mock.called, \
            'tf.variable_scope not called in Generator Inference(reuse=True)'
        expected_call = mock.call('generator', reuse=True)
        assert scope_mock.call_args == expected_call, \
            'tf.variable_scope called with wrong arguments in Generator Inference(reuse=True)'
|
||||||
|
|
||||||
|
|
||||||
|
@test_safe
def test_model_loss(model_loss):
    """
    Check that model_loss returns two rank-0 loss tensors
    :param model_loss: Function under test, called as
        model_loss(input_real, input_z, out_channel_dim)
    """
    out_channel_dim = 4
    input_real = tf.placeholder(tf.float32, [None, 28, 28, out_channel_dim])
    input_z = tf.placeholder(tf.float32, [None, 100])

    d_loss, g_loss = model_loss(input_real, input_z, out_channel_dim)

    _assert_tensor_shape(d_loss, [], 'Discriminator Loss')
    # Check the generator loss itself; the discriminator loss was previously
    # checked twice and the generator loss never.
    _assert_tensor_shape(g_loss, [], 'Generator Loss')
|
||||||
|
|
||||||
|
|
||||||
|
@test_safe
def test_model_opt(model_opt, tf_module):
    """
    Check that model_opt calls trainable_variables and returns two values
    :param model_opt: Function under test, called as
        model_opt(d_loss, g_loss, learning_rate, beta1)
    :param tf_module: Module whose trainable_variables attribute is mocked
    """
    with TmpMock(tf_module, 'trainable_variables') as mock_trainable_variables:
        with tf.variable_scope('discriminator'):
            discriminator_logits = tf.Variable(tf.zeros([3, 3]))
        with tf.variable_scope('generator'):
            generator_logits = tf.Variable(tf.zeros([3, 3]))

        # The mocked trainable_variables hands back exactly these two variables
        mock_trainable_variables.return_value = [discriminator_logits, generator_logits]
        d_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=discriminator_logits,
            labels=[[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]))
        g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=generator_logits,
            labels=[[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]))
        learning_rate = 0.001
        beta1 = 0.9

        # The results are not inspected, but the unpacking itself fails
        # unless model_opt returns exactly two values.
        d_train_opt, g_train_opt = model_opt(d_loss, g_loss, learning_rate, beta1)
        assert mock_trainable_variables.called,\
            'tf.trainable_variables not called'
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue