Federated Learning Overview¶
The Federated Learning Platform implements federated learning with the Flower (flwr) framework: models are trained across multiple devices while raw data stays local, and only model updates are shared and securely aggregated.
Federated Learning Concepts¶
What is Federated Learning?¶
Federated Learning is a machine learning approach that enables training models across decentralized data sources without requiring data to be centralized. Key principles include:
- Data Privacy: Raw data never leaves client devices
- Distributed Training: Model training occurs on local devices
- Secure Aggregation: Only model updates are shared and aggregated
- Collaborative Learning: Multiple parties contribute to model improvement
Benefits¶
- Privacy Preservation: Sensitive data remains on local devices
- Reduced Bandwidth: Only model parameters are transmitted
- Regulatory Compliance: Meets data protection requirements
- Scalability: Can handle thousands of participating devices
- Robustness: Resilient to individual device failures
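To make the aggregation principle concrete, here is a minimal, framework-free sketch of FedAvg-style aggregation: each client contributes only its updated parameters, weighted by its number of local examples. The client sizes and values are illustrative.

```python
def fedavg(client_updates):
    """Average client parameter vectors, weighted by local example counts.

    client_updates: list of (num_examples, parameter_list) pairs.
    """
    total = sum(n for n, _ in client_updates)
    dim = len(client_updates[0][1])
    return [
        sum(n * params[i] for n, params in client_updates) / total
        for i in range(dim)
    ]

# Two clients: one with 100 local examples, one with 300
updates = [(100, [0.0, 2.0]), (300, [4.0, 2.0])]
print(fedavg(updates))  # [3.0, 2.0]
```

The larger client pulls the first parameter toward its own value; raw data never appears in the exchange, only parameters and example counts.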
Platform Architecture¶
```mermaid
graph TB
    subgraph "Orchestrator Node"
        WEB_UI[Web Interface<br/>Management Dashboard]
        BACKEND[Backend API<br/>Coordination Service]
        MONITOR[Monitoring<br/>Grafana + Tempo]
    end

    subgraph "Flower Framework"
        SUPERLINK[Superlink<br/>Communication Hub<br/>Port 9091/9093]

        subgraph "Server Side"
            AGGREGATOR[Aggregator<br/>ServerApp<br/>Port 9092]
            SERVER_LOGIC[Server Logic<br/>server_app.py]
            GLOBAL_MODEL[Global Model<br/>Management]
        end

        subgraph "Client Side"
            SUPERNODE1[Supernode 1<br/>Client Manager]
            SUPERNODE2[Supernode 2<br/>Client Manager]
            CLIENT1[ClientApp 1<br/>Local Training]
            CLIENT2[ClientApp 2<br/>Local Training]
            CLIENT3[ClientApp 3<br/>Local Training]
            CLIENTN[ClientApp N<br/>Local Training]
        end
    end

    subgraph "Data & Models"
        DATASET1[Local Dataset 1<br/>Partition 0]
        DATASET2[Local Dataset 2<br/>Partition 1]
        DATASET3[Local Dataset 3<br/>Partition 2]
        DATASETN[Local Dataset N<br/>Partition N]
        MODEL_STORAGE[Model Storage<br/>Weights & Checkpoints]
    end

    subgraph "Observability"
        TELEMETRY[OpenTelemetry<br/>Metrics Collection]
        TRACES[Distributed Tracing<br/>Training Workflow]
    end

    WEB_UI --> BACKEND
    BACKEND --> SUPERLINK
    SUPERLINK --> AGGREGATOR
    AGGREGATOR --> SERVER_LOGIC
    SERVER_LOGIC --> GLOBAL_MODEL
    SUPERLINK --> SUPERNODE1
    SUPERLINK --> SUPERNODE2
    SUPERNODE1 --> CLIENT1
    SUPERNODE1 --> CLIENT2
    SUPERNODE2 --> CLIENT3
    SUPERNODE2 --> CLIENTN
    CLIENT1 --> DATASET1
    CLIENT2 --> DATASET2
    CLIENT3 --> DATASET3
    CLIENTN --> DATASETN
    AGGREGATOR --> MODEL_STORAGE
    CLIENT1 --> TELEMETRY
    CLIENT2 --> TELEMETRY
    CLIENT3 --> TELEMETRY
    CLIENTN --> TELEMETRY
    AGGREGATOR --> TELEMETRY
    TELEMETRY --> TRACES
    TRACES --> MONITOR
```
Flower Framework Integration¶
Core Components¶
Superlink¶
The central communication hub that coordinates all federated learning operations.
Configuration:
```yaml
# Superlink service configuration
superlink:
  image: flwr/superlink:1.15.2
  command: ["--insecure"]
  ports:
    - "9091:9091"  # Fleet API
    - "9093:9093"  # SuperLink API
  environment:
    - FLWR_TELEMETRY_ENABLED=true
```
Key Responsibilities:
- Client registration and management
- Round coordination and scheduling
- Message routing between server and clients
- State management and persistence
ServerApp (Aggregator)¶
Implements the federated learning server logic and model aggregation.
```python
# server_app.py - Federated learning server implementation
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg
from flwr.common import Context


def weighted_average(metrics):
    """Aggregate metrics using an example-count weighted average."""
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    return {"accuracy": sum(accuracies) / sum(examples)}


def server_fn(context: Context) -> ServerAppComponents:
    """Configure the federated learning server."""
    # Define aggregation strategy
    strategy = FedAvg(
        fraction_fit=0.8,        # Fraction of clients sampled for training
        fraction_evaluate=0.8,   # Fraction of clients sampled for evaluation
        min_fit_clients=2,       # Minimum clients for training
        min_evaluate_clients=2,  # Minimum clients for evaluation
        min_available_clients=2, # Minimum available clients
        evaluate_metrics_aggregation_fn=weighted_average,
        fit_metrics_aggregation_fn=weighted_average,
    )

    # Server configuration
    config = ServerConfig(
        num_rounds=10,      # Number of federated learning rounds
        round_timeout=300,  # Timeout per round in seconds
    )

    # Both the strategy and the run config must be returned
    return ServerAppComponents(strategy=strategy, config=config)


# Create ServerApp
app = ServerApp(server_fn=server_fn)
```
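To see how the metric aggregation behaves, here is the `weighted_average` helper as a standalone example (the metric values are made up): a client holding three times the data pulls the global accuracy toward its own score.

```python
def weighted_average(metrics):
    """Aggregate metrics using an example-count weighted average."""
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    return {"accuracy": sum(accuracies) / sum(examples)}

# Client 2 has 3x the data, so the result lands closer to 0.9 than 0.8
metrics = [(10, {"accuracy": 0.8}), (30, {"accuracy": 0.9})]
print(weighted_average(metrics))  # {'accuracy': 0.875}
```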
ClientApp¶
Implements local training logic on client devices.
```python
# client_app.py - Federated learning client implementation
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Context
import torch
import torch.nn as nn


class FlowerClient(NumPyClient):
    """Flower client implementing local training."""

    def __init__(self, model, train_loader, test_loader, device):
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device

    def fit(self, parameters, config):
        """Train the model locally and return updated parameters."""
        self.set_parameters(parameters)

        # Local training
        self.model.train()
        optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=config.get("learning_rate", 0.01),
        )
        criterion = nn.CrossEntropyLoss()
        epochs = config.get("local_epochs", 1)

        loss = torch.tensor(0.0)
        for _ in range(epochs):
            for data, target in self.train_loader:
                data, target = data.to(self.device), target.to(self.device)
                optimizer.zero_grad()
                output = self.model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

        # Return updated parameters, local example count, and metrics
        return (
            self.get_parameters(),
            len(self.train_loader.dataset),
            {"loss": float(loss)},
        )

    def evaluate(self, parameters, config):
        """Evaluate the model locally and return loss and metrics."""
        self.set_parameters(parameters)
        self.model.eval()
        criterion = nn.CrossEntropyLoss()
        test_loss = 0.0
        correct = 0
        with torch.no_grad():
            for data, target in self.test_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                test_loss += criterion(output, target).item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
        accuracy = correct / len(self.test_loader.dataset)
        return float(test_loss), len(self.test_loader.dataset), {"accuracy": accuracy}

    def get_parameters(self, config=None):
        """Extract model parameters as NumPy arrays."""
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def set_parameters(self, parameters):
        """Set model parameters from NumPy arrays."""
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = {k: torch.tensor(v) for k, v in params_dict}
        self.model.load_state_dict(state_dict, strict=True)


def client_fn(context: Context) -> Client:
    """Create and configure a Flower client."""
    # Load client configuration
    partition_id = context.node_config.get("partition-id", 0)

    # Load local data
    train_loader, test_loader = load_data(partition_id)

    # Initialize model
    model = create_model()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # ClientApp expects a Client, so convert the NumPyClient
    return FlowerClient(model, train_loader, test_loader, device).to_client()


# Create ClientApp
app = ClientApp(client_fn=client_fn)
```
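The `get_parameters`/`set_parameters` pair relies on `state_dict()` returning keys in a stable order, so the flat parameter list can be zipped back onto the same keys. A framework-free sketch of that contract, using plain dicts in place of tensors:

```python
def to_param_list(state):
    """Flatten an ordered state dict into a flat list of parameter arrays."""
    return list(state.values())

def from_param_list(keys, parameters):
    """Rebuild the state dict by zipping keys with the flat list."""
    return dict(zip(keys, parameters))

# Python dicts preserve insertion order, mirroring state_dict() behavior
state = {"conv1.weight": [1.0, 2.0], "conv1.bias": [0.5]}
flat = to_param_list(state)
restored = from_param_list(state.keys(), flat)
assert restored == state  # lossless round-trip
```

If client and server ever disagree on model architecture (and hence key order), this zip silently mismatches parameters, which is why `load_state_dict(..., strict=True)` is used above.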
Training Workflow¶
```mermaid
sequenceDiagram
    participant User
    participant WebUI
    participant Backend
    participant Superlink
    participant Aggregator
    participant Supernode
    participant Client
    participant Monitor

    User->>WebUI: Start Training
    WebUI->>Backend: POST /training/start
    Backend->>Superlink: Initialize FL Session
    Superlink->>Aggregator: Start ServerApp
    Aggregator->>Aggregator: Initialize Global Model
    Backend->>Supernode: Deploy Client Configs
    Supernode->>Client: Start ClientApp
    Client->>Client: Load Local Data
    Client->>Monitor: Log Client Ready

    loop Training Rounds (1 to N)
        Aggregator->>Superlink: Request Client Selection
        Superlink->>Supernode: Select Available Clients
        Supernode->>Client: Send Training Request
        Client->>Client: Download Global Model
        Client->>Client: Local Training (Multiple Epochs)
        Client->>Monitor: Log Training Metrics
        Client->>Supernode: Send Model Updates
        Supernode->>Aggregator: Aggregate Client Updates
        Aggregator->>Aggregator: Update Global Model
        Aggregator->>Monitor: Log Round Metrics
        Aggregator->>Superlink: Broadcast Updated Model
        Superlink->>Backend: Round Complete
        Backend->>WebUI: Progress Update
        WebUI->>User: Display Progress
    end

    Aggregator->>Backend: Training Complete
    Backend->>WebUI: Final Results
    WebUI->>User: Show Results
```
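The round loop in the diagram — broadcast the global model, train locally, aggregate by example count — can be sketched as plain Python. This is a toy simulation, not the platform's code: each "client" is just an example count plus a local optimum, and "local training" moves the scalar model halfway toward it.

```python
def run_round(global_model, clients):
    """One federated round: broadcast, local training, weighted aggregation."""
    updates = []
    for num_examples, local_optimum in clients:
        # Broadcast: each client starts from the current global model
        local = global_model
        # "Local training": move halfway toward this client's optimum
        local += 0.5 * (local_optimum - local)
        updates.append((num_examples, local))
    # Aggregate: example-count weighted average of the local models
    total = sum(n for n, _ in updates)
    return sum(n * m for n, m in updates) / total

clients = [(100, 1.0), (100, 3.0)]  # (examples, local optimum)
model = 0.0
for _ in range(10):
    model = run_round(model, clients)
print(model)  # converges toward 2.0, the weighted mean of the optima
```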
Data Management¶
Dataset Partitioning¶
The platform supports multiple data partitioning strategies:
IID (Independent and Identically Distributed)¶
```python
# generate_dataset.py - IID partitioning
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner


def create_iid_partitions(num_clients: int = 10):
    """Create IID data partitions for federated learning."""
    # Load dataset with an IID partitioner
    fds = FederatedDataset(
        dataset="zalando-datasets/fashion_mnist",
        partitioners={"train": IidPartitioner(num_partitions=num_clients)},
    )

    # Create partitions
    partitions = []
    for i in range(num_clients):
        partition = fds.load_partition(i, "train")
        partitions.append(partition)
    return partitions
```
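IID partitioning amounts to shuffling the dataset and splitting it into near-equal shards, so each shard mirrors the global label distribution. A minimal sketch of the index split (independent of flwr_datasets):

```python
import random

def iid_partition(num_samples, num_clients, seed=42):
    """Shuffle sample indices and deal them into near-equal shards."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    # Deal indices round-robin: client i gets every num_clients-th index
    return [indices[i::num_clients] for i in range(num_clients)]

parts = iid_partition(num_samples=1000, num_clients=10)
print([len(p) for p in parts])  # every client gets 100 samples
```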
Non-IID Partitioning¶
```python
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import DirichletPartitioner


def create_non_iid_partitions(num_clients: int = 10, alpha: float = 0.1):
    """Create non-IID data partitions using a Dirichlet distribution."""
    fds = FederatedDataset(
        dataset="zalando-datasets/fashion_mnist",
        partitioners={
            "train": DirichletPartitioner(
                num_partitions=num_clients,
                partition_by="label",
                alpha=alpha,  # Lower alpha = more non-IID
            )
        },
    )

    partitions = []
    for i in range(num_clients):
        partition = fds.load_partition(i, "train")
        partitions.append(partition)
    return partitions
```
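To see how `alpha` controls heterogeneity: a Dirichlet sample can be drawn with the standard Gamma trick (sample Gamma(alpha, 1) per class and normalize). Low `alpha` concentrates a client's data on a few labels; high `alpha` approaches a uniform label mix. A dependency-free sketch:

```python
import random

def dirichlet_sample(alpha, num_classes, rng):
    """Sample class proportions from Dirichlet(alpha) via the Gamma trick."""
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(num_classes)]
    total = sum(draws)
    return [d / total for d in draws]

rng = random.Random(0)
skewed = dirichlet_sample(alpha=0.1, num_classes=10, rng=rng)        # few labels dominate
uniformish = dirichlet_sample(alpha=100.0, num_classes=10, rng=rng)  # near-uniform
print(max(skewed), max(uniformish))
```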
Data Privacy Protection¶
Differential Privacy (Future Enhancement)¶
```python
# Differential privacy implementation (planned)
from opacus import PrivacyEngine
import torch
import torch.nn as nn


class DifferentiallyPrivateClient(FlowerClient):
    """Client with differential privacy protection."""

    def __init__(self, model, train_loader, test_loader, device, privacy_config):
        super().__init__(model, train_loader, test_loader, device)
        self.privacy_engine = PrivacyEngine()
        self.epsilon = privacy_config.get("epsilon", 1.0)
        self.delta = privacy_config.get("delta", 1e-5)

    def fit(self, parameters, config):
        """Train with differential privacy."""
        self.set_parameters(parameters)

        # Attach privacy engine (clips per-sample gradients and adds noise)
        self.model, optimizer, train_loader = self.privacy_engine.make_private(
            module=self.model,
            optimizer=torch.optim.SGD(self.model.parameters(), lr=0.01),
            data_loader=self.train_loader,
            noise_multiplier=1.1,
            max_grad_norm=1.0,
        )

        # Training with privacy protection
        criterion = nn.CrossEntropyLoss()
        loss = torch.tensor(0.0)
        for _ in range(config.get("local_epochs", 1)):
            for data, target in train_loader:
                data, target = data.to(self.device), target.to(self.device)
                optimizer.zero_grad()
                output = self.model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

        # Check how much of the privacy budget has been spent
        epsilon = self.privacy_engine.get_epsilon(self.delta)

        return (
            self.get_parameters(),
            len(self.train_loader.dataset),
            {"loss": float(loss), "epsilon": epsilon},
        )
```
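The core DP-SGD mechanism Opacus applies — clip each per-sample gradient to `max_grad_norm`, then add Gaussian noise scaled by `noise_multiplier` — can be sketched without any ML framework. Here `grad` is just a list of floats standing in for a gradient vector:

```python
import math
import random

def clip_and_noise(grad, max_grad_norm, noise_multiplier, rng):
    """Clip a gradient vector to a norm bound, then add Gaussian noise."""
    norm = math.sqrt(sum(g * g for g in grad))
    # Scale down only if the gradient exceeds the bound
    scale = min(1.0, max_grad_norm / norm) if norm > 0 else 1.0
    clipped = [g * scale for g in grad]
    # Noise standard deviation is proportional to the clipping bound
    sigma = noise_multiplier * max_grad_norm
    return [g + rng.gauss(0.0, sigma) for g in clipped]

rng = random.Random(0)
private_grad = clip_and_noise([3.0, 4.0], max_grad_norm=1.0,
                              noise_multiplier=1.1, rng=rng)
print(private_grad)  # clipped to norm 1, then perturbed
```

Clipping bounds any single example's influence on the update; the added noise then masks whatever influence remains, which is what makes the (epsilon, delta) accounting possible.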
Model Management¶
Model Architecture Support¶
The platform supports various model architectures:
```python
# task.py - Model definitions
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleCNN(nn.Module):
    """Simple CNN for Fashion-MNIST."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)  # 64 channels * 12 * 12 after conv + pool
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


class ResNet18(nn.Module):
    """ResNet-18 for image classification (not yet implemented)."""

    def __init__(self, num_classes=10):
        super().__init__()
        # ResNet-18 implementation (e.g. via torchvision.models) goes here
        raise NotImplementedError("ResNet-18 support is planned")


def create_model(model_type: str = "cnn", num_classes: int = 10):
    """Factory function to create models."""
    if model_type == "cnn":
        return SimpleCNN(num_classes)
    elif model_type == "resnet18":
        return ResNet18(num_classes)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")
```
Model Versioning and Storage¶
```python
# Model storage and versioning
import os
from datetime import datetime, timezone
from typing import Optional

import torch


class ModelManager:
    """Manage model storage and versioning."""

    def __init__(self, storage_path: str = "/app/models"):
        self.storage_path = storage_path
        os.makedirs(storage_path, exist_ok=True)

    def save_model(self, model, run_id: str, round_num: int, model_type: str = "global"):
        """Save a model checkpoint."""
        timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
        filename = f"{run_id}_round_{round_num}_{model_type}_{timestamp}.pth"
        filepath = os.path.join(self.storage_path, filename)

        checkpoint = {
            "model_state_dict": model.state_dict(),
            "run_id": run_id,
            "round_num": round_num,
            "model_type": model_type,
            "timestamp": timestamp,
            "metadata": {
                "framework": "pytorch",
                "architecture": model.__class__.__name__,
            },
        }
        torch.save(checkpoint, filepath)
        return filepath

    def load_model(self, model, filepath: str):
        """Load a model from a checkpoint."""
        checkpoint = torch.load(filepath, map_location="cpu")
        model.load_state_dict(checkpoint["model_state_dict"])
        return model, checkpoint["metadata"]

    def list_models(self, run_id: Optional[str] = None):
        """List available model checkpoints, optionally filtered by run id."""
        models = []
        for filename in os.listdir(self.storage_path):
            if filename.endswith(".pth"):
                if run_id is None or filename.startswith(run_id):
                    models.append(filename)
        return sorted(models)
```
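The checkpoint naming scheme encodes run, round, and model type, so listing a run's checkpoints is a prefix match on the filename. A small sketch of building and filtering such names (the run ids and timestamps here are illustrative):

```python
def checkpoint_name(run_id, round_num, model_type, timestamp):
    """Build a checkpoint filename: <run>_round_<n>_<type>_<timestamp>.pth"""
    return f"{run_id}_round_{round_num}_{model_type}_{timestamp}.pth"

def list_for_run(filenames, run_id):
    """Filter checkpoint names by run-id prefix, sorted lexicographically."""
    return sorted(f for f in filenames if f.endswith(".pth") and f.startswith(run_id))

names = [
    checkpoint_name("run42", 1, "global", "20250101_120000"),
    checkpoint_name("run42", 2, "global", "20250101_121000"),
    checkpoint_name("run7", 1, "global", "20250101_120500"),
]
print(list_for_run(names, "run42"))  # only the two run42 checkpoints
```

Because the timestamp uses a fixed-width `%Y%m%d_%H%M%S` format, lexicographic sort also orders same-round checkpoints chronologically.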
Performance Optimization¶
Client Selection Strategies¶
```python
# Advanced client selection
from flwr.common import FitIns
from flwr.server.strategy import FedAvg
import numpy as np


class AdaptiveClientSelection(FedAvg):
    """Adaptive client selection based on past client performance."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.client_performance = {}

    def configure_fit(self, server_round, parameters, client_manager):
        """Configure clients for training with adaptive selection."""
        # Get available clients (client_manager.all() maps cid -> ClientProxy)
        clients = list(client_manager.all().values())

        if server_round == 1:
            # Random selection for the first round
            selected_clients = list(np.random.choice(
                clients,
                size=min(self.min_fit_clients, len(clients)),
                replace=False,
            ))
        else:
            # Performance-based selection
            client_scores = []
            for client in clients:
                score = self.client_performance.get(client.cid, 0.5)
                client_scores.append((client, score))

            # Sort by performance and select the top clients
            client_scores.sort(key=lambda x: x[1], reverse=True)
            selected_clients = [
                client for client, _ in client_scores[: self.min_fit_clients]
            ]

        # Create fit instructions as (ClientProxy, FitIns) pairs
        fit_config = {
            "learning_rate": 0.01,
            "local_epochs": 1,
            "batch_size": 32,
        }
        return [
            (client, FitIns(parameters, fit_config)) for client in selected_clients
        ]

    def aggregate_fit(self, server_round, results, failures):
        """Aggregate results and update client performance scores."""
        for client, fit_res in results:
            if "accuracy" in fit_res.metrics:
                self.client_performance[client.cid] = fit_res.metrics["accuracy"]

        # Standard FedAvg aggregation
        return super().aggregate_fit(server_round, results, failures)
```
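The selection rule above — score each client, take the top k, fall back to a default score for unseen clients — works on plain data too. A minimal sketch with illustrative client ids and scores:

```python
def select_top_clients(client_ids, performance, k, default_score=0.5):
    """Pick the k clients with the best recorded performance scores."""
    # Unseen clients get a neutral default so they still have a chance
    scored = [(cid, performance.get(cid, default_score)) for cid in client_ids]
    scored.sort(key=lambda x: x[1], reverse=True)
    return [cid for cid, _ in scored[:k]]

performance = {"a": 0.9, "b": 0.4}  # "c" is unseen, gets the default 0.5
print(select_top_clients(["a", "b", "c"], performance, k=2))  # ['a', 'c']
```

Note that a pure top-k rule can starve weak clients permanently; the neutral default only helps clients that have never reported, so a production policy would usually mix in some random exploration.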
Communication Optimization¶
```python
# Model compression for efficient communication
import torch.nn as nn
import torch.nn.utils.prune as prune


class CompressedClient(FlowerClient):
    """Client that sparsifies weights before sending updates."""

    def __init__(self, model, train_loader, test_loader, device, compression_ratio=0.1):
        super().__init__(model, train_loader, test_loader, device)
        self.compression_ratio = compression_ratio

    def get_parameters(self, config=None):
        """Get magnitude-pruned model parameters."""
        # Apply L1 magnitude-based pruning to conv and linear layers
        for module in self.model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                prune.l1_unstructured(module, name="weight", amount=self.compression_ratio)
                # Make pruning permanent so state_dict holds the sparse weights
                prune.remove(module, "weight")

        # Extract the (now sparse) parameters as NumPy arrays
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]
```
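The effect of L1 (magnitude) unstructured pruning — zeroing the smallest-magnitude fraction of weights — is easy to show on a plain list of floats:

```python
def l1_prune(weights, amount):
    """Zero out the `amount` fraction of weights with smallest magnitude."""
    k = int(len(weights) * amount)
    if k == 0:
        return list(weights)
    # k-th smallest magnitude is the pruning threshold
    threshold = sorted(abs(w) for w in weights)[k - 1]
    pruned, zeroed = [], 0
    for w in weights:
        if abs(w) <= threshold and zeroed < k:
            pruned.append(0.0)
            zeroed += 1
        else:
            pruned.append(w)
    return pruned

weights = [0.05, -2.0, 0.3, -0.01, 1.5]
print(l1_prune(weights, amount=0.4))  # [0.0, -2.0, 0.3, 0.0, 1.5]
```

Zeros compress well (or can be sent as sparse index/value pairs), which is where the communication savings come from; the trade-off is a small accuracy hit that grows with the pruned fraction.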
Next: Continue to Flower Framework for detailed Flower implementation and configuration.