-
Notifications
You must be signed in to change notification settings - Fork 0
/
phase2_train.py
83 lines (64 loc) · 2.55 KB
/
phase2_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from models.vggnet import vggnet16
from models.resnet import resnet34
from models.fpn_resnet import resnet_fpn
from models.fpn_resnet2 import res_fpn_adv_decoder
### IMPORTANT ###
# When True, training resumes from the checkpoint file at MODEL_PATH
# (train_model() calls load_state_dict on it before training starts).
LOAD_PAST_MODEL = True
### Training parameters ###
# Number of passes over the training set.
EPOCHS = 5
# Mini-batch size for the DataLoader.
BATCH_SIZE = 4
# Prefer GPU when available; all tensors and the model are moved here.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', DEVICE)
# Checkpoint path used both for resuming and for saving.
# NOTE(review): the name encodes previous training history ("15+midway+10") —
# presumably epochs trained so far; verify before renaming.
MODEL_PATH = 'resnet_fpn_adv_decoder_animated_or_real_or_anime_15+midway+10.pth'
# Convert PIL images to tensors, then resize to the 224x224 input the network expects.
DATA_TRANSFORM = transforms.Compose([transforms.ToTensor(), transforms.Resize((224,224))])
def train_model(epochs=EPOCHS, lr=0.01, momentum=0.9):
    """Train the FPN-ResNet classifier on ./Dataset/Train, checkpointing as it goes.

    Fixes over the previous version:
      * torch.load now uses map_location=DEVICE so a GPU-saved checkpoint
        can be resumed on a CPU-only machine.
      * The printed loss statistic was running_loss / 2000 with running_loss
        reset every batch, i.e. a single batch loss divided by 2000; it now
        prints the batch loss and the running epoch average.
      * model.train() is called explicitly so dropout/batch-norm layers are
        in training mode regardless of the model's default state.
      * The two duplicate `if epoch % 2 == 0` checkpoint blocks are merged.

    Args:
        epochs: number of passes over the training set (default: module EPOCHS).
        lr: SGD learning rate (default 0.01, unchanged from the original).
        momentum: SGD momentum (default 0.9, unchanged from the original).

    Side effects: reads images from ./Dataset/Train and writes .pth
    checkpoint files (midway snapshots plus MODEL_PATH).
    """
    ### Creating the model ###
    print('CREATING MODEL')
    model = res_fpn_adv_decoder(num_classes=3)
    if LOAD_PAST_MODEL:
        # map_location ensures the checkpoint loads onto the active device
        # even if it was saved from a different one (e.g. CUDA -> CPU).
        model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    model.to(DEVICE)
    model.train()  # make sure dropout/batch-norm are in training mode
    ### Training ###
    print('GETTING DATA')
    train_data = ImageFolder('./Dataset/Train', transform=DATA_TRANSFORM)
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    print('TRAINING MODEL')
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(train_loader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, img_class = data
            inputs = inputs.to(DEVICE)
            img_class = img_class.to(DEVICE)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, img_class)
            loss.backward()
            # clip exploding gradients before the optimizer step
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            # print statistics: current batch loss and running average over the epoch
            running_loss += loss.item()
            print('[%d, %5d] loss: %.15f (epoch avg: %.15f)' %
                  (epoch + 1, i + 1, loss.item(), running_loss / (i + 1)))
        # Checkpoint every other epoch: a midway snapshot plus the canonical path.
        if epoch % 2 == 0:
            print('SAVING MODEL')
            torch.save(model.state_dict(), f'{MODEL_PATH[:-4]}+midway+{epoch}.pth')
            torch.save(model.state_dict(), MODEL_PATH)
    ### Save the model ###
    print('SAVING MODEL')
    torch.save(model.state_dict(), MODEL_PATH)
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    train_model()