
Federated Learning Overview

The Federated Learning Platform implements federated learning with the Flower (flwr) framework, enabling distributed machine learning across multiple devices while keeping raw data local and aggregating model updates securely.

Federated Learning Concepts

What is Federated Learning?

Federated Learning is a machine learning approach that enables training models across decentralized data sources without requiring data to be centralized. Key principles include:

  • Data Privacy: Raw data never leaves client devices
  • Distributed Training: Model training occurs on local devices
  • Secure Aggregation: Only model updates are shared and aggregated
  • Collaborative Learning: Multiple parties contribute to model improvement
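The aggregation idea behind these principles can be illustrated with a toy NumPy sketch (the client updates below are made up for illustration): each client's parameters are averaged, weighted by how many local examples that client trained on.

```python
import numpy as np

def fed_avg(updates):
    """Average client parameter arrays, weighted by local example counts.

    updates: list of (num_examples, parameter_array) tuples.
    """
    total = sum(n for n, _ in updates)
    return sum(n * params for n, params in updates) / total

# Two hypothetical clients: one trained on 10 examples, one on 30
client_a = (10, np.array([1.0, 0.0]))
client_b = (30, np.array([0.0, 1.0]))

global_params = fed_avg([client_a, client_b])
print(global_params)  # [0.25 0.75]
```

Only these parameter arrays cross the network; the examples that produced them never leave the clients.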

Benefits

  • Privacy Preservation: Sensitive data remains on local devices
  • Reduced Bandwidth: Only model parameters are transmitted
  • Regulatory Compliance: Meets data protection requirements
  • Scalability: Can handle thousands of participating devices
  • Robustness: Resilient to individual device failures

Platform Architecture

graph TB
    subgraph "Orchestrator Node"
        WEB_UI[Web Interface<br/>Management Dashboard]
        BACKEND[Backend API<br/>Coordination Service]
        MONITOR[Monitoring<br/>Grafana + Tempo]
    end

    subgraph "Flower Framework"
        SUPERLINK[Superlink<br/>Communication Hub<br/>Port 9091/9093]

        subgraph "Server Side"
            AGGREGATOR[Aggregator<br/>ServerApp<br/>Port 9092]
            SERVER_LOGIC[Server Logic<br/>server_app.py]
            GLOBAL_MODEL[Global Model<br/>Management]
        end

        subgraph "Client Side"
            SUPERNODE1[Supernode 1<br/>Client Manager]
            SUPERNODE2[Supernode 2<br/>Client Manager]

            CLIENT1[ClientApp 1<br/>Local Training]
            CLIENT2[ClientApp 2<br/>Local Training]
            CLIENT3[ClientApp 3<br/>Local Training]
            CLIENTN[ClientApp N<br/>Local Training]
        end
    end

    subgraph "Data & Models"
        DATASET1[Local Dataset 1<br/>Partition 0]
        DATASET2[Local Dataset 2<br/>Partition 1]
        DATASET3[Local Dataset 3<br/>Partition 2]
        DATASETN[Local Dataset N<br/>Partition N]

        MODEL_STORAGE[Model Storage<br/>Weights & Checkpoints]
    end

    subgraph "Observability"
        TELEMETRY[OpenTelemetry<br/>Metrics Collection]
        TRACES[Distributed Tracing<br/>Training Workflow]
    end

    WEB_UI --> BACKEND
    BACKEND --> SUPERLINK
    SUPERLINK --> AGGREGATOR
    AGGREGATOR --> SERVER_LOGIC
    SERVER_LOGIC --> GLOBAL_MODEL

    SUPERLINK --> SUPERNODE1
    SUPERLINK --> SUPERNODE2

    SUPERNODE1 --> CLIENT1
    SUPERNODE1 --> CLIENT2
    SUPERNODE2 --> CLIENT3
    SUPERNODE2 --> CLIENTN

    CLIENT1 --> DATASET1
    CLIENT2 --> DATASET2
    CLIENT3 --> DATASET3
    CLIENTN --> DATASETN

    AGGREGATOR --> MODEL_STORAGE

    CLIENT1 --> TELEMETRY
    CLIENT2 --> TELEMETRY
    CLIENT3 --> TELEMETRY
    CLIENTN --> TELEMETRY
    AGGREGATOR --> TELEMETRY

    TELEMETRY --> TRACES
    TRACES --> MONITOR

Flower Framework Integration

Core Components

SuperLink

The SuperLink is the central communication hub that coordinates all federated learning operations.

Configuration:

# Superlink service configuration
superlink:
  image: flwr/superlink:1.15.2
  command: ["--insecure"]
  ports:
    - "9091:9091"  # ServerAppIO API (used by the ServerApp)
    - "9093:9093"  # Exec API (used by the flwr CLI)
  environment:
    - FLWR_TELEMETRY_ENABLED=true

Key Responsibilities:

  • Client registration and management
  • Round coordination and scheduling
  • Message routing between server and clients
  • State management and persistence

ServerApp (Aggregator)

Implements the federated learning server logic and model aggregation.

# server_app.py - Federated learning server implementation
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg
from flwr.common import Context

def server_fn(context: Context) -> ServerAppComponents:
    """Configure federated learning server."""

    # Define aggregation strategy
    strategy = FedAvg(
        fraction_fit=0.8,  # Fraction of clients used for training
        fraction_evaluate=0.8,  # Fraction of clients used for evaluation
        min_fit_clients=2,  # Minimum clients required for training
        min_evaluate_clients=2,  # Minimum clients required for evaluation
        min_available_clients=2,  # Minimum available clients
        evaluate_metrics_aggregation_fn=weighted_average,
        fit_metrics_aggregation_fn=weighted_average,
    )

    # Server configuration
    config = ServerConfig(
        num_rounds=10,  # Number of federated learning rounds
        round_timeout=300,  # Timeout per round in seconds
    )

    # Return the strategy together with the run configuration
    return ServerAppComponents(strategy=strategy, config=config)

# Create ServerApp
app = ServerApp(server_fn=server_fn)

def weighted_average(metrics):
    """Aggregate metrics using weighted average."""
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    return {"accuracy": sum(accuracies) / sum(examples)}
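As a quick sanity check, `weighted_average` can be exercised on its own with hypothetical per-client results (the example counts and accuracies below are made up):

```python
def weighted_average(metrics):
    """Aggregate metrics using a weighted average over example counts."""
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    return {"accuracy": sum(accuracies) / sum(examples)}

# Hypothetical results from two clients: (num_examples, metrics dict)
metrics = [(100, {"accuracy": 0.80}), (300, {"accuracy": 0.90})]
print(weighted_average(metrics))  # {'accuracy': 0.875}
```

The client with more examples pulls the aggregate toward its own accuracy, which is why weighting by example count matters on unevenly sized partitions.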

ClientApp

Implements local training logic on client devices.

# client_app.py - Federated learning client implementation
from flwr.client import ClientApp, NumPyClient
from flwr.common import Context
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

class FlowerClient(NumPyClient):
    """Flower client implementing local training."""

    def __init__(self, model, train_loader, test_loader, device):
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device

    def fit(self, parameters, config):
        """Train model locally and return updated parameters."""
        # Set model parameters
        self.set_parameters(parameters)

        # Local training
        self.model.train()
        optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=config.get("learning_rate", 0.01)
        )
        criterion = nn.CrossEntropyLoss()

        epochs = config.get("local_epochs", 1)
        for epoch in range(epochs):
            for batch_idx, (data, target) in enumerate(self.train_loader):
                data, target = data.to(self.device), target.to(self.device)
                optimizer.zero_grad()
                output = self.model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

        # Return updated parameters and metrics
        return (
            self.get_parameters(),
            len(self.train_loader.dataset),
            {"loss": float(loss)}
        )

    def evaluate(self, parameters, config):
        """Evaluate model locally and return metrics."""
        self.set_parameters(parameters)

        self.model.eval()
        criterion = nn.CrossEntropyLoss()
        test_loss = 0
        correct = 0

        with torch.no_grad():
            for data, target in self.test_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                test_loss += criterion(output, target).item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        accuracy = correct / len(self.test_loader.dataset)
        return float(test_loss), len(self.test_loader.dataset), {"accuracy": accuracy}

    def get_parameters(self, config=None):
        """Extract model parameters as NumPy arrays."""
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def set_parameters(self, parameters):
        """Set model parameters from numpy arrays."""
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = {k: torch.tensor(v) for k, v in params_dict}
        self.model.load_state_dict(state_dict, strict=True)

def client_fn(context: Context):
    """Create and configure Flower client."""
    # Load client configuration
    partition_id = context.node_config.get("partition-id", 0)

    # Load local data (load_data is a helper defined alongside the model code)
    train_loader, test_loader = load_data(partition_id)

    # Initialize model
    model = create_model()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # ClientApp expects a Client, so convert the NumPyClient
    return FlowerClient(model, train_loader, test_loader, device).to_client()

# Create ClientApp
app = ClientApp(client_fn=client_fn)
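The `get_parameters`/`set_parameters` pair is the heart of the parameter exchange. A small standalone check (using a throwaway linear model, not the platform's) shows the round trip preserves weights exactly:

```python
import torch
import torch.nn as nn

model_a = nn.Linear(4, 2)
model_b = nn.Linear(4, 2)  # different random initialization

# Extract parameters from model_a as NumPy arrays (get_parameters pattern) ...
params = [val.cpu().numpy() for _, val in model_a.state_dict().items()]

# ... and load them into model_b (set_parameters pattern)
state_dict = {
    k: torch.tensor(v) for k, v in zip(model_b.state_dict().keys(), params)
}
model_b.load_state_dict(state_dict, strict=True)

x = torch.randn(1, 4)
with torch.no_grad():
    assert torch.allclose(model_a(x), model_b(x))  # identical outputs
```

This is exactly what happens each round: the client serializes its weights to NumPy, and the server-side global model is rebuilt from those arrays.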

Training Workflow

sequenceDiagram
    participant User
    participant WebUI
    participant Backend
    participant Superlink
    participant Aggregator
    participant Supernode
    participant Client
    participant Monitor

    User->>WebUI: Start Training
    WebUI->>Backend: POST /training/start
    Backend->>Superlink: Initialize FL Session
    Superlink->>Aggregator: Start ServerApp
    Aggregator->>Aggregator: Initialize Global Model

    Backend->>Supernode: Deploy Client Configs
    Supernode->>Client: Start ClientApp
    Client->>Client: Load Local Data
    Client->>Monitor: Log Client Ready

    loop Training Rounds (1 to N)
        Aggregator->>Superlink: Request Client Selection
        Superlink->>Supernode: Select Available Clients
        Supernode->>Client: Send Training Request

        Client->>Client: Download Global Model
        Client->>Client: Local Training (Multiple Epochs)
        Client->>Monitor: Log Training Metrics
        Client->>Supernode: Send Model Updates

        Supernode->>Aggregator: Aggregate Client Updates
        Aggregator->>Aggregator: Update Global Model
        Aggregator->>Monitor: Log Round Metrics

        Aggregator->>Superlink: Broadcast Updated Model
        Superlink->>Backend: Round Complete
        Backend->>WebUI: Progress Update
        WebUI->>User: Display Progress
    end

    Aggregator->>Backend: Training Complete
    Backend->>WebUI: Final Results
    WebUI->>User: Show Results

Data Management

Dataset Partitioning

The platform supports multiple data partitioning strategies:

IID (Independent and Identically Distributed)

# generate_dataset.py - IID partitioning
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner

def create_iid_partitions(num_clients: int = 10):
    """Create IID data partitions for federated learning."""

    # Load dataset
    fds = FederatedDataset(
        dataset="zalando-datasets/fashion_mnist",
        partitioners={"train": IidPartitioner(num_partitions=num_clients)}
    )

    # Create partitions
    partitions = []
    for i in range(num_clients):
        partition = fds.load_partition(i, "train")
        partitions.append(partition)

    return partitions

Non-IID Partitioning

from flwr_datasets.partitioner import DirichletPartitioner

def create_non_iid_partitions(num_clients: int = 10, alpha: float = 0.1):
    """Create non-IID data partitions using Dirichlet distribution."""

    fds = FederatedDataset(
        dataset="zalando-datasets/fashion_mnist",
        partitioners={
            "train": DirichletPartitioner(
                num_partitions=num_clients,
                partition_by="label",
                alpha=alpha  # Lower alpha = more non-IID
            )
        }
    )

    partitions = []
    for i in range(num_clients):
        partition = fds.load_partition(i, "train")
        partitions.append(partition)

    return partitions
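The effect of `alpha` can be seen without downloading any data. This toy sketch (not the partitioner's internals) samples per-client label proportions from a Dirichlet distribution, the same mechanism `DirichletPartitioner` uses:

```python
import numpy as np

rng = np.random.default_rng(42)
num_labels = 10

# Each row is one client's label distribution drawn from Dirichlet(alpha)
skewed = rng.dirichlet([0.1] * num_labels, size=3)     # low alpha: highly skewed
uniform = rng.dirichlet([100.0] * num_labels, size=3)  # high alpha: near-uniform

print(skewed.max(axis=1))   # each client concentrated on a few labels
print(uniform.max(axis=1))  # each label close to 1/num_labels per client
```

Low `alpha` concentrates each client's data on a handful of labels (strongly non-IID); high `alpha` pushes every client toward the global label distribution.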

Data Privacy Protection

Differential Privacy (Future Enhancement)

# Differential privacy implementation (planned)
from opacus import PrivacyEngine
import torch
import torch.nn as nn

class DifferentiallyPrivateClient(FlowerClient):
    """Client with differential privacy protection."""

    def __init__(self, model, train_loader, test_loader, device, privacy_config):
        super().__init__(model, train_loader, test_loader, device)
        self.privacy_engine = PrivacyEngine()
        self.epsilon = privacy_config.get("epsilon", 1.0)
        self.delta = privacy_config.get("delta", 1e-5)

    def fit(self, parameters, config):
        """Train with differential privacy."""
        self.set_parameters(parameters)

        # Attach privacy engine
        self.model, optimizer, train_loader = self.privacy_engine.make_private(
            module=self.model,
            optimizer=torch.optim.SGD(self.model.parameters(), lr=0.01),
            data_loader=self.train_loader,
            noise_multiplier=1.1,
            max_grad_norm=1.0,
        )

        # Training with privacy protection
        for epoch in range(config.get("local_epochs", 1)):
            for batch_idx, (data, target) in enumerate(train_loader):
                data, target = data.to(self.device), target.to(self.device)
                optimizer.zero_grad()
                output = self.model(data)
                loss = nn.CrossEntropyLoss()(output, target)
                loss.backward()
                optimizer.step()

        # Check privacy budget
        epsilon = self.privacy_engine.get_epsilon(self.delta)

        return (
            self.get_parameters(),
            len(self.train_loader.dataset),
            {"loss": float(loss), "epsilon": epsilon}
        )
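The per-step mechanism Opacus applies (clip each per-sample gradient, then add calibrated Gaussian noise to the average) can be sketched in plain NumPy. The gradients, clipping threshold, and noise scale below are illustrative values, not the platform's configuration:

```python
import numpy as np

def dp_sgd_step(per_sample_grads, max_grad_norm=1.0, noise_multiplier=1.1, rng=None):
    """Clip each per-sample gradient to max_grad_norm, average, add noise."""
    rng = rng or np.random.default_rng(0)
    clipped = []
    for g in per_sample_grads:
        norm = np.linalg.norm(g)
        clipped.append(g * min(1.0, max_grad_norm / norm))  # bound each sample's influence
    mean_grad = np.mean(clipped, axis=0)
    sigma = noise_multiplier * max_grad_norm / len(per_sample_grads)
    return mean_grad + rng.normal(0.0, sigma, size=mean_grad.shape)

grads = [np.array([3.0, 4.0]), np.array([0.3, 0.4])]  # norms 5.0 and 0.5
noisy = dp_sgd_step(grads)
# The large gradient is rescaled to norm 1.0 before averaging
```

Clipping bounds any single example's contribution; the noise then makes the released update differentially private, with the privacy cost tracked as epsilon.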

Model Management

Model Architecture Support

The platform supports various model architectures:

# task.py - Model definitions
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    """Simple CNN for Fashion-MNIST."""

    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

class ResNet18(nn.Module):
    """ResNet-18 for image classification (wraps torchvision's implementation)."""

    def __init__(self, num_classes=10):
        super(ResNet18, self).__init__()
        from torchvision.models import resnet18
        self.backbone = resnet18(num_classes=num_classes)
        # Fashion-MNIST images are single-channel; adapt the input stem
        self.backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

    def forward(self, x):
        return self.backbone(x)

def create_model(model_type: str = "cnn", num_classes: int = 10):
    """Factory function to create models."""
    if model_type == "cnn":
        return SimpleCNN(num_classes)
    elif model_type == "resnet18":
        return ResNet18(num_classes)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")
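A quick smoke test of the CNN on a Fashion-MNIST-sized input; the architecture is reproduced here (dropout omitted for brevity) so the snippet runs standalone:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    """Same layer shapes as above, slightly simplified (no dropout)."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)   # 64 channels x 12 x 12 after pooling
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)

# One 28x28 grayscale image -> ten class log-probabilities
model = SimpleCNN()
out = model(torch.randn(1, 1, 28, 28))
print(out.shape)  # torch.Size([1, 10])
```

The 9216 input size of `fc1` comes directly from the convolution arithmetic: 28 → 26 → 24 after the two 3x3 convolutions, then 12 after 2x2 pooling, times 64 channels.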

Model Versioning and Storage

# Model storage and versioning
import torch
import os
from datetime import datetime

class ModelManager:
    """Manage model storage and versioning."""

    def __init__(self, storage_path: str = "/app/models"):
        self.storage_path = storage_path
        os.makedirs(storage_path, exist_ok=True)

    def save_model(self, model, run_id: str, round_num: int, model_type: str = "global"):
        """Save model checkpoint."""
        timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
        filename = f"{run_id}_round_{round_num}_{model_type}_{timestamp}.pth"
        filepath = os.path.join(self.storage_path, filename)

        checkpoint = {
            "model_state_dict": model.state_dict(),
            "run_id": run_id,
            "round_num": round_num,
            "model_type": model_type,
            "timestamp": timestamp,
            "metadata": {
                "framework": "pytorch",
                "architecture": model.__class__.__name__
            }
        }

        torch.save(checkpoint, filepath)
        return filepath

    def load_model(self, model, filepath: str):
        """Load model from checkpoint."""
        checkpoint = torch.load(filepath, map_location="cpu")
        model.load_state_dict(checkpoint["model_state_dict"])
        return model, checkpoint["metadata"]

    def list_models(self, run_id: str = None):
        """List available model checkpoints."""
        models = []
        for filename in os.listdir(self.storage_path):
            if filename.endswith(".pth"):
                if run_id is None or filename.startswith(run_id):
                    models.append(filename)
        return sorted(models)
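The save/load pattern `ModelManager` relies on can be exercised standalone, using a throwaway linear model and a temporary directory rather than the platform's storage path:

```python
import os
import tempfile
import torch
import torch.nn as nn

model = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp:
    filepath = os.path.join(tmp, "demo_round_1_global.pth")

    # Save a checkpoint in the same shape ModelManager writes
    torch.save({
        "model_state_dict": model.state_dict(),
        "round_num": 1,
        "metadata": {"framework": "pytorch", "architecture": "Linear"},
    }, filepath)

    # Reload into a fresh model and verify the weights match
    restored = nn.Linear(4, 2)
    checkpoint = torch.load(filepath, map_location="cpu")
    restored.load_state_dict(checkpoint["model_state_dict"])
    assert torch.equal(model.weight, restored.weight)
```

Storing metadata beside the state dict is what lets `list_models` and downstream tooling identify a checkpoint's run, round, and architecture without loading the weights.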

Performance Optimization

Client Selection Strategies

# Advanced client selection
from flwr.server.strategy import FedAvg
from flwr.common import FitIns
import numpy as np

class AdaptiveClientSelection(FedAvg):
    """Adaptive client selection based on past client performance."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.client_performance = {}

    def configure_fit(self, server_round, parameters, client_manager):
        """Configure clients for training with adaptive selection."""

        # ClientManager.all() returns a dict of cid -> ClientProxy
        clients = list(client_manager.all().values())

        if server_round == 1:
            # Random selection for the first round
            selected_clients = np.random.choice(
                clients,
                size=min(self.min_fit_clients, len(clients)),
                replace=False
            )
        else:
            # Performance-based selection
            client_scores = []
            for client in clients:
                score = self.client_performance.get(client.cid, 0.5)
                client_scores.append((client, score))

            # Sort by performance and select the top clients
            client_scores.sort(key=lambda x: x[1], reverse=True)
            selected_clients = [
                client for client, _ in client_scores[:self.min_fit_clients]
            ]

        # Pair each client with fit instructions (global parameters + config)
        fit_config = {
            "learning_rate": 0.01,
            "local_epochs": 1,
            "batch_size": 32
        }
        return [(client, FitIns(parameters, fit_config)) for client in selected_clients]

    def aggregate_fit(self, server_round, results, failures):
        """Aggregate results and update client performance."""

        # Update client performance metrics
        for client, fit_res in results:
            if "accuracy" in fit_res.metrics:
                self.client_performance[client.cid] = fit_res.metrics["accuracy"]

        # Standard aggregation
        return super().aggregate_fit(server_round, results, failures)
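The selection rule itself is easy to verify in isolation (the client IDs and accuracy scores below are made up):

```python
def select_top_clients(performance, k):
    """Return the k client ids with the highest recorded accuracy."""
    ranked = sorted(performance.items(), key=lambda item: item[1], reverse=True)
    return [cid for cid, _ in ranked[:k]]

performance = {"client-a": 0.91, "client-b": 0.72, "client-c": 0.88}
print(select_top_clients(performance, 2))  # ['client-a', 'client-c']
```

Note that purely greedy selection can starve slow or low-accuracy clients and bias the global model; in practice it is often mixed with some random sampling.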

Communication Optimization

# Model compression for efficient communication
import torch.nn as nn
import torch.nn.utils.prune as prune

class CompressedClient(FlowerClient):
    """Client with model compression."""

    def __init__(self, model, train_loader, test_loader, device, compression_ratio=0.1):
        super().__init__(model, train_loader, test_loader, device)
        self.compression_ratio = compression_ratio

    def get_parameters(self, config=None):
        """Get compressed (pruned) model parameters."""
        # Apply magnitude-based pruning to conv and linear layers, then make
        # it permanent so state_dict holds plain (sparse) weight tensors again
        for module in self.model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                prune.l1_unstructured(module, name="weight", amount=self.compression_ratio)
                prune.remove(module, "weight")

        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]
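Pruning's effect on the transmitted tensors can be checked standalone (a throwaway linear layer, with 50% pruning for visibility rather than the 10% default above):

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(10, 10)
prune.l1_unstructured(layer, name="weight", amount=0.5)
prune.remove(layer, "weight")  # bake the pruning mask into the weight tensor

sparsity = (layer.weight == 0).float().mean().item()
print(f"sparsity: {sparsity:.2f}")  # 0.50
```

Note the arrays are the same shape after pruning; the bandwidth savings only materialize once the zeroed entries are sparsely encoded (or quantized) before transmission.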

Next: Continue to Flower Framework for detailed Flower implementation and configuration.