Skip to content

Commit

Permalink
Fix loss computation for PyTorch (#1685)
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll committed Feb 20, 2023
1 parent 38c208b commit 5718e98
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 30 deletions.
9 changes: 2 additions & 7 deletions examples/advanced_pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,23 +79,18 @@ def test(net, testloader, steps: int = None, device: str = "cpu"):
print("Starting evalutation...")
net.to(device) # move model to GPU if available
criterion = torch.nn.CrossEntropyLoss()
correct, total, loss = 0, 0, 0.0
correct, loss = 0, 0.0
net.eval()
with torch.no_grad():
for batch_idx, (images, labels) in enumerate(testloader):
images, labels = images.to(device), labels.to(device)
outputs = net(images)
loss += criterion(outputs, labels).item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
if steps is not None and batch_idx == steps:
break
if steps is None:
loss /= len(testloader.dataset)
else:
loss /= total
accuracy = correct / total
accuracy = correct / len(testloader.dataset)
net.to("cpu") # move model back to CPU
return loss, accuracy

Expand Down
7 changes: 2 additions & 5 deletions examples/embedded_devices/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,16 +163,13 @@ def test(
) -> Tuple[float, float]:
"""Validate the network on the entire test set."""
criterion = nn.CrossEntropyLoss()
correct = 0
total = 0
loss = 0.0
correct, loss = 0, 0.0
with torch.no_grad():
for data in testloader:
images, labels = data[0].to(device), data[1].to(device)
outputs = net(images)
loss += criterion(outputs, labels).item()
_, predicted = torch.max(outputs.data, 1) # pylint: disable=no-member
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = correct / total
accuracy = correct / len(testloader.dataset)
return loss, accuracy
6 changes: 3 additions & 3 deletions examples/mt-pytorch/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,15 @@ def train(net, trainloader, epochs):
def test(net, testloader):
"""Validate the model on the test set."""
criterion = torch.nn.CrossEntropyLoss()
correct, total, loss = 0, 0, 0.0
correct, loss = 0, 0.0
with torch.no_grad():
for images, labels in tqdm(testloader):
outputs = net(images.to(DEVICE))
labels = labels.to(DEVICE)
loss += criterion(outputs, labels).item()
total += labels.size(0)
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
return loss / len(testloader.dataset), correct / total
accuracy = correct / len(testloader.dataset)
return loss, accuracy


def load_data():
Expand Down
5 changes: 2 additions & 3 deletions examples/opacus/dp_cifar_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,15 @@ def train(net, trainloader, privacy_engine, epochs):

def test(net, testloader):
criterion = torch.nn.CrossEntropyLoss()
correct, total, loss = 0, 0, 0.0
correct, loss = 0, 0.0
with torch.no_grad():
for data in testloader:
images, labels = data[0].to(DEVICE), data[1].to(DEVICE)
outputs = net(images)
loss += criterion(outputs, labels).item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = correct / total
accuracy = correct / len(testloader.dataset)
return loss, accuracy


Expand Down
7 changes: 2 additions & 5 deletions examples/pytorch_from_centralized_to_federated/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,7 @@ def test(
"""Validate the network on the entire test set."""
# Define loss and metrics
criterion = nn.CrossEntropyLoss()
correct = 0
total = 0
loss = 0.0
correct, loss = 0, 0.0

# Evaluate the network
net.to(device)
Expand All @@ -126,9 +124,8 @@ def test(
outputs = net(images)
loss += criterion(outputs, labels).item()
_, predicted = torch.max(outputs.data, 1) # pylint: disable=no-member
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = correct / total
accuracy = correct / len(testloader.dataset)
return loss, accuracy


Expand Down
8 changes: 4 additions & 4 deletions examples/quickstart_pytorch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

def train(net, trainloader, epochs):
"""Train the model on the training set."""
criterion = torch.nn.CrossEntropyLoss(reduction="sum")
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
for _ in range(epochs):
for images, labels in tqdm(trainloader):
Expand All @@ -53,16 +53,16 @@ def train(net, trainloader, epochs):

def test(net, testloader):
"""Validate the model on the test set."""
criterion = torch.nn.CrossEntropyLoss(reduction="sum")
criterion = torch.nn.CrossEntropyLoss()
correct, loss = 0, 0.0
with torch.no_grad():
for images, labels in tqdm(testloader):
outputs = net(images.to(DEVICE))
labels = labels.to(DEVICE)
loss += criterion(outputs, labels).item()
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
total = len(testloader.dataset)
return loss / total, correct / total
accuracy = correct / len(testloader.dataset)
return loss, accuracy


def load_data():
Expand Down
5 changes: 2 additions & 3 deletions examples/simulation_pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,14 @@ def train(net, trainloader, epochs, device: str):
def test(net, testloader, device: str):
"""Validate the network on the entire test set."""
criterion = torch.nn.CrossEntropyLoss()
correct, total, loss = 0, 0, 0.0
correct, loss = 0, 0.0
net.eval()
with torch.no_grad():
for data in testloader:
images, labels = data[0].to(device), data[1].to(device)
outputs = net(images)
loss += criterion(outputs, labels).item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = correct / total
accuracy = correct / len(testloader.dataset)
return loss, accuracy

0 comments on commit 5718e98

Please sign in to comment.