Skip to content

Commit

Permalink
Lightning example (adap#1344)
Browse files Browse the repository at this point in the history
  • Loading branch information
edogab33 committed Aug 1, 2022
1 parent 0eafca6 commit 1c2fc0b
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 12 deletions.
13 changes: 3 additions & 10 deletions examples/quickstart_pytorch_lightning/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,7 @@
import mnist
import pytorch_lightning as pl
from collections import OrderedDict


import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class FlowerClient(fl.client.NumPyClient):
Expand All @@ -19,7 +12,7 @@ def __init__(self, model, train_loader, val_loader, test_loader):
self.val_loader = val_loader
self.test_loader = test_loader

def get_parameters(self):
def get_parameters(self, config):
encoder_params = _get_parameters(self.model.encoder)
decoder_params = _get_parameters(self.model.decoder)
return encoder_params + decoder_params
Expand All @@ -34,7 +27,7 @@ def fit(self, parameters, config):
trainer = pl.Trainer(max_epochs=1, progress_bar_refresh_rate=0)
trainer.fit(self.model, self.train_loader, self.val_loader)

return self.get_parameters(), 55000, {}
return self.get_parameters(config={}), 55000, {}

def evaluate(self, parameters, config):
self.set_parameters(parameters)
Expand Down Expand Up @@ -63,7 +56,7 @@ def main() -> None:

# Flower client
client = FlowerClient(model, train_loader, val_loader, test_loader)
fl.client.start_numpy_client("127.0.0.1:8080", client)
fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=client)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion examples/quickstart_pytorch_lightning/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ authors = ["The Flower Authors <[email protected]>"]

[tool.poetry.dependencies]
python = "^3.7"
flwr = "^0.17.0"
flwr = "^1.0.0"
# flwr = { path = "../../", develop = true } # Development
pytorch-lightning = "^1.4.7"
torchvision = "^0.10.0"
2 changes: 1 addition & 1 deletion examples/quickstart_pytorch_lightning/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def main() -> None:
# Start Flower server for three rounds of federated learning
fl.server.start_server(
server_address="0.0.0.0:8080",
config={"num_rounds": 10},
config=fl.server.ServerConfig(num_rounds=10),
strategy=strategy,
)

Expand Down

0 comments on commit 1c2fc0b

Please sign in to comment.