
Federated Learning at Scale: Privacy-Preserving Distributed Training
January 12, 2025 · Sarah Johnson, Distributed ML Engineer · 4 min read
Train ML models across millions of devices without centralizing data. Essential for privacy but vulnerable to poisoning attacks.
Core Architecture
import copy
import torch
import torch.nn as nn
from typing import List

class FederatedLearning:
    def __init__(self, global_model, num_clients=1000):
        self.global_model = global_model
        self.num_clients = num_clients

    def federated_averaging(self, client_models: List[nn.Module]):
        """
        FedAvg algorithm: average model weights from clients.
        ⚠️ Vulnerable to poisoning if a client sends malicious updates
        """
        global_dict = self.global_model.state_dict()
        # Average each parameter tensor across all client models
        for key in global_dict.keys():
            global_dict[key] = torch.stack([
                client.state_dict()[key].float()
                for client in client_models
            ]).mean(0)
        self.global_model.load_state_dict(global_dict)
        return self.global_model

    def client_update(self, client_data, epochs=5):
        """Local training on the client device."""
        model = copy.deepcopy(self.global_model)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        criterion = nn.CrossEntropyLoss()
        for epoch in range(epochs):
            for inputs, targets in client_data:  # assumes (inputs, targets) batches
                optimizer.zero_grad()
                loss = criterion(model(inputs), targets)
                loss.backward()
                optimizer.step()
        return model
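A minimal usage sketch under toy assumptions (the nn.Linear model and the simulated local drift are hypothetical, purely to exercise federated_averaging):

# Toy example: three clients, one FedAvg step
global_model = nn.Linear(10, 2)
fl = FederatedLearning(global_model, num_clients=3)
client_models = [copy.deepcopy(global_model) for _ in range(3)]
for m in client_models:
    with torch.no_grad():
        for p in m.parameters():
            p.add_(torch.randn_like(p) * 0.01)  # stand-in for local training
fl.federated_averaging(client_models)  # global weights become the client mean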
Secure Aggregation
import torch

class SecureAggregation:
    """Prevent the server from seeing individual client updates."""
    def __init__(self, num_clients):
        self.num_clients = num_clients

    def add_noise(self, gradients):
        """Add differential-privacy noise."""
        noise = torch.randn_like(gradients) * 0.1  # σ = 0.1
        return gradients + noise

    def encrypt(self, update):
        # Placeholder: in production, use secret sharing / homomorphic encryption
        return update

    def decrypt(self, aggregated):
        # Placeholder inverse of encrypt
        return aggregated

    def aggregate_securely(self, client_updates):
        """
        Secure multi-party computation.
        The server only sees the aggregated result, not individual updates.
        """
        # In production: use cryptographic techniques (secret sharing)
        # Simplified here for demonstration
        encrypted_updates = [self.encrypt(u) for u in client_updates]
        aggregated = sum(encrypted_updates) / len(encrypted_updates)
        return self.decrypt(aggregated)
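The encrypt/decrypt stubs above gloss over the real mechanism. Here is a minimal sketch of one standard approach, pairwise additive masking, in the spirit of Bonawitz et al.'s secure aggregation protocol; the masked_updates helper and the shared-seed generator are simplifications standing in for real pairwise key agreement:

import torch

def masked_updates(client_updates, seed=42):
    """Each pair of clients shares a random mask that cancels in the sum."""
    n = len(client_updates)
    gen = torch.Generator().manual_seed(seed)  # stands in for pairwise key exchange
    masked = [u.clone() for u in client_updates]
    for i in range(n):
        for j in range(i + 1, n):
            mask = torch.randn(client_updates[0].shape, generator=gen)
            masked[i] += mask  # client i adds the shared mask
            masked[j] -= mask  # client j subtracts the same mask
    return masked

updates = [torch.randn(5) for _ in range(4)]
masked = masked_updates(updates)
# The sum is preserved, but any single masked update reveals nothing useful
assert torch.allclose(sum(masked), sum(updates), atol=1e-5)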
Gradient Poisoning Defense
import torch

def byzantine_robust_aggregation(client_gradients, tolerance=0.1):
    """
    Defend against malicious clients sending poisoned gradients.
    ⚠️ Attack: malicious client sends large gradients to corrupt the model
    Defense: detect and remove outlier gradients
    """
    stacked = torch.stack(client_gradients)  # (num_clients, num_params)
    # The median gradient is robust to outliers
    median = torch.median(stacked, dim=0)[0]
    # Distance of each client's update from the median
    distances = torch.norm(stacked - median, dim=1)
    # Drop the most distant fraction of clients (top 10% by default)
    threshold = torch.quantile(distances, 1 - tolerance)
    valid_gradients = [
        g for g, d in zip(client_gradients, distances)
        if d <= threshold
    ]
    return torch.mean(torch.stack(valid_gradients), dim=0)
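A quick sanity check with hypothetical numbers: nine honest clients plus one attacker sending an oversized update. The attacker's update sits far from the median and is filtered before averaging:

honest = [torch.randn(100) * 0.1 for _ in range(9)]
poisoned = torch.ones(100) * 50.0  # attacker's oversized update
aggregated = byzantine_robust_aggregation(honest + [poisoned])
print(aggregated.abs().mean())  # stays small: the poisoned update was dropped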
Differential Privacy
import math
import torch

class DifferentiallyPrivateFedAvg:
    def __init__(self, epsilon=1.0, delta=1e-5):
        """
        epsilon: privacy budget (lower = more private)
        delta: failure probability
        """
        self.epsilon = epsilon
        self.delta = delta

    def clip_and_noise(self, gradients, clip_norm=1.0):
        """
        Clip the update's L2 norm, then add Gaussian noise for a DP guarantee.
        """
        # Scale the whole update so its L2 norm is at most clip_norm
        scale = (clip_norm / (torch.norm(gradients) + 1e-12)).clamp(max=1.0)
        clipped = gradients * scale
        # Add noise calibrated by the Gaussian mechanism
        sensitivity = 2 * clip_norm  # max influence of one client
        sigma = sensitivity * math.sqrt(2 * math.log(1.25 / self.delta)) / self.epsilon
        noise = torch.randn_like(clipped) * sigma
        return clipped + noise
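Applied per client per round to a (hypothetical) flat update vector, this shows how aggressively a tight budget perturbs the signal:

dp = DifferentiallyPrivateFedAvg(epsilon=1.0, delta=1e-5)
update = torch.randn(1000)
private_update = dp.clip_and_noise(update, clip_norm=1.0)
print(torch.norm(private_update - update))  # clipping plus noise dominate at ε = 1

Budgets also compose across rounds, so production systems track cumulative spend with a privacy accountant rather than applying this formula once.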
Production Deployment
import copy
import random
import torch
import torch.nn as nn

# Coordinator (central server)
class FederatedCoordinator:
    def __init__(self, model, num_rounds=100):
        self.global_model = model
        self.num_rounds = num_rounds

    def train(self, client_pool, clients_per_round=100):
        for round_num in range(self.num_rounds):
            # Sample a subset of clients for this round
            selected = random.sample(client_pool, clients_per_round)
            # Client training (parallel in production; sequential here)
            client_updates = []
            for client in selected:
                updated_model = client.train(self.global_model)
                # Flatten parameters so updates can be compared as vectors
                client_updates.append(
                    nn.utils.parameters_to_vector(updated_model.parameters())
                )
            # Aggregate with Byzantine robustness
            aggregated = byzantine_robust_aggregation(client_updates)
            nn.utils.vector_to_parameters(aggregated, self.global_model.parameters())
            # Evaluate the new global model
            accuracy = self.evaluate()
            print(f"Round {round_num}: Accuracy = {accuracy:.2%}")

    def evaluate(self):
        # Placeholder: score the global model on held-out data
        return 0.0

# Client (edge device)
class FederatedClient:
    def __init__(self, local_data):
        self.data = local_data

    def train(self, global_model, epochs=5):
        model = copy.deepcopy(global_model)
        # Train on local data (never sent to the server!)
        # ... training loop ...
        return model
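Wiring the pieces together (a hypothetical toy model and empty client data, since the local training loop above is elided):

model = nn.Linear(10, 2)
clients = [FederatedClient(local_data=None) for _ in range(500)]  # data stays on-device
coordinator = FederatedCoordinator(model, num_rounds=10)
coordinator.train(clients, clients_per_round=100)  # prints placeholder accuracy per round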
Real-World Scale
# Example: Google's federated learning for Gboard
federated_system = FederatedLearning(
    global_model=language_model,   # the on-device next-word prediction model
    num_clients=100_000_000        # ~100M Android devices
)
# Each device trains locally on the user's typing data.
# Privacy preserved: raw data never leaves the device.
# Only model updates are aggregated.
Warnings ⚠️
Poisoning Attacks:
- A malicious client sends crafted gradients
- Can backdoor the model or degrade performance
- Defense: Robust aggregation + anomaly detection
Privacy Leakage:
- Gradients can leak training data (membership inference, gradient inversion)
- Defense: Differential privacy (adds noise)
- Trade-off: privacy vs. accuracy
Communication Cost:
- 100M clients × model size × rounds = massive bandwidth
- Solution: Gradient compression and quantization (see the sketch below)
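A minimal sketch of one such compression option, top-k sparsification (the topk_sparsify helper is hypothetical; real systems layer on quantization and error feedback):

import torch

def topk_sparsify(update, k_ratio=0.01):
    """Send only the top 1% of entries by magnitude as (indices, values)."""
    flat = update.flatten()
    k = max(1, int(flat.numel() * k_ratio))
    _, indices = torch.topk(flat.abs(), k)
    return indices, flat[indices]  # ~99% less upload than the dense tensor

update = torch.randn(10_000)
indices, values = topk_sparsify(update)  # 100 of 10,000 entries transmitted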