Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make examples use start_client(). #2718

Merged
merged 15 commits into from
Jan 25, 2024
Prev Previous commit
Next Next commit
minimal fix opacus sim
  • Loading branch information
jafermarq committed Dec 19, 2023
commit b65f52454701b9356e2674f2264c9bd9a4585f1a
14 changes: 7 additions & 7 deletions examples/opacus/dp_cifar_simulation.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import math
from collections import OrderedDict
from typing import Callable, Optional, Tuple
from typing import Callable, Dict, Optional, Tuple

import flwr as fl
import numpy as np
import torch
import torchvision.transforms as transforms
from opacus.dp_model_inspector import DPModelInspector
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from flwr.common.typing import Scalar

from dp_cifar_main import DEVICE, PARAMS, DPCifarClient, Net, test

Expand All @@ -23,8 +23,6 @@ def client_fn(cid: str) -> fl.client.Client:
# Load model.
model = Net()
# Check model is compatible with Opacus.
# inspector = DPModelInspector()
# print(f"Is the model valid? {inspector.validate(model)}")

# Load data partition (divide CIFAR10 into NUM_CLIENTS distinct partitions, using 30% for validation).
transform = transforms.Compose(
Expand All @@ -50,7 +48,9 @@ def client_fn(cid: str) -> fl.client.Client:

# Define an evaluation function for centralized evaluation (using whole CIFAR10 testset).
def get_evaluate_fn() -> Callable[[fl.common.NDArrays], Optional[Tuple[float, float]]]:
def evaluate(weights: fl.common.NDArrays) -> Optional[Tuple[float, float]]:
def evaluate(
server_round: int, parameters: fl.common.NDArrays, config: Dict[str, Scalar]
):
transform = transforms.Compose(
[
transforms.ToTensor(),
Expand All @@ -63,7 +63,7 @@ def evaluate(weights: fl.common.NDArrays) -> Optional[Tuple[float, float]]:
state_dict = OrderedDict(
{
k: torch.tensor(np.atleast_1d(v))
for k, v in zip(model.state_dict().keys(), weights)
for k, v in zip(model.state_dict().keys(), parameters)
}
)
model.load_state_dict(state_dict, strict=True)
Expand All @@ -82,7 +82,7 @@ def main() -> None:
client_fn=client_fn,
num_clients=NUM_CLIENTS,
client_resources={"num_cpus": 1},
num_rounds=3,
config=fl.server.ServerConfig(num_rounds=3),
strategy=fl.server.strategy.FedAvg(
fraction_fit=0.1, fraction_evaluate=0.1, evaluate_fn=get_evaluate_fn()
),
Expand Down
Loading