Federated Learning: Training Models Where the Data Lives

Imagine a group of hospitals trying to train a disease-risk model together.
Each hospital has valuable patient records, but nobody is allowed (or willing) to centralize them.
Federated learning solves this dilemma by moving training to the data: clients compute updates locally and share only model information.

Federated Learning (FL) decentralizes conventional ML training: multiple clients (phones, hospitals, banks, sensors) collaboratively learn a shared model while keeping their data local.
Clients compute local updates, and a coordinating server aggregates them into a new global model.

Given the increasing importance of data privacy and the massive amounts of data generated on personal devices, FL has become a core technique for training models in today’s data-centric world.

Why Federated Learning exists (and what it is not)

Centralized machine learning typically assumes one of these workflows:

  1. collect data from sources into a central store,
  2. train a model on the pooled dataset,
  3. deploy the trained model.

This can be difficult when data is sensitive (health, finance), regulated (GDPR, CCPA), or expensive to move (network and storage costs).

FL is a family of distributed optimization methods that keep data local and communicate model weights, gradients, or compressed updates.

Two common settings have different assumptions:

  • Cross-device FL: millions of unreliable, intermittently available devices (phones). Communication and partial participation dominate.
  • Cross-silo FL: tens to hundreds of reliable organizations (hospitals, banks). Governance, domain shift, and different pipelines dominate.

It helps to separate three related ideas:

  • Federated learning: distributed optimization under data-locality constraints.
  • Privacy: an additional property (often via secure aggregation and/or differential privacy).
  • Personalization: optional methods that adapt a global model to each client.

FL does not automatically guarantee strong privacy.
The server sees (some form of) updates, and updates can leak information unless protections are used.

If you want a crisp mental model: FL is “distributed training without centralizing data”, not “private training by default”.

The canonical Federated Learning loop

In one federated round $t$:

  1. Initialization: the server initializes a global model (or loads a checkpoint).
  2. Client sampling: the server selects a subset $S_t$ of available clients.
  3. Broadcast: the server sends current weights $\theta^t$ to the selected clients.
  4. Local training: each client runs a few local steps/epochs of SGD on its own data.
  5. Upload: clients send updated weights $\theta_k^{t+1}$ or updates $\Delta_k^t$ (often compressed), not raw data.
  6. Aggregation: the server combines updates into the next global model $\theta^{t+1}$.

Technical details: objective and FedAvg

Assume there are $C$ clients in total, with $K$ of them selected per round (i.e., $|S_t| = K$).
Client $k$ has dataset $\mathcal{D}_k$ of size $n_k$.

The global empirical risk minimization objective across all clients is

$$
\min_{\theta} F(\theta) = \sum_{k=1}^{C} \frac{n_k}{\sum_{j=1}^{C} n_j} F_k(\theta),
\quad
F_k(\theta) = \frac{1}{n_k} \sum_{(x,y) \in \mathcal{D}_k} \ell(\theta; x, y).
$$

In practice, each round only touches a subset $S_t$.

Federated averaging (FedAvg) is “local SGD + weighted averaging”. In round $t$:

  • Each selected client $k \in S_t$ starts from $\theta^t$ and runs $E$ local epochs (or a fixed number of steps):

$$
\theta_{k}^{t+1} \leftarrow \text{LocalTrain}(\theta^t, \mathcal{D}_k).
$$

  • The server aggregates using data-size weights over the participating clients:

$$
\theta^{t+1} \leftarrow \sum_{k \in S_t} \frac{n_k}{\sum_{j \in S_t} n_j} \, \theta_{k}^{t+1}.
$$

If clients instead send updates $\Delta_k^t = \theta_{k}^{t+1} - \theta^t$, aggregation becomes

$$
\theta^{t+1} \leftarrow \theta^t + \sum_{k \in S_t} \frac{n_k}{\sum_{j \in S_t} n_j} \, \Delta_k^t.
$$

This weighting matters: averaging client models uniformly can overweight small clients and destabilize training.
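To make the round structure and the weighted aggregation concrete, here is a minimal NumPy sketch under simplifying assumptions: every client participates each round, the model is linear, the data is synthetic, and `local_train` is an illustrative stand-in for LocalTrain.

```python
import numpy as np

rng = np.random.default_rng(0)

def local_train(theta, X, y, lr=0.1, epochs=5):
    """Plain local SGD on a least-squares objective (illustrative stand-in for LocalTrain)."""
    theta = theta.copy()
    for _ in range(epochs):
        grad = X.T @ (X @ theta - y) / len(y)  # full-batch gradient for simplicity
        theta -= lr * grad
    return theta

# Synthetic, heterogeneous clients with very different dataset sizes n_k.
d = 10
clients = [(rng.normal(size=(n, d)), rng.normal(size=n)) for n in (20, 200, 2000)]

theta_global = np.zeros(d)
for round_t in range(50):
    # A real system samples a subset S_t of available clients; here every client participates.
    local_models = [local_train(theta_global, X, y) for X, y in clients]
    sizes = np.array([len(y) for _, y in clients], dtype=float)
    weights = sizes / sizes.sum()
    # FedAvg step: data-size-weighted average of the returned client models.
    theta_global = sum(w * m for w, m in zip(weights, local_models))
```

Switching to a sampled subset $S_t$, or to sending deltas instead of full weights, changes only a few lines of this loop.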

What changes compared with centralized training

In centralized training, minibatches are drawn from one pooled, shuffled dataset, so they are close to IID.
In FL, data is commonly non-IID (heterogeneous): each client may represent a narrow slice of the population.

Two practical consequences dominate:

  • Client drift: local steps move toward each client’s local optimum, so updates disagree and can cancel.
  • Systems coupling: only some clients participate each round, communication is expensive, and stragglers/dropouts happen.

One useful rule of thumb: the more heterogeneous the clients (label skew, feature shift, concept drift), the fewer local epochs/steps you can safely take before the global objective starts to “wobble”.

Common algorithmic variants (when FedAvg is not enough)

FedAvg is the baseline, but many practical stacks use one of these extensions:

  • FedProx: adds a proximal term to reduce drift under heterogeneity.
  • FedOpt / server-side optimizers: treat aggregated updates as a “gradient-like” signal and apply Adam/Yogi/Adagrad on the server (often called FedAdam/FedYogi).
  • Scaffold / control variates: reduce client drift by correcting local updates.

These methods don’t eliminate heterogeneity, but they often stabilize training and reduce rounds-to-target.
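As a concrete illustration of the FedProx idea, here is a hedged sketch that adds a proximal term to the local least-squares gradient from the FedAvg sketch above; the value of $\mu$ is illustrative.

```python
import numpy as np

def fedprox_local_train(theta_global, X, y, mu=0.1, lr=0.1, epochs=5):
    """Local SGD on least squares plus a proximal term (mu/2) * ||theta - theta_global||^2."""
    theta = theta_global.copy()
    for _ in range(epochs):
        grad = X.T @ (X @ theta - y) / len(y)
        # Gradient of the proximal term pulls the local model back toward the global weights,
        # limiting client drift when local data is far from the population distribution.
        grad += mu * (theta - theta_global)
        theta -= lr * grad
    return theta
```

This drops in for `local_train` in the earlier sketch; a larger $\mu$ means less drift but also slower local progress.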

How FL is evaluated (what “good” looks like)

FL results are easy to misread if you only report one global metric.
Good reports usually include:

  • Global performance: accuracy/AUC on a central or pooled-like test set (if you can construct one).
  • Per-client distribution: mean/median plus spread (e.g., quantiles) across clients.
  • Tail metrics: worst-$p\%$ client performance, or performance on minority/rare clients.
  • Communication cost: rounds-to-target and uplink bytes per client.
  • Stability: variance across rounds and sensitivity to client sampling.
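To make the per-client and tail reporting concrete, here is a small sketch, assuming you already have one metric value per client; the worst-10% cutoff is an illustrative choice.

```python
import numpy as np

def summarize_per_client(accs, worst_frac=0.10):
    """Summarize one metric value per client: center, spread, and tail behavior."""
    accs = np.sort(np.asarray(accs, dtype=float))
    k = max(1, int(np.ceil(worst_frac * len(accs))))
    return {
        "mean": float(accs.mean()),
        "median": float(np.median(accs)),
        "p10": float(np.quantile(accs, 0.10)),
        "p90": float(np.quantile(accs, 0.90)),
        "worst_10pct_mean": float(accs[:k].mean()),  # average over the worst-performing clients
    }

# Illustrative: per-client accuracies for 100 heterogeneous clients.
rng = np.random.default_rng(0)
print(summarize_per_client(rng.beta(8, 2, size=100)))
```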

For cross-silo FL, also report site effects (domain shift): performance by institution and over time.

If you can, report two baselines:

  • centralized (pooled) training (upper bound if pooling were allowed),
  • local-only training per client (lower bound that may still beat a weak global model under high heterogeneity).

What practitioners must design for

FL is as much an engineering problem as it is an optimization problem. These issues typically determine success.

Data and optimization challenges

  • Non-IID data: class imbalance, concept drift, and per-client label skew are common.
  • Unbalanced data: $n_k$ can vary by orders of magnitude.
  • Partial participation: the set $S_t$ changes each round; stragglers happen.
  • Convergence instability: too many local steps or too high a client learning rate can make training diverge.

Common mitigations:

  • tune local steps $E$ (or steps per round) and aggregation frequency,
  • use server-side adaptive optimizers (FedAdam / FedYogi) with cautious client learning rates,
  • add regularization to reduce drift (for example FedProx),
  • clip client updates to limit outliers (also useful for privacy),
  • evaluate per-client metrics, not only global averages.
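For the server-side adaptive optimizer, here is a minimal FedAdam-style sketch: the weighted average of client deltas is treated as a pseudo-gradient and fed through Adam-like moment estimates on the server. The hyperparameters are illustrative, not recommended values.

```python
import numpy as np

class ServerAdam:
    """FedAdam-style server optimizer: apply Adam-like moments to the aggregated client update."""

    def __init__(self, dim, lr=0.1, beta1=0.9, beta2=0.99, eps=1e-3):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = np.zeros(dim)   # first moment of the pseudo-gradient
        self.v = np.zeros(dim)   # second moment of the pseudo-gradient

    def step(self, theta, avg_delta):
        # avg_delta: data-size-weighted average of client deltas for this round.
        self.m = self.beta1 * self.m + (1 - self.beta1) * avg_delta
        self.v = self.beta2 * self.v + (1 - self.beta2) * avg_delta**2
        return theta + self.lr * self.m / (np.sqrt(self.v) + self.eps)

# Usage per round: theta_global = server.step(theta_global, avg_delta)
```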

Systems challenges

  • Communication bottlenecks: uplink is often the limiting factor.
  • Device heterogeneity: clients have different compute, memory, and energy budgets.
  • Reliability: clients can drop mid-round.

Common mitigations:

  • update compression (quantization, sparsification),
  • periodic training windows (for example, train only while the device is charging and on Wi-Fi),
  • robust orchestration to handle missing clients and stragglers.
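As one example of update compression, here is a hedged sketch of naive 8-bit uniform quantization of an update vector; production systems typically add stochastic rounding, sparsification, or error feedback.

```python
import numpy as np

def quantize_8bit(delta):
    """Uniformly quantize a float32 update to uint8 codes plus a per-tensor offset and scale."""
    lo, hi = float(delta.min()), float(delta.max())
    scale = (hi - lo) / 255.0 or 1.0           # avoid division by zero for constant updates
    codes = np.round((delta - lo) / scale).astype(np.uint8)
    return codes, lo, scale

def dequantize_8bit(codes, lo, scale):
    return codes.astype(np.float32) * scale + lo

delta = np.random.default_rng(0).normal(size=100_000).astype(np.float32)
codes, lo, scale = quantize_8bit(delta)
recovered = dequantize_8bit(codes, lo, scale)
print(delta.nbytes, "->", codes.nbytes)        # roughly 4x fewer uplink bytes
```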

A practical design checklist

If you’re implementing FL, you usually need to decide (explicitly) on:

  1. Client selection policy: how many clients per round, and whether to oversample rare clients.
  2. Local computation budget: steps/epochs, batch size, optimizer, and learning rate schedule.
  3. Server optimizer: plain averaging vs FedAdam/FedYogi, momentum, and update clipping.
  4. Update representation: weights vs deltas, compression, and sparsification.
  5. Privacy posture: secure aggregation? DP? threat model and acceptable leakage.
  6. Robustness: poisoning defenses, anomaly detection, and client authentication.
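One lightweight way to keep these six decisions explicit (and reviewable) is a single configuration object; the field names and defaults below are illustrative assumptions, not tied to any particular framework.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class FLConfig:
    # 1. Client selection policy
    clients_per_round: int = 50
    # 2. Local computation budget
    local_epochs: int = 1
    batch_size: int = 32
    client_lr: float = 0.05
    # 3. Server optimizer
    server_optimizer: str = "fedavg"          # or "fedadam", "fedyogi"
    server_lr: float = 1.0
    update_clip_norm: Optional[float] = 1.0
    # 4. Update representation
    send_deltas: bool = True
    compression: Optional[str] = "quant8"     # e.g., 8-bit quantization, top-k sparsification
    # 5. Privacy posture
    secure_aggregation: bool = True
    dp_noise_multiplier: Optional[float] = None
    # 6. Robustness
    robust_aggregation: Optional[str] = None  # e.g., "median", "trimmed_mean"
    require_client_auth: bool = True
```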

Privacy and security: what FL does and does not do

FL reduces raw-data exposure, but does not eliminate privacy risk.
Updates can leak information (e.g., membership or attributes), and malicious clients can poison the model.

At a high level, there are three different goals people conflate:

  • Confidentiality from the server: the server should not see individual client updates (secure aggregation).
  • Statistical privacy guarantees: limit what can be inferred about any participant (differential privacy).
  • Robustness to adversaries: tolerate malicious clients and sybils (robust aggregation + authentication).

Threats and mitigations

| Threat | What it means in practice | Typical mitigation |
| --- | --- | --- |
| Update leakage (membership or attribute inference) | Updates reveal whether a record was present, or leak attributes | Differential privacy (client-level or example-level), clipping + noise |
| Gradient inversion | Reconstruct inputs from the gradients they produced | Secure aggregation, DP, representation learning, smaller batch exposure |
| Model poisoning | A client sends crafted updates to insert a backdoor | Robust aggregation (median, trimmed mean), anomaly detection |
| Sybil attacks | One attacker pretends to be many clients | Strong client authentication, rate limits, reputation systems |
| Dropout attacks on secure aggregation | Malicious coordination breaks privacy guarantees | Protocol hardening, minimum participation thresholds |
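To make the "robust aggregation" mitigation concrete, here is a hedged sketch of a coordinate-wise trimmed mean over client updates; the trimming fraction and the toy outlier are illustrative.

```python
import numpy as np

def trimmed_mean_aggregate(updates, trim_frac=0.1):
    """Coordinate-wise trimmed mean over client updates: robust to a few extreme (malicious) values."""
    U = np.sort(np.stack(updates), axis=0)      # shape: (num_clients, dim), sorted per coordinate
    k = int(trim_frac * U.shape[0])
    trimmed = U[k:U.shape[0] - k] if k > 0 else U
    return trimmed.mean(axis=0)

# Illustrative: nine honest updates near zero plus one crafted outlier.
rng = np.random.default_rng(0)
updates = [rng.normal(scale=0.1, size=5) for _ in range(9)] + [np.full(5, 50.0)]
print(trimmed_mean_aggregate(updates))
```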

Two widely used privacy tools:

  1. Secure aggregation: the server learns only an aggregate (e.g., a sum) of client updates, not each individual update.
  2. Differential privacy (DP): adds rigor to privacy claims, typically by clipping updates and adding calibrated noise.

For example, a simplified client-side DP mechanism clips the update norm and adds Gaussian noise. In practice, many deployments use secure aggregation + client-side clipping, and optionally add DP noise.
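Here is a minimal sketch of that mechanism; the clip norm and noise multiplier are illustrative, and a real deployment would track the resulting privacy budget with a privacy accountant.

```python
import numpy as np

def privatize_update(delta, clip_norm=1.0, noise_multiplier=1.1, rng=None):
    """Clip the update to an L2 norm bound, then add Gaussian noise scaled to that bound."""
    if rng is None:
        rng = np.random.default_rng()
    norm = np.linalg.norm(delta)
    clipped = delta * min(1.0, clip_norm / (norm + 1e-12))   # bound each client's contribution
    noise = rng.normal(scale=noise_multiplier * clip_norm, size=delta.shape)
    return clipped + noise
```

In practice, the noise for a client-level guarantee is often added to the aggregate (or split across clients under secure aggregation) rather than in full on each client.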

When FL is a good idea (and when it is not)

FL is often a strong fit when:

  • data is naturally distributed (devices, institutions),
  • moving data is difficult or prohibited,
  • on-device personalization or continual learning is valuable,
  • there are many clients with moderate local data.

FL is often a poor fit when:

  • data can be centrally pooled safely and cheaply,
  • model updates are larger than the data that would be moved,
  • the problem requires tight synchronization (some deep RL and sequence training setups),
  • clients are too scarce to learn a stable global signal.

One very common failure mode: if clients have different label definitions / measurement pipelines (especially in cross-silo settings), FL may optimize an inconsistent objective unless you align schemas, logging, and evaluation.

Applications (with the FL-specific nuance)

  • Next-word prediction (the original FL use case): Google improved its mobile-keyboard language models by training on users’ typing behavior on-device, without collecting the raw text.
  • Mobile personalization: keyboard prediction, ranking, and on-device adaptation.
  • Healthcare: cross-hospital training without sharing records; careful evaluation for site effects.
  • Finance: fraud models across institutions; strong adversarial assumptions are necessary.
  • IoT and smart cities: distributed sensors; bandwidth constraints dominate design choices.
  • Autonomous driving fleets: learning from geographically diverse environments; heavy emphasis on robustness.

Future directions (high-signal areas)

  • Personalized FL: methods that explicitly model client heterogeneity (for example fine-tuning heads or meta-learning).
  • Robust FL: stronger defenses against poisoning and sybil attacks with explicit threat models.
  • Better benchmarks: realistic non-IID datasets, heterogeneous device simulators, and standardized reporting.
  • Cross-silo FL: small number of reliable institutions (hospitals); different assumptions than cross-device FL.

FAQ

How does FL handle non-IID data across clients?

Non-IID data is the default in FL. Common techniques include reducing local steps, using server-side adaptive optimizers, adding regularization (e.g., FedProx), and evaluating per-client (not just global) metrics.

What are the communication costs associated with FL?

Communication can dominate runtime. Typical mitigations include update compression (quantization/sparsification), reducing rounds (more local compute per round, carefully), and choosing smaller models or partial updates.

How can FL be integrated with existing machine learning frameworks?

Popular options include TensorFlow Federated (TFF), Flower, FedML, OpenFL, and FATE. In practice, integration usually revolves around (1) a client training loop, (2) a secure aggregation / transport layer, and (3) evaluation and monitoring.

Can FL be combined with other privacy-preserving techniques?

Yes. Secure aggregation and differential privacy are commonly combined. Secure aggregation limits what the server can see; DP provides formal privacy guarantees, typically with clipping + noise and a privacy accountant.

Can FL be used with random forest models?

It can, but it often looks different: clients train local models and the server ensembles predictions or merges structures. Tree-based FL is possible but less “plug-and-play” than neural FedAvg.

Can FL be used with LightGBM models?

It can, typically via federated boosting variants or by ensembling/knowledge-distillation approaches. Communication and model structure complicate naive parameter averaging.
