Federated Learning (FL) is an emerging approach to machine learning (ML) where model training data is not stored in a central location. During ML training, we typically need access to the entire training dataset in one place. To scale performance, we may divide the training data between multiple CPUs, multiple GPUs, or a cluster of machines, but the full dataset is still available in a single location. This poses challenges when we gather data from many distributed devices, like mobile phones or industrial sensors. For privacy reasons, we may not be able to collect training data from mobile phones. In an industrial setting, we may not have the network connectivity to push a large volume of sensor data to a central location. In addition, raw sensor data may contain sensitive information about a plant’s operation.
In this blog, we will demonstrate a working example of federated learning on AWS. We’ll also discuss some design challenges you may face when implementing FL for your own use case.
Federated learning challenges
In FL, all data stays on the devices. A training coordinator process invokes a training round or epoch, and the devices update a local copy of the model using local data. When complete, the devices each send their updated copy of the model weights to the coordinator. The coordinator combines these weights using an approach like federated averaging, and sends the new weights to each device. Figures 1 and 2 compare traditional ML with federated learning.
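The weight-combination step described above can be sketched in a few lines. This is a minimal illustration of federated averaging in plain Python, assuming each device reports its weights as nested lists of floats along with the number of local examples it trained on. The function name and data layout are our own assumptions for illustration, not part of any framework.

```python
from typing import List, Sequence, Tuple

Weights = List[List[float]]  # one flat list of floats per model layer


def federated_average(updates: Sequence[Tuple[Weights, int]]) -> Weights:
    """Average client weights, weighting each client by its example count."""
    if not updates:
        raise ValueError("need at least one client update")
    total_examples = sum(n for _, n in updates)
    if total_examples <= 0:
        raise ValueError("total example count must be positive")

    first_weights, _ = updates[0]
    # Start from zeros with the same shape as the first client's weights.
    averaged = [[0.0] * len(layer) for layer in first_weights]

    for weights, num_examples in updates:
        share = num_examples / total_examples  # this client's contribution
        for layer_idx, layer in enumerate(weights):
            for i, value in enumerate(layer):
                averaged[layer_idx][i] += share * value
    return averaged


# Two hypothetical devices: (weights, local example count)
device_a = ([[1.0, 2.0], [0.5]], 100)
device_b = ([[3.0, 4.0], [1.5]], 300)
new_weights = federated_average([device_a, device_b])
print(new_weights)  # [[2.5, 3.5], [1.25]]
```

Because device_b trained on three times as many examples, its weights count three times as much in the result. Only the weights and example counts travel to the coordinator; the training data itself never leaves the devices.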
Traditional ML training:
Compared to traditional ML, federated learning poses several challenges.
Unpredictable training participants. Mobile and industrial devices may have intermittent connectivity. We may have thousands of devices, and the communication with these devices may be asynchronous. The training coordinator needs a way to discover and communicate with these devices.
Asynchronous communication. As previously noted, devices like IoT sensors may use communication protocols like MQTT, and rely on IoT frameworks with asynchronous pub/sub messaging. As a result, the training coordinator cannot simply make a synchronous call to a device to send updated weights.
Passing large messages. The model weights can be several MB in size, making them too large for typical MQTT payloads. This can be problematic when network connections are not high bandwidth.
Emerging frameworks. Deep learning frameworks like TensorFlow and PyTorch do not yet fully support FL. Each has emerging options for edge-oriented ML, but they are not yet production-ready, and are intended mostly for simulation work.
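To make the large-message challenge concrete, here is a minimal sketch of one common workaround: send the weights inline only when they fit in a single message, and otherwise send a small pointer to a file stored elsewhere. AWS IoT Core documents a 128 KB limit on MQTT payloads, but we treat the limit as configurable. The function name, message fields, bucket, and key below are hypothetical, and the upload itself is left to the caller.

```python
import base64
import json

# Payload budget for one MQTT message; 128 KB matches the documented
# AWS IoT Core limit, but treat the exact value as configurable.
MAX_MQTT_PAYLOAD_BYTES = 128 * 1024


def build_weights_message(round_id: int, weights_blob: bytes,
                          bucket: str, key: str,
                          limit: int = MAX_MQTT_PAYLOAD_BYTES) -> dict:
    """Return an inline message if it fits, else a pointer to external storage."""
    inline = {
        "round": round_id,
        "kind": "inline",
        "weights_b64": base64.b64encode(weights_blob).decode("ascii"),
    }
    if len(json.dumps(inline).encode("utf-8")) <= limit:
        return inline
    # Too large for one message: the caller uploads the blob out of band
    # (hypothetical bucket/key) and the message carries only a reference.
    return {
        "round": round_id,
        "kind": "pointer",
        "bucket": bucket,
        "key": key,
        "size_bytes": len(weights_blob),
    }


small = build_weights_message(1, b"\x00" * 1024,
                              "example-fl-bucket", "round-1/weights.bin")
large = build_weights_message(1, b"\x00" * (5 * 1024 * 1024),
                              "example-fl-bucket", "round-1/weights.bin")
print(small["kind"], large["kind"])  # inline pointer
```

The pointer message stays tiny no matter how large the model grows, which keeps the pub/sub channel for control traffic and moves bulk data over a channel better suited to it.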
Solving for federated learning challenges
Let’s begin by discussing the framework used for FL. This framework should work with any of the major deep learning systems like PyTorch and TensorFlow. It should clearly separate what happens on the training coordinator versus what happens on the devices. The Flower framework satisfies these requirements. In the simplest cases, Flower only requires that a device or client implement four methods: get model weights, set model weights, run a training round, and return metrics (like accuracy). Flower provides a default weight combination strategy, federated averaging, although we could design our own.
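To show what those four client operations look like in practice, here is a minimal sketch in plain Python. It deliberately does not import or subclass anything from Flower; the class name, method names, and the toy one-weight model (y = w * x) are our own assumptions for illustration. The metric here is mean squared error rather than accuracy, because the toy model is a regression.

```python
from typing import Dict, List, Tuple


class ToyDeviceClient:
    """Hypothetical client exposing the four operations; not a Flower class."""

    def __init__(self, data: List[Tuple[float, float]], lr: float = 0.01):
        self.data = data      # local (x, y) pairs that never leave the device
        self.lr = lr
        self.weights = [0.0]  # single weight of the model y = w * x

    def get_weights(self) -> List[float]:
        """Return a copy of the current model weights."""
        return list(self.weights)

    def set_weights(self, weights: List[float]) -> None:
        """Replace the local weights with those sent by the coordinator."""
        self.weights = list(weights)

    def fit(self, epochs: int = 1) -> Tuple[List[float], int]:
        """Run local gradient descent; return new weights and example count."""
        for _ in range(epochs):
            for x, y in self.data:
                error = self.weights[0] * x - y
                self.weights[0] -= self.lr * 2 * error * x
        return self.get_weights(), len(self.data)

    def evaluate(self) -> Dict[str, float]:
        """Return metrics computed on the local data."""
        mse = sum((self.weights[0] * x - y) ** 2
                  for x, y in self.data) / len(self.data)
        return {"mse": mse, "num_examples": len(self.data)}


client = ToyDeviceClient(data=[(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)])
client.set_weights([0.0])                # weights pushed by the coordinator
loss_before = client.evaluate()["mse"]
updated, n = client.fit(epochs=20)       # one local training round
loss_after = client.evaluate()["mse"]
print(loss_before > loss_after, n)
```

The coordinator only ever sees the return values of these methods, which is what keeps the separation between coordinator and device clean.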
Next, let’s consider the challenges of devices and coordinator-to-device communication. Here it helps to have a specific use case in mind. Let’s focus on an industrial use case using the AWS IoT services. The IoT services provide device discovery and asynchronous messaging tools. To bridge the gap between the synchronous training process and the asynchronous device communication, we run the Flower code in Amazon ECS tasks on AWS Fargate, orchestrated by an AWS Step Functions workflow.
Devices: Our devices are registered with AWS IoT Core, which gives access to MQTT communication. We use an AWS IoT Greengrass device, which lets us push AWS Lambda functions to the core device to run the actual ML training code. Greengrass also provides access to other AWS services and provides a streaming capability for pushing large payloads to the AWS Cloud. Devices publish heartbeat messages to an MQTT topic to facilitate device discovery by the training coordinator.
Flower processes: The Flower framework requires a server and one or more clients. The clients communicate with the server synchronously over gRPC, so we cannot directly run the Flower client code on the devices. Rather, we use Amazon ECS containers on AWS Fargate to run the Flower server and client code. The client container acts as a proxy and communicates with the devices using MQTT messaging, Greengrass streaming, and direct Amazon S3 file transfer. Since the devices cannot send responses directly to the proxy containers, we use IoT rules to push responses to an Amazon DynamoDB bookkeeping table.
Orchestration: We use a Step Functions workflow to launch a training run. It performs device discovery, launches the server container, and then launches one proxy client container for each Greengrass core.
Metrics: The Flower server and proxy containers emit useful data to logs. We use Amazon CloudWatch log metric filters to collect this information on a CloudWatch dashboard.
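Two ideas from the components above lend themselves to a short sketch: discovering devices from their heartbeats, and letting a proxy turn an asynchronous device reply into a blocking call by polling a bookkeeping store. The code below is a minimal illustration, not the prototype's implementation. The function names, the 60-second freshness window, and the record fields are assumptions, and an in-memory iterator stands in for the bookkeeping table.

```python
import time
from typing import Callable, Dict, List, Optional


def discover_devices(last_heartbeat: Dict[str, float], now: float,
                     max_age_s: float = 60.0) -> List[str]:
    """Return devices whose latest heartbeat is no older than max_age_s."""
    return sorted(device for device, ts in last_heartbeat.items()
                  if now - ts <= max_age_s)


def wait_for_response(fetch: Callable[[], Optional[dict]],
                      timeout_s: float = 30.0, poll_s: float = 1.0,
                      sleep: Callable[[float], None] = time.sleep,
                      clock: Callable[[], float] = time.monotonic) -> dict:
    """Poll fetch() until it returns a record or the timeout expires."""
    deadline = clock() + timeout_s
    while True:
        record = fetch()  # e.g. a lookup in the bookkeeping store
        if record is not None:
            return record
        if clock() >= deadline:
            raise TimeoutError("device did not respond before the deadline")
        sleep(poll_s)


# Heartbeat timestamps (seconds) per hypothetical device name.
heartbeats = {"core-1": 1000.0, "core-2": 900.0, "core-3": 990.0}
available = discover_devices(heartbeats, now=1010.0, max_age_s=60.0)
print(available)  # ['core-1', 'core-3']

# In-memory stand-in for the bookkeeping table: empty twice, then a reply.
pending = iter([None, None, {"device": "core-1", "round": 1, "status": "done"}])
reply = wait_for_response(lambda: next(pending), timeout_s=5.0, poll_s=0.0)
print(reply["status"])  # done
```

Passing fetch as a callable keeps the waiting logic independent of where responses are stored, so the same loop can be exercised in a test without any cloud resources. The timeout matters in practice because devices with intermittent connectivity may never reply.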
Figure 3 shows the high-level design of the prototype.
As you move into production with FL, you must design for a new set of challenges compared to traditional ML.
What type of devices do you have? Mobile and IoT devices are the most common candidates for FL.
How do the devices communicate with the coordinator? Devices come and go, and don’t always support synchronous communication. How will you discover devices that are available for FL and manage asynchronous communication? IoT frameworks are designed to work with large fleets of devices with intermittent connectivity. Mobile devices, by contrast, may find it easier to push results back to a training coordinator using regular API endpoints.
How will you integrate FL into your ML practice? As you saw in this example, we built a working FL prototype using AWS IoT, container, database, and application integration services. A data science team is likely to work in Amazon SageMaker or an equivalent ML environment that provides model registries, experiment tracking, and other features. How will you bridge this gap? (We may need to invent a new term – “MLEdgeOps.”)
In this blog, we gave a working example of federated learning in an IoT scenario using the Flower framework. We discussed the challenges involved in FL compared to traditional ML when building your own FL solution. As FL is an important and emerging topic in edge ML scenarios, we invite you to try our GitHub sample code. Please give us feedback as GitHub issues, particularly if you see additional federated learning challenges that we haven’t covered yet.