Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create experimental mt-pytorch code example #1446

Merged
merged 4 commits into from
Oct 6, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions examples/mt-pytorch/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Multi-Tenant Federated Learning with Flower and PyTorch

This example contains highly experimental code. Please consult the regular PyTorch code examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart_pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced_pytorch)) to learn how to use Flower with PyTorch.

## Setup

```bash
./dev/venv-reset.sh
```

## Exec

```bash
python driver.py
```
38 changes: 38 additions & 0 deletions examples/mt-pytorch/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import flwr as fl

from task import (
Net,
DEVICE,
load_data,
get_parameters,
set_parameters,
train,
test,
)


# Build the simple CNN, place it on the selected device, and fetch the
# CIFAR-10 train/test loaders used by the client below.
net = Net()
net.to(DEVICE)
trainloader, testloader = load_data()

# Define Flower client
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client backed by the module-level ``net`` and data loaders."""

    def get_parameters(self, config):
        """Return the current weights of ``net`` as a list of NumPy arrays."""
        return get_parameters(net)

    def fit(self, parameters, config):
        """Load ``parameters`` into ``net``, train for one epoch, return the result.

        Returns a tuple of (updated weights, number of training examples,
        metrics dict); the metrics dict is empty.
        """
        # The helper takes the model as its first argument.
        set_parameters(net, parameters)
        train(net, trainloader, epochs=1)
        return get_parameters(net), len(trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load ``parameters`` into ``net`` and evaluate it on the test loader.

        Returns a tuple of (loss, number of test examples, metrics dict with
        an ``"accuracy"`` entry).
        """
        set_parameters(net, parameters)
        loss, accuracy = test(net, testloader)
        return loss, len(testloader.dataset), {"accuracy": accuracy}


# Start Flower client: hand a FlowerClient instance to the Flower runtime,
# which connects to the given address.
fl.client.start_numpy_client(client=FlowerClient(), server_address="[::]:9091")
86 changes: 86 additions & 0 deletions examples/mt-pytorch/driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import random
import time
from typing import List

from flwr.common import FitIns, ServerMessage, ndarrays_to_parameters
from flwr.driver import (
    CreateTasksRequest,
    CreateTasksResponse,
    Driver,
    GetClientsRequest,
    GetClientsResponse,
    GetResultsRequest,
    GetResultsResponse,
    Result,
    Task,
    TaskAssignment,
)

from task import Net, get_parameters, set_parameters

# -------------------------------------------------------------------------- Driver SDK
driver = Driver(certificates=None, driver_service_address="[::]:9091")
# -------------------------------------------------------------------------- Driver SDK

# Number of server rounds to run, and the parameters sent to clients in every
# round (taken from a freshly constructed, untrained Net).
num_rounds = 3
parameters = ndarrays_to_parameters(get_parameters(net=Net()))

# -------------------------------------------------------------------------- Driver SDK
driver.connect()
# -------------------------------------------------------------------------- Driver SDK

for server_round in range(num_rounds):
    print(f"Commencing server round {server_round + 1}")

    # Get a list of client IDs from the server
    get_clients_req = GetClientsRequest()

    # ---------------------------------------------------------------------- Driver SDK
    get_clients_res: GetClientsResponse = driver.get_clients(req=get_clients_req)
    # ---------------------------------------------------------------------- Driver SDK

    # Sample up to three clients. random.sample raises ValueError when asked
    # for more items than the population holds, so clamp the sample size to
    # the number of clients actually available.
    all_client_ids: List[int] = list(get_clients_res.client_ids)
    print(f"Got {len(all_client_ids)} client IDs")
    sampled_client_ids: List[int] = random.sample(
        all_client_ids, min(3, len(all_client_ids))
    )
    print(f"Sampled {len(sampled_client_ids)} client IDs")

    # Schedule one fit task for all sampled clients
    fit_ins: FitIns = FitIns(parameters=parameters, config={})
    task = Task(task_id=123, legacy_server_message=ServerMessage(fit_ins=fit_ins))
    task_assignment: TaskAssignment = TaskAssignment(
        task=task, client_ids=sampled_client_ids
    )
    create_tasks_req = CreateTasksRequest(task_assignments=[task_assignment])

    # ---------------------------------------------------------------------- Driver SDK
    create_tasks_res: CreateTasksResponse = driver.create_tasks(req=create_tasks_req)
    # ---------------------------------------------------------------------- Driver SDK

    print(f"Scheduled {len(create_tasks_res.task_ids)} tasks")

    # Wait for results.
    # NOTE(review): there is no early stop or client-side timeout here. The
    # review discussion assumed the server enforces a timeout on all tasks, so
    # the driver simply waits until every task has reported back -- confirm.
    task_ids: List[int] = create_tasks_res.task_ids
    all_results: List[Result] = []
    while True:
        get_results_req = GetResultsRequest(task_ids=task_ids)

        # ------------------------------------------------------------------ Driver SDK
        get_results_res: GetResultsResponse = driver.get_results(
            req=get_results_req
        )
        # ------------------------------------------------------------------ Driver SDK

        results: List[Result] = get_results_res.results
        print(f"Got {len(get_results_res.results)} results")

        all_results += results
        # ">=" rather than "==": if the count ever overshoots, an equality
        # test would never fire and the loop would spin forever.
        if len(all_results) >= len(task_ids):
            break

        # Pause between polls instead of querying the service in a hot loop.
        time.sleep(1)

    # "Aggregate" results (this example only collects the client messages)
    client_messages = [result.legacy_client_message for result in all_results]
    print(f"Received {len(client_messages)} results")

