Commit e44778a

Use YAML files to manage train_segment_w_context.py

ceshine committed Oct 27, 2019
1 parent 6e9bd11 commit e44778a
Showing 2 changed files with 79 additions and 47 deletions.
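Editor's note: the hyperparameters that used to be argparse flags now come from a YAML config. Below is a minimal sketch of the mapping yaml.load would have to return for train_segment_w_context.py. The keys are taken from the code in this diff; the example values are assumptions that simply mirror the removed argparse defaults and old hard-coded constants, not values from this repository:

    # Hypothetical shape of the config read by train_segment_w_context.py,
    # written as the Python dict that yaml.load would return.
    config = {
        "segment_w_context": {
            "training": {
                "batch_size": 32,       # was --batch-size
                "lr": 3e-4,             # was --lr
                "eps": 1e-7,            # was hard-coded in the old Adam comment
                "weight_decay": 0.02,   # was the hard-coded [0.02, 0]
                "steps": 30000,         # was --steps (still overridable via --steps)
                "offset": 0,            # was --offset
                "ckpt_interval": 4000,  # was --ckpt-interval
            },
            "model": {
                "fcn_dim": 512,             # was --fcn-dim
                "p_drop": 0.5,              # was --drop
                "se_reduction": 0,          # was --se-reduction
                "max_len": -1,              # was --max-len
                "finetune_context": False,  # was --finetune-context
                "n_mixture": 2,             # was num_mixtures=2
            },
        },
        # "context_base" and "segment_base" are filled in at runtime
        # from the base model directories' saved config.yaml files.
    }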
17 changes: 9 additions & 8 deletions yt8m/train_pure_segment.py
@@ -133,15 +133,15 @@ def prepare_models(config, state_dict=None):
         segment_model.load_state_dict(state_dict)
     if isinstance(segment_model, SampleFrameModelWrapper):
         segment_model = segment_model.model
-    return SegmentModelWrapper(segment_model), config
+    return SegmentModelWrapper(segment_model)
 
 
 @telegram_sender(token=BOT_TOKEN, chat_id=CHAT_ID, name="Training on Segment")
 def main():
     parser = argparse.ArgumentParser()
     arg = parser.add_argument
-    arg('base_model_dir', type=str)
     arg('config', type=str)
+    arg('base_model_dir', type=str)
     arg('--steps', type=int, default=-1)
     arg('--fold', type=int, default=0)
     arg('--name', type=str, default="model")
@@ -163,30 +163,30 @@ def main():
         video_config = yaml.load(fin)
     config.update(video_config)
     state_dict = torch.load(str(base_model_dir / "model.pth"))
-    model, video_config = prepare_models(config, state_dict=state_dict)
+    model = prepare_models(config, state_dict=state_dict)
 
     print(model)
+    lr = float(training_config["lr"])
     optimizer_grouped_parameters = [
         {
             'params': [p for n, p in model.named_parameters()
                        if not any(nd in n for nd in NO_DECAY)],
-            'lr': args.lr
+            'lr': lr
         },
         {
             'params': [p for n, p in model.named_parameters()
                        if any(nd in n for nd in NO_DECAY)],
-            'lr': args.lr
+            'lr': lr
         }
     ]
     optimizer = WeightDecayOptimizerWrapper(
         torch.optim.Adam(
             optimizer_grouped_parameters,
-            lr=float(training_config["lr"]),
-            eps=float(training_config["eps"])),
+            lr=lr, eps=float(training_config["eps"])),
         [training_config["weight_decay"], 0]
     )
     # optimizer = torch.optim.Adam(
-    #     optimizer_grouped_parameters, lr=args.lr, eps=1e-7)
+    #     optimizer_grouped_parameters, lr=lr, eps=1e-7)
 
     n_steps = training_config["steps"]
     checkpoints = CheckpointCallback(
@@ -228,6 +228,7 @@ def main():
     bot.load_model(checkpoints.best_performers[0][1])
     checkpoints.remove_checkpoints(keep=0)
 
+    # save the model
     target_dir = (MODEL_DIR /
                   f"{args.name}_{args.fold}_{datetime.now().strftime('%Y%m%d-%H%M')}")
     target_dir.mkdir(parents=True)
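Editor's note: the grouping pattern above splits each module's parameters in two, and WeightDecayOptimizerWrapper then applies the configured decay only to the first group, so biases and normalization weights are excluded from weight decay. A self-contained sketch of the same grouping in plain PyTorch, substituting torch.optim.AdamW's per-group decoupled weight_decay for helperbot's wrapper; the tiny model and the constants are illustrative assumptions:

    import torch
    from torch import nn

    # Substrings that mark parameters excluded from weight decay,
    # matching the NO_DECAY list used in both training scripts.
    NO_DECAY = ['bias', 'LayerNorm.weight', 'BatchNorm.weight']

    class TinyBlock(nn.Module):
        """Toy stand-in for the segment model (illustrative only)."""
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(16, 32)
            self.LayerNorm = nn.LayerNorm(32)  # attribute named to match NO_DECAY
            self.out = nn.Linear(32, 4)

        def forward(self, x):
            return self.out(torch.relu(self.LayerNorm(self.fc(x))))

    model = TinyBlock()
    lr = 3e-4  # placeholder; the scripts read this from the YAML config

    optimizer = torch.optim.AdamW([
        {   # regular weights: decayed
            'params': [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in NO_DECAY)],
            'weight_decay': 0.02,
        },
        {   # biases and norm weights: no decay
            'params': [p for n, p in model.named_parameters()
                       if any(nd in n for nd in NO_DECAY)],
            'weight_decay': 0.0,
        },
    ], lr=lr, eps=1e-7)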
109 changes: 70 additions & 39 deletions yt8m/train_segment_w_context.py
@@ -5,6 +5,7 @@
 
 import torch
 # from torch.optim.lr_scheduler import CosineAnnealingLR
+import yaml
 import numpy as np
 from helperbot import (
     WeightDecayOptimizerWrapper,
@@ -23,9 +24,9 @@
 from .loss import SampledCrossEntropyLoss
 from .telegram_tokens import BOT_TOKEN, CHAT_ID
 from .telegram_sender import telegram_sender
-from .train_pure_segment import (
-    YoutubeBot, get_loaders, patch
-)
+from .train_video import create_video_model
+from .train_pure_segment import YoutubeBot, get_loaders
 
 
 CACHE_DIR = Path('./data/cache/segment/')
 CACHE_DIR.mkdir(exist_ok=True, parents=True)
@@ -35,12 +36,17 @@
 NO_DECAY = ['bias', 'LayerNorm.weight', 'BatchNorm.weight']
 
 
-def prepare_models(args):
-    model_dir = Path(args.model_dir)
-    context_model = patch(torch.load(str(model_dir / args.context_model)))
+def prepare_models(config, *, context_state_dict, segment_state_dict):
+    # Restore the video model for the context encoder
+    context_model = create_video_model(config["context_base"]["model"])
+    if context_state_dict is not None:
+        context_model.load_state_dict(context_state_dict)
     if isinstance(context_model, SampleFrameModelWrapper):
         context_model = context_model.model
-    segment_model = patch(torch.load(str(model_dir / args.segment_model)))
+    # Restore the video model for the segment encoder
+    segment_model = create_video_model(config["segment_base"]["model"])
+    if segment_state_dict is not None:
+        segment_model.load_state_dict(segment_state_dict)
     if isinstance(segment_model, SampleFrameModelWrapper):
         segment_model = segment_model.model
     if isinstance(segment_model, NeXtVLADModel):
@@ -63,82 +69,102 @@ def prepare_models(args):
         context_model = GatedDBofContextEncoder(context_model)
     else:
         raise ValueError("Model not supported yet!")
+    model_config = config["segment_w_context"]["model"]
     return ContextualSegmentModel(
         context_model, segment_model, context_dim, segment_dim,
-        args.fcn_dim, args.drop, se_reduction=args.se_reduction,
-        max_video_len=args.max_len, train_context=args.finetune_context,
-        num_mixtures=2
+        model_config["fcn_dim"], model_config["p_drop"],
+        se_reduction=model_config["se_reduction"],
+        max_video_len=model_config["max_len"],
+        train_context=model_config["finetune_context"],
+        num_mixtures=model_config["n_mixture"]
     ).cuda()


 @telegram_sender(token=BOT_TOKEN, chat_id=CHAT_ID, name="Training on Segment")
 def main():
     parser = argparse.ArgumentParser()
     arg = parser.add_argument
-    arg('model_dir', type=str)
-    arg('context_model', type=str)
-    arg('segment_model', type=str)
-    arg('--batch-size', type=int, default=32)
-    arg('--lr', type=float, default=3e-4)
-    arg('--steps', type=int, default=30000)
-    arg('--offset', type=int, default=0)
-    arg('--ckpt-interval', type=int, default=4000)
+    arg('config')
+    arg('context_model_dir', type=str)
+    arg('segment_model_dir', type=str)
+    arg('--steps', type=int, default=-1)
     arg('--fold', type=int, default=0)
-    arg('--drop', type=float, default=0.5)
-    arg('--fcn-dim', type=int, default=512)
-    arg('--max-len', type=int, default=-1)
-    arg('--se-reduction', type=int, default=0)
-    arg('--finetune-context', action="store_true")
-    arg('--name', type=str, default="model")
+    arg('--name', type=str, default="context_model")
     args = parser.parse_args()

+    with open(args.config) as fin:
+        config = yaml.load(fin)
+    training_config = config["segment_w_context"]["training"]
     train_loader, valid_loader = get_loaders(
-        args, seed=int(os.environ.get("SEED", "9293")), offset=args.offset)
+        training_config["batch_size"], fold=args.fold,
+        seed=int(os.environ.get("SEED", "9293")),
+        offset=training_config["offset"])
 
-    model = prepare_models(args)
+    if args.steps > 0:
+        # override
+        training_config["steps"] = args.steps
 
+    context_model_dir = Path(args.context_model_dir)
+    with open(context_model_dir / "config.yaml") as fin:
+        context_config = yaml.load(fin)
+    config["context_base"] = context_config["video"]
+    context_state_dict = torch.load(str(context_model_dir / "model.pth"))
+    segment_model_dir = Path(args.segment_model_dir)
+    with open(segment_model_dir / "config.yaml") as fin:
+        segment_config = yaml.load(fin)
+    config["segment_base"] = segment_config["video"]
+    segment_state_dict = torch.load(str(segment_model_dir / "model.pth"))
+    model = prepare_models(
+        config,
+        context_state_dict=context_state_dict,
+        segment_state_dict=segment_state_dict)
     print(model)

-    # optimizer_grouped_parameters = []
+    lr = float(training_config["lr"])
     optimizer_grouped_parameters = [
         {
             'params': [p for n, p in model.segment_model.named_parameters()
                        if not any(nd in n for nd in NO_DECAY)],
-            'lr': args.lr / 2
+            'lr': lr / 2
         },
         {
             'params': [p for n, p in model.segment_model.named_parameters()
                        if any(nd in n for nd in NO_DECAY)],
-            'lr': args.lr / 2
+            'lr': lr / 2
         }
     ]
-    if args.finetune_context:
+    if config["segment_w_context"]["model"]["finetune_context"]:
         optimizer_grouped_parameters += [
             {
                 'params': [p for n, p in model.context_model.named_parameters()
                            if not any(nd in n for nd in NO_DECAY)],
-                'lr': args.lr / 4
+                'lr': lr / 4
             },
             {
                 'params': [p for n, p in model.context_model.named_parameters()
                            if any(nd in n for nd in NO_DECAY)],
-                'lr': args.lr / 4
+                'lr': lr / 4
             }
         ]
     for module in (model.expert_fc, model.gating_fc, model.intermediate_fc):
         optimizer_grouped_parameters += [
             {
                 'params': [p for n, p in module.named_parameters()
                            if not any(nd in n for nd in NO_DECAY)],
-                'lr': args.lr
+                'lr': lr
             },
             {
                 'params': [p for n, p in module.named_parameters()
                            if any(nd in n for nd in NO_DECAY)],
-                'lr': args.lr
+                'lr': lr
             }
         ]
     optimizer = WeightDecayOptimizerWrapper(
-        torch.optim.Adam(optimizer_grouped_parameters, lr=args.lr),
-        [0.02, 0] * (len(optimizer_grouped_parameters) // 2)
+        torch.optim.Adam(
+            optimizer_grouped_parameters,
+            lr=lr, eps=float(training_config["eps"])),
+        [training_config["weight_decay"], 0] * (
+            len(optimizer_grouped_parameters) // 2)
     )
 
     n_steps = args.steps
@@ -176,15 +202,20 @@ def main():
         pbar=True, use_tensorboard=False
     )
     bot.train(
-        total_steps=n_steps, checkpoint_interval=args.ckpt_interval
+        total_steps=n_steps, checkpoint_interval=training_config["ckpt_interval"]
     )
     bot.load_model(checkpoints.best_performers[0][1])
     checkpoints.remove_checkpoints(keep=0)
 
+    # save the model
+    target_dir = (MODEL_DIR /
+                  f"{args.name}_{args.fold}_{datetime.now().strftime('%Y%m%d-%H%M')}")
+    target_dir.mkdir(parents=True)
     torch.save(
-        bot.model, MODEL_DIR /
-        f"{args.name}_{args.fold}_{datetime.now().strftime('%Y%m%d-%H%M')}.pth"
+        bot.model.state_dict(), target_dir / "model.pth"
    )
+    with open(target_dir / "config.yaml", "w") as fout:
+        fout.write(yaml.dump(config, default_flow_style=False))
 
 
 if __name__ == "__main__":
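Editor's note: after this change both trainers end by writing a self-describing checkpoint directory, with model.pth holding a state_dict and config.yaml holding the config that built the model. Saving state_dict() instead of the pickled module (the old torch.save(bot.model, ...)) keeps checkpoints loadable even when class definitions move, which is presumably the motivation here. A minimal sketch of the matching restore step, mirroring how main() above rebuilds the two base video models; create_video_model and the directory layout come from the diff, while the helper name and the config["video"]["model"] key path are assumptions based on it:

    from pathlib import Path

    import torch
    import yaml

    from yt8m.train_video import create_video_model  # as imported in the diff

    def restore_video_model(model_dir):
        """Hypothetical helper: rebuild a saved video model from its directory."""
        model_dir = Path(model_dir)
        with open(model_dir / "config.yaml") as fin:
            # yaml.load without an explicit Loader is deprecated in PyYAML >= 5.1;
            # safe_load is the usual replacement for plain config files.
            config = yaml.safe_load(fin)
        model = create_video_model(config["video"]["model"])
        model.load_state_dict(torch.load(str(model_dir / "model.pth")))
        return model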