Skip to content

Commit

Permalink
Create utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Bei-Jin authored Aug 20, 2018
1 parent 957a79b commit e0beadf
Showing 1 changed file with 159 additions and 0 deletions.
159 changes: 159 additions & 0 deletions code/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
"""
Some codes from https://github.com/Newmu/dcgan_code
"""
import cv2
import random
import imageio
import scipy.misc
import numpy as np


def transform(image):
    """Rescale pixel values linearly: ``x -> x / 127.5 - 1``.

    With this mapping 0 becomes -1 and 255 becomes 1. Works on anything
    that supports division and subtraction by a float (scalars, NumPy
    arrays).
    """
    half_range = 127.5
    scaled = image / half_range
    return scaled - 1.


def inverse_transform(images):
    """Rescale values linearly: ``x -> (x + 1) / 2``.

    With this mapping -1 becomes 0 and 1 becomes 1. Works on scalars and
    NumPy arrays alike.
    """
    shifted = images + 1.
    return shifted / 2.


def save_images(images, size, image_path):
    """Write a batch of images to ``image_path`` as one tiled picture.

    The batch is first passed through ``inverse_transform`` and multiplied
    by 255, then handed to ``imsave`` together with the grid ``size``.
    Returns whatever ``imsave`` returns.
    """
    pixel_values = inverse_transform(images) * 255.
    return imsave(pixel_values, size, image_path)


def merge(images, size):
    """Tile a batch of images into a single grid image.

    Args:
        images: NumPy array indexed as (n, height, width, channels). Each
            image must be broadcastable to (height, width, 3).
        size: ``(rows, cols)`` of the grid. Images are placed row by row.

    Returns:
        A float64 array of shape ``(height * rows, width * cols, 3)``.
        Grid cells that receive no image stay zero.
    """
    h, w = images.shape[1], images.shape[2]
    rows, cols = size[0], size[1]
    img = np.zeros((h * rows, w * cols, 3))

    for idx, image in enumerate(images):
        # divmod keeps both grid coordinates integral; plain `/` yields a
        # float on Python 3, which is not a valid slice index.
        row, col = divmod(idx, cols)
        img[row * h:(row + 1) * h, col * w:(col + 1) * w, :] = image

    return img


def imsave(images, size, path):
    """Tile ``images`` into a grid with ``merge`` and write it to ``path``.

    ``scipy.misc.imsave`` has been removed from SciPy, so the grid is
    written with ``imageio.imwrite`` instead (imageio is already imported
    by this module).

    Args:
        images: batch of images with pixel values expressed on a 0-255
            scale, as produced by ``save_images``.
        size: ``(rows, cols)`` grid layout forwarded to ``merge``.
        path: destination file name; the format is chosen by imageio.

    Returns:
        Whatever ``imageio.imwrite`` returns.

    Note:
        Values are rounded and clipped to [0, 255] before the uint8 cast.
        Unlike the old SciPy helper, no min/max contrast stretching is
        applied, so pixel values are written as given.
    """
    grid = merge(images, size)
    pixels = np.clip(np.rint(grid), 0, 255).astype(np.uint8)
    return imageio.imwrite(path, pixels)


def get_minibatches_idx(n, minibatch_size, shuffle=False):
    """Split the indices ``0 .. n-1`` into consecutive minibatches.

    Used to shuffle the dataset at each iteration.

    Args:
        n: total number of samples.
        minibatch_size: number of indices per minibatch (must be >= 1).
        shuffle: when true, the index array is shuffled in place with
            ``random.shuffle`` before it is split.

    Returns:
        A list of ``(minibatch_number, index_array)`` tuples. Each index
        array is an int32 NumPy array; the last one holds the remainder
        when ``n`` is not a multiple of ``minibatch_size``.

    Raises:
        ValueError: if ``minibatch_size`` is smaller than 1.
    """
    if minibatch_size < 1:
        raise ValueError("minibatch_size must be a positive integer")

    idx_list = np.arange(n, dtype="int32")

    if shuffle:
        random.shuffle(idx_list)

    # Stepping by minibatch_size covers every full batch and, through
    # slice truncation, the shorter leftover batch at the end.
    minibatches = [idx_list[start:start + minibatch_size]
                   for start in range(0, n, minibatch_size)]

    # Materialise the pairs: a bare zip() is a one-shot iterator on
    # Python 3, which breaks callers that use len() or iterate twice.
    return list(enumerate(minibatches))


def draw_frame(img, is_input):
    """Paint a 2-pixel coloured border around an image.

    Args:
        img: NumPy array indexed as (height, width, channels). A
            single-channel image is first expanded to three channels by
            repetition; a three-channel image is modified in place.
        is_input: selects the border colour. True sets channel 1 to 255
            and channels 0 and 2 to 0; False sets channel 2 to 255 and
            channels 0 and 1 to 0.

    Returns:
        The framed image: the same array object for three-channel input,
        a new array for single-channel input.

    Note:
        Channel 1 is green in both RGB and BGR. Whether channel 2 is blue
        or red depends on the caller's channel order, which is not
        determined here.
    """
    if img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)

    color = (0, 255, 0) if is_input else (0, 0, 255)

    # Each slice is a view, so assigning into it writes through to img.
    # Using one colour for all four edges fixes the left edge of the
    # non-input frame, which previously left channel 1 untouched.
    for edge in (img[:2], img[-2:], img[:, :2], img[:, -2:]):
        edge[..., :3] = color

    return img


