forked from adap/flower
-
Notifications
You must be signed in to change notification settings - Fork 0
/
client.py
41 lines (29 loc) · 1.08 KB
/
client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from flwr.client import ClientApp, NumPyClient
from task import DEVICE, Net, get_weights, load_data, set_weights, test, train
# Module-level model and data loaders, created once at import time and shared
# by every FlowerClient in this process. (Per the original note: a simple CNN
# on CIFAR-10 — both are defined in `task`.)
net = Net()
net = net.to(DEVICE)
(trainloader, testloader) = load_data()
# Define FlowerClient (instantiated by client_fn and the legacy entry point)
class FlowerClient(NumPyClient):
    """NumPy-based Flower client operating on the module-level `net`.

    Both `fit` and `evaluate` first load the server-supplied weights into the
    shared module-level model, then train or evaluate it on the module-level
    data loaders.
    """

    # Epoch count used when the server's fit config has no "local_epochs"
    # entry; equal to the value that was previously hard-coded.
    default_local_epochs = 1

    def fit(self, parameters, config):
        """Load `parameters` into `net`, train locally, and return the update.

        Args:
            parameters: Global model weights sent by the server, in the form
                accepted by `set_weights`.
            config: Per-round configuration sent by the server. If it has a
                "local_epochs" entry, that many epochs are trained; otherwise
                `default_local_epochs` is used.

        Returns:
            Tuple of (updated weights from `get_weights`, number of training
            examples, metrics returned by `train`).
        """
        set_weights(net, parameters)
        # The server config is a dict of scalars; coerce to int so values such
        # as 2.0 or "2" still give a usable epoch count. Treat a missing
        # config the same as an empty one.
        epochs = int((config or {}).get("local_epochs", self.default_local_epochs))
        results = train(
            net, trainloader, testloader, epochs=epochs, device=DEVICE
        )
        return get_weights(net), len(trainloader.dataset), results

    def evaluate(self, parameters, config):
        """Load `parameters` into `net` and evaluate it on the local test set.

        Args:
            parameters: Global model weights sent by the server.
            config: Per-round evaluation configuration (currently unused).

        Returns:
            Tuple of (loss, number of test examples, {"accuracy": accuracy}),
            where loss and accuracy are whatever `test` returns.
        """
        set_weights(net, parameters)
        loss, accuracy = test(net, testloader)
        return loss, len(testloader.dataset), {"accuracy": accuracy}
def client_fn(cid: str):
    """Build a Flower `Client` by wrapping a fresh `FlowerClient`.

    `cid` is part of the callback signature but is not used here: every
    client id receives an equivalent client, backed by the same module-level
    model and data loaders.
    """
    numpy_client = FlowerClient()
    return numpy_client.to_client()
# Flower ClientApp wrapping `client_fn`; presumably discovered by the Flower
# runtime through this module-level `app` name (confirm against launch config).
app = ClientApp(client_fn=client_fn)
# Legacy mode: when run directly as a script, connect a single client to the
# server at the hard-coded local address using `start_client`.
if __name__ == "__main__":
    from flwr.client import start_client

    legacy_address = "127.0.0.1:8080"
    legacy_client = FlowerClient().to_client()
    start_client(server_address=legacy_address, client=legacy_client)