Skip to content

Commit

Permalink
Update simulation tensorflow example for Flower 1.0 (adap#1333)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel J. Beutel <[email protected]>
  • Loading branch information
tanertopal and danieljanes committed Jul 27, 2022
1 parent b850b50 commit 58503d4
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 20 deletions.
11 changes: 4 additions & 7 deletions examples/simulation_tensorflow/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
[build-system]
requires = [
"poetry==1.1.14",
]
build-backend = "poetry.masonry.api"
requires = ["poetry_core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "simulation_tensorflow"
Expand All @@ -11,7 +9,6 @@ description = "Federated Learning Simulation Quickstart with Flower"
authors = ["The Flower Authors <[email protected]>"]

[tool.poetry.dependencies]
python = "^3.7"
flwr = { extras = ["simulation"], version = "^0.17.0" }
# flwr = { extras = ["simulation"], path = "../../", develop = true } # Development
python = ">=3.7,<3.11"
flwr = { extras = ["simulation"], version = "^1.0.0rc" }
tensorflow-cpu = "^2.9.1"
29 changes: 18 additions & 11 deletions examples/simulation_tensorflow/sim.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
},
"outputs": [],
"source": [
"!pip install -U flwr[\"simulation\"]"
"!pip install -U --pre flwr[\"simulation\"] tensorflow"
]
},
{
Expand Down Expand Up @@ -69,7 +69,7 @@
"- `fit`: Receive model parameters from the server, train the model parameters on the local data, and return the (updated) model parameters to the server \n",
"- `evaluate`: Received model parameters from the server, evaluate the model parameters on the local data, and return the evaluation result to the server\n",
"\n",
"We mentioned that our clients will use TensorFlow/Keras for the model training and evaluation. Keras models provide methods that make the implementation staightforward: we can update the local model with server-provides parameters through `model.set_weights`, we can train/evaluate the model through `fit/evaluate`, and we can get the updated model parameters through `model.get_weights`.\n",
"We mentioned that our clients will use TensorFlow/Keras for the model training and evaluation. Keras models provide methods that make the implementation straightforward: we can update the local model with server-provides parameters through `model.set_weights`, we can train/evaluate the model through `fit/evaluate`, and we can get the updated model parameters through `model.get_weights`.\n",
"\n",
"Let's see a simple implementation:"
]
Expand All @@ -88,7 +88,7 @@
" self.x_train, self.y_train = x_train, y_train\n",
" self.x_val, self.y_val = x_val, y_val\n",
"\n",
" def get_parameters(self):\n",
" def get_parameters(self, config):\n",
" return self.model.get_weights()\n",
"\n",
" def fit(self, parameters, config):\n",
Expand All @@ -106,11 +106,11 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Our class `FlowerClient` defines how local training/evaluation will be performed and allows Flower to call the local training/evaluation through `fit` and `evaluate`. Each instance of `FlowerClient` represents a *single client* in our federated learning system. Federated learning systems have multiple clients (otherwise there's not much to federate, is there?), so each client will be represented by its own instance of `FlowerClient`. If we have, for example, three clients in our workload, we'd have three instances of `FlowerClient`. Flower calls `FlowerClient.fit` on the respective instance when the server selects a particular client for training (and `FlowerClient.evaluate` for evaluation).\n",
"Our class `FlowerClient` defines how local training/evaluation will be performed and allows Flower to call the local training/evaluation through `fit` and `evaluate`. Each instance of `FlowerClient` represents a *single client* in our federated learning system. Federated learning systems have multiple clients (otherwise, there's not much to federate, is there?), so each client will be represented by its own instance of `FlowerClient`. If we have, for example, three clients in our workload, we'd have three instances of `FlowerClient`. Flower calls `FlowerClient.fit` on the respective instance when the server selects a particular client for training (and `FlowerClient.evaluate` for evaluation).\n",
"\n",
"In this notebook, we want to simulate a federated learning system with 100 clients on a single machine. This means that the server and all 100 clients will live on a single machine and share resources such as CPU, GPU, and memory. Having 100 clients would mean having 100 instances of `FlowerClient` im memory. Doing this on a single machine can quickly exhaust the available memory resources, even if only a subset of these clients participates in a single round of federated learning.\n",
"In this notebook, we want to simulate a federated learning system with 100 clients on a single machine. This means that the server and all 100 clients will live on a single machine and share resources such as CPU, GPU, and memory. Having 100 clients would mean having 100 instances of `FlowerClient` in memory. Doing this on a single machine can quickly exhaust the available memory resources, even if only a subset of these clients participates in a single round of federated learning.\n",
"\n",
"In addition to the regular capabilities where server and clients run on multiple machines, Flower therefore provides special simulation capabilities that create `FlowerClient` instances only when they are actually necessary for training or evaluation. To enable the Flower framework to create clients when necessary, we need to implement a function called `client_fn` that creates a `FlowerClient` instance on demand. Flower calls `client_fn` whenever it needs an instance of one particular client to call `fit` or `evaluate` (those instances are usually discarded after use). Clients are identified by a client ID, or short `cid`. The `cid` can be used, for example, to load different local data partitions for each client:"
"In addition to the regular capabilities where server and clients run on multiple machines, Flower, therefore, provides special simulation capabilities that create `FlowerClient` instances only when they are actually necessary for training or evaluation. To enable the Flower framework to create clients when necessary, we need to implement a function called `client_fn` that creates a `FlowerClient` instance on demand. Flower calls `client_fn` whenever it needs an instance of one particular client to call `fit` or `evaluate` (those instances are usually discarded after use). Clients are identified by a client ID, or short `cid`. The `cid` can be used, for example, to load different local data partitions for each client:"
]
},
{
Expand Down Expand Up @@ -155,7 +155,7 @@
"id": "6SVawWSgO48Q"
},
"source": [
"We now have `FlowerClient` which defines client-side training and evaluation and `client_fn` which allows Flower to create `FlowerClient` instances whenever it needs to call `fit` or `evaluate` on one particular client. The last step is to start the actual simulation using `flwr.simulation.start_simulation`. \n",
"We now have `FlowerClient` which defines client-side training and evaluation, and `client_fn`, which allows Flower to create `FlowerClient` instances whenever it needs to call `fit` or `evaluate` on one particular client. The last step is to start the actual simulation using `flwr.simulation.start_simulation`. \n",
"\n",
"The function `start_simulation` accepts a number of arguments, amongst them the `client_fn` used to create `FlowerClient` instances, the number of clients to simulate `num_clients`, the number of rounds `num_rounds`, and the strategy. The strategy encapsulates the federated learning approach/algorithm, for example, *Federated Averaging* (FedAvg).\n",
"\n",
Expand Down Expand Up @@ -189,7 +189,7 @@
"fl.simulation.start_simulation(\n",
" client_fn=client_fn,\n",
" num_clients=NUM_CLIENTS,\n",
" num_rounds=5,\n",
" config=fl.server.ServerConfig(num_rounds=5),\n",
" strategy=strategy,\n",
")"
]
Expand All @@ -204,7 +204,7 @@
"\n",
"- Deploy server and clients on different machines using `start_server` and `start_client`\n",
"- Customize the server-side execution through custom strategies\n",
"- Customize the client-side exectution through `config` dictionaries"
"- Customize the client-side execution through `config` dictionaries"
]
}
],
Expand All @@ -215,11 +215,18 @@
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3.8.12 ('.venv': poetry)",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python"
"name": "python",
"version": "3.8.12"
},
"vscode": {
"interpreter": {
"hash": "17fc6998859a199880698aabe9f010df4b656b8df14144b324867d1aa519436f"
}
}
},
"nbformat": 4,
Expand Down
4 changes: 2 additions & 2 deletions examples/simulation_tensorflow/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, model, x_train, y_train) -> None:
self.x_train, self.y_train = x_train[:split_idx], y_train[:split_idx]
self.x_val, self.y_val = x_train[split_idx:], y_train[split_idx:]

def get_parameters(self):
def get_parameters(self, config):
return self.model.get_weights()

def fit(self, parameters, config):
Expand Down Expand Up @@ -62,7 +62,7 @@ def main() -> None:
client_fn=client_fn,
num_clients=NUM_CLIENTS,
client_resources={"num_cpus": 4},
num_rounds=5,
config=fl.server.ServerConfig(num_rounds=5),
strategy=fl.server.strategy.FedAvg(
fraction_fit=0.1,
fraction_evaluate=0.1,
Expand Down

0 comments on commit 58503d4

Please sign in to comment.