Federated Learning at Scale: Privacy-Preserving Distributed Training

Train ML models across millions of devices without centralizing data. Essential for privacy but vulnerable to poisoning attacks.

Core Architecture

import copy
from typing import List

import torch
import torch.nn as nn

class FederatedLearning:
    def __init__(self, global_model, num_clients=1000):
        self.global_model = global_model
        self.num_clients = num_clients

    def federated_averaging(self, client_models: List[nn.Module]):
        """
        FedAvg algorithm: average model weights from clients.

        ⚠️ Vulnerable to poisoning if a client sends malicious updates
        """
        global_dict = self.global_model.state_dict()

        # Average all client model parameters, key by key
        for key in global_dict.keys():
            global_dict[key] = torch.stack([
                client.state_dict()[key].float()
                for client in client_models
            ]).mean(0)

        self.global_model.load_state_dict(global_dict)
        return self.global_model

    def client_update(self, client_data, epochs=5):
        """Local training on the client device."""
        model = copy.deepcopy(self.global_model)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        criterion = nn.CrossEntropyLoss()

        for epoch in range(epochs):
            for inputs, targets in client_data:
                optimizer.zero_grad()
                loss = criterion(model(inputs), targets)
                loss.backward()
                optimizer.step()

        return model
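
A minimal usage sketch (the toy linear layer and the hand-set weight shifts are illustrative, standing in for real local training):

global_model = nn.Linear(2, 1)
fl = FederatedLearning(global_model, num_clients=2)

# Pretend two clients trained locally and their weights drifted apart
client_a, client_b = copy.deepcopy(global_model), copy.deepcopy(global_model)
with torch.no_grad():
    client_a.weight += 0.5
    client_b.weight -= 0.5

fl.federated_averaging([client_a, client_b])
print(global_model.weight)  # the element-wise mean of the two client models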

Secure Aggregation

import torch

class SecureAggregation:
    """Prevent the server from seeing individual client updates."""

    def __init__(self, num_clients):
        self.num_clients = num_clients

    def add_noise(self, gradients):
        """Add differential privacy noise."""
        noise = torch.randn_like(gradients) * 0.1  # σ = 0.1
        return gradients + noise

    def aggregate_securely(self, client_updates):
        """
        Secure multi-party computation via pairwise masking.
        Server only sees the aggregated result, not individual updates.

        In production: use cryptographic secret sharing with dropout
        handling (Bonawitz et al., 2017). Simplified here for
        demonstration: each client pair (i, j) shares a random mask
        that i adds and j subtracts, so every mask cancels in the sum.
        """
        n = len(client_updates)
        masked = [u.clone() for u in client_updates]
        for i in range(n):
            for j in range(i + 1, n):
                mask = torch.randn_like(client_updates[i])
                masked[i] += mask  # looks like noise in isolation...
                masked[j] -= mask  # ...but cancels against the paired share
        # Each masked update is individually meaningless;
        # only the mean recovers the true aggregate
        return sum(masked) / n
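
A quick check that the masks cancel (three constant toy update vectors, chosen for readability):

agg = SecureAggregation(num_clients=3)
updates = [torch.full((4,), v) for v in (1.0, 2.0, 3.0)]

print(agg.aggregate_securely(updates))  # ≈ tensor([2., 2., 2., 2.])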

Gradient Poisoning Defense

import torch

def byzantine_robust_aggregation(client_gradients, tolerance=0.1):
    """
    Defend against malicious clients sending poisoned gradients.
    Expects each client gradient as a flattened 1-D tensor.

    ⚠️ Attack: Malicious client sends large gradients to corrupt model
    Defense: Detect and remove outlier gradients

    tolerance: fraction of clients to discard as outliers
    """
    stacked = torch.stack(client_gradients)  # (num_clients, dim)

    # Median gradient is robust to outliers
    median = torch.median(stacked, dim=0)[0]
    distances = torch.norm(stacked - median, dim=1)  # one distance per client

    # Remove gradients that are too far from the median
    threshold = torch.quantile(distances, 1.0 - tolerance)
    valid_gradients = [
        g for g, d in zip(client_gradients, distances)
        if d <= threshold
    ]

    return torch.mean(torch.stack(valid_gradients), dim=0)
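
For example, with nine honest clients and one attacker (toy values):

honest = [torch.randn(8) * 0.01 for _ in range(9)]
poisoned = torch.full((8,), 100.0)  # attacker's oversized update

result = byzantine_robust_aggregation(honest + [poisoned])
print(result.norm())  # small: the poisoned gradient was filtered out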

Differential Privacy

import numpy as np
import torch

class DifferentiallyPrivateFedAvg:
    def __init__(self, epsilon=1.0, delta=1e-5):
        """
        epsilon: Privacy budget (lower = more private)
        delta: Failure probability
        """
        self.epsilon = epsilon
        self.delta = delta

    def clip_and_noise(self, gradients, clip_norm=1.0):
        """
        Clip gradients to a maximum L2 norm + add Gaussian noise
        for an (epsilon, delta)-DP guarantee.
        """
        # Scale the update so its L2 norm is at most clip_norm
        norm = torch.norm(gradients)
        scale = torch.clamp(clip_norm / (norm + 1e-12), max=1.0)
        clipped = gradients * scale

        # Add noise calibrated via the Gaussian mechanism
        sensitivity = 2 * clip_norm  # max influence of replacing one client
        sigma = sensitivity * np.sqrt(2 * np.log(1.25 / self.delta)) / self.epsilon

        noise = torch.randn_like(clipped) * sigma
        return clipped + noise
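
Usage sketch (the 10-dimensional update is illustrative; note how much noise a tight budget of epsilon = 1.0 injects):

dp = DifferentiallyPrivateFedAvg(epsilon=1.0, delta=1e-5)
update = torch.randn(10) * 5.0        # raw local update
private = dp.clip_and_noise(update)   # clipped to norm 1.0, then noised
print(update.norm(), private.norm())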

Production Deployment

import copy
import random

import torch

# Coordinator (central server)
class FederatedCoordinator:
    def __init__(self, model, num_rounds=100):
        self.global_model = model
        self.num_rounds = num_rounds

    def train(self, client_pool, clients_per_round=100):
        for round_num in range(self.num_rounds):
            # Sample a subset of clients each round
            selected = random.sample(client_pool, clients_per_round)

            # Client training (run in parallel in production)
            client_updates = []
            for client in selected:
                updated_model = client.train(self.global_model)
                # Flatten parameters so robust aggregation can compare them
                client_updates.append(
                    torch.nn.utils.parameters_to_vector(updated_model.parameters())
                )

            # Aggregate with Byzantine robustness
            aggregated = byzantine_robust_aggregation(client_updates)
            torch.nn.utils.vector_to_parameters(
                aggregated, self.global_model.parameters()
            )

            # Evaluate on held-out server-side data
            accuracy = self.evaluate()
            print(f"Round {round_num}: Accuracy = {accuracy:.2%}")

# Client (edge device)
class FederatedClient:
    def __init__(self, local_data):
        self.data = local_data

    def train(self, global_model, epochs=5):
        model = copy.deepcopy(global_model)
        # Train on local data (never sent to the server!)
        # ... local training loop, e.g. client_update from above ...
        return model

Real-World Scale

# Example: Google's federated learning for Gboard
federated_system = FederatedLearning(
    global_model=language_model,  # on-device next-word prediction model
    num_clients=100_000_000       # ~100M Android devices
)

# Each device trains locally on the user's typing data.
# Privacy preserved: raw data never leaves the device;
# only model updates are aggregated.

Warnings ⚠️

Poisoning Attacks:

  • Malicious client sends crafted gradients
  • Can backdoor model or degrade performance
  • Defense: Robust aggregation + anomaly detection

Privacy Leakage:

  • Gradients can leak training data (gradient inversion, membership inference); see the sketch after this list
  • Defense: Differential privacy (adds calibrated noise)
  • Tradeoff: Privacy vs. accuracy
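
A minimal illustration of gradient leakage (the toy embedding layer and token IDs are illustrative): with an embedding layer, the nonzero rows of the gradient reveal exactly which tokens a user typed.

import torch
import torch.nn as nn

emb = nn.Embedding(1000, 16)
tokens = torch.tensor([[42, 7, 7, 911]])  # a user's "private" input
emb(tokens).sum().backward()

# Only the rows for tokens 7, 42, and 911 have nonzero gradients
leaked = emb.weight.grad.abs().sum(dim=1).nonzero().flatten()
print(leaked)  # tensor([  7,  42, 911]): the input tokens, recovered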

Communication Cost:

  • 100M clients × model size × rounds = massive bandwidth
  • Solution: Gradient compression, quantization (sketched below)
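
One common compression approach is top-k sparsification: each client uploads only the largest-magnitude entries of its update. A minimal sketch (compress_topk and the 1% ratio are illustrative, not from a specific library):

import torch

def compress_topk(update: torch.Tensor, k_fraction: float = 0.01):
    """Keep only the largest-magnitude entries (top-k sparsification)."""
    flat = update.flatten()
    k = max(1, int(flat.numel() * k_fraction))
    _, indices = torch.topk(flat.abs(), k)
    compressed = torch.zeros_like(flat)
    compressed[indices] = flat[indices]
    return compressed.view_as(update)

# At k_fraction=0.01 a client sends ~1% of the values (plus their indices),
# cutting upload bandwidth by roughly two orders of magnitude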

Related Chronicles: Decentralized AI Training Catastrophe (2051)

Tools: Flower, PySyft