def load_kth_data(f_name, data_path, image_size, K, T):
    """Load a random grayscale clip of ``K + T`` frames from a KTH video.

    Args:
        f_name: whitespace-separated record ``"<video> <start> <end>"``.
            The video file is ``data_path + <video> + "_uncomp.avi"``.
        data_path: prefix prepended to the video name (plain string
            concatenation, so it must include any trailing separator).
        image_size: frames are resized to ``image_size x image_size``.
        K: number of leading frames used for the difference images
            (presumably the observed/input frames -- confirm with caller).
        T: number of additional frames loaded after the first ``K``
            (presumably the frames to predict -- confirm with caller).

    Returns:
        ``(seq, diff)`` where ``seq`` is a float32 array of shape
        ``(image_size, image_size, K + T, 1)`` holding frames passed
        through ``transform``, and ``diff`` is a float32 array of shape
        ``(image_size, image_size, K - 1, 1)`` holding the differences
        between consecutive frames among the first ``K``, computed after
        ``inverse_transform``. With probability 0.5 the clip is reversed
        along axis 1 before the differences are taken.
    """
    flip = np.random.binomial(1, .5, 1)[0]
    tokens = f_name.split()
    vid_path = data_path + tokens[0] + "_uncomp.avi"
    vid = imageio.get_reader(vid_path, "ffmpeg")
    try:
        low = int(tokens[1])
        # Last admissible start index (exclusive) so K+T frames still fit.
        high = np.min([int(tokens[2]), vid.get_length()]) - K - T + 1
        if low == high:
            stidx = 0
        else:
            if low >= high:
                # Report the offending video before randint raises.
                print(vid_path)
            stidx = np.random.randint(low=low, high=high)

        seq = np.zeros((image_size, image_size, K + T, 1), dtype="float32")
        for t in range(K + T):
            frame = cv2.resize(vid.get_data(stidx + t),
                               (image_size, image_size))
            gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
            seq[:, :, t] = transform(gray[:, :, None])
    finally:
        # Release the reader even when decoding fails, so repeated calls
        # do not accumulate open video readers.
        vid.close()

    if flip == 1:
        seq = seq[:, ::-1]

    diff = np.zeros((image_size, image_size, K - 1, 1), dtype="float32")
    for t in range(1, K):
        prev_frame = inverse_transform(seq[:, :, t - 1])
        next_frame = inverse_transform(seq[:, :, t])
        diff[:, :, t - 1] = (next_frame.astype("float32")
                             - prev_frame.astype("float32"))

    return seq, diff


def load_s1m_data(f_name, data_path, trainlist, K, T):
    """Load a random colour clip of ``K + T`` frames, retrying on failure.

    Args:
        f_name: video file name, appended to ``data_path``.
        data_path: prefix prepended to the file name (plain string
            concatenation, so it must include any trailing separator).
        trainlist: sequence of alternative file names. Whenever loading
            the current video raises, a random entry from this list is
            tried instead, until one succeeds.
        K: number of leading frames used for the difference images
            (presumably the observed/input frames -- confirm with caller).
        T: number of additional frames loaded after the first ``K``
            (presumably the frames to predict -- confirm with caller).

    Returns:
        ``(seq, diff)`` where ``seq`` is a float32 array of shape
        ``(240, 320, K + T, 3)`` holding resized frames with the channel
        axis reversed and passed through ``transform``, and ``diff`` is a
        float32 array of shape ``(240, 320, K - 1, 1)`` holding grayscale
        differences between consecutive frames among the first ``K``,
        divided by 255. With probability 0.5 the clip is reversed along
        axis 1 before the differences are taken.
    """
    flip = np.random.binomial(1, .5, 1)[0]
    vid_path = data_path + f_name
    img_size = [240, 320]

    while True:
        vid = None
        try:
            vid = imageio.get_reader(vid_path, "ffmpeg")
            low = 1
            high = vid.get_length() - K - T + 1
            if low == high:
                stidx = 0
            else:
                stidx = np.random.randint(low=low, high=high)

            seq = np.zeros((img_size[0], img_size[1], K + T, 3),
                           dtype="float32")
            for t in range(K + T):
                # cv2.resize takes (width, height); [:, :, ::-1] reverses
                # the channel order of the decoded frame.
                frame = cv2.resize(vid.get_data(stidx + t),
                                   (img_size[1], img_size[0]))[:, :, ::-1]
                seq[:, :, t] = transform(frame)

            if flip == 1:
                seq = seq[:, ::-1]

            diff = np.zeros((img_size[0], img_size[1], K - 1, 1),
                            dtype="float32")
            for t in range(1, K):
                prev_frame = inverse_transform(seq[:, :, t - 1]) * 255
                prev_gray = cv2.cvtColor(prev_frame.astype("uint8"),
                                         cv2.COLOR_BGR2GRAY)
                next_frame = inverse_transform(seq[:, :, t]) * 255
                next_gray = cv2.cvtColor(next_frame.astype("uint8"),
                                         cv2.COLOR_BGR2GRAY)
                diff[:, :, t - 1, 0] = (next_gray.astype("float32")
                                        - prev_gray.astype("float32")) / 255.
            break
        except Exception:
            # In case the current video is bad load a random one
            rep_idx = np.random.randint(low=0, high=len(trainlist))
            f_name = trainlist[rep_idx]
            vid_path = data_path + f_name
        finally:
            # Close the reader on success and on every failed attempt so
            # retries do not accumulate open video readers.
            if vid is not None:
                try:
                    vid.close()
                except Exception:
                    # Closing is best effort; a failure here must not
                    # discard a clip that was loaded successfully.
                    pass

    return seq, diff

0 comments on commit e0beadf

Please sign in to comment.