-
Notifications
You must be signed in to change notification settings - Fork 1
/
MDN.py
112 lines (84 loc) · 3.44 KB
/
MDN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Reshape
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K
from tensorflow_probability import distributions as tfd
class MDN:
    """Mixture Density Network: models a conditional distribution over a
    latent vector ``z`` given a label vector, as a mixture of ``cat_num``
    Gaussian components over a ``z_dim``-dimensional latent space.
    """

    def __init__(self, labels_dim, z_dim=128, cat_num=5):
        # labels_dim: dimensionality of the conditioning label vector.
        # z_dim:      dimensionality of the latent vector z being modeled.
        # cat_num:    number of mixture components.
        self.labels_dim = labels_dim
        self.z_dim = z_dim
        self.cat_num = cat_num
        # Inputs carry an explicit leading axis of size 1:
        # labels -> (1, labels_dim), z -> (1, z_dim).
        self.labels_shape = (1, labels_dim)
        self.z_shape = (1, z_dim)

    def elu_modified(self, x):
        """ELU shifted by +1 so outputs are strictly positive.

        Used as the activation for the mixture scales so that ``sigma``
        is always a valid (positive) Normal scale parameter.
        """
        return tf.nn.elu(x) + 1

    def get_model(self):
        """Build the MDN layers and return them in a dict.

        Returns:
            dict with the two Input layers ('labels', 'z') and three
            output heads: 'alpha' — softmax mixture weights (cat_num
            units), 'mu' — linear component means (cat_num*z_dim units),
            'sigma' — positive component scales (cat_num*z_dim units).
        """
        l_labels = Input(shape=self.labels_shape)
        l_z = Input(shape=self.z_shape)
        l_fc1 = Dense(units=512, activation='tanh')(l_labels)
        l_fc2 = Dense(units=1024, activation='tanh')(l_fc1)
        l_fc3 = Dense(units=512, activation='tanh')(l_fc2)
        l_alpha = Dense(units=self.cat_num, activation='softmax')(l_fc3)
        l_mu = Dense(units=self.cat_num * self.z_dim)(l_fc3)
        # elu_modified keeps sigma strictly positive.
        l_sigma = Dense(units=self.cat_num * self.z_dim,
                        activation=self.elu_modified)(l_fc3)
        return {'labels': l_labels,
                'z': l_z,
                'alpha': l_alpha,
                'mu': l_mu,
                'sigma': l_sigma}

    def mdn_loss(self, z, alpha, mu, sigma):
        """Negative mean log-likelihood of ``z`` under the predicted mixture.

        The flat (cat_num*z_dim) 'mu'/'sigma' heads are reshaped to
        (batch, z_dim, cat_num); 'alpha' is repeated along its size-1
        middle axis so it aligns with the per-dimension parameters.
        """
        # Repeat mixture weights across the z_dim axis (input axis 1 has
        # size 1 — see self.labels_shape), then add a trailing axis.
        alpha = K.repeat_elements(alpha, self.z_dim, axis=1)
        alpha = K.expand_dims(alpha, axis=3)
        mu = K.reshape(mu, (tf.shape(mu)[0], self.z_dim, self.cat_num))
        mu = K.expand_dims(mu, axis=3)
        sigma = K.reshape(sigma, (tf.shape(sigma)[0], self.z_dim, self.cat_num))
        sigma = K.expand_dims(sigma, axis=3)
        gm = tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(probs=alpha),
            components_distribution=tfd.Normal(
                loc=mu,
                scale=sigma))
        # z arrives as (batch, 1, z_dim); transpose so z_dim is axis 1,
        # matching the distribution's batch shape above.
        z = tf.transpose(z, (0, 2, 1))
        return tf.reduce_mean(-gm.log_prob(z))

    def train(self, learning_rate, batch_size, epoch_num, labels_train, z_train):
        """Train the MDN; the mixture NLL is attached via ``add_loss``.

        Args:
            learning_rate: Adam learning rate.
            batch_size: minibatch size passed to ``fit``.
            epoch_num: number of training epochs.
            labels_train, z_train: training inputs for the two Input layers.

        Returns:
            The fitted training ``Model`` (inputs ``[labels, z]``,
            outputs ``[alpha, mu, sigma]``).
        """
        model = self.get_model()
        labels = model['labels']
        z = model['z']
        alpha = model['alpha']
        mu = model['mu']
        sigma = model['sigma']
        model_train = Model([labels, z], [alpha, mu, sigma])
        # Loss depends on the symbolic inputs/outputs, so it is attached
        # with add_loss and compile() takes no per-output loss.
        model_train.add_loss(self.mdn_loss(z, alpha, mu, sigma))
        # FIX: `lr` is deprecated (and removed in modern Keras); the
        # supported keyword is `learning_rate`.
        adam = Adam(learning_rate=learning_rate)
        model_train.compile(optimizer=adam)
        # NOTE(review): validation_data reuses the training set, so the
        # reported val_loss mirrors training loss — confirm this is intended.
        model_train.fit(
            [labels_train, z_train],
            batch_size=batch_size,
            epochs=epoch_num,
            validation_data=([labels_train, z_train], None)
        )
        return model_train

    def test(self, labels_test, weights_dir):
        """Load trained weights and predict mixture parameters.

        Args:
            labels_test: label inputs to condition on.
            weights_dir: path passed to ``Model.load_weights``.

        Returns:
            (alpha, mu, sigma) where mu and sigma are reshaped to
            (batch, z_dim, cat_num); alpha keeps its predicted shape.
        """
        model = self.get_model()
        # Inference model conditions on labels only (the 'z' input is
        # needed just for the training loss).
        model_test = Model(model['labels'],
                           [model['alpha'], model['mu'], model['sigma']])
        model_test.load_weights(weights_dir)
        predictions = model_test.predict(labels_test)
        predictions_alpha = predictions[0]
        predictions_mu = predictions[1]
        predictions_sigma = predictions[2]
        predictions_mu = np.reshape(
            predictions_mu,
            (np.shape(predictions_mu)[0], self.z_dim, self.cat_num))
        # FIX: reshape sigma using its OWN batch dimension (the original
        # indexed predictions_mu, which only worked because the batch
        # sizes happen to match).
        predictions_sigma = np.reshape(
            predictions_sigma,
            (np.shape(predictions_sigma)[0], self.z_dim, self.cat_num))
        return predictions_alpha, predictions_mu, predictions_sigma