Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Video generation fix, Checkpoint restore on training and some code cleanup. #247

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
120 changes: 73 additions & 47 deletions evaluate.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,53 @@
from __future__ import print_function
import sys
sys.path.insert(0, 'src')
import transform, numpy as np, vgg, pdb, os
import scipy.misc
import tensorflow as tf
from utils import save_img, get_img, exists, list_files

import os
import tempfile
from argparse import ArgumentParser
from collections import defaultdict
import time
import json
import subprocess
import numpy
from moviepy.video.io.VideoFileClip import VideoFileClip

import moviepy.video.io.ffmpeg_tools as ffmpeg_tools
import moviepy.video.io.ffmpeg_writer as ffmpeg_writer
import numpy as np
import tensorflow as tf
from moviepy.video.io.VideoFileClip import VideoFileClip

import src.transform
from src.utils import save_img, get_img, exists, list_files

BATCH_SIZE = 4
DEVICE = '/gpu:0'


def ffwd_video(path_in, path_out, checkpoint_dir, device_t='/gpu:0', batch_size=4):
video_clip = VideoFileClip(path_in, audio=False)
video_writer = ffmpeg_writer.FFMPEG_VideoWriter(path_out, video_clip.size, video_clip.fps, codec="libx264",
preset="medium", bitrate="2000k",
audiofile=path_in, threads=None,
ffmpeg_params=None)

# Create a temporary file to store the audio.
fp = tempfile.NamedTemporaryFile(suffix='.aac')
temp_audio_file_name = fp.name
fp.close()

# Create a temporary file to store the video.
fp = tempfile.NamedTemporaryFile(suffix='.mp4')
temp_video_file_name = fp.name
fp.close()

# Extract the audio.
ffmpeg_tools.ffmpeg_extract_audio(path_in, temp_audio_file_name)

video_writer = ffmpeg_writer.FFMPEG_VideoWriter(temp_video_file_name, video_clip.size, video_clip.fps,
codec="libx264", preset="medium", audiofile=None, threads=None,
ffmpeg_params=["-b:v", "2000k"])

g = tf.Graph()
soft_config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
soft_config.gpu_options.allow_growth = True
with g.as_default(), g.device(device_t), \
tf.compat.v1.Session(config=soft_config) as sess:
tf.compat.v1.Session(config=soft_config) as sess:
batch_shape = (batch_size, video_clip.size[1], video_clip.size[0], 3)
img_placeholder = tf.compat.v1.placeholder(tf.float32, shape=batch_shape,
name='img_placeholder')
name='img_placeholder')

preds = transform.net(img_placeholder)
preds = src.transform.net(img_placeholder)
saver = tf.compat.v1.train.Saver()
if os.path.isdir(checkpoint_dir):
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
Expand Down Expand Up @@ -67,6 +80,13 @@ def style_and_write(count):

video_writer.close()

# Merge audio and video
ffmpeg_tools.ffmpeg_merge_video_audio(temp_video_file_name, temp_audio_file_name, path_out)

# Delete temporary files
os.remove(temp_video_file_name)
os.remove(temp_audio_file_name)


# get img_shape
def ffwd(data_in, paths_out, checkpoint_dir, device_t='/gpu:0', batch_size=4):
Expand All @@ -85,12 +105,12 @@ def ffwd(data_in, paths_out, checkpoint_dir, device_t='/gpu:0', batch_size=4):
soft_config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
soft_config.gpu_options.allow_growth = True
with g.as_default(), g.device(device_t), \
tf.compat.v1.Session(config=soft_config) as sess:
tf.compat.v1.Session(config=soft_config) as sess:
batch_shape = (batch_size,) + img_shape
img_placeholder = tf.compat.v1.placeholder(tf.float32, shape=batch_shape,
name='img_placeholder')
name='img_placeholder')

preds = transform.net(img_placeholder)
preds = src.transform.net(img_placeholder)
saver = tf.compat.v1.train.Saver()
if os.path.isdir(checkpoint_dir):
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
Expand All @@ -101,38 +121,40 @@ def ffwd(data_in, paths_out, checkpoint_dir, device_t='/gpu:0', batch_size=4):
else:
saver.restore(sess, checkpoint_dir)

num_iters = int(len(paths_out)/batch_size)
num_iters = int(len(paths_out) / batch_size)
for i in range(num_iters):
pos = i * batch_size
curr_batch_out = paths_out[pos:pos+batch_size]
curr_batch_out = paths_out[pos:pos + batch_size]
if is_paths:
curr_batch_in = data_in[pos:pos+batch_size]
curr_batch_in = data_in[pos:pos + batch_size]
X = np.zeros(batch_shape, dtype=np.float32)
for j, path_in in enumerate(curr_batch_in):
img = get_img(path_in)
assert img.shape == img_shape, \
'Images have different dimensions. ' + \
'Images have different dimensions. ' + \
'Resize images or use --allow-different-dimensions.'
X[j] = img
else:
X = data_in[pos:pos+batch_size]
X = data_in[pos:pos + batch_size]

