diff --git a/evaluate.py b/evaluate.py index 5c3f502..e095ef9 100644 --- a/evaluate.py +++ b/evaluate.py @@ -257,7 +257,7 @@ def main_worker(): model = net.InpaintGenerator().to(device) model_path = args.ckpt data = torch.load(args.ckpt, map_location=device) - model.load_state_dict(data['netG']) + model.load_state_dict(data) print('loading from: {}'.format(args.ckpt)) model.eval()