Merge pull request #12 from wingedsheep/feature/routing_transformer_parameters

Added extra configuration options to the routing transformer model.
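A minimal construction sketch (not part of the commit) showing the newly exposed options; the import path follows the repository layout shown in the diff below, and the argument values are illustrative only, mirroring the new defaults:

```python
# Sketch only, not part of this commit. Assumes the import path matches the
# repository layout; parameter values are illustrative.
from mgt.models.routing_transformer_model import RoutingTransformerModel


def build_model(dictionary):
    """Build a model with the newly exposed options.

    `dictionary` is an mgt Dictionary produced elsewhere (e.g. by a data manager).
    """
    return RoutingTransformerModel(
        dictionary=dictionary,
        max_sequence_length=2048,  # new default (was 4096)
        dim=512,
        depth=12,                  # new default (was 6)
        heads=6,                   # new default (was 4)
        window_size=128,           # new default (was 256)
        reversible=True,           # newly exposed option
    )
```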
wingedsheep committed Dec 3, 2021
2 parents c9052f4 + 1f14bc6 commit ec935cb
Showing 3 changed files with 48 additions and 21 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -3,6 +3,11 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## 0.3.2 - 2021-12-03

### Added
- More configuration options for the routing transformer

## 0.3.1 - 2021-11-26

### Updated
62 changes: 42 additions & 20 deletions mgt/models/routing_transformer_model.py
@@ -34,17 +34,34 @@ def get_device():
return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


defaults = {
'max_sequence_length': 2048,
'learning_rate': 1e-4,
'dropout': 0.1,
'dim': 512,
'depth': 12,
'heads': 6,
'window_size': 128,
'reversible': True
}


def get_or_default(dictionary: dict, key: str):
return dictionary[key] if key in dictionary else defaults[key]


class RoutingTransformerModel(object):

def __init__(self,
dictionary: Dictionary,
max_sequence_length=4096,
learning_rate=1e-4,
dropout=0.1,
dim=512,
depth=6,
heads=4,
window_size=256
max_sequence_length=defaults['max_sequence_length'],
learning_rate=defaults['learning_rate'],
dropout=defaults['dropout'],
dim=defaults['dim'],
depth=defaults['depth'],
heads=defaults['heads'],
window_size=defaults['window_size'],
reversible=defaults['reversible']
):
self.dictionary = dictionary
self.learning_rate = learning_rate
@@ -54,6 +71,7 @@ def __init__(self,
self.depth = depth
self.heads = heads
self.window_size = window_size
self.reversible = reversible
self.model = self.create_model()
self.optimizer = self.create_optimizer()

@@ -103,7 +121,8 @@ def train(self,
epoch_losses.append(loss_item)

if nr_of_batches_processed % report_per_x_batches == 0:
print(f"Processed {nr_of_batches_processed} / {batches_per_epoch} with loss {np.mean(batch_losses)}.")
print(
f"Processed {nr_of_batches_processed} / {batches_per_epoch} with loss {np.mean(batch_losses)}.")
batch_losses = []

epoch_loss = np.mean(epoch_losses)
@@ -135,7 +154,8 @@ def create_model(self):
max_seq_len=self.max_sequence_length,
attn_dropout=self.dropout,
ff_dropout=self.dropout,
causal=True
causal=True,
reversible=self.reversible
)

model = AutoregressiveWrapper(model,
@@ -159,25 +179,27 @@ def save_checkpoint(self, path):
'depth': self.depth,
'window_size': self.window_size,
'heads': self.heads,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'reversible': self.reversible,
'model_state_dict': self.model.state_dict()
}, path)

@staticmethod
def load_checkpoint(path) -> RoutingTransformerModel:
checkpoint = torch.load(path)

model = RoutingTransformerModel(
dictionary=checkpoint['dictionary'],
max_sequence_length=checkpoint['max_sequence_length'],
learning_rate=checkpoint['learning_rate'],
dropout=checkpoint['dropout'],
dim=checkpoint['dim'],
depth=checkpoint['depth'],
window_size=checkpoint['window_size'],
heads=checkpoint['heads']
max_sequence_length=get_or_default(checkpoint, 'max_sequence_length'),
learning_rate=get_or_default(checkpoint, 'learning_rate'),
dropout=get_or_default(checkpoint, 'dropout'),
dim=get_or_default(checkpoint, 'dim'),
depth=get_or_default(checkpoint, 'depth'),
window_size=get_or_default(checkpoint, 'window_size'),
heads=get_or_default(checkpoint, 'heads'),
reversible=get_or_default(checkpoint, 'reversible')
)

model.model.load_state_dict(checkpoint['model_state_dict'])
model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
model.model.load_state_dict(checkpoint['model_state_dict'], strict=False)
model.create_optimizer()

return model
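
Note on backward compatibility (editorial remark, not part of the commit): because the optimizer state is no longer stored and a new `reversible` key is added, `load_checkpoint` now fills missing keys via `get_or_default` and loads weights with `strict=False`, so checkpoints written by earlier versions should still load. A rough sketch, assuming a hypothetical checkpoint file saved before this change:

```python
# Sketch only. "pre_0_3_2_checkpoint.pt" is a hypothetical path to a checkpoint
# saved by an earlier release (no 'reversible' key in the saved dict).
from mgt.models.routing_transformer_model import RoutingTransformerModel

model = RoutingTransformerModel.load_checkpoint("pre_0_3_2_checkpoint.pt")
# Missing configuration keys fall back to the module-level defaults via
# get_or_default; load_state_dict(strict=False) tolerates absent or mismatched
# keys; the optimizer is rebuilt rather than restored from the checkpoint.
```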
2 changes: 1 addition & 1 deletion setup.py
@@ -1,7 +1,7 @@
from setuptools import setup, find_namespace_packages

setup(name='music-generation-toolbox',
version='0.3.1',
version='0.3.2',
description='Toolbox for generating music',
author='Vincent Bons',
url='https://github.com/wingedsheep/music-generation-toolbox',