_preds = sess.run(preds, feed_dict={img_placeholder:X})
_preds = sess.run(preds, feed_dict={img_placeholder: X})
for j, path_out in enumerate(curr_batch_out):
save_img(path_out, _preds[j])
remaining_in = data_in[num_iters*batch_size:]
remaining_out = paths_out[num_iters*batch_size:]

remaining_in = data_in[num_iters * batch_size:]
remaining_out = paths_out[num_iters * batch_size:]
if len(remaining_in) > 0:
ffwd(remaining_in, remaining_out, checkpoint_dir,
device_t=device_t, batch_size=1)
ffwd(remaining_in, remaining_out, checkpoint_dir,
device_t=device_t, batch_size=1)


def ffwd_to_img(in_path, out_path, checkpoint_dir, device='/cpu:0'):
paths_in, paths_out = [in_path], [out_path]
ffwd(paths_in, paths_out, checkpoint_dir, batch_size=1, device_t=device)

def ffwd_different_dimensions(in_path, out_path, checkpoint_dir,
device_t=DEVICE, batch_size=4):

def ffwd_different_dimensions(in_path, out_path, checkpoint_dir,
device_t=DEVICE, batch_size=4):
in_path_of_shape = defaultdict(list)
out_path_of_shape = defaultdict(list)
for i in range(len(in_path)):
Expand All @@ -143,8 +165,9 @@ def ffwd_different_dimensions(in_path, out_path, checkpoint_dir,
out_path_of_shape[shape].append(out_image)
for shape in in_path_of_shape:
print('Processing images of shape %s' % shape)
ffwd(in_path_of_shape[shape], out_path_of_shape[shape],
checkpoint_dir, device_t, batch_size)
ffwd(in_path_of_shape[shape], out_path_of_shape[shape],
checkpoint_dir, device_t, batch_size)


def build_parser():
parser = ArgumentParser()
Expand All @@ -154,7 +177,7 @@ def build_parser():
metavar='CHECKPOINT', required=True)

parser.add_argument('--in-path', type=str,
dest='in_path',help='dir or file to transform',
dest='in_path', help='dir or file to transform',
metavar='IN_PATH', required=True)

help_out = 'destination (dir or file) of transformed file or files'
Expand All @@ -163,26 +186,28 @@ def build_parser():
required=True)

parser.add_argument('--device', type=str,
dest='device',help='device to perform compute on',
dest='device', help='device to perform compute on',
metavar='DEVICE', default=DEVICE)

parser.add_argument('--batch-size', type=int,
dest='batch_size',help='batch size for feedforwarding',
dest='batch_size', help='batch size for feedforwarding',
metavar='BATCH_SIZE', default=BATCH_SIZE)

parser.add_argument('--allow-different-dimensions', action='store_true',
dest='allow_different_dimensions',
dest='allow_different_dimensions',
help='allow different image dimensions')

return parser


def check_opts(opts):
exists(opts.checkpoint_dir, 'Checkpoint not found!')
exists(opts.in_path, 'In path not found!')
if os.path.isdir(opts.out_path):
exists(opts.out_path, 'out dir not found!')
assert opts.batch_size > 0


def main():
parser = build_parser()
opts = parser.parse_args()
Expand All @@ -191,22 +216,23 @@ def main():
if not os.path.isdir(opts.in_path):
if os.path.exists(opts.out_path) and os.path.isdir(opts.out_path):
out_path = \
os.path.join(opts.out_path,os.path.basename(opts.in_path))
os.path.join(opts.out_path, os.path.basename(opts.in_path))
else:
out_path = opts.out_path

ffwd_to_img(opts.in_path, out_path, opts.checkpoint_dir,
device=opts.device)
else:
files = list_files(opts.in_path)
full_in = [os.path.join(opts.in_path,x) for x in files]
full_out = [os.path.join(opts.out_path,x) for x in files]
full_in = [os.path.join(opts.in_path, x) for x in files]
full_out = [os.path.join(opts.out_path, x) for x in files]
if opts.allow_different_dimensions:
ffwd_different_dimensions(full_in, full_out, opts.checkpoint_dir,
device_t=opts.device, batch_size=opts.batch_size)
else :
ffwd_different_dimensions(full_in, full_out, opts.checkpoint_dir,
device_t=opts.device, batch_size=opts.batch_size)
else:
ffwd(full_in, full_out, opts.checkpoint_dir, device_t=opts.device,
batch_size=opts.batch_size)
batch_size=opts.batch_size)


if __name__ == '__main__':
main()
Loading