# Repeat

# -------------------------------------------------------------------------- Driver SDK
# All rounds are done: close the connection opened earlier by driver.connect().
driver.disconnect()
# -------------------------------------------------------------------------- Driver SDK
16 changes: 16 additions & 0 deletions examples/mt-pytorch/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[build-system]
requires = ["poetry_core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "mt-pytorch"
version = "0.1.0"
description = "Multi-Tenant Federated Learning with Flower and PyTorch"
authors = ["The Flower Authors <[email protected]>"]

[tool.poetry.dependencies]
python = "^3.7"
flwr = { path = "../../", develop = true, extras = ["simulation"] }
torch = "^1.12.0"
torchvision = "^0.13.0"
tqdm = "^4.63.0"
25 changes: 25 additions & 0 deletions examples/mt-pytorch/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import List, Tuple

import flwr as fl
from flwr.common import Metrics


# Define metric aggregation function
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Aggregate per-client accuracy into an example-weighted average.

    Args:
        metrics: One ``(num_examples, metrics_dict)`` pair per client; every
            ``metrics_dict`` must contain an ``"accuracy"`` entry.

    Returns:
        ``{"accuracy": weighted_mean}``, or an empty dict when there is
        nothing to average (no clients, or a total of zero examples).
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)

    # Without any examples the weighted mean is undefined; report no metric
    # instead of failing with ZeroDivisionError.
    if total_examples == 0:
        return {}

    # Multiply accuracy of each client by number of examples used
    weighted_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)

    # Aggregate and return custom metric (weighted average)
    return {"accuracy": weighted_sum / total_examples}


# Define strategy: FedAvg, with client evaluation metrics combined by
# weighted_average.
strategy = fl.server.strategy.FedAvg(
    evaluate_metrics_aggregation_fn=weighted_average,
)

# Start Flower server for three rounds
fl.server.start_server(
    strategy=strategy,
    config=fl.server.ServerConfig(num_rounds=3),
    server_address="0.0.0.0:8080",
)
79 changes: 79 additions & 0 deletions examples/mt-pytorch/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import warnings
from collections import OrderedDict

import flwr as fl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor
from tqdm import tqdm


# Silence UserWarning messages for the whole process.
warnings.filterwarnings("ignore", category=UserWarning)

# Use the first CUDA device when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")


class Net(nn.Module):
    """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""

    def __init__(self) -> None:
        super().__init__()
        # Two convolution stages that share one 2x2 max-pooling layer.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Three fully connected layers ending in 10 outputs.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x`` and return the raw 10-way outputs."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        # Flatten each sample's 16 x 5 x 5 feature map into one vector.
        flat = features.view(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return logits


def train(net, trainloader, epochs):
    """Fit ``net`` on ``trainloader`` for ``epochs`` passes.

    Uses cross-entropy loss and SGD (lr=0.001, momentum=0.9); the model is
    updated in place and nothing is returned.
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    sgd = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    for _ in range(epochs):
        for batch_images, batch_labels in tqdm(trainloader):
            sgd.zero_grad()
            predictions = net(batch_images.to(DEVICE))
            batch_loss = loss_fn(predictions, batch_labels.to(DEVICE))
            batch_loss.backward()
            sgd.step()


def test(net, testloader):
    """Validate the model on the test set.

    Args:
        net: Model to evaluate; it is called on batches moved to ``DEVICE``.
        testloader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        A ``(loss, accuracy)`` tuple: the mean cross-entropy loss per example
        and the fraction of correctly classified examples. Both are ``0.0``
        when the loader yields no examples.
    """
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    with torch.no_grad():
        for images, labels in tqdm(testloader):
            outputs = net(images.to(DEVICE))
            labels = labels.to(DEVICE)
            batch_size = labels.size(0)
            # The criterion returns the mean over the batch; weight it by the
            # batch size so the final value is a per-example mean regardless
            # of the loader's batch size.
            loss += criterion(outputs, labels).item() * batch_size
            total += batch_size
            correct += (torch.max(outputs, 1)[1] == labels).sum().item()

    # Nothing was evaluated: avoid dividing by zero.
    if total == 0:
        return 0.0, 0.0
    return loss / total, correct / total


def load_data():
    """Return ``(trainloader, testloader)`` for CIFAR-10.

    The data is downloaded to ``./data`` if needed. The training loader
    shuffles and uses batches of 32; the test loader uses the DataLoader
    defaults.
    """
    normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    transform = Compose([ToTensor(), normalize])

    trainset = CIFAR10("./data", train=True, download=True, transform=transform)
    testset = CIFAR10("./data", train=False, download=True, transform=transform)

    trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
    testloader = DataLoader(testset)
    return trainloader, testloader


def get_parameters(net):
    """Return every ``state_dict`` entry of ``net`` as a NumPy array, in order."""
    arrays = []
    for tensor in net.state_dict().values():
        arrays.append(tensor.cpu().numpy())
    return arrays


def set_parameters(net, parameters):
    """Load ``parameters`` into ``net``.

    The arrays are paired positionally with the keys of ``net.state_dict()``,
    converted to tensors, and loaded with ``strict=True``.
    """
    keys = net.state_dict().keys()
    new_state = OrderedDict(
        (key, torch.tensor(array)) for key, array in zip(keys, parameters)
    )
    net.load_state_dict(new_state, strict=True)