Code Review and Improvement Recommendations¶
This document presents the findings of a thorough analysis of the Federated Learning Platform codebase, with specific, actionable recommendations for improving security, performance, reliability, and maintainability.
Executive Summary¶
The codebase demonstrates solid architectural foundations with modern technologies (FastAPI, Next.js, Flower), but several critical areas require immediate attention for production deployment.
Total Issues Found: 54
Critical/High-Priority Issues (27):
- Critical Security Issues: 8 (Issues 1-8)
- Authentication/Authorization Issues: 5 (Issues 9-13)
- Performance Issues: 6 (Issues 14-19)
- Frontend Security Issues: 4 (Issues 20-23)
- Configuration/Environment Issues: 4 (Issues 24-27)

Medium/Low-Priority Issues (27):
- Testing and Quality Issues: 8 (Issues 28-35)
- Operational/Monitoring Issues: 6 (Issues 36-41)
- Docker/Deployment Issues: 5 (Issues 42-46)
- Code Quality Issues: 4 (Issues 47-50)
- Additional Security Hardening: 4 (Issues 51-54)
Critical Security Issues (Issues 1-8)¶
1. Hardcoded Secrets and Weak Defaults¶
Location: backend/app/database/mongodb.py lines 9-10, backend/app/services/auth_service.py line 17
Risk: High - Unauthorized database access, JWT token compromise
Current Code:
# backend/app/database/mongodb.py
DEFAULT_MONGO_USERNAME = "admin"
DEFAULT_MONGO_PASSWORD = "password"

# backend/app/services/auth_service.py
SECRET_KEY = os.getenv("SECRET_KEY", "your-secret-key-change-in-production")
Recommendation:
# Secure implementation
SECRET_KEY = os.getenv("SECRET_KEY")
if not SECRET_KEY:
    raise ValueError("SECRET_KEY environment variable must be set")

# Remove hardcoded defaults
mongo_username = os.getenv("MONGO_DB_USERNAME")
mongo_password = os.getenv("MONGO_DB_PASSWORD")
if not mongo_username or not mongo_password:
    raise ValueError("Database credentials must be provided via environment variables")
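The same fail-fast check recurs for every required variable; a small helper (hypothetical name `require_env`) keeps the startup validation in one place:

```python
import os

def require_env(name: str) -> str:
    """Return a required environment variable, failing fast at startup."""
    value = os.getenv(name)
    if not value:
        raise ValueError(f"{name} environment variable must be set")
    return value

# Startup then reduces to:
# SECRET_KEY = require_env("SECRET_KEY")
# mongo_username = require_env("MONGO_DB_USERNAME")
```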
2. Insecure Password Hashing Implementation¶
Location: backend/app/services/auth_service.py lines 21-35
Risk: High - SHA-256 is a fast hash, so stolen hashes can be brute-forced quickly even when salted
Current Code:
def hash_password(password: str, salt: str) -> str:
    return hashlib.sha256((password + salt).encode()).hexdigest()
Recommendation:
from passlib.context import CryptContext

pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
    bcrypt__rounds=12  # Strong hashing
)

def hash_password(password: str) -> str:
    return pwd_context.hash(password)

def verify_password(plain_password: str, hashed_password: str) -> bool:
    return pwd_context.verify(plain_password, hashed_password)
3. Missing Input Validation and Rate Limiting¶
Location: All API endpoints
Risk: High - DoS attacks, injection vulnerabilities
Recommendation:
# Add rate limiting
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Add input validation models
import re
from typing import Optional
from pydantic import BaseModel, Field, validator

class TrainingJobCreate(BaseModel):
    name: str = Field(..., min_length=1, max_length=100)
    description: Optional[str] = Field(None, max_length=500)

    @validator('name')
    def validate_name(cls, v):
        if not re.match(r'^[a-zA-Z0-9_\-\s]+$', v):
            raise ValueError('Name contains invalid characters')
        return v

@router.post("/training/jobs")
@limiter.limit("10/minute")
async def create_training_job(
    request: Request,
    job_data: TrainingJobCreate,  # Validated input
    current_user: User = Depends(get_current_user)
):
    # Implementation
    ...
4. Insecure CORS Configuration¶
Location: backend/app/main.py lines 71-84
Risk: Medium - Cross-origin attacks
Current Code:
cors_config = {
    "allow_credentials": True,
    "allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH", "HEAD"],
    "allow_headers": ["*"],   # Too permissive
    "expose_headers": ["*"],  # Too permissive
}
Recommendation:
cors_config = {
    "allow_credentials": True,
    "allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    "allow_headers": [
        "Accept",
        "Accept-Language",
        "Content-Language",
        "Content-Type",
        "Authorization"
    ],
    "expose_headers": ["Content-Length", "Content-Type"],
    "max_age": 3600
}
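For completeness, the tightened dictionary can be applied through FastAPI's standard CORSMiddleware; the origins list below is a placeholder (note that a `*` origin cannot be combined with `allow_credentials=True`):

```python
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

# allow_origins must be an explicit list when credentials are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://app.example.com"],  # placeholder origin
    **cors_config
)
```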
5. Missing Security Headers¶
Location: backend/app/main.py
Risk: Medium - XSS, clickjacking vulnerabilities
Recommendation:
from fastapi import Request
from fastapi.responses import Response

@app.middleware("http")
async def add_security_headers(request: Request, call_next):
    response = await call_next(request)
    response.headers["X-Content-Type-Options"] = "nosniff"
    response.headers["X-Frame-Options"] = "DENY"
    response.headers["X-XSS-Protection"] = "1; mode=block"
    response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
    response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
    response.headers["Content-Security-Policy"] = "default-src 'self'; script-src 'self' 'unsafe-inline'"
    return response
6. Insufficient Error Handling¶
Location: Throughout codebase
Risk: Medium - Information disclosure
Recommendation:
import logging
from datetime import datetime

from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse

logger = logging.getLogger(__name__)

class APIError(Exception):
    def __init__(self, status_code: int, detail: str, error_code: str = None):
        self.status_code = status_code
        self.detail = detail
        self.error_code = error_code

@app.exception_handler(APIError)
async def api_error_handler(request: Request, exc: APIError):
    # Log error details for debugging (sanitized)
    logger.error(f"API Error: {exc.error_code} - {exc.detail}")
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": {
                "code": exc.error_code,
                "message": exc.detail,
                "timestamp": datetime.utcnow().isoformat()
            }
        }
    )

@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    # Log full error for debugging
    logger.exception("Unhandled exception occurred")
    # Return generic error to user
    return JSONResponse(
        status_code=500,
        content={
            "error": {
                "code": "INTERNAL_ERROR",
                "message": "An internal error occurred",
                "timestamp": datetime.utcnow().isoformat()
            }
        }
    )
7. JWT Token Exposure in Logs¶
Location: backend/app/main.py lines 106-109
Risk: High - Token compromise
Recommendation:
# Sanitize sensitive data before logging
def sanitize_config_for_logging(config):
    sanitized = config.copy()
    # Remove or mask sensitive fields
    if 'secret_key' in sanitized:
        sanitized['secret_key'] = '***REDACTED***'
    return sanitized

logger.info(f"CORS Configuration: {sanitize_config_for_logging(cors_config)}")

# Never log tokens, passwords, or other secrets
def sanitize_request_data(data):
    sensitive_fields = ['password', 'token', 'secret', 'key', 'authorization']
    sanitized = data.copy()
    for field in sensitive_fields:
        if field in sanitized:
            sanitized[field] = '***REDACTED***'
    return sanitized
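The helper above only masks top-level keys; a recursive variant (a sketch, using the same field list) also covers nested request payloads:

```python
SENSITIVE_FIELDS = {'password', 'token', 'secret', 'key', 'authorization'}

def sanitize_nested(data):
    """Recursively mask sensitive fields in dicts and lists before logging."""
    if isinstance(data, dict):
        return {
            k: '***REDACTED***' if k.lower() in SENSITIVE_FIELDS else sanitize_nested(v)
            for k, v in data.items()
        }
    if isinstance(data, list):
        return [sanitize_nested(item) for item in data]
    return data
```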
8. Database Connection Security¶
Location: backend/app/database/mongodb.py
Risk: Medium - Connection hijacking
Recommendation:
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo.errors import ServerSelectionTimeoutError

class MongoDB:
    client: Optional[AsyncIOMotorClient] = None

    @classmethod
    async def connect_to_mongodb(cls):
        mongodb_url = os.getenv("MONGODB_URL")
        if not mongodb_url:
            raise ValueError("MONGODB_URL environment variable required")
        try:
            cls.client = AsyncIOMotorClient(
                mongodb_url,
                serverSelectionTimeoutMS=5000,
                connectTimeoutMS=10000,
                socketTimeoutMS=10000,
                maxPoolSize=50,  # Connection pooling
                minPoolSize=5,
                maxIdleTimeMS=30000,
                tls=True,  # Enable TLS (modern PyMongo/Motor option; ssl= is deprecated)
                tlsAllowInvalidCertificates=False,  # Never relax this outside development
                retryWrites=True,
                retryReads=True
            )
            # Test connection
            await cls.client.admin.command('ping')
            logger.info("Secure MongoDB connection established")
        except ServerSelectionTimeoutError:
            logger.error("Failed to connect to MongoDB")
            raise
Authentication/Authorization Issues (Issues 9-13)¶
9. Missing Token Revocation¶
Location: JWT implementation
Risk: High - Compromised tokens remain valid
Recommendation:
import redis.asyncio as redis
from datetime import datetime

class TokenBlacklist:
    def __init__(self, redis_client):
        self.redis = redis_client

    async def revoke_token(self, token: str, expires_at: datetime):
        # Calculate TTL based on token expiration
        ttl = int((expires_at - datetime.utcnow()).total_seconds())
        if ttl > 0:
            await self.redis.setex(f"blacklist:{token}", ttl, "revoked")

    async def is_revoked(self, token: str) -> bool:
        return await self.redis.exists(f"blacklist:{token}")

# Usage in auth verification
async def verify_token(token: str):
    if await token_blacklist.is_revoked(token):
        raise HTTPException(status_code=401, detail="Token has been revoked")
    # Continue with normal JWT verification
    payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    return payload
10. Weak Password Requirements¶
Location: backend/app/models/user.py lines 31-41
Risk: Medium - Weak user passwords
Current Code:
@validator('password')
def validate_password(cls, v):
    if len(v) < 8:
        raise ValueError('Password must be at least 8 characters long')
    # Basic checks only
Recommendation:
import re
from zxcvbn import zxcvbn

@validator('password')
def validate_password(cls, v):
    # Length requirement
    if len(v) < 12:
        raise ValueError('Password must be at least 12 characters long')
    # Complexity requirements
    if not re.search(r'[A-Z]', v):
        raise ValueError('Password must contain at least one uppercase letter')
    if not re.search(r'[a-z]', v):
        raise ValueError('Password must contain at least one lowercase letter')
    if not re.search(r'\d', v):
        raise ValueError('Password must contain at least one digit')
    if not re.search(r'[!@#$%^&*(),.?":{}|<>]', v):
        raise ValueError('Password must contain at least one special character')
    # Check against common passwords
    strength = zxcvbn(v)
    if strength['score'] < 3:
        raise ValueError('Password is too weak. Please choose a stronger password.')
    # Check for common patterns
    if re.search(r'(.)\1{2,}', v):  # Three or more identical characters in a row
        raise ValueError('Password cannot contain three or more repeated characters in a row')
    return v
11. Missing Rate Limiting on Authentication¶
Location: backend/app/routes/auth.py
Risk: High - Brute force attacks
Recommendation:
from datetime import datetime, timedelta

from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

limiter = Limiter(key_func=get_remote_address)

# Failed attempt tracking (in-memory, so per-process only; use a shared
# store such as Redis when running multiple workers)
failed_attempts = {}

@router.post("/login")
@limiter.limit("5/minute")  # 5 attempts per minute per IP
async def login(
    request: Request,
    user_credentials: UserLogin,
    db=Depends(get_database)
):
    client_ip = get_remote_address(request)
    # Check if IP is temporarily blocked
    if client_ip in failed_attempts:
        if failed_attempts[client_ip]['count'] >= 5:
            if datetime.utcnow() < failed_attempts[client_ip]['blocked_until']:
                raise HTTPException(
                    status_code=429,
                    detail="Too many failed attempts. Try again later."
                )
            else:
                # Reset after block period
                del failed_attempts[client_ip]
    try:
        user = await UserService.authenticate_user(
            db, user_credentials.username, user_credentials.password
        )
        if not user:
            # Track failed attempt
            if client_ip not in failed_attempts:
                failed_attempts[client_ip] = {'count': 0, 'blocked_until': None}
            failed_attempts[client_ip]['count'] += 1
            if failed_attempts[client_ip]['count'] >= 5:
                # Block for 15 minutes
                failed_attempts[client_ip]['blocked_until'] = datetime.utcnow() + timedelta(minutes=15)
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Incorrect username or password"
            )
        # Clear failed attempts on successful login
        if client_ip in failed_attempts:
            del failed_attempts[client_ip]
        # Generate token
        access_token = AuthService.create_access_token(data={"sub": user.username})
        return {"access_token": access_token, "token_type": "bearer"}
    except HTTPException:
        # Let deliberate HTTP errors (401/429) pass through unchanged
        raise
    except Exception as e:
        logger.error(f"Login error for {user_credentials.username}: {str(e)}")
        raise HTTPException(status_code=500, detail="Authentication failed")
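The inline dict logic above can be factored into a small tracker class (a hypothetical sketch with an injectable clock for testing); it implements the same threshold-and-block policy and is still per-process, so the same Redis caveat applies:

```python
from datetime import datetime, timedelta

class FailedAttemptTracker:
    """Tracks failed logins per client and blocks after a threshold."""

    def __init__(self, max_attempts: int = 5, block_minutes: int = 15, now=datetime.utcnow):
        self.max_attempts = max_attempts
        self.block_for = timedelta(minutes=block_minutes)
        self.now = now  # injectable clock for testing
        self._attempts: dict = {}

    def is_blocked(self, client_ip: str) -> bool:
        entry = self._attempts.get(client_ip)
        if not entry or entry["blocked_until"] is None:
            return False
        if self.now() >= entry["blocked_until"]:
            del self._attempts[client_ip]  # block period expired, reset
            return False
        return True

    def record_failure(self, client_ip: str) -> None:
        entry = self._attempts.setdefault(client_ip, {"count": 0, "blocked_until": None})
        entry["count"] += 1
        if entry["count"] >= self.max_attempts:
            entry["blocked_until"] = self.now() + self.block_for

    def record_success(self, client_ip: str) -> None:
        self._attempts.pop(client_ip, None)  # clear history on successful login
```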
12. Authorization Bypass in Optional Auth¶
Location: backend/app/core/auth.py lines 93-114
Risk: Medium - Unauthorized access
Current Code:
async def get_current_user_optional(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
    db: AsyncIOMotorDatabase = Depends(get_database)
) -> Optional[User]:
    if not credentials:
        return None
    # Potential bypass issues
Recommendation:
async def get_current_user_optional(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
    db: AsyncIOMotorDatabase = Depends(get_database)
) -> Optional[User]:
    """Get current user if token is provided and valid, otherwise return None.

    IMPORTANT: Endpoints using this dependency MUST implement their own
    authorization checks for sensitive operations.
    """
    if not credentials:
        return None
    try:
        # Strict token validation
        token_data = AuthService.verify_token(credentials.credentials)
        if not token_data or not token_data.username:
            logger.warning("Invalid token format in optional auth")
            return None
        # Check token blacklist
        if await token_blacklist.is_revoked(credentials.credentials):
            logger.warning("Revoked token used in optional auth")
            return None
        # Get user and verify active status
        user = await UserService.get_user_by_id(db, token_data.user_id)
        if not user or not user.is_active:
            logger.warning(f"Inactive user attempted access: {token_data.user_id}")
            return None
        return user
    except Exception as e:
        logger.warning(f"Optional auth failed: {str(e)}")
        return None

# Audit function for endpoints using optional auth
def require_auth_for_sensitive_ops(user: Optional[User], operation: str):
    """Helper to ensure sensitive operations require authentication."""
    if user is None:
        raise HTTPException(
            status_code=401,
            detail=f"Authentication required for {operation}"
        )
    return user
13. Session Management Issues¶
Location: JWT implementation
Risk: Medium - Session fixation
Recommendation:
import json
import secrets
from datetime import datetime, timedelta

class SessionManager:
    def __init__(self, redis_client):
        self.redis = redis_client
        self.session_timeout = timedelta(hours=24)
        self.refresh_threshold = timedelta(hours=1)

    async def create_session(self, user_id: str, device_info: dict) -> dict:
        """Create a new session with refresh token."""
        session_id = secrets.token_urlsafe(32)
        refresh_token = secrets.token_urlsafe(32)
        session_data = {
            "user_id": user_id,
            "session_id": session_id,
            "device_info": device_info,
            "created_at": datetime.utcnow().isoformat(),
            "last_activity": datetime.utcnow().isoformat()
        }
        # Store session
        await self.redis.setex(
            f"session:{session_id}",
            int(self.session_timeout.total_seconds()),
            json.dumps(session_data)
        )
        # Store refresh token
        await self.redis.setex(
            f"refresh:{refresh_token}",
            int(self.session_timeout.total_seconds()),
            session_id
        )
        return {
            "session_id": session_id,
            "refresh_token": refresh_token
        }

    async def refresh_session(self, refresh_token: str) -> Optional[dict]:
        """Refresh session and generate new tokens (token rotation)."""
        session_id = await self.redis.get(f"refresh:{refresh_token}")
        if not session_id:
            return None
        session_data = await self.redis.get(f"session:{session_id}")
        if not session_data:
            return None
        # Generate new tokens
        new_session_id = secrets.token_urlsafe(32)
        new_refresh_token = secrets.token_urlsafe(32)
        # Update session data
        session_info = json.loads(session_data)
        session_info["session_id"] = new_session_id
        session_info["last_activity"] = datetime.utcnow().isoformat()
        # Store new session
        await self.redis.setex(
            f"session:{new_session_id}",
            int(self.session_timeout.total_seconds()),
            json.dumps(session_info)
        )
        # Store new refresh token
        await self.redis.setex(
            f"refresh:{new_refresh_token}",
            int(self.session_timeout.total_seconds()),
            new_session_id
        )
        # Invalidate old tokens
        await self.redis.delete(f"session:{session_id}")
        await self.redis.delete(f"refresh:{refresh_token}")
        return {
            "session_id": new_session_id,
            "refresh_token": new_refresh_token,
            "user_id": session_info["user_id"]
        }

    async def invalidate_session(self, session_id: str):
        """Invalidate a specific session."""
        await self.redis.delete(f"session:{session_id}")

    async def invalidate_all_user_sessions(self, user_id: str):
        """Invalidate all sessions for a user."""
        # This would require indexing sessions by user_id
        # Implementation depends on Redis setup
        pass
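One way to implement the `invalidate_all_user_sessions` stub is to maintain a per-user index of session ids alongside the sessions themselves. Sketched here with plain dicts so the shape is clear; in Redis the index would be a set maintained with SADD/SREM and the sessions removed with DEL:

```python
class UserSessionIndex:
    """Per-user index of active session ids (in-memory sketch)."""

    def __init__(self):
        self._sessions: dict = {}  # session_id -> session data
        self._by_user: dict = {}   # user_id -> {session_id, ...}

    def create_session(self, user_id: str, session_id: str, data: dict) -> None:
        self._sessions[session_id] = data
        self._by_user.setdefault(user_id, set()).add(session_id)

    def invalidate_session(self, user_id: str, session_id: str) -> None:
        self._sessions.pop(session_id, None)
        self._by_user.get(user_id, set()).discard(session_id)

    def invalidate_all_user_sessions(self, user_id: str) -> int:
        """Delete every session for a user; returns how many were removed."""
        session_ids = self._by_user.pop(user_id, set())
        for session_id in session_ids:
            self._sessions.pop(session_id, None)
        return len(session_ids)
```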
Performance Issues (Issues 14-19)¶
14. Database Connection Management¶
Location: backend/app/database/mongodb.py
Risk: High - Connection exhaustion
Recommendation:
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo.errors import ServerSelectionTimeoutError

class MongoDB:
    client: Optional[AsyncIOMotorClient] = None

    @classmethod
    async def connect_to_mongodb(cls):
        mongodb_url = os.getenv("MONGODB_URL")
        try:
            cls.client = AsyncIOMotorClient(
                mongodb_url,
                # Connection pool settings
                maxPoolSize=50,           # Maximum connections
                minPoolSize=10,           # Minimum connections
                maxIdleTimeMS=30000,      # Close idle connections after 30s
                waitQueueTimeoutMS=5000,  # Wait 5s for connection from pool
                serverSelectionTimeoutMS=5000,
                connectTimeoutMS=10000,
                socketTimeoutMS=10000,
                # Retry settings
                retryWrites=True,
                retryReads=True,
                # Monitoring
                heartbeatFrequencyMS=10000
            )
            # Test connection
            await cls.client.admin.command('ping')
            pool_options = cls.client.options.pool_options
            logger.info(f"MongoDB connected with pool size: {pool_options.max_pool_size}")
        except Exception as e:
            logger.error(f"MongoDB connection failed: {e}")
            raise

    @classmethod
    async def get_connection_stats(cls):
        """Get connection pool statistics via the public options API.
        (Counting active connections would require private internals,
        so it is omitted here.)"""
        if cls.client:
            pool_options = cls.client.options.pool_options
            return {
                "max_pool_size": pool_options.max_pool_size,
                "min_pool_size": pool_options.min_pool_size,
            }
        return None
15. Missing Database Indexes¶
Location: MongoDB collections
Risk: High - Slow queries
Recommendation:
async def create_indexes(db):
    """Create database indexes for optimal query performance."""
    # Users collection indexes
    await db.users.create_index("username", unique=True, name="username_unique_idx")
    await db.users.create_index("email", unique=True, sparse=True, name="email_unique_idx")
    await db.users.create_index([("is_active", 1), ("created_at", -1)], name="active_users_idx")
    await db.users.create_index("roles", name="user_roles_idx")
    # Training jobs collection indexes
    await db.training_jobs.create_index([("user_id", 1), ("status", 1)], name="user_jobs_status_idx")
    await db.training_jobs.create_index([("created_at", -1)], name="jobs_created_desc_idx")
    await db.training_jobs.create_index([("status", 1), ("updated_at", -1)], name="jobs_status_updated_idx")
    await db.training_jobs.create_index("project_id", name="jobs_project_idx")
    # Projects collection indexes
    await db.projects.create_index([("owner_id", 1), ("created_at", -1)], name="project_owner_idx")
    await db.projects.create_index("name", name="project_name_idx")
    await db.projects.create_index([("is_active", 1), ("updated_at", -1)], name="active_projects_idx")
    # Configs collection indexes
    await db.configs.create_index([("project_id", 1), ("type", 1)], name="project_config_type_idx")
    await db.configs.create_index("created_at", name="config_created_idx")
    # Ansible jobs collection indexes
    await db.ansible_jobs.create_index([("status", 1), ("created_at", -1)], name="ansible_status_created_idx")
    await db.ansible_jobs.create_index("user_id", name="ansible_user_idx")
    # Compound index for complex queries
    await db.training_jobs.create_index([
        ("user_id", 1),
        ("project_id", 1),
        ("status", 1),
        ("created_at", -1)
    ], name="jobs_complex_query_idx")
    logger.info("Database indexes created successfully")

# Query optimization examples
async def get_user_jobs_optimized(db, user_id: str, status: str = None, limit: int = 20):
    """Optimized query using indexes."""
    query = {"user_id": user_id}
    if status:
        query["status"] = status
    # This will use the user_jobs_status_idx index
    cursor = db.training_jobs.find(query).sort("created_at", -1).limit(limit)
    return await cursor.to_list(length=limit)
16. No Caching Implementation¶
Location: API responses
Risk: Medium - Poor performance
Recommendation:
import hashlib
import json
from functools import wraps

import redis.asyncio as redis

class CacheManager:
    def __init__(self, redis_url: str):
        self.redis = redis.from_url(redis_url)
        self.default_ttl = 300  # 5 minutes

    async def get(self, key: str):
        """Get cached value."""
        try:
            value = await self.redis.get(key)
            return json.loads(value) if value else None
        except Exception as e:
            logger.error(f"Cache get error: {e}")
            return None

    async def set(self, key: str, value, ttl: int = None):
        """Set cached value."""
        try:
            ttl = ttl or self.default_ttl
            await self.redis.setex(key, ttl, json.dumps(value, default=str))
        except Exception as e:
            logger.error(f"Cache set error: {e}")

    async def delete(self, key: str):
        """Delete cached value."""
        try:
            await self.redis.delete(key)
        except Exception as e:
            logger.error(f"Cache delete error: {e}")

    async def invalidate_pattern(self, pattern: str):
        """Invalidate all keys matching pattern.
        Note: KEYS blocks Redis on large keyspaces; prefer SCAN
        (scan_iter) in production."""
        try:
            keys = await self.redis.keys(pattern)
            if keys:
                await self.redis.delete(*keys)
        except Exception as e:
            logger.error(f"Cache invalidate error: {e}")

# Cache decorator
def cache_response(ttl: int = 300, key_prefix: str = ""):
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Generate cache key. Caution: str() of ORM/DB objects is not
            # stable across processes; prefer building keys from explicit
            # identifiers (user id, query params) in real handlers.
            cache_key = f"{key_prefix}:{func.__name__}:{hashlib.md5(str(args + tuple(kwargs.items())).encode()).hexdigest()}"
            # Try to get from cache
            cached_result = await cache_manager.get(cache_key)
            if cached_result is not None:
                return cached_result
            # Execute function and cache result
            result = await func(*args, **kwargs)
            await cache_manager.set(cache_key, result, ttl)
            return result
        return wrapper
    return decorator
# Usage examples
@router.get("/projects")
@cache_response(ttl=600, key_prefix="projects")
async def get_projects(
    current_user: User = Depends(get_current_user),
    db = Depends(get_database)
):
    # This response will be cached for 10 minutes
    projects = await ProjectService.get_user_projects(db, current_user.id)
    return projects

@router.get("/training/jobs/{job_id}")
@cache_response(ttl=60, key_prefix="job_details")
async def get_job_details(
    job_id: str,
    current_user: User = Depends(get_current_user),
    db = Depends(get_database)
):
    # Cache job details for 1 minute
    job = await TrainingService.get_job(db, job_id)
    return job

# Cache invalidation on updates
@router.put("/projects/{project_id}")
async def update_project(
    project_id: str,
    project_data: ProjectUpdate,
    current_user: User = Depends(get_current_user),
    db = Depends(get_database)
):
    # Update project
    updated_project = await ProjectService.update_project(db, project_id, project_data)
    # Invalidate related caches
    await cache_manager.invalidate_pattern("projects:*")
    await cache_manager.invalidate_pattern(f"project_details:{project_id}:*")
    return updated_project
17. Inefficient Query Patterns¶
Location: Database queries
Risk: Medium - Slow response times
Recommendation:
from datetime import datetime
from typing import List

from bson import ObjectId

# Bad: N+1 query problem
async def get_jobs_with_projects_bad(db, user_id: str):
    jobs = await db.training_jobs.find({"user_id": user_id}).to_list(length=100)
    for job in jobs:
        # This creates N additional queries!
        project = await db.projects.find_one({"_id": job["project_id"]})
        job["project"] = project
    return jobs

# Good: Use aggregation pipeline
async def get_jobs_with_projects_optimized(db, user_id: str):
    pipeline = [
        {"$match": {"user_id": user_id}},
        {"$lookup": {
            "from": "projects",
            "localField": "project_id",
            "foreignField": "_id",
            "as": "project"
        }},
        {"$unwind": "$project"},
        {"$sort": {"created_at": -1}},
        {"$limit": 100},
        {"$project": {
            "_id": 1,
            "name": 1,
            "status": 1,
            "created_at": 1,
            "project.name": 1,
            "project.description": 1
        }}
    ]
    cursor = db.training_jobs.aggregate(pipeline)
    return await cursor.to_list(length=100)

# Efficient pagination
async def get_jobs_paginated(db, user_id: str, page: int = 1, limit: int = 20):
    skip = (page - 1) * limit
    # Get total count efficiently
    total_count = await db.training_jobs.count_documents({"user_id": user_id})
    # Get paginated results
    pipeline = [
        {"$match": {"user_id": user_id}},
        {"$sort": {"created_at": -1}},
        {"$skip": skip},
        {"$limit": limit},
        {"$lookup": {
            "from": "projects",
            "localField": "project_id",
            "foreignField": "_id",
            "as": "project"
        }},
        {"$unwind": "$project"}
    ]
    cursor = db.training_jobs.aggregate(pipeline)
    jobs = await cursor.to_list(length=limit)
    return {
        "jobs": jobs,
        "total": total_count,
        "page": page,
        "pages": (total_count + limit - 1) // limit
    }

# Bulk operations for better performance
async def update_multiple_jobs_status(db, job_ids: List[str], status: str):
    """Update multiple jobs in a single operation."""
    result = await db.training_jobs.update_many(
        {"_id": {"$in": [ObjectId(job_id) for job_id in job_ids]}},
        {
            "$set": {
                "status": status,
                "updated_at": datetime.utcnow()
            }
        }
    )
    return result.modified_count
18. Memory Leaks in WebSocket Connections¶
Location: frontend/src/services/WebSocketManager.ts
Risk: Medium - Memory exhaustion
Recommendation:
class WebSocketManager {
  private ws: WebSocket | null = null;
  private subscribers = new Map<string, Subscription>();
  private reconnectTimeout: NodeJS.Timeout | null = null;
  private heartbeatInterval: NodeJS.Timeout | null = null;
  private connectionStatus: 'disconnected' | 'connecting' | 'connected' | 'error' = 'disconnected';
  private maxReconnectAttempts = 5;
  private reconnectAttempts = 0;
  private reconnectInterval = 5000;

  // Memory leak prevention
  private messageQueue: any[] = [];
  private maxQueueSize = 1000;
  private cleanupInterval: NodeJS.Timeout | null = null;

  constructor() {
    // Start cleanup routine
    this.startCleanupRoutine();
    // Cleanup on page unload
    if (typeof window !== 'undefined') {
      window.addEventListener('beforeunload', this.cleanup.bind(this));
      window.addEventListener('pagehide', this.cleanup.bind(this));
    }
  }

  private startCleanupRoutine() {
    this.cleanupInterval = setInterval(() => {
      this.cleanupStaleSubscriptions();
      this.limitMessageQueue();
    }, 30000); // Cleanup every 30 seconds
  }

  private cleanupStaleSubscriptions() {
    const now = Date.now();
    const staleThreshold = 5 * 60 * 1000; // 5 minutes
    for (const [id, subscription] of this.subscribers.entries()) {
      if (now - subscription.lastActivity > staleThreshold) {
        console.log(`Removing stale subscription: ${id}`);
        this.subscribers.delete(id);
      }
    }
  }

  private limitMessageQueue() {
    if (this.messageQueue.length > this.maxQueueSize) {
      // Remove oldest messages
      this.messageQueue = this.messageQueue.slice(-this.maxQueueSize / 2);
      console.warn('Message queue size limited to prevent memory leak');
    }
  }

  private handleMessage = (event: MessageEvent) => {
    try {
      const data = JSON.parse(event.data);
      // Add to queue with size limit
      this.messageQueue.push({
        data,
        timestamp: Date.now()
      });
      if (this.messageQueue.length > this.maxQueueSize) {
        this.messageQueue.shift(); // Remove oldest
      }
      // Notify subscribers
      for (const [id, subscription] of this.subscribers.entries()) {
        try {
          if (this.matchesFilter(data, subscription.filter)) {
            subscription.callback(data);
            subscription.lastActivity = Date.now();
          }
        } catch (error) {
          console.error(`Error in subscription ${id}:`, error);
          // Remove problematic subscription
          this.subscribers.delete(id);
        }
      }
    } catch (error) {
      console.error('Error parsing WebSocket message:', error);
    }
  };

  private startHeartbeat() {
    this.stopHeartbeat();
    this.heartbeatInterval = setInterval(() => {
      if (this.ws?.readyState === WebSocket.OPEN) {
        this.ws.send(JSON.stringify({ type: 'ping' }));
      }
    }, 30000); // Ping every 30 seconds
  }

  private stopHeartbeat() {
    if (this.heartbeatInterval) {
      clearInterval(this.heartbeatInterval);
      this.heartbeatInterval = null;
    }
  }

  subscribe(callback: JobSubscriber, filter?: SubscriptionFilter): string {
    const subscriptionId = `sub_${Date.now()}_${Math.random().toString(36).slice(2, 11)}`;
    this.subscribers.set(subscriptionId, {
      id: subscriptionId,
      callback,
      filter,
      lastActivity: Date.now()
    });
    // Connect if this is the first subscription
    if (this.subscribers.size === 1 && this.connectionStatus === 'disconnected') {
      this.connect();
    }
    return subscriptionId;
  }

  unsubscribe(subscriptionId: string): void {
    this.subscribers.delete(subscriptionId);
    // Disconnect if no more subscribers
    if (this.subscribers.size === 0) {
      this.disconnect();
    }
  }

  private cleanup() {
    console.log('Cleaning up WebSocket manager');
    // Clear all timers
    this.stopHeartbeat();
    if (this.reconnectTimeout) {
      clearTimeout(this.reconnectTimeout);
      this.reconnectTimeout = null;
    }
    if (this.cleanupInterval) {
      clearInterval(this.cleanupInterval);
      this.cleanupInterval = null;
    }
    // Clear subscribers
    this.subscribers.clear();
    // Clear message queue
    this.messageQueue = [];
    // Close WebSocket
    if (this.ws) {
      this.ws.onopen = null;
      this.ws.onmessage = null;
      this.ws.onclose = null;
      this.ws.onerror = null;
      this.ws.close();
      this.ws = null;
    }
    this.connectionStatus = 'disconnected';
  }

  // Singleton pattern to prevent multiple instances
  private static instance: WebSocketManager | null = null;

  static getInstance(): WebSocketManager {
    if (!WebSocketManager.instance) {
      WebSocketManager.instance = new WebSocketManager();
    }
    return WebSocketManager.instance;
  }

  static cleanup() {
    if (WebSocketManager.instance) {
      WebSocketManager.instance.cleanup();
      WebSocketManager.instance = null;
    }
  }
}

// Export singleton instance
export const websocketManager = WebSocketManager.getInstance();

// Cleanup on module unload
if (typeof window !== 'undefined') {
  window.addEventListener('beforeunload', () => {
    WebSocketManager.cleanup();
  });
}
19. No Connection Pooling¶
Location: External service calls
Risk: Medium - Resource exhaustion
Recommendation:
import asyncio
import os
from typing import Optional

import aiohttp

class HTTPConnectionPool:
    def __init__(self, max_connections: int = 100, max_connections_per_host: int = 30):
        self.connector = aiohttp.TCPConnector(
            limit=max_connections,
            limit_per_host=max_connections_per_host,
            ttl_dns_cache=300,  # DNS cache TTL
            use_dns_cache=True,
            keepalive_timeout=30,
            enable_cleanup_closed=True
        )
        self.timeout = aiohttp.ClientTimeout(
            total=30,     # Total timeout
            connect=10,   # Connection timeout
            sock_read=10  # Socket read timeout
        )
        self.session: Optional[aiohttp.ClientSession] = None

    async def get_session(self) -> aiohttp.ClientSession:
        """Get or create HTTP session with connection pooling."""
        if self.session is None or self.session.closed:
            self.session = aiohttp.ClientSession(
                connector=self.connector,
                timeout=self.timeout,
                headers={
                    'User-Agent': 'FederatedLearningPlatform/1.0',
                    'Accept': 'application/json',
                    'Connection': 'keep-alive'
                }
            )
        return self.session

    async def close(self):
        """Close the session and connector."""
        if self.session and not self.session.closed:
            await self.session.close()
        await self.connector.close()

# Global connection pool
http_pool = HTTPConnectionPool()

class ExternalAPIClient:
    def __init__(self):
        self.base_urls = {
            'flower': os.getenv('FLOWER_API_URL', 'http://localhost:9093'),
            'monitoring': os.getenv('MONITORING_URL', 'http://localhost:3001')
        }

    async def make_request(
        self,
        method: str,
        url: str,
        service: str = 'flower',
        **kwargs
    ):
        """Make HTTP request using connection pool."""
        session = await http_pool.get_session()
        full_url = f"{self.base_urls[service]}{url}"
        try:
            async with session.request(method, full_url, **kwargs) as response:
                response.raise_for_status()
                return await response.json()
        except asyncio.TimeoutError:
            # Checked before ClientError: aiohttp timeouts can subclass both
            logger.error(f"Request timeout: {full_url}")
            raise
        except aiohttp.ClientError as e:
            logger.error(f"HTTP request failed: {e}")
            raise

    async def get_flower_status(self):
        """Get Flower server status using connection pool."""
        return await self.make_request('GET', '/api/v1/status', 'flower')

    async def submit_training_job(self, job_data: dict):
        """Submit training job using connection pool."""
        return await self.make_request(
            'POST',
            '/api/v1/jobs',
            'flower',
            json=job_data
        )

# Cleanup on application shutdown
async def cleanup_http_pool():
    """Cleanup HTTP connection pool."""
    await http_pool.close()

# Usage in FastAPI
@app.on_event("shutdown")
async def shutdown_event():
    await cleanup_http_pool()
    logger.info("HTTP connection pool closed")
Frontend Security Issues (Issues 20-23)¶
20. Insecure Token Storage¶
Location: Frontend localStorage usage
Risk: High - XSS token theft
Current Code:
// Vulnerable: storing JWT in localStorage
localStorage.setItem('token', accessToken);
const token = localStorage.getItem('token');
Recommendation:
// Secure: Use httpOnly cookies
class AuthService {
async login(credentials: LoginCredentials): Promise<void> {
const response = await fetch('/api/auth/login', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(credentials),
credentials: 'include' // Include cookies
});
if (response.ok) {
// Token is set as httpOnly cookie by server
// No client-side token storage needed
this.setAuthState(true);
}
}
async makeAuthenticatedRequest(url: string, options: RequestInit = {}) {
return fetch(url, {
...options,
credentials: 'include', // Include httpOnly cookies
headers: {
...options.headers,
'X-Requested-With': 'XMLHttpRequest' // CSRF protection
}
});
}
async logout(): Promise<void> {
await fetch('/api/auth/logout', {
method: 'POST',
credentials: 'include'
});
this.setAuthState(false);
}
}
// Backend implementation for httpOnly cookies
@router.post("/login")
async def login(
response: Response,
user_credentials: UserLogin,
db=Depends(get_database)
):
user = await UserService.authenticate_user(db, user_credentials.username, user_credentials.password)
if not user:
raise HTTPException(status_code=401, detail="Invalid credentials")
access_token = AuthService.create_access_token(data={"sub": user.username})
# Set httpOnly cookie
response.set_cookie(
key="access_token",
value=access_token,
httponly=True, # Prevents XSS access
secure=True, # HTTPS only
samesite="strict", # CSRF protection
max_age=1800 # 30 minutes
)
return {"message": "Login successful"}
21. Missing Content Security Policy¶
Location: Frontend headers
Risk: Medium - XSS attacks
Recommendation:
// next.config.js - Add CSP headers
const ContentSecurityPolicy = `
default-src 'self';
script-src 'self' 'unsafe-eval' 'unsafe-inline';
style-src 'self' 'unsafe-inline';
img-src 'self' data: https:;
font-src 'self';
object-src 'none';
base-uri 'self';
form-action 'self';
frame-ancestors 'none';
block-all-mixed-content;
upgrade-insecure-requests;
`;
const securityHeaders = [
{
key: 'Content-Security-Policy',
value: ContentSecurityPolicy.replace(/\s{2,}/g, ' ').trim()
},
{
key: 'Referrer-Policy',
value: 'strict-origin-when-cross-origin'
},
{
key: 'X-Frame-Options',
value: 'DENY'
},
{
key: 'X-Content-Type-Options',
value: 'nosniff'
},
{
key: 'X-DNS-Prefetch-Control',
value: 'off' // valid values are 'on' and 'off'
},
{
key: 'Strict-Transport-Security',
value: 'max-age=31536000; includeSubDomains; preload'
},
{
key: 'Permissions-Policy',
value: 'camera=(), microphone=(), geolocation=()'
}
];
module.exports = {
async headers() {
return [
{
source: '/(.*)',
headers: securityHeaders,
},
];
},
};
// Runtime CSP for dynamic content (note: a meta-tag CSP is weaker than the
// header-based policy above and cannot set directives like frame-ancestors)
export function CSPProvider({ children }: { children: React.ReactNode }) {
useEffect(() => {
// Generate nonce for inline scripts
const nonce = crypto.randomUUID();
// Update CSP with nonce
const meta = document.createElement('meta');
meta.httpEquiv = 'Content-Security-Policy';
meta.content = `script-src 'self' 'nonce-${nonce}'; object-src 'none';`;
document.head.appendChild(meta);
return () => {
document.head.removeChild(meta);
};
}, []);
return <>{children}</>;
}
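The whitespace-collapsing step in next.config.js above (`replace(/\s{2,}/g, ' ').trim()`) turns the multi-line policy template into a single header value. The same transformation can be sanity-checked in isolation; a Python sketch of the regex, not project code:

```python
import re

# Multi-line CSP template, analogous to the one in next.config.js
csp = """
    default-src 'self';
    script-src 'self';
    object-src 'none';
"""

# Collapse runs of 2+ whitespace characters into single spaces,
# mirroring the JS `.replace(/\s{2,}/g, ' ').trim()`
header_value = re.sub(r"\s{2,}", " ", csp).strip()
print(header_value)
# default-src 'self'; script-src 'self'; object-src 'none';
```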
22. File Upload Vulnerabilities¶
Location: frontend/src/app/ansible/components/ui/FileUpload.tsx
Risk: High - Malicious file uploads
Recommendation:
// Frontend: Enhanced validation
interface FileUploadProps {
maxFileSize?: number;
allowedTypes?: string[];
onUploadProgress?: (progress: number) => void;
}
export default function SecureFileUpload({
maxFileSize = 100 * 1024 * 1024, // 100MB
allowedTypes = ['.zip', '.tar', '.gz', '.tar.gz'],
onUploadProgress
}: FileUploadProps) {
const validateFile = (file: File): string | null => {
// Size validation
if (file.size > maxFileSize) {
return `File size exceeds ${maxFileSize / 1024 / 1024}MB limit`;
}
// Type validation (endsWith handles compound extensions such as .tar.gz,
// which a split('.').pop() check would miss)
const lowerName = file.name.toLowerCase();
if (!allowedTypes.some(ext => lowerName.endsWith(ext))) {
return `File type not allowed; expected one of: ${allowedTypes.join(', ')}`;
}
// Name validation
if (!/^[a-zA-Z0-9._-]+$/.test(file.name)) {
return 'File name contains invalid characters';
}
return null;
};
const uploadFile = async (file: File) => {
const validationError = validateFile(file);
if (validationError) {
throw new Error(validationError);
}
const formData = new FormData();
formData.append('file', file);
formData.append('checksum', await calculateChecksum(file));
// fetch() does not expose upload progress events, so use XMLHttpRequest here
return new Promise<any>((resolve, reject) => {
const xhr = new XMLHttpRequest();
xhr.open('POST', '/api/upload');
xhr.withCredentials = true; // send httpOnly cookies
xhr.upload.onprogress = (event) => {
if (event.lengthComputable) {
onUploadProgress?.((event.loaded / event.total) * 100);
}
};
xhr.onload = () => {
if (xhr.status >= 200 && xhr.status < 300) {
resolve(JSON.parse(xhr.responseText));
} else {
let message = 'Upload failed';
try { message = JSON.parse(xhr.responseText).message || message; } catch {}
reject(new Error(message));
}
};
xhr.onerror = () => reject(new Error('Network error during upload'));
xhr.send(formData);
});
};
const calculateChecksum = async (file: File): Promise<string> => {
const buffer = await file.arrayBuffer();
const hashBuffer = await crypto.subtle.digest('SHA-256', buffer);
const hashArray = Array.from(new Uint8Array(hashBuffer));
return hashArray.map(b => b.toString(16).padStart(2, '0')).join('');
};
}
# Backend: Comprehensive validation
from fastapi import Depends, Form, HTTPException, UploadFile
from uuid import uuid4
import hashlib
import io
import magic
import tarfile
import zipfile
class FileValidator:
ALLOWED_MIME_TYPES = {
'application/zip',
'application/x-tar',
'application/gzip',
'application/x-gzip'
}
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB
@staticmethod
async def validate_upload(file: UploadFile, expected_checksum: str = None) -> dict:
# Read file content
content = await file.read()
await file.seek(0) # Reset file pointer
# Size validation
if len(content) > FileValidator.MAX_FILE_SIZE:
raise HTTPException(400, "File too large")
# MIME type validation
mime_type = magic.from_buffer(content, mime=True)
if mime_type not in FileValidator.ALLOWED_MIME_TYPES:
raise HTTPException(400, f"Invalid file type: {mime_type}")
# Checksum validation
if expected_checksum:
actual_checksum = hashlib.sha256(content).hexdigest()
if actual_checksum != expected_checksum:
raise HTTPException(400, "File integrity check failed")
# Archive validation
try:
if mime_type == 'application/zip':
with zipfile.ZipFile(io.BytesIO(content)) as zf:
# Check for zip bombs
total_size = sum(info.file_size for info in zf.infolist())
if total_size > FileValidator.MAX_FILE_SIZE * 10:
raise HTTPException(400, "Archive too large when extracted")
# Check for directory traversal
for info in zf.infolist():
if '..' in info.filename or info.filename.startswith('/'):
raise HTTPException(400, "Invalid file path in archive")
elif mime_type in ['application/x-tar', 'application/gzip']:
with tarfile.open(fileobj=io.BytesIO(content)) as tf:
# Similar validations for tar files
for member in tf.getmembers():
if '..' in member.name or member.name.startswith('/'):
raise HTTPException(400, "Invalid file path in archive")
except HTTPException:
raise # preserve the specific validation errors raised above
except Exception as e:
raise HTTPException(400, f"Invalid archive: {str(e)}")
return {
"filename": file.filename,
"size": len(content),
"mime_type": mime_type,
"checksum": hashlib.sha256(content).hexdigest()
}
@router.post("/upload")
async def upload_file(
file: UploadFile,
checksum: str = Form(...),
current_user: User = Depends(get_current_user)
):
# Validate file
validation_result = await FileValidator.validate_upload(file, checksum)
# Virus scanning (integrate with ClamAV or similar)
await scan_for_viruses(file)
# Save file securely
secure_filename = f"{uuid4()}_{file.filename}"
file_path = UPLOAD_DIR / secure_filename
await file.seek(0) # rewind in case the virus scanner consumed the stream
with open(file_path, "wb") as f:
f.write(await file.read())
# Log upload
logger.info(f"File uploaded: {file.filename} by user {current_user.id}")
return {
"message": "File uploaded successfully",
"file_id": secure_filename,
**validation_result
}
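The archive checks above can be exercised without FastAPI. A self-contained sketch of the directory-traversal test on an in-memory zip (illustrative helper names, not the project's validator):

```python
import io
import zipfile

def has_unsafe_paths(zip_bytes: bytes) -> bool:
    """True if any archive member could escape the extraction directory."""
    with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
        return any(
            '..' in info.filename or info.filename.startswith('/')
            for info in zf.infolist()
        )

def make_zip(name: str) -> bytes:
    """Build a one-entry zip archive in memory."""
    buf = io.BytesIO()
    with zipfile.ZipFile(buf, 'w') as zf:
        zf.writestr(name, 'payload')
    return buf.getvalue()

print(has_unsafe_paths(make_zip('../../etc/passwd')))   # True
print(has_unsafe_paths(make_zip('model/weights.bin')))  # False
```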
23. WebSocket Security Issues¶
Location: frontend/src/services/WebSocketManager.ts
Risk: Medium - Unauthorized access
Current Code:
const wsBaseUrl = process.env.NEXT_PUBLIC_WS_URL || "ws://localhost:8000";
return `${wsBaseUrl}/ws/jobs`; // No auth token
Recommendation:
// Frontend: Secure WebSocket with authentication
class SecureWebSocketManager {
private ws: WebSocket | null = null;
private authToken: string | null = null;
private heartbeatInterval: NodeJS.Timeout | null = null;
async connect(): Promise<void> {
// Get auth token from secure storage or API
this.authToken = await this.getAuthToken();
if (!this.authToken) {
throw new Error('Authentication required for WebSocket connection');
}
const wsUrl = this.buildSecureWebSocketUrl();
this.ws = new WebSocket(wsUrl);
this.ws.onopen = this.handleOpen;
this.ws.onmessage = this.handleMessage;
this.ws.onclose = this.handleClose;
this.ws.onerror = this.handleError;
}
private buildSecureWebSocketUrl(): string {
const wsBaseUrl = process.env.NEXT_PUBLIC_WS_URL || "wss://localhost:8000";
const params = new URLSearchParams({
token: this.authToken!,
timestamp: Date.now().toString(),
client_id: this.generateClientId()
});
return `${wsBaseUrl}/ws/jobs?${params.toString()}`;
}
private async getAuthToken(): Promise<string | null> {
try {
// Get token from httpOnly cookie via API call
const response = await fetch('/api/auth/ws-token', {
credentials: 'include'
});
if (response.ok) {
const data = await response.json();
return data.ws_token;
}
} catch (error) {
console.error('Failed to get WebSocket token:', error);
}
return null;
}
private clientId: string | null = null;
private generateClientId(): string {
// Generate once and reuse, so the URL and auth_confirm message carry the same id
if (!this.clientId) {
this.clientId = `client_${Date.now()}_${Math.random().toString(36).slice(2, 11)}`;
}
return this.clientId;
}
private handleOpen = () => {
console.log('Secure WebSocket connected');
this.startHeartbeat();
// Send authentication confirmation
this.send({
type: 'auth_confirm',
client_id: this.generateClientId(),
timestamp: Date.now()
});
};
private startHeartbeat() {
this.heartbeatInterval = setInterval(() => {
if (this.ws?.readyState === WebSocket.OPEN) {
this.send({ type: 'ping', timestamp: Date.now() });
}
}, 30000);
}
private send(data: any) {
if (this.ws?.readyState === WebSocket.OPEN) {
this.ws.send(JSON.stringify(data));
}
}
}
# Backend: Secure WebSocket endpoint
from fastapi import WebSocket, WebSocketDisconnect, Query, HTTPException
from jose import jwt, JWTError
from typing import Dict, Set
from datetime import datetime, timedelta, timezone
import json
class WebSocketManager:
def __init__(self):
self.active_connections: Dict[str, WebSocket] = {}
self.user_connections: Dict[str, Set[str]] = {}
async def connect(self, websocket: WebSocket, client_id: str, user_id: str):
await websocket.accept()
self.active_connections[client_id] = websocket
if user_id not in self.user_connections:
self.user_connections[user_id] = set()
self.user_connections[user_id].add(client_id)
logger.info(f"WebSocket connected: {client_id} for user {user_id}")
async def disconnect(self, client_id: str, user_id: str):
if client_id in self.active_connections:
del self.active_connections[client_id]
if user_id in self.user_connections:
self.user_connections[user_id].discard(client_id)
if not self.user_connections[user_id]:
del self.user_connections[user_id]
logger.info(f"WebSocket disconnected: {client_id}")
websocket_manager = WebSocketManager()
async def verify_websocket_token(token: str) -> dict:
"""Verify WebSocket authentication token."""
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
# Check token type
if payload.get('type') != 'websocket':
raise JWTError("Invalid token type")
# Check expiration (jwt.decode already enforces exp; this is defense in depth.
# Use a timezone-aware clock: naive utcnow().timestamp() applies the local zone)
exp = payload.get('exp')
if exp and datetime.now(timezone.utc).timestamp() > exp:
raise JWTError("Token expired")
return payload
except JWTError:
raise HTTPException(status_code=401, detail="Invalid WebSocket token")
@app.websocket("/ws/jobs")
async def websocket_endpoint(
websocket: WebSocket,
token: str = Query(...),
client_id: str = Query(...),
timestamp: str = Query(...)
):
try:
# Verify authentication
payload = await verify_websocket_token(token)
user_id = payload.get('user_id')
if not user_id:
await websocket.close(code=1008, reason="Authentication failed")
return
# Verify timestamp (prevent replay attacks)
try:
request_time = int(timestamp)
current_time = int(datetime.now(timezone.utc).timestamp() * 1000)
if abs(current_time - request_time) > 60000: # 1 minute tolerance
await websocket.close(code=1008, reason="Request too old")
return
except ValueError:
await websocket.close(code=1008, reason="Invalid timestamp")
return
# Connect user
await websocket_manager.connect(websocket, client_id, user_id)
try:
while True:
# Receive message
data = await websocket.receive_text()
message = json.loads(data)
# Handle different message types
if message.get('type') == 'ping':
await websocket.send_text(json.dumps({
'type': 'pong',
'timestamp': datetime.utcnow().isoformat()
}))
elif message.get('type') == 'subscribe':
# Handle subscription requests
await handle_subscription(websocket, user_id, message)
except WebSocketDisconnect:
await websocket_manager.disconnect(client_id, user_id)
except Exception as e:
logger.error(f"WebSocket error: {e}")
await websocket_manager.disconnect(client_id, user_id)
except Exception as e:
logger.error(f"WebSocket connection error: {e}")
await websocket.close(code=1008, reason="Connection error")
# Generate WebSocket token endpoint
@router.get("/auth/ws-token")
async def get_websocket_token(current_user: User = Depends(get_current_user)):
"""Generate a short-lived token for WebSocket authentication."""
token_data = {
"user_id": current_user.id,
"username": current_user.username,
"type": "websocket",
"exp": datetime.utcnow() + timedelta(minutes=5) # Short-lived
}
ws_token = jwt.encode(token_data, SECRET_KEY, algorithm=ALGORITHM)
return {"ws_token": ws_token}
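The replay-window check in the endpoint above reduces to one comparison: a client timestamp (milliseconds since epoch) must lie within a tolerance of the server clock. A minimal standalone sketch (hypothetical helper name):

```python
import time
from typing import Optional

def is_fresh(request_ms: int, tolerance_ms: int = 60_000,
             now_ms: Optional[int] = None) -> bool:
    """True if the client timestamp is within tolerance of the server clock."""
    if now_ms is None:
        now_ms = int(time.time() * 1000)
    return abs(now_ms - request_ms) <= tolerance_ms

now = int(time.time() * 1000)
print(is_fresh(now - 5_000, now_ms=now))    # True: 5 s old
print(is_fresh(now - 120_000, now_ms=now))  # False: 2 min old
```

Clock skew between client and server eats into the tolerance, so in practice the window should exceed the expected skew.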
Configuration/Environment Issues (Issues 24-27)¶
24. Insecure Development Configuration¶
Location: frontend/next.config.ts
Risk: Medium - Production vulnerabilities
Current Code:
eslint: {
ignoreDuringBuilds: true, // Dangerous for production
},
typescript: {
ignoreBuildErrors: true, // Dangerous for production
},
Recommendation:
const path = require('path'); // required by outputFileTracingRoot below

const isDevelopment = process.env.NODE_ENV === 'development';
const isProduction = process.env.NODE_ENV === 'production';
const nextConfig = {
// Environment-specific configurations
eslint: {
ignoreDuringBuilds: isDevelopment, // Only ignore in development
dirs: ['src', 'pages', 'components', 'lib', 'utils']
},
typescript: {
ignoreBuildErrors: isDevelopment, // Only ignore in development
tsconfigPath: './tsconfig.json'
},
// Production optimizations
...(isProduction && {
output: 'standalone',
experimental: {
outputFileTracingRoot: path.join(__dirname, '../../'),
},
compiler: {
removeConsole: true, // Remove console.log in production
},
poweredByHeader: false, // Remove X-Powered-By header
generateEtags: false, // Disable ETags for security
}),
// Security headers
async headers() {
return [
{
source: '/(.*)',
headers: [
{
key: 'X-Frame-Options',
value: 'DENY'
},
{
key: 'X-Content-Type-Options',
value: 'nosniff'
},
{
key: 'Referrer-Policy',
value: 'strict-origin-when-cross-origin'
}
]
}
];
},
// Environment variables validation
env: {
NEXT_PUBLIC_API_URL: process.env.NEXT_PUBLIC_API_URL,
NEXT_PUBLIC_WS_URL: process.env.NEXT_PUBLIC_WS_URL,
},
// Webpack configuration
webpack: (config, { dev, isServer }) => {
if (!dev && !isServer) {
// Production client-side optimizations
config.optimization.splitChunks = {
chunks: 'all',
cacheGroups: {
vendor: {
test: /[\\/]node_modules[\\/]/,
name: 'vendors',
chunks: 'all',
},
},
};
}
return config;
},
};
module.exports = nextConfig;
25. Missing Environment Variable Validation¶
Location: Application startup
Risk: Medium - Runtime failures
Recommendation:
# Pydantic v1 style; with Pydantic v2, import BaseSettings from the separate
# pydantic-settings package and use field_validator instead of validator
from pydantic import BaseSettings, validator, Field
from typing import List, Optional
import logging
import os

logger = logging.getLogger(__name__)
class Settings(BaseSettings):
# Database
mongodb_url: str = Field(..., env="MONGODB_URL")
mongodb_db_name: str = Field("fl_orchestrator_db", env="MONGODB_DB_NAME")
# Security
secret_key: str = Field(..., env="SECRET_KEY")
algorithm: str = Field("HS256", env="ALGORITHM")
access_token_expire_minutes: int = Field(30, env="ACCESS_TOKEN_EXPIRE_MINUTES")
# Environment
environment: str = Field("development", env="NODE_ENV")
debug: bool = Field(False, env="DEBUG")
# CORS
allowed_origins: List[str] = Field(default_factory=list, env="ALLOWED_ORIGINS")
# External Services
orchestrator_ip: str = Field(..., env="ORCHESTRATOR_IP")
frontend_url: str = Field(..., env="FRONTEND_URL")
# OpenTelemetry
otel_exporter_otlp_endpoint: Optional[str] = Field(None, env="OTEL_EXPORTER_OTLP_ENDPOINT")
otel_service_name: str = Field("fl-backend", env="OTEL_SERVICE_NAME")
# Redis (for caching and sessions)
redis_url: Optional[str] = Field(None, env="REDIS_URL")
@validator('secret_key')
def validate_secret_key(cls, v):
if len(v) < 32:
raise ValueError('SECRET_KEY must be at least 32 characters long')
if v == 'your-secret-key-change-in-production':
raise ValueError('SECRET_KEY must be changed from default value')
return v
@validator('environment')
def validate_environment(cls, v):
allowed_envs = ['development', 'staging', 'production']
if v not in allowed_envs:
raise ValueError(f'ENVIRONMENT must be one of: {allowed_envs}')
return v
@validator('mongodb_url')
def validate_mongodb_url(cls, v):
if not v.startswith(('mongodb://', 'mongodb+srv://')):
raise ValueError('MONGODB_URL must be a valid MongoDB connection string')
return v
@validator('allowed_origins')
def validate_allowed_origins(cls, v, values):
if values.get('environment') == 'production' and not v:
raise ValueError('ALLOWED_ORIGINS must be specified in production')
return v
@validator('orchestrator_ip')
def validate_orchestrator_ip(cls, v):
import ipaddress
try:
ipaddress.ip_address(v)
except ValueError:
# Could be hostname, validate format
if not v.replace('.', '').replace('-', '').isalnum():
raise ValueError('ORCHESTRATOR_IP must be valid IP address or hostname')
return v
class Config:
env_file = ".env"
case_sensitive = False
# Validate settings on startup
try:
settings = Settings()
logger.info("Environment validation successful")
logger.info(f"Running in {settings.environment} mode")
except Exception as e:
logger.error(f"Environment validation failed: {e}")
raise SystemExit(1)
# Environment-specific configurations
if settings.environment == "production":
# Production-specific validations
required_prod_vars = [
'SECRET_KEY', 'MONGODB_URL', 'ORCHESTRATOR_IP',
'FRONTEND_URL', 'ALLOWED_ORIGINS'
]
missing_vars = [var for var in required_prod_vars if not getattr(settings, var.lower(), None)]
if missing_vars:
logger.error(f"Missing required production variables: {missing_vars}")
raise SystemExit(1)
# Export settings
__all__ = ["settings"]
26. Hardcoded Database Credentials¶
Location: backend/app/database/mongodb.py
Risk: High - Credential exposure
Recommendation:
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo.errors import ServerSelectionTimeoutError
import os
import logging
from typing import Optional
from urllib.parse import quote_plus
logger = logging.getLogger(__name__)
class MongoDB:
client: Optional[AsyncIOMotorClient] = None
db_name: Optional[str] = None
@classmethod
async def connect_to_mongodb(cls):
# Get credentials from environment (required)
mongodb_url = os.getenv("MONGODB_URL")
if not mongodb_url:
# Build URL from components if not provided as full URL
username = os.getenv("MONGO_DB_USERNAME")
password = os.getenv("MONGO_DB_PASSWORD")
host = os.getenv("MONGO_DB_HOST", "mongodb")
port = os.getenv("MONGO_DB_PORT", "27017")
if not username or not password:
raise ValueError(
"MongoDB credentials must be provided via MONGODB_URL or "
"MONGO_DB_USERNAME/MONGO_DB_PASSWORD environment variables"
)
# URL encode credentials to handle special characters
username_encoded = quote_plus(username)
password_encoded = quote_plus(password)
mongodb_url = f"mongodb://{username_encoded}:{password_encoded}@{host}:{port}/"
cls.db_name = os.getenv("MONGODB_DB_NAME")
if not cls.db_name:
raise ValueError("MONGODB_DB_NAME environment variable must be set")
try:
cls.client = AsyncIOMotorClient(
mongodb_url,
serverSelectionTimeoutMS=5000,
connectTimeoutMS=10000,
socketTimeoutMS=10000,
maxPoolSize=50,
minPoolSize=5,
maxIdleTimeMS=30000,
retryWrites=True,
retryReads=True,
# Security settings
tls=os.getenv("MONGO_SSL", "false").lower() == "true", # 'tls' supersedes the deprecated 'ssl' option
authSource=os.getenv("MONGO_AUTH_SOURCE", "admin"),
authMechanism=os.getenv("MONGO_AUTH_MECHANISM", "SCRAM-SHA-1")
)
# Test connection
await cls.client.admin.command('ping')
# Log connection success (without credentials)
host_info = mongodb_url.split('@')[-1] if '@' in mongodb_url else mongodb_url
logger.info(f"Connected to MongoDB at {host_info}")
except ServerSelectionTimeoutError:
logger.error("Cannot connect to MongoDB - check credentials and network")
cls.client = None
raise
except Exception as e:
logger.error(f"MongoDB connection error: {str(e)}")
cls.client = None
raise
@classmethod
async def close_mongodb_connection(cls):
if cls.client:
cls.client.close()
cls.client = None
logger.info("Closed MongoDB connection")
@classmethod
def get_db(cls):
if cls.client is None:
raise ConnectionError("MongoDB client not initialized")
return cls.client[cls.db_name]
@classmethod
async def health_check(cls) -> dict:
"""Check MongoDB health and return status."""
if cls.client is None:
return {"status": "disconnected", "error": "Client not initialized"}
try:
# Test connection
result = await cls.client.admin.command('ping')
# Get server info
server_info = await cls.client.admin.command('buildInfo')
return {
"status": "connected",
"version": server_info.get("version"),
"database": cls.db_name,
"ping": "ok" if result.get("ok") == 1 else "failed"
}
except Exception as e:
return {"status": "error", "error": str(e)}
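The `quote_plus` step above matters whenever credentials contain reserved characters; unescaped, an `@` or `:` in a password corrupts the connection string. A small standalone sketch (hypothetical helper name):

```python
from urllib.parse import quote_plus

def build_mongo_url(username: str, password: str,
                    host: str = "mongodb", port: str = "27017") -> str:
    # quote_plus escapes characters like '@' and ':' that would otherwise
    # break the mongodb:// connection-string syntax
    return f"mongodb://{quote_plus(username)}:{quote_plus(password)}@{host}:{port}/"

print(build_mongo_url("admin", "p@ss:word"))
# mongodb://admin:p%40ss%3Aword@mongodb:27017/
```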
27. Overly Permissive CORS¶
Location: backend/app/main.py
Risk: Medium - Cross-origin attacks
Current Code:
cors_config = {
"allow_credentials": True,
"allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH", "HEAD"],
"allow_headers": ["*"], # Too permissive
"expose_headers": ["*"], # Too permissive
}
Recommendation:
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from typing import List
import logging
import os

logger = logging.getLogger(__name__)
def configure_cors(app: FastAPI, settings):
"""Configure CORS with security best practices."""
# Define allowed origins based on environment
if settings.environment == "production":
allowed_origins = settings.allowed_origins
if not allowed_origins:
raise ValueError("ALLOWED_ORIGINS must be specified in production")
else:
# Development origins
allowed_origins = [
"http://localhost:3000",
"http://localhost:4000",
"http://127.0.0.1:3000",
"http://127.0.0.1:4000"
]
# Add configured origins
if settings.allowed_origins:
allowed_origins.extend(settings.allowed_origins)
# Remove duplicates while preserving order
allowed_origins = list(dict.fromkeys(allowed_origins))
# Validate origins format
for origin in allowed_origins:
if not origin.startswith(('http://', 'https://')):
raise ValueError(f"Invalid origin format: {origin}")
# Secure CORS configuration
cors_config = {
"allow_origins": allowed_origins,
"allow_credentials": True,
"allow_methods": [
"GET",
"POST",
"PUT",
"DELETE",
"OPTIONS"
],
"allow_headers": [
"Accept",
"Accept-Language",
"Content-Language",
"Content-Type",
"Authorization",
"X-Requested-With",
"X-CSRF-Token"
],
"expose_headers": [
"Content-Length",
"Content-Type",
"X-Total-Count",
"X-Page-Count"
],
"max_age": 3600 # Cache preflight for 1 hour
}
# Add CORS middleware
app.add_middleware(CORSMiddleware, **cors_config)
# Log CORS configuration (sanitized)
logger.info(f"CORS configured for {len(allowed_origins)} origins")
logger.debug(f"Allowed origins: {allowed_origins}")
return cors_config
# Custom CORS validation middleware
@app.middleware("http")
async def validate_cors_request(request: Request, call_next):
"""Additional CORS validation middleware."""
origin = request.headers.get("origin")
if origin:
# Check if origin is in allowed list
if origin not in settings.allowed_origins and settings.environment == "production":
logger.warning(f"Blocked request from unauthorized origin: {origin}")
return JSONResponse(
status_code=403,
content={"error": "Origin not allowed"}
)
# Check for suspicious patterns
if any(suspicious in origin.lower() for suspicious in ['localhost', '127.0.0.1']) and settings.environment == "production":
logger.warning(f"Blocked localhost request in production: {origin}")
return JSONResponse(
status_code=403,
content={"error": "Localhost not allowed in production"}
)
response = await call_next(request)
return response
# Usage
cors_config = configure_cors(app, settings)
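The origin handling in `configure_cors` combines two small steps: reject malformed origins, then deduplicate while preserving order via `dict.fromkeys`. In isolation (hypothetical helper name):

```python
from typing import List

def normalize_origins(origins: List[str]) -> List[str]:
    """Validate origin format, then dedupe while preserving insertion order."""
    for origin in origins:
        if not origin.startswith(('http://', 'https://')):
            raise ValueError(f"Invalid origin format: {origin}")
    # dict keys are ordered, so this removes duplicates without reordering
    return list(dict.fromkeys(origins))

print(normalize_origins([
    "http://localhost:3000",
    "https://app.example.com",
    "http://localhost:3000",
]))
# ['http://localhost:3000', 'https://app.example.com']
```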
Testing and Quality Issues (Issues 28-35)¶
28. No Test Coverage¶
Risk: High - Undetected bugs
Recommendation:
# pytest configuration - pytest.ini
[tool:pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts =
--cov=app
--cov-report=html
--cov-report=term-missing
--cov-fail-under=80
--strict-markers
--disable-warnings
markers =
unit: Unit tests
integration: Integration tests
e2e: End-to-end tests
slow: Slow running tests
# Backend test structure
# tests/
# ├── conftest.py
# ├── unit/
# │ ├── test_auth_service.py
# │ ├── test_user_service.py
# │ └── test_training_service.py
# ├── integration/
# │ ├── test_api_routes.py
# │ └── test_database.py
# └── e2e/
# └── test_training_workflow.py
# tests/conftest.py
import pytest
import asyncio
from httpx import AsyncClient
from motor.motor_asyncio import AsyncIOMotorClient
from app.main import app
from app.database.mongodb import MongoDB
from app.services.auth_service import AuthService, UserService
@pytest.fixture(scope="session")
def event_loop():
"""Create an instance of the default event loop for the test session."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture
async def test_db():
"""Create test database."""
test_client = AsyncIOMotorClient("mongodb://localhost:27017")
test_db = test_client["test_fl_db"]
yield test_db
await test_client.drop_database("test_fl_db")
test_client.close()
@pytest.fixture
async def client():
"""Create test client."""
async with AsyncClient(app=app, base_url="http://test") as ac:
yield ac
@pytest.fixture
async def test_user(test_db):
"""Create test user."""
user_data = {
"username": "testuser",
"email": "test@example.com",
"password": "TestPassword123!",
"is_active": True
}
user = await UserService.create_user(test_db, user_data)
return user
@pytest.fixture
async def auth_headers(test_user):
"""Create authentication headers."""
token = AuthService.create_access_token(data={"sub": test_user.username})
return {"Authorization": f"Bearer {token}"}
# tests/unit/test_auth_service.py
import pytest
from app.services.auth_service import AuthService, UserService
from app.models.user import UserCreate
class TestAuthService:
def test_password_hashing(self):
"""Test password hashing and verification."""
password = "TestPassword123!"
hashed = AuthService.hash_password(password)
assert hashed != password
assert AuthService.verify_password(password, hashed)
assert not AuthService.verify_password("wrong", hashed)
def test_token_creation_and_verification(self):
"""Test JWT token creation and verification."""
data = {"sub": "testuser", "user_id": "123"}
token = AuthService.create_access_token(data)
assert token is not None
decoded = AuthService.verify_token(token)
assert decoded["sub"] == "testuser"
assert decoded["user_id"] == "123"
def test_invalid_token(self):
"""Test invalid token handling."""
invalid_token = "invalid.token.here"
result = AuthService.verify_token(invalid_token)
assert result is None
@pytest.mark.asyncio
class TestUserService:
async def test_create_user(self, test_db):
"""Test user creation."""
user_data = UserCreate(
username="newuser",
email="new@example.com",
password="NewPassword123!",
confirm_password="NewPassword123!"
)
user = await UserService.create_user(test_db, user_data)
assert user.username == "newuser"
assert user.email == "new@example.com"
assert user.is_active is True
async def test_authenticate_user(self, test_db, test_user):
"""Test user authentication."""
# Valid credentials
user = await UserService.authenticate_user(
test_db, test_user.username, "TestPassword123!"
)
assert user is not None
assert user.username == test_user.username
# Invalid credentials
user = await UserService.authenticate_user(
test_db, test_user.username, "wrongpassword"
)
assert user is None
# tests/integration/test_api_routes.py
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
class TestAuthRoutes:
async def test_register_user(self, client: AsyncClient):
"""Test user registration endpoint."""
user_data = {
"username": "newuser",
"email": "new@example.com",
"password": "NewPassword123!",
"confirm_password": "NewPassword123!"
}
response = await client.post("/auth/register", json=user_data)
assert response.status_code == 201
assert "user_id" in response.json()
async def test_login(self, client: AsyncClient, test_user):
"""Test user login endpoint."""
login_data = {
"username": test_user.username,
"password": "TestPassword123!"
}
response = await client.post("/auth/login", json=login_data)
assert response.status_code == 200
assert "access_token" in response.json()
async def test_protected_route(self, client: AsyncClient, auth_headers):
"""Test protected route access."""
response = await client.get("/auth/me", headers=auth_headers)
assert response.status_code == 200
assert "username" in response.json()
# Frontend testing with Jest and React Testing Library
# package.json test scripts
{
"scripts": {
"test": "jest",
"test:watch": "jest --watch",
"test:coverage": "jest --coverage",
"test:e2e": "playwright test"
}
}
# jest.config.js
module.exports = {
testEnvironment: 'jsdom',
setupFilesAfterEnv: ['<rootDir>/src/test/setup.ts'],
moduleNameMapper: {
'^@/(.*)$': '<rootDir>/src/$1',
},
collectCoverageFrom: [
'src/**/*.{ts,tsx}',
'!src/**/*.d.ts',
'!src/test/**/*',
],
coverageThreshold: {
global: {
branches: 70,
functions: 70,
lines: 70,
statements: 70,
},
},
};
# src/test/setup.ts
import '@testing-library/jest-dom';
import { server } from './mocks/server';
beforeAll(() => server.listen());
afterEach(() => server.resetHandlers());
afterAll(() => server.close());
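The password-hashing unit test above exercises a round-trip property: hashing is one-way, verification accepts the original password and rejects others. The same property can be checked with a self-contained stand-in using stdlib scrypt; this is an illustrative sketch, not the project's `AuthService`:

```python
import hashlib
import hmac
import os

def hash_password(password: str) -> bytes:
    """Salted scrypt hash; the random salt is stored alongside the digest."""
    salt = os.urandom(16)
    digest = hashlib.scrypt(password.encode(), salt=salt, n=2**14, r=8, p=1)
    return salt + digest

def verify_password(password: str, stored: bytes) -> bool:
    """Recompute the hash with the stored salt; constant-time comparison."""
    salt, digest = stored[:16], stored[16:]
    candidate = hashlib.scrypt(password.encode(), salt=salt, n=2**14, r=8, p=1)
    return hmac.compare_digest(candidate, digest)

hashed = hash_password("TestPassword123!")
print(verify_password("TestPassword123!", hashed))  # True
print(verify_password("wrong", hashed))             # False
```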
29. No Input Validation Models¶
Risk: Medium - Data integrity issues
Recommendation:
from pydantic import BaseModel, Field, validator
from typing import Optional, List, Dict, Any
from datetime import datetime
from enum import Enum
import re
# Training Job Models
class TrainingJobStatus(str, Enum):
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class TrainingJobCreate(BaseModel):
name: str = Field(..., min_length=1, max_length=100)
description: Optional[str] = Field(None, max_length=500)
project_id: str = Field(..., min_length=1)
config: Dict[str, Any] = Field(default_factory=dict)
@validator('name')
def validate_name(cls, v):
if not re.match(r'^[a-zA-Z0-9_\-\s]+$', v):
raise ValueError('Name can only contain letters, numbers, spaces, hyphens, and underscores')
return v.strip()
@validator('config')
def validate_config(cls, v):
# Validate configuration structure
required_fields = ['rounds', 'min_clients', 'strategy']
for field in required_fields:
if field not in v:
raise ValueError(f'Missing required config field: {field}')
# Validate field types and ranges
if not isinstance(v.get('rounds'), int) or v['rounds'] < 1 or v['rounds'] > 1000:
raise ValueError('rounds must be an integer between 1 and 1000')
if not isinstance(v.get('min_clients'), int) or v['min_clients'] < 1 or v['min_clients'] > 100:
raise ValueError('min_clients must be an integer between 1 and 100')
allowed_strategies = ['FedAvg', 'FedProx', 'FedOpt']
if v.get('strategy') not in allowed_strategies:
raise ValueError(f'strategy must be one of: {allowed_strategies}')
return v
class TrainingJobUpdate(BaseModel):
name: Optional[str] = Field(None, min_length=1, max_length=100)
description: Optional[str] = Field(None, max_length=500)
status: Optional[TrainingJobStatus] = None
config: Optional[Dict[str, Any]] = None
@validator('name')
def validate_name(cls, v):
if v is not None:
if not re.match(r'^[a-zA-Z0-9_\-\s]+$', v):
raise ValueError('Name can only contain letters, numbers, spaces, hyphens, and underscores')
return v.strip()
return v
# Project Models
class ProjectCreate(BaseModel):
name: str = Field(..., min_length=1, max_length=100)
description: Optional[str] = Field(None, max_length=1000)
ml_framework: str = Field(..., regex=r'^(tensorflow|pytorch|sklearn)$')
dataset_type: str = Field(..., regex=r'^(image|text|tabular|time_series)$')
@validator('name')
def validate_name(cls, v):
if not re.match(r'^[a-zA-Z0-9_\-\s]+$', v):
raise ValueError('Project name can only contain letters, numbers, spaces, hyphens, and underscores')
return v.strip()
# Ansible Job Models
class AnsibleCommand(str, Enum):
DEPLOY = "deploy"
START = "start"
STOP = "stop"
STATUS = "status"
COPY_DATASET = "copy_dataset"
class AnsibleJobCreate(BaseModel):
command: AnsibleCommand
inventory: str = Field(..., min_length=1)
extra_vars: Optional[Dict[str, Any]] = Field(default_factory=dict)
target_hosts: Optional[List[str]] = Field(default_factory=list)
@validator('inventory')
def validate_inventory(cls, v):
allowed_inventories = ['aggregator', 'client', 'all']
if v not in allowed_inventories:
raise ValueError(f'inventory must be one of: {allowed_inventories}')
return v
@validator('target_hosts')
def validate_target_hosts(cls, v):
if v:
for host in v:
# Basic hostname/IP validation
if not re.match(r'^[a-zA-Z0-9\-\.]+$', host):
raise ValueError(f'Invalid hostname: {host}')
return v
# File Upload Models
class FileUploadResponse(BaseModel):
filename: str
size: int
mime_type: str
checksum: str
upload_id: str
created_at: datetime
# Query Parameter Models
class PaginationParams(BaseModel):
page: int = Field(1, ge=1, le=1000)
limit: int = Field(20, ge=1, le=100)
@property
def skip(self) -> int:
return (self.page - 1) * self.limit
class TrainingJobFilters(BaseModel):
status: Optional[TrainingJobStatus] = None
project_id: Optional[str] = None
created_after: Optional[datetime] = None
created_before: Optional[datetime] = None
@root_validator
def validate_date_range(cls, values):
created_after = values.get('created_after')
created_before = values.get('created_before')
if created_after and created_before:
if created_after >= created_before:
raise ValueError('created_after must be before created_before')
return values
# WebSocket Message Models
class WebSocketMessage(BaseModel):
type: str = Field(..., regex=r'^[a-z_]+$')
data: Dict[str, Any] = Field(default_factory=dict)
timestamp: datetime = Field(default_factory=datetime.utcnow)
class JobUpdateMessage(WebSocketMessage):
type: str = Field("job_update", const=True)
job_id: str = Field(..., min_length=1)
status: TrainingJobStatus
progress: Optional[float] = Field(None, ge=0.0, le=1.0)
# Response Models
class PaginatedResponse(BaseModel):
items: List[Any]
total: int = Field(..., ge=0)
page: int = Field(..., ge=1)
pages: int = Field(..., ge=0)
has_next: bool
has_prev: bool
class ErrorResponse(BaseModel):
error: str
code: Optional[str] = None
details: Optional[Dict[str, Any]] = None
timestamp: datetime = Field(default_factory=datetime.utcnow)
# Usage in routes with validation
@router.post("/training/jobs", response_model=TrainingJobResponse)
async def create_training_job(
job_data: TrainingJobCreate, # Automatic validation
current_user: User = Depends(get_current_user),
db = Depends(get_database)
):
# job_data is already validated by Pydantic
job = await TrainingService.create_job(db, job_data, current_user.id)
return job
@router.get("/training/jobs", response_model=PaginatedResponse)
async def get_training_jobs(
pagination: PaginationParams = Depends(), # Query param validation
filters: TrainingJobFilters = Depends(), # Filter validation
current_user: User = Depends(get_current_user),
db = Depends(get_database)
):
jobs = await TrainingService.get_user_jobs(
db, current_user.id, pagination, filters
)
return jobs
30. Missing API Documentation¶
Risk: Low - Developer productivity
Recommendation:
from fastapi import FastAPI, Body, Depends, HTTPException, status
from fastapi.openapi.utils import get_openapi
from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html

# Enhanced FastAPI app with comprehensive documentation
app = FastAPI(
    title="Federated Learning Platform API",
    description="""
A comprehensive API for federated learning orchestration with authentication.

## Features
* **Authentication**: JWT-based authentication with role-based access control
* **Training Management**: Create, monitor, and manage federated learning jobs
* **Project Management**: Organize ML projects and configurations
* **Ansible Integration**: Automated deployment and orchestration
* **Real-time Updates**: WebSocket-based live monitoring

## Authentication
Most endpoints require authentication. To authenticate:
1. Register a new user account via `/auth/register`
2. Login via `/auth/login` to receive an access token
3. Include the token in the Authorization header: `Bearer <token>`

## Rate Limiting
API endpoints are rate limited to prevent abuse:
* Authentication endpoints: 5 requests per minute
* General endpoints: 100 requests per minute
* File uploads: 10 requests per minute

## Error Handling
The API uses standard HTTP status codes and returns errors in JSON format:
```json
{
    "error": "Error description",
    "code": "ERROR_CODE",
    "timestamp": "2024-01-01T00:00:00Z"
}
```
    """,
    version="1.2.0",
    contact={
        "name": "Federated Learning Platform Team",
        "email": "support@fl-platform.com",
        "url": "https://fl-platform.com/support"
    },
    license_info={
        "name": "MIT License",
        "url": "https://opensource.org/licenses/MIT"
    },
    servers=[
        {
            "url": "http://localhost:8000",
            "description": "Development server"
        },
        {
            "url": "https://api.fl-platform.com",
            "description": "Production server"
        }
    ]
)

# Custom OpenAPI schema
def custom_openapi():
    if app.openapi_schema:
        return app.openapi_schema
    openapi_schema = get_openapi(
        title=app.title,
        version=app.version,
        description=app.description,
        routes=app.routes,
    )
    # Add security schemes
    openapi_schema["components"]["securitySchemes"] = {
        "BearerAuth": {
            "type": "http",
            "scheme": "bearer",
            "bearerFormat": "JWT",
            "description": "JWT token obtained from /auth/login"
        }
    }
    # Add global security requirement
    openapi_schema["security"] = [{"BearerAuth": []}]
    # Add custom tags
    openapi_schema["tags"] = [
        {
            "name": "authentication",
            "description": "User authentication and authorization"
        },
        {
            "name": "training",
            "description": "Federated learning training management"
        },
        {
            "name": "projects",
            "description": "ML project and configuration management"
        },
        {
            "name": "ansible",
            "description": "Automated deployment and orchestration"
        },
        {
            "name": "monitoring",
            "description": "System monitoring and health checks"
        }
    ]
    app.openapi_schema = openapi_schema
    return app.openapi_schema

app.openapi = custom_openapi

# Enhanced route documentation
@router.post(
    "/training/jobs",
    response_model=TrainingJobResponse,
    status_code=status.HTTP_201_CREATED,
    summary="Create a new training job",
    description="""
Create a new federated learning training job.

The training job will be queued and executed based on the provided configuration.
You can monitor the job progress via WebSocket connections or by polling the job status endpoint.

**Required permissions**: User must be authenticated and have access to the specified project.
    """,
    responses={
        201: {
            "description": "Training job created successfully",
            "content": {
                "application/json": {
                    "example": {
                        "id": "job_123",
                        "name": "MNIST Classification",
                        "status": "pending",
                        "created_at": "2024-01-01T00:00:00Z"
                    }
                }
            }
        },
        400: {
            "description": "Invalid request data",
            "content": {
                "application/json": {
                    "example": {
                        "error": "Invalid configuration",
                        "code": "INVALID_CONFIG"
                    }
                }
            }
        },
        401: {"description": "Authentication required"},
        403: {"description": "Insufficient permissions"},
        422: {"description": "Validation error"}
    },
    tags=["training"]
)
async def create_training_job(
    job_data: TrainingJobCreate = Body(
        ...,
        example={
            "name": "MNIST Classification",
            "description": "Federated learning on MNIST dataset",
            "project_id": "proj_123",
            "config": {
                "rounds": 10,
                "min_clients": 2,
                "strategy": "FedAvg"
            }
        }
    ),
    current_user: User = Depends(get_current_user),
    db = Depends(get_database)
):
    """Create a new federated learning training job."""
    job = await TrainingService.create_job(db, job_data, current_user.id)
    return job

# Custom documentation pages
@app.get("/docs", include_in_schema=False)
async def custom_swagger_ui_html():
    return get_swagger_ui_html(
        openapi_url=app.openapi_url,
        title=f"{app.title} - Swagger UI",
        oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
        swagger_js_url="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui-bundle.js",
        swagger_css_url="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui.css",
        swagger_ui_parameters={
            "deepLinking": True,
            "displayRequestDuration": True,
            "docExpansion": "none",
            "operationsSorter": "alpha",
            "filter": True,
            "showExtensions": True,
            "showCommonExtensions": True
        }
    )

@app.get("/redoc", include_in_schema=False)
async def redoc_html():
    return get_redoc_html(
        openapi_url=app.openapi_url,
        title=f"{app.title} - ReDoc",
        redoc_js_url="https://cdn.jsdelivr.net/npm/redoc@2.0.0/bundles/redoc.standalone.js",
    )
31. Missing Code Quality Tooling¶
Risk: Medium - Maintainability
Recommendation:
# Backend: Python code quality tools
# pyproject.toml
[tool.black]
line-length = 88
target-version = ['py310']
include = '\.pyi?$'
extend-exclude = '''
/(
# directories
\.eggs
| \.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| build
| dist
)/
'''
[tool.isort]
profile = "black"
multi_line_output = 3
line_length = 88
known_first_party = ["app"]
[tool.flake8]
max-line-length = 88
extend-ignore = ["E203", "W503"]
exclude = [".git", "__pycache__", "build", "dist"]
[tool.mypy]
python_version = "3.10"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = true
disallow_incomplete_defs = true
check_untyped_defs = true
disallow_untyped_decorators = true
no_implicit_optional = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_no_return = true
warn_unreachable = true
strict_equality = true
# Pre-commit configuration - .pre-commit-config.yaml
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files
      - id: check-merge-conflict
  - repo: https://github.com/psf/black
    rev: 23.3.0
    hooks:
      - id: black
  - repo: https://github.com/pycqa/isort
    rev: 5.12.0
    hooks:
      - id: isort
  - repo: https://github.com/pycqa/flake8
    rev: 6.0.0
    hooks:
      - id: flake8

# Makefile for development tasks (recipe lines must be indented with tabs)
.PHONY: format lint test security-check

format:
	black app tests
	isort app tests

lint:
	flake8 app tests
	mypy app
	pylint app

test:
	pytest tests/ -v --cov=app --cov-report=html

security-check:
	bandit -r app/
	safety check

quality-check: format lint test security-check
	@echo "All quality checks passed!"
32. Missing Type Safety¶
Risk: Medium - Runtime errors
Recommendation:
// Enhanced TypeScript configuration
// tsconfig.json
{
  "compilerOptions": {
    "target": "ES2022",
    "lib": ["dom", "dom.iterable", "ES6"],
    "strict": true,
    "noEmit": true,
    "esModuleInterop": true,
    "module": "esnext",
    "moduleResolution": "bundler",
    "resolveJsonModule": true,
    "isolatedModules": true,
    "jsx": "preserve",
    "incremental": true,
    "baseUrl": ".",
    "paths": {
      "@/*": ["./src/*"]
    },
    // Strict type checking
    "noImplicitAny": true,
    "noImplicitReturns": true,
    "noImplicitThis": true,
    "noUnusedLocals": true,
    "noUnusedParameters": true,
    "exactOptionalPropertyTypes": true,
    "noUncheckedIndexedAccess": true,
    "noImplicitOverride": true
  }
}
// Enhanced type definitions
export interface User {
  readonly id: string;
  readonly username: string;
  readonly email: string | null;
  readonly isActive: boolean;
  readonly roles: readonly string[];
  readonly createdAt: string;
  readonly updatedAt: string;
}

export interface TrainingJob {
  readonly id: string;
  readonly name: string;
  readonly status: TrainingJobStatus;
  readonly projectId: string;
  readonly config: TrainingConfig;
  readonly createdAt: string;
}

export type TrainingJobStatus =
  | 'pending'
  | 'running'
  | 'completed'
  | 'failed'
  | 'cancelled';

// Type guards
export function isTrainingJob(obj: unknown): obj is TrainingJob {
  return (
    typeof obj === 'object' &&
    obj !== null &&
    'id' in obj &&
    'name' in obj &&
    'status' in obj &&
    typeof (obj as any).id === 'string'
  );
}
33. No Integration Tests¶
Risk: High - System failures
Recommendation:
# Integration test setup
# tests/integration/conftest.py
import pytest
import asyncio
from httpx import AsyncClient
from testcontainers.mongodb import MongoDbContainer

@pytest.fixture(scope="session")
async def mongodb_container():
    """Start MongoDB container for integration tests."""
    with MongoDbContainer("mongo:5.0") as mongodb:
        yield mongodb

@pytest.fixture
async def client(test_app):
    """Create test client."""
    async with AsyncClient(app=test_app, base_url="http://test") as ac:
        yield ac

# tests/integration/test_training_workflow.py
@pytest.mark.asyncio
class TestTrainingWorkflow:
    async def test_complete_training_workflow(self, client: AsyncClient):
        """Test end-to-end training workflow."""
        # 1. Register user
        user_data = {
            "username": "testuser",
            "email": "test@example.com",
            "password": "TestPassword123!",
            "confirm_password": "TestPassword123!"
        }
        register_response = await client.post("/auth/register", json=user_data)
        assert register_response.status_code == 201
        # 2. Login
        login_response = await client.post("/auth/login", json={
            "username": "testuser",
            "password": "TestPassword123!"
        })
        assert login_response.status_code == 200
        token = login_response.json()["access_token"]
        headers = {"Authorization": f"Bearer {token}"}
        # 3. Create training job
        job_data = {
            "name": "Test Training Job",
            "project_id": "test-project",
            "config": {
                "rounds": 5,
                "min_clients": 2,
                "strategy": "FedAvg"
            }
        }
        job_response = await client.post("/training/jobs", json=job_data, headers=headers)
        assert job_response.status_code == 201
        job_id = job_response.json()["id"]
        # 4. Check job status
        status_response = await client.get(f"/training/jobs/{job_id}", headers=headers)
        assert status_response.status_code == 200
        assert status_response.json()["status"] == "pending"
Operational/Monitoring Issues (Issues 36-41)¶
36. Insufficient Logging and Monitoring¶
Location: Application-wide
Risk: High - Operational blindness
Recommendation:
import logging
import structlog
from pythonjsonlogger import jsonlogger
from opentelemetry import trace, metrics
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor

# Structured logging configuration
def configure_logging():
    """Configure structured logging with JSON output."""
    # Configure structlog
    structlog.configure(
        processors=[
            structlog.stdlib.filter_by_level,
            structlog.stdlib.add_logger_name,
            structlog.stdlib.add_log_level,
            structlog.stdlib.PositionalArgumentsFormatter(),
            structlog.processors.TimeStamper(fmt="iso"),
            structlog.processors.StackInfoRenderer(),
            structlog.processors.format_exc_info,
            structlog.processors.UnicodeDecoder(),
            structlog.processors.JSONRenderer()
        ],
        context_class=dict,
        logger_factory=structlog.stdlib.LoggerFactory(),
        wrapper_class=structlog.stdlib.BoundLogger,
        cache_logger_on_first_use=True,
    )
    # Configure standard logging
    formatter = jsonlogger.JsonFormatter(
        '%(asctime)s %(name)s %(levelname)s %(message)s'
    )
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    root_logger = logging.getLogger()
    root_logger.addHandler(handler)
    root_logger.setLevel(logging.INFO)
# Application metrics
class ApplicationMetrics:
    def __init__(self):
        self.meter = metrics.get_meter(__name__)
        # Request metrics
        self.request_counter = self.meter.create_counter(
            "http_requests_total",
            description="Total HTTP requests"
        )
        self.request_duration = self.meter.create_histogram(
            "http_request_duration_seconds",
            description="HTTP request duration"
        )
        # Training job metrics
        self.training_jobs_total = self.meter.create_counter(
            "training_jobs_total",
            description="Total training jobs created"
        )
        self.training_jobs_active = self.meter.create_up_down_counter(
            "training_jobs_active",
            description="Currently active training jobs"
        )
        # Database metrics
        self.db_connections = self.meter.create_up_down_counter(
            "database_connections_active",
            description="Active database connections"
        )
        self.db_query_duration = self.meter.create_histogram(
            "database_query_duration_seconds",
            description="Database query duration"
        )
# Middleware for request logging and metrics
import time
from uuid import uuid4

from fastapi import Request

# Use a single ApplicationMetrics instance; referencing the opentelemetry
# `metrics` module directly would shadow the instruments defined above.
app_metrics = ApplicationMetrics()

@app.middleware("http")
async def logging_middleware(request: Request, call_next):
    start_time = time.time()
    # Create request context
    request_id = str(uuid4())
    # Structured logging context
    logger = structlog.get_logger().bind(
        request_id=request_id,
        method=request.method,
        url=str(request.url),
        user_agent=request.headers.get("user-agent"),
        remote_addr=request.client.host if request.client else None
    )
    logger.info("Request started")
    try:
        response = await call_next(request)
        # Calculate duration
        duration = time.time() - start_time
        # Log response
        logger.info(
            "Request completed",
            status_code=response.status_code,
            duration=duration
        )
        # Record metrics
        app_metrics.request_counter.add(1, {
            "method": request.method,
            "status": str(response.status_code),
            "endpoint": request.url.path
        })
        app_metrics.request_duration.record(duration, {
            "method": request.method,
            "endpoint": request.url.path
        })
        return response
    except Exception as e:
        duration = time.time() - start_time
        logger.error(
            "Request failed",
            error=str(e),
            duration=duration,
            exc_info=True
        )
        # Record error metrics
        app_metrics.request_counter.add(1, {
            "method": request.method,
            "status": "500",
            "endpoint": request.url.path
        })
        raise
# Health check endpoint with detailed monitoring
@router.get("/health/detailed")
async def detailed_health_check():
    """Comprehensive health check with system metrics."""
    health_status = {
        "status": "healthy",
        "timestamp": datetime.utcnow().isoformat(),
        "version": "1.0.0",
        "checks": {}
    }
    # Database health
    try:
        db_health = await MongoDB.health_check()
        health_status["checks"]["database"] = db_health
    except Exception as e:
        health_status["checks"]["database"] = {
            "status": "unhealthy",
            "error": str(e)
        }
        health_status["status"] = "degraded"
    # Redis health (if configured)
    if redis_client:
        try:
            await redis_client.ping()
            health_status["checks"]["redis"] = {"status": "healthy"}
        except Exception as e:
            health_status["checks"]["redis"] = {
                "status": "unhealthy",
                "error": str(e)
            }
            health_status["status"] = "degraded"
    # System metrics
    import psutil
    health_status["system"] = {
        "cpu_percent": psutil.cpu_percent(),
        "memory_percent": psutil.virtual_memory().percent,
        "disk_percent": psutil.disk_usage('/').percent,
        "load_average": psutil.getloadavg() if hasattr(psutil, 'getloadavg') else None
    }
    # Application metrics
    health_status["application"] = {
        "active_training_jobs": await get_active_training_jobs_count(),
        "total_users": await get_total_users_count(),
        "uptime_seconds": time.time() - app_start_time
    }
    return health_status
37. No Health Checks¶
Location: Application monitoring
Risk: Medium - Service availability issues
Recommendation:
# Comprehensive health check system
@router.get("/health")
async def basic_health_check():
    """Basic health check for load balancers."""
    return {"status": "ok", "timestamp": datetime.utcnow().isoformat()}

@router.get("/health/ready")
async def readiness_check():
    """Readiness check for Kubernetes."""
    try:
        # Check database connectivity
        await MongoDB.get_db().command("ping")
        # Check critical services
        checks = {
            "database": "ok",
            "redis": "ok" if await check_redis() else "fail"
        }
        if all(status == "ok" for status in checks.values()):
            return {"status": "ready", "checks": checks}
        raise HTTPException(503, {"status": "not_ready", "checks": checks})
    except HTTPException:
        # Re-raise as-is so the 503 above is not rewrapped by the handler below
        raise
    except Exception as e:
        raise HTTPException(503, {"status": "not_ready", "error": str(e)})

@router.get("/health/live")
async def liveness_check():
    """Liveness check for Kubernetes."""
    # Simple check that the application is running
    return {"status": "alive", "timestamp": datetime.utcnow().isoformat()}
38. Missing Metrics Collection¶
Location: Application monitoring
Risk: Medium - Performance visibility
Recommendation:
# Prometheus metrics integration
import time

from fastapi import Request, Response
from prometheus_client import Counter, Histogram, Gauge, generate_latest

# Define metrics
REQUEST_COUNT = Counter('http_requests_total', 'Total HTTP requests', ['method', 'endpoint', 'status'])
REQUEST_DURATION = Histogram('http_request_duration_seconds', 'HTTP request duration')
ACTIVE_CONNECTIONS = Gauge('active_connections', 'Active database connections')
TRAINING_JOBS = Gauge('training_jobs_active', 'Active training jobs')

@app.middleware("http")
async def metrics_middleware(request: Request, call_next):
    start_time = time.time()
    response = await call_next(request)
    # Record metrics (label values must be strings)
    REQUEST_COUNT.labels(
        method=request.method,
        endpoint=request.url.path,
        status=str(response.status_code)
    ).inc()
    REQUEST_DURATION.observe(time.time() - start_time)
    return response

@router.get("/metrics")
async def metrics_endpoint():
    """Prometheus metrics endpoint."""
    return Response(generate_latest(), media_type="text/plain")
39. No Alerting System¶
Location: Operational monitoring
Risk: High - Incident response delays
Recommendation:
# Alert manager integration
import json
import logging
import os
import smtplib
from datetime import datetime
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart

logger = logging.getLogger(__name__)

class AlertManager:
    def __init__(self):
        self.smtp_server = os.getenv("SMTP_SERVER")
        self.smtp_port = int(os.getenv("SMTP_PORT", "587"))
        self.smtp_username = os.getenv("SMTP_USERNAME")
        self.smtp_password = os.getenv("SMTP_PASSWORD")
        self.alert_recipients = os.getenv("ALERT_RECIPIENTS", "").split(",")

    async def send_alert(self, severity: str, title: str, message: str, details: dict = None):
        """Send alert notification."""
        # Log alert
        logger.error(f"ALERT [{severity}]: {title}", extra={
            "alert_severity": severity,
            "alert_title": title,
            "alert_message": message,
            "alert_details": details
        })
        # Send email notification
        if self.smtp_server and self.alert_recipients:
            await self._send_email_alert(severity, title, message, details)
        # Send to external monitoring (PagerDuty, Slack, etc.)
        await self._send_external_alert(severity, title, message, details)

    async def _send_email_alert(self, severity: str, title: str, message: str, details: dict):
        """Send email alert."""
        try:
            msg = MIMEMultipart()
            msg['From'] = self.smtp_username
            msg['To'] = ", ".join(self.alert_recipients)
            msg['Subject'] = f"[{severity.upper()}] FL Platform Alert: {title}"
            body = f"""
Alert: {title}
Severity: {severity}
Message: {message}
Time: {datetime.utcnow().isoformat()}

Details:
{json.dumps(details, indent=2) if details else 'None'}
"""
            msg.attach(MIMEText(body, 'plain'))
            server = smtplib.SMTP(self.smtp_server, self.smtp_port)
            server.starttls()
            server.login(self.smtp_username, self.smtp_password)
            server.send_message(msg)
            server.quit()
        except Exception as e:
            logger.error(f"Failed to send email alert: {e}")
# Alert conditions
alert_manager = AlertManager()

async def check_system_health():
    """Check system health and trigger alerts."""
    # Check database connectivity
    try:
        await MongoDB.get_db().command("ping")
    except Exception as e:
        await alert_manager.send_alert(
            "critical",
            "Database Connection Failed",
            f"Cannot connect to MongoDB: {str(e)}"
        )
    # Check system resources
    import psutil
    cpu_percent = psutil.cpu_percent()
    memory_percent = psutil.virtual_memory().percent
    disk_percent = psutil.disk_usage('/').percent
    if cpu_percent > 90:
        await alert_manager.send_alert(
            "warning",
            "High CPU Usage",
            f"CPU usage is {cpu_percent}%"
        )
    if memory_percent > 90:
        await alert_manager.send_alert(
            "warning",
            "High Memory Usage",
            f"Memory usage is {memory_percent}%"
        )
    if disk_percent > 90:
        await alert_manager.send_alert(
            "critical",
            "Low Disk Space",
            f"Disk usage is {disk_percent}%"
        )

# Scheduled health checks
from apscheduler.schedulers.asyncio import AsyncIOScheduler

scheduler = AsyncIOScheduler()
scheduler.add_job(check_system_health, 'interval', minutes=5)
scheduler.start()
40. Inadequate Backup Procedures¶
Location: Data management
Risk: High - Data loss
Recommendation:
# Automated backup system
import os
import subprocess
import boto3
from datetime import datetime, timedelta

class BackupManager:
    def __init__(self):
        self.s3_client = boto3.client('s3')
        self.backup_bucket = os.getenv("BACKUP_S3_BUCKET")
        self.mongodb_url = os.getenv("MONGODB_URL")
        self.backup_retention_days = int(os.getenv("BACKUP_RETENTION_DAYS", "30"))

    async def create_database_backup(self):
        """Create MongoDB backup."""
        try:
            timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
            backup_filename = f"mongodb_backup_{timestamp}.gz"
            # Create mongodump (--archive requires the = form for a file path)
            dump_command = [
                "mongodump",
                "--uri", self.mongodb_url,
                "--gzip",
                f"--archive={backup_filename}"
            ]
            result = subprocess.run(dump_command, capture_output=True, text=True)
            if result.returncode != 0:
                raise Exception(f"Mongodump failed: {result.stderr}")
            # Upload to S3
            if self.backup_bucket:
                s3_key = f"database_backups/{backup_filename}"
                self.s3_client.upload_file(backup_filename, self.backup_bucket, s3_key)
                # Clean up local file
                os.remove(backup_filename)
                logger.info(f"Database backup uploaded to S3: {s3_key}")
            return backup_filename
        except Exception as e:
            logger.error(f"Database backup failed: {e}")
            await alert_manager.send_alert(
                "critical",
                "Database Backup Failed",
                f"Failed to create database backup: {str(e)}"
            )
            raise

    async def cleanup_old_backups(self):
        """Remove old backups based on retention policy."""
        try:
            cutoff_date = datetime.utcnow() - timedelta(days=self.backup_retention_days)
            # List objects in backup folder
            response = self.s3_client.list_objects_v2(
                Bucket=self.backup_bucket,
                Prefix="database_backups/"
            )
            deleted_count = 0
            for obj in response.get('Contents', []):
                if obj['LastModified'].replace(tzinfo=None) < cutoff_date:
                    self.s3_client.delete_object(
                        Bucket=self.backup_bucket,
                        Key=obj['Key']
                    )
                    deleted_count += 1
            logger.info(f"Cleaned up {deleted_count} old backups")
        except Exception as e:
            logger.error(f"Backup cleanup failed: {e}")

    async def restore_database_backup(self, backup_filename: str):
        """Restore database from backup."""
        try:
            # Download from S3
            s3_key = f"database_backups/{backup_filename}"
            self.s3_client.download_file(self.backup_bucket, s3_key, backup_filename)
            # Restore using mongorestore
            restore_command = [
                "mongorestore",
                "--uri", self.mongodb_url,
                "--gzip",
                f"--archive={backup_filename}",
                "--drop"  # Drop existing collections
            ]
            result = subprocess.run(restore_command, capture_output=True, text=True)
            if result.returncode != 0:
                raise Exception(f"Mongorestore failed: {result.stderr}")
            # Clean up local file
            os.remove(backup_filename)
            logger.info(f"Database restored from backup: {backup_filename}")
        except Exception as e:
            logger.error(f"Database restore failed: {e}")
            raise

# Scheduled backups
backup_manager = BackupManager()

# Daily backup at 2 AM
scheduler.add_job(
    backup_manager.create_database_backup,
    'cron',
    hour=2,
    minute=0
)

# Weekly cleanup on Sundays
scheduler.add_job(
    backup_manager.cleanup_old_backups,
    'cron',
    day_of_week=6,
    hour=3,
    minute=0
)
41. Missing Disaster Recovery¶
Location: Infrastructure
Risk: Critical - Business continuity
Recommendation:
# Disaster Recovery Plan Documentation
# docs/disaster-recovery.md

## Disaster Recovery Procedures

### 1. Database Recovery
```bash
# Restore from latest backup
python -c "
from app.backup_manager import BackupManager
import asyncio

async def restore():
    backup_manager = BackupManager()
    # Get latest backup
    latest_backup = await backup_manager.get_latest_backup()
    await backup_manager.restore_database_backup(latest_backup)

asyncio.run(restore())
"
```

### 2. Application Recovery
```bash
# Deploy to backup infrastructure
kubectl apply -f k8s/disaster-recovery/
kubectl rollout status deployment/fl-backend-dr
kubectl rollout status deployment/fl-frontend-dr
```

### 3. Data Center Failover
```bash
# Switch DNS to backup data center
aws route53 change-resource-record-sets \
  --hosted-zone-id Z123456789 \
  --change-batch file://dns-failover.json
```

Kubernetes disaster recovery configuration (k8s/disaster-recovery/deployment.yaml):

apiVersion: apps/v1
kind: Deployment
metadata:
  name: fl-backend-dr
  namespace: fl-platform-dr
spec:
  replicas: 2
  selector:
    matchLabels:
      app: fl-backend-dr
  template:
    metadata:
      labels:
        app: fl-backend-dr
    spec:
      containers:
        - name: backend
          image: fl-platform/backend:latest
          env:
            - name: MONGODB_URL
              valueFrom:
                secretKeyRef:
                  name: mongodb-dr-secret
                  key: connection-string
            - name: ENVIRONMENT
              value: "disaster-recovery"
          resources:
            requests:
              memory: "512Mi"
              cpu: "250m"
            limits:
              memory: "1Gi"
              cpu: "500m"
          livenessProbe:
            httpGet:
              path: /health/live
              port: 8000
            initialDelaySeconds: 30
            periodSeconds: 10
          readinessProbe:
            httpGet:
              path: /health/ready
              port: 8000
            initialDelaySeconds: 5
            periodSeconds: 5
Docker/Deployment Issues (Issues 42-46)¶
42. Insecure Docker Configuration¶
Location: docker-compose.yml, Dockerfiles
Risk: High - Container security vulnerabilities
Recommendation:
# Secure Dockerfile for backend
FROM python:3.10-slim as builder
# Create non-root user
RUN groupadd -r appuser && useradd -r -g appuser appuser
# Install system dependencies
RUN apt-get update && apt-get install -y \
build-essential \
curl \
&& rm -rf /var/lib/apt/lists/*
# Set working directory
WORKDIR /app
# Copy requirements and install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir --user -r requirements.txt
# Production stage
FROM python:3.10-slim
# Create non-root user
RUN groupadd -r appuser && useradd -r -g appuser appuser
# Install runtime dependencies only
RUN apt-get update && apt-get install -y \
curl \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
# Copy installed packages from builder
COPY --from=builder /root/.local /home/appuser/.local
# Set working directory
WORKDIR /app
# Copy application code
COPY --chown=appuser:appuser . .
# Switch to non-root user
USER appuser
# Add local bin to PATH
ENV PATH=/home/appuser/.local/bin:$PATH
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1
# Expose port
EXPOSE 8000
# Run application
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
# Secure docker-compose.yml
version: '3.8'
services:
  backend:
    build:
      context: ./backend
      dockerfile: Dockerfile
    ports:
      - "8000:8000"
    environment:
      - MONGODB_URL=${MONGODB_URL}
      - SECRET_KEY=${SECRET_KEY}
      - ENVIRONMENT=production
    volumes:
      - ./logs:/app/logs:rw
    networks:
      - fl-network
    restart: unless-stopped
    security_opt:
      - no-new-privileges:true
    cap_drop:
      - ALL
    cap_add:
      - NET_BIND_SERVICE
    read_only: true
    tmpfs:
      - /tmp
      - /var/tmp
    user: "1000:1000"

  mongodb:
    image: mongo:5.0
    environment:
      - MONGO_INITDB_ROOT_USERNAME=${MONGO_ROOT_USERNAME}
      - MONGO_INITDB_ROOT_PASSWORD=${MONGO_ROOT_PASSWORD}
    volumes:
      - mongodb_data:/data/db:rw
      - ./mongo-init:/docker-entrypoint-initdb.d:ro
    networks:
      - fl-network
    restart: unless-stopped
    security_opt:
      - no-new-privileges:true
    user: "999:999"

networks:
  fl-network:
    driver: bridge
    ipam:
      config:
        - subnet: 172.20.0.0/16

volumes:
  mongodb_data:
    driver: local
43. Missing Production Docker Compose¶
Location: Deployment configuration
Risk: Medium - Production deployment issues
Recommendation:
# docker-compose.prod.yml
version: '3.8'
services:
  backend:
    image: fl-platform/backend:${VERSION:-latest}
    ports:
      - "8000:8000"
    environment:
      - MONGODB_URL=${MONGODB_URL}
      - SECRET_KEY=${SECRET_KEY}
      - ENVIRONMENT=production
      - REDIS_URL=${REDIS_URL}
      - OTEL_EXPORTER_OTLP_ENDPOINT=${OTEL_ENDPOINT}
    volumes:
      - ./logs:/app/logs:rw
      - ./uploads:/app/uploads:rw
    networks:
      - fl-network
    restart: unless-stopped
    deploy:
      replicas: 3
      resources:
        limits:
          cpus: '1.0'
          memory: 1G
        reservations:
          cpus: '0.5'
          memory: 512M
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 40s
    logging:
      driver: "json-file"
      options:
        max-size: "10m"
        max-file: "3"

  frontend:
    image: fl-platform/frontend:${VERSION:-latest}
    ports:
      - "3000:3000"
    environment:
      - NEXT_PUBLIC_API_URL=${API_URL}
      - NODE_ENV=production
    networks:
      - fl-network
    restart: unless-stopped
    deploy:
      replicas: 2
      resources:
        limits:
          cpus: '0.5'
          memory: 512M

  nginx:
    image: nginx:alpine
    ports:
      - "80:80"
      - "443:443"
    volumes:
      - ./nginx/nginx.conf:/etc/nginx/nginx.conf:ro
      - ./nginx/ssl:/etc/nginx/ssl:ro
    networks:
      - fl-network
    restart: unless-stopped
    depends_on:
      - backend
      - frontend

  redis:
    image: redis:7-alpine
    command: redis-server --requirepass ${REDIS_PASSWORD}
    volumes:
      - redis_data:/data
    networks:
      - fl-network
    restart: unless-stopped

networks:
  fl-network:
    driver: overlay
    attachable: true

volumes:
  redis_data:
    driver: local
44. No Container Security Scanning¶
Location: CI/CD pipeline
Risk: High - Vulnerable container images
Recommendation:
# .github/workflows/security-scan.yml
name: Container Security Scan
on:
  push:
    branches: [main, develop]
  pull_request:
    branches: [main]
jobs:
  container-scan:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Build Docker images
        run: |
          docker build -t fl-backend:test ./backend
          docker build -t fl-frontend:test ./frontend
      - name: Run Trivy vulnerability scanner
        uses: aquasecurity/trivy-action@master
        with:
          image-ref: 'fl-backend:test'
          format: 'sarif'
          output: 'trivy-results.sarif'
      - name: Upload Trivy scan results
        uses: github/codeql-action/upload-sarif@v2
        with:
          sarif_file: 'trivy-results.sarif'
      - name: Run Snyk container scan
        uses: snyk/actions/docker@master
        env:
          SNYK_TOKEN: ${{ secrets.SNYK_TOKEN }}
        with:
          image: fl-backend:test
          args: --severity-threshold=high
      - name: Run Docker Bench Security
        run: |
          docker run --rm --net host --pid host --userns host --cap-add audit_control \
            -e DOCKER_CONTENT_TRUST=$DOCKER_CONTENT_TRUST \
            -v /etc:/etc:ro \
            -v /usr/bin/containerd:/usr/bin/containerd:ro \
            -v /usr/bin/runc:/usr/bin/runc:ro \
            -v /usr/lib/systemd:/usr/lib/systemd:ro \
            -v /var/lib:/var/lib:ro \
            -v /var/run/docker.sock:/var/run/docker.sock:ro \
            --label docker_bench_security \
            docker/docker-bench-security
45. Missing Deployment Automation¶
Location: CI/CD pipeline
Risk: Medium - Manual deployment errors
Recommendation:
# .github/workflows/deploy.yml
name: Deploy to Production
on:
  push:
    tags:
      - 'v*'
env:
  ECR_REPOSITORY: fl-platform
jobs:
  deploy:
    runs-on: ubuntu-latest
    environment: production
    steps:
      - uses: actions/checkout@v3
      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@v2
        with:
          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          aws-region: us-west-2
      - name: Login to Amazon ECR
        id: login-ecr
        uses: aws-actions/amazon-ecr-login@v1
      - name: Build and push Docker images
        env:
          ECR_REGISTRY: ${{ steps.login-ecr.outputs.registry }}
          IMAGE_TAG: ${{ github.ref_name }}
        run: |
          # Build and push backend
          docker build -t $ECR_REGISTRY/$ECR_REPOSITORY/backend:$IMAGE_TAG ./backend
          docker push $ECR_REGISTRY/$ECR_REPOSITORY/backend:$IMAGE_TAG
          # Build and push frontend
          docker build -t $ECR_REGISTRY/$ECR_REPOSITORY/frontend:$IMAGE_TAG ./frontend
          docker push $ECR_REGISTRY/$ECR_REPOSITORY/frontend:$IMAGE_TAG
      - name: Deploy to EKS
        env:
          # The registry/tag variables must be redeclared here; env blocks are per-step
          ECR_REGISTRY: ${{ steps.login-ecr.outputs.registry }}
          IMAGE_TAG: ${{ github.ref_name }}
        run: |
          aws eks update-kubeconfig --name fl-platform-prod --region us-west-2
          # Update image tags in the deployments
          kubectl set image deployment/fl-backend \
            backend=$ECR_REGISTRY/$ECR_REPOSITORY/backend:$IMAGE_TAG
          kubectl set image deployment/fl-frontend \
            frontend=$ECR_REGISTRY/$ECR_REPOSITORY/frontend:$IMAGE_TAG
          # Wait for rollout
          kubectl rollout status deployment/fl-backend
          kubectl rollout status deployment/fl-frontend
      - name: Run smoke tests
        run: |
          # Wait for pods to be ready
          kubectl wait --for=condition=ready pod -l app=fl-backend --timeout=300s
          # Run basic health checks
          BACKEND_URL=$(kubectl get service fl-backend -o jsonpath='{.status.loadBalancer.ingress[0].hostname}')
          curl -f http://$BACKEND_URL/health || exit 1
      - name: Notify deployment
        if: always()
        uses: 8398a7/action-slack@v3
        with:
          status: ${{ job.status }}
          channel: '#deployments'
          webhook_url: ${{ secrets.SLACK_WEBHOOK }}
46. No Rollback Procedures¶
Location: Deployment process
Risk: High - Failed deployment recovery
Recommendation:
#!/bin/bash
# scripts/rollback.sh
set -e
NAMESPACE=${NAMESPACE:-fl-platform}
DEPLOYMENT=${1:-fl-backend}
REVISION=${2:-}
echo "Rolling back deployment: $DEPLOYMENT"
if [ -n "$REVISION" ]; then
  echo "Rolling back to revision: $REVISION"
  kubectl rollout undo deployment/$DEPLOYMENT --to-revision=$REVISION -n $NAMESPACE
else
  echo "Rolling back to previous revision"
  kubectl rollout undo deployment/$DEPLOYMENT -n $NAMESPACE
fi
echo "Waiting for rollback to complete..."
kubectl rollout status deployment/$DEPLOYMENT -n $NAMESPACE --timeout=300s
echo "Verifying rollback..."
kubectl get pods -l app=$DEPLOYMENT -n $NAMESPACE
# Run health checks; fail the script if the service never becomes healthy
echo "Running health checks..."
BACKEND_URL=$(kubectl get service $DEPLOYMENT -n $NAMESPACE -o jsonpath='{.status.loadBalancer.ingress[0].hostname}')
healthy=false
for i in {1..10}; do
  if curl -sf "http://$BACKEND_URL/health"; then
    echo "Health check passed"
    healthy=true
    break
  fi
  echo "Health check failed, attempt $i/10"
  sleep 10
done
if [ "$healthy" != "true" ]; then
  echo "Rollback finished but health checks never passed" >&2
  exit 1
fi
echo "Rollback completed successfully"
Code Quality Issues (Issues 47-50)¶
47. Inconsistent Code Style¶
Location: Codebase-wide
Risk: Low - Developer productivity
Recommendation:
// .eslintrc.json
{
  "extends": [
    "next/core-web-vitals",
    "plugin:@typescript-eslint/recommended",
    "prettier"
  ],
  "parser": "@typescript-eslint/parser",
  "plugins": ["@typescript-eslint"],
  "rules": {
    "@typescript-eslint/no-unused-vars": "error",
    "@typescript-eslint/no-explicit-any": "warn",
    "@typescript-eslint/explicit-function-return-type": "warn",
    "prefer-const": "error",
    "no-var": "error"
  }
}
// prettier.config.js
module.exports = {
  semi: true,
  trailingComma: 'es5',
  singleQuote: true,
  printWidth: 80,
  tabWidth: 2,
  useTabs: false,
};
48. Missing Documentation Comments¶
Location: Code functions and classes
Risk: Low - Code maintainability
Recommendation:
from datetime import datetime
from uuid import uuid4

class TrainingService:
    """Service for managing federated learning training jobs.

    This service handles the creation, monitoring, and management of
    federated learning training jobs, including coordination with the
    Flower framework and Ansible deployment.
    """

    @staticmethod
    async def create_job(
        db: Database,
        job_data: TrainingJobCreate,
        user_id: str
    ) -> TrainingJob:
        """Create a new training job.

        Args:
            db: Database connection
            job_data: Training job configuration data
            user_id: ID of the user creating the job

        Returns:
            TrainingJob: Created training job instance

        Raises:
            ValueError: If job configuration is invalid
            DatabaseError: If database operation fails

        Example:
            >>> job_data = TrainingJobCreate(
            ...     name="MNIST Training",
            ...     project_id="proj_123",
            ...     config={"rounds": 10, "min_clients": 2}
            ... )
            >>> job = await TrainingService.create_job(db, job_data, "user_123")
        """
        # Validate project access
        project = await ProjectService.get_user_project(db, job_data.project_id, user_id)
        if not project:
            raise ValueError("Project not found or access denied")
        # Create job document
        job_doc = {
            "id": str(uuid4()),
            "name": job_data.name,
            "description": job_data.description,
            "project_id": job_data.project_id,
            "user_id": user_id,
            "config": job_data.config,
            "status": "pending",
            "created_at": datetime.utcnow(),
            "updated_at": datetime.utcnow()
        }
        # Insert into database
        result = await db.training_jobs.insert_one(job_doc)
        job_doc["_id"] = result.inserted_id
        logger.info(f"Created training job {job_doc['id']} for user {user_id}")
        return TrainingJob(**job_doc)
49. No Error Handling Standards¶
Location: Application-wide
Risk: Medium - Inconsistent error responses
Recommendation:
from typing import Optional

# Custom exception classes
class FLPlatformException(Exception):
    """Base exception for the FL Platform."""
    def __init__(self, message: str, code: Optional[str] = None, details: Optional[dict] = None):
        self.message = message
        self.code = code or self.__class__.__name__
        self.details = details or {}
        super().__init__(self.message)

class ValidationError(FLPlatformException):
    """Raised when input validation fails."""
    pass

class AuthenticationError(FLPlatformException):
    """Raised when authentication fails."""
    pass

class AuthorizationError(FLPlatformException):
    """Raised when authorization fails."""
    pass

class ResourceNotFoundError(FLPlatformException):
    """Raised when a requested resource is not found."""
    pass

class DatabaseError(FLPlatformException):
    """Raised when a database operation fails."""
    pass

# Global exception handler
@app.exception_handler(FLPlatformException)
async def fl_platform_exception_handler(request: Request, exc: FLPlatformException):
    """Handle custom FL Platform exceptions."""
    status_code_map = {
        ValidationError: 400,
        AuthenticationError: 401,
        AuthorizationError: 403,
        ResourceNotFoundError: 404,
        DatabaseError: 500,
    }
    status_code = status_code_map.get(type(exc), 500)
    error_response = {
        "error": exc.message,
        "code": exc.code,
        "details": exc.details,
        "timestamp": datetime.utcnow().isoformat()
    }
    # Log the error with structured context
    logger.error(
        f"Exception: {exc.code}",
        extra={
            "error_code": exc.code,
            "error_message": exc.message,
            "error_details": exc.details,
            "request_path": request.url.path,
            "request_method": request.method
        }
    )
    return JSONResponse(
        status_code=status_code,
        content=error_response
    )

# Usage in services
async def get_training_job(db: Database, job_id: str, user_id: str) -> TrainingJob:
    """Get a training job by ID, validating user access."""
    try:
        job_doc = await db.training_jobs.find_one({"id": job_id})
    except PyMongoError as e:
        raise DatabaseError(
            "Failed to retrieve training job",
            code="DATABASE_QUERY_FAILED",
            details={"job_id": job_id, "error": str(e)}
        ) from e
    if not job_doc:
        raise ResourceNotFoundError(
            f"Training job not found: {job_id}",
            code="TRAINING_JOB_NOT_FOUND",
            details={"job_id": job_id}
        )
    # Check user access
    if job_doc["user_id"] != user_id:
        raise AuthorizationError(
            "Access denied to training job",
            code="TRAINING_JOB_ACCESS_DENIED",
            details={"job_id": job_id, "user_id": user_id}
        )
    return TrainingJob(**job_doc)
50. Missing Input Sanitization¶
Location: User input handling
Risk: Medium - Data integrity and security
Recommendation:
import html
import re
from typing import Any, Dict, Optional

class InputSanitizer:
    """Utility class for sanitizing user inputs."""

    @staticmethod
    def sanitize_string(value: str, max_length: Optional[int] = None) -> str:
        """Sanitize a string input."""
        if not isinstance(value, str):
            raise ValidationError("Input must be a string")
        # Remove null bytes
        value = value.replace('\x00', '')
        # Strip whitespace
        value = value.strip()
        # HTML-escape
        value = html.escape(value)
        # Check length
        if max_length and len(value) > max_length:
            raise ValidationError(f"Input too long (max {max_length} characters)")
        return value

    @staticmethod
    def sanitize_filename(filename: str) -> str:
        """Sanitize a filename for safe storage."""
        if not filename:
            raise ValidationError("Filename cannot be empty")
        # Remove path separators and dangerous characters
        filename = re.sub(r'[<>:"/\\|?*\x00-\x1f]', '', filename)
        # Remove leading/trailing dots and spaces
        filename = filename.strip('. ')
        # Check for reserved names (Windows)
        reserved_names = {
            'CON', 'PRN', 'AUX', 'NUL',
            'COM1', 'COM2', 'COM3', 'COM4', 'COM5', 'COM6', 'COM7', 'COM8', 'COM9',
            'LPT1', 'LPT2', 'LPT3', 'LPT4', 'LPT5', 'LPT6', 'LPT7', 'LPT8', 'LPT9'
        }
        if filename.upper() in reserved_names:
            raise ValidationError(f"Reserved filename: {filename}")
        # Limit length
        if len(filename) > 255:
            raise ValidationError("Filename too long")
        return filename

    @staticmethod
    def sanitize_dict(data: Dict[str, Any], allowed_keys: Optional[set] = None) -> Dict[str, Any]:
        """Sanitize a dictionary input."""
        if not isinstance(data, dict):
            raise ValidationError("Input must be a dictionary")
        sanitized = {}
        for key, value in data.items():
            # Sanitize the key
            clean_key = InputSanitizer.sanitize_string(key, max_length=100)
            # Drop keys outside the allow-list
            if allowed_keys and clean_key not in allowed_keys:
                continue
            # Sanitize the value based on its type
            if isinstance(value, str):
                sanitized[clean_key] = InputSanitizer.sanitize_string(value)
            elif isinstance(value, (int, float, bool)):
                sanitized[clean_key] = value
            elif isinstance(value, dict):
                sanitized[clean_key] = InputSanitizer.sanitize_dict(value)
            elif isinstance(value, list):
                sanitized[clean_key] = [
                    InputSanitizer.sanitize_string(item) if isinstance(item, str) else item
                    for item in value[:100]  # Limit list size
                ]
        return sanitized

# Enhanced Pydantic validators with sanitization
class TrainingJobCreate(BaseModel):
    name: str = Field(..., min_length=1, max_length=100)
    description: Optional[str] = Field(None, max_length=500)
    project_id: str = Field(..., min_length=1)
    config: Dict[str, Any] = Field(default_factory=dict)

    @validator('name')
    def sanitize_name(cls, v):
        return InputSanitizer.sanitize_string(v, max_length=100)

    @validator('description')
    def sanitize_description(cls, v):
        if v is not None:
            return InputSanitizer.sanitize_string(v, max_length=500)
        return v

    @validator('config')
    def sanitize_config(cls, v):
        allowed_config_keys = {
            'rounds', 'min_clients', 'strategy', 'client_resources',
            'server_address', 'server_port', 'ssl_enabled'
        }
        return InputSanitizer.sanitize_dict(v, allowed_config_keys)
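As a quick sanity check on the filename rules, here is a standalone sketch that duplicates just the character filter and dot/space trimming (illustrative only; production code should call `InputSanitizer.sanitize_filename`, which adds the reserved-name and length checks):

```python
import re

def strip_dangerous_chars(filename: str) -> str:
    """Duplicate of the character filter used in sanitize_filename above."""
    # Remove path separators, shell metacharacters, and control characters
    filename = re.sub(r'[<>:"/\\|?*\x00-\x1f]', '', filename)
    # Trim leading/trailing dots and spaces
    return filename.strip('. ')
```

Path-traversal payloads collapse to harmless names: `strip_dangerous_chars('../../etc/passwd')` returns `'etcpasswd'` because the slashes are removed and the leading dots stripped.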
Additional Security Hardening (Issues 51-54)¶
51. Missing Rate Limiting Implementation¶
Location: API endpoints
Risk: High - DoS attacks
Recommendation:
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded

# Rate limiter configuration
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Apply rate limiting to routes
@router.post("/auth/login")
@limiter.limit("5/minute")
async def login(request: Request, user_credentials: UserLogin):
    """Login with rate limiting."""
    # Login logic here
    pass

@router.post("/training/jobs")
@limiter.limit("10/minute")
async def create_training_job(request: Request, job_data: TrainingJobCreate):
    """Create training job with rate limiting."""
    # Job creation logic here
    pass
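slowapi handles the HTTP plumbing; the underlying mechanism is simply a per-key request log pruned to a time window. A standalone sketch of that idea (illustrative, not slowapi's actual implementation):

```python
import time
from collections import defaultdict
from typing import Dict, List, Optional

class SlidingWindowLimiter:
    """Allow at most `limit` calls per `window` seconds per key (e.g. client IP)."""

    def __init__(self, limit: int, window: float):
        self.limit = limit
        self.window = window
        self._hits: Dict[str, List[float]] = defaultdict(list)

    def allow(self, key: str, now: Optional[float] = None) -> bool:
        now = time.monotonic() if now is None else now
        # Keep only hits still inside the window
        hits = [t for t in self._hits[key] if now - t < self.window]
        if len(hits) >= self.limit:
            self._hits[key] = hits
            return False  # over the limit; do not record the rejected call
        hits.append(now)
        self._hits[key] = hits
        return True
```

A "5/minute" policy corresponds to `SlidingWindowLimiter(limit=5, window=60.0)`: the sixth call inside any 60-second span is rejected, while other keys are unaffected.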
52. Insufficient Session Management¶
Location: Authentication system
Risk: Medium - Session hijacking
Recommendation:
# Enhanced session management
import json
import secrets
from datetime import datetime
from typing import Optional

class SessionManager:
    def __init__(self, redis_client):
        self.redis = redis_client
        self.session_timeout = 3600  # 1 hour

    async def create_session(self, user_id: str, request: Request) -> str:
        """Create a new session."""
        session_id = secrets.token_urlsafe(32)
        session_data = {
            "user_id": user_id,
            "created_at": datetime.utcnow().isoformat(),
            "ip_address": request.client.host,
            "user_agent": request.headers.get("user-agent", ""),
            "last_activity": datetime.utcnow().isoformat()
        }
        await self.redis.setex(
            f"session:{session_id}",
            self.session_timeout,
            json.dumps(session_data)
        )
        return session_id

    async def invalidate_session(self, session_id: str) -> None:
        """Delete a session immediately."""
        await self.redis.delete(f"session:{session_id}")

    async def validate_session(self, session_id: str, request: Request) -> Optional[dict]:
        """Validate a session and check for anomalies."""
        session_data = await self.redis.get(f"session:{session_id}")
        if not session_data:
            return None
        session = json.loads(session_data)
        # Check IP address consistency
        if session["ip_address"] != request.client.host:
            logger.warning(f"IP address mismatch for session {session_id}")
            await self.invalidate_session(session_id)
            return None
        # Update last activity and refresh the TTL
        session["last_activity"] = datetime.utcnow().isoformat()
        await self.redis.setex(
            f"session:{session_id}",
            self.session_timeout,
            json.dumps(session)
        )
        return session
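The Redis calls above are hard to exercise without infrastructure. The same validation flow, sketched against an in-memory store purely for illustration (class and method names here are hypothetical):

```python
import secrets
from datetime import datetime, timedelta
from typing import Dict, Optional

class InMemorySessionStore:
    """Illustrative stand-in for the Redis-backed SessionManager above."""

    def __init__(self, timeout_seconds: int = 3600):
        self.timeout = timedelta(seconds=timeout_seconds)
        self._sessions: Dict[str, dict] = {}

    def create_session(self, user_id: str, ip_address: str) -> str:
        session_id = secrets.token_urlsafe(32)
        self._sessions[session_id] = {
            "user_id": user_id,
            "ip_address": ip_address,
            "expires_at": datetime.utcnow() + self.timeout,
        }
        return session_id

    def validate_session(self, session_id: str, ip_address: str) -> Optional[dict]:
        session = self._sessions.get(session_id)
        if session is None or datetime.utcnow() > session["expires_at"]:
            self._sessions.pop(session_id, None)  # expired or unknown
            return None
        if session["ip_address"] != ip_address:
            # Same anomaly rule as above: an IP mismatch invalidates the session
            del self._sessions[session_id]
            return None
        return session
```

A request from a new IP both fails validation and destroys the session, so a hijacker cannot retry, but neither can the legitimate user until they log in again; this trade-off matters for clients behind mobile carriers or rotating proxies.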
53. Missing CSRF Protection¶
Location: State-changing endpoints
Risk: Medium - Cross-site request forgery
Recommendation:
from fastapi_csrf_protect import CsrfProtect
from pydantic import BaseModel

# CSRF protection configuration; CsrfSettings is a user-defined settings model
class CsrfSettings(BaseModel):
    secret_key: str
    cookie_samesite: str = "strict"
    cookie_secure: bool = True
    cookie_httponly: bool = True

@CsrfProtect.load_config
def get_csrf_config():
    return CsrfSettings(
        secret_key=settings.secret_key,
        cookie_samesite="strict",
        cookie_secure=True,
        cookie_httponly=True
    )

@router.post("/training/jobs")
async def create_training_job(
    request: Request,
    job_data: TrainingJobCreate,
    csrf_protect: CsrfProtect = Depends()
):
    """Create training job with CSRF protection."""
    await csrf_protect.validate_csrf(request)
    # Job creation logic here
    pass
54. Inadequate Audit Logging¶
Location: Security-sensitive operations
Risk: Medium - Compliance and forensics
Recommendation:
import structlog

class AuditLogger:
    """Audit logging for security-sensitive operations."""

    def __init__(self):
        self.logger = structlog.get_logger("audit")

    async def log_authentication(self, username: str, success: bool, ip_address: str, user_agent: str):
        """Log authentication attempts."""
        self.logger.info(
            "authentication_attempt",
            username=username,
            success=success,
            ip_address=ip_address,
            user_agent=user_agent,
            timestamp=datetime.utcnow().isoformat()
        )

    async def log_authorization(self, user_id: str, resource: str, action: str, granted: bool):
        """Log authorization decisions."""
        self.logger.info(
            "authorization_check",
            user_id=user_id,
            resource=resource,
            action=action,
            granted=granted,
            timestamp=datetime.utcnow().isoformat()
        )

    async def log_data_access(self, user_id: str, resource_type: str, resource_id: str, action: str):
        """Log data access operations."""
        self.logger.info(
            "data_access",
            user_id=user_id,
            resource_type=resource_type,
            resource_id=resource_id,
            action=action,
            timestamp=datetime.utcnow().isoformat()
        )

    async def log_admin_action(self, admin_user_id: str, action: str, target: str, details: Optional[dict] = None):
        """Log administrative actions."""
        self.logger.info(
            "admin_action",
            admin_user_id=admin_user_id,
            action=action,
            target=target,
            details=details or {},
            timestamp=datetime.utcnow().isoformat()
        )

# Usage in routes
audit_logger = AuditLogger()

@router.post("/auth/login")
async def login(request: Request, user_credentials: UserLogin):
    """Login with audit logging."""
    success = False
    try:
        user = await UserService.authenticate_user(
            db, user_credentials.username, user_credentials.password
        )
        if not user:
            raise HTTPException(401, "Invalid credentials")
        success = True
        # Create session and return token
        # ...
    finally:
        # Log the attempt whether it succeeded, failed, or raised
        await audit_logger.log_authentication(
            username=user_credentials.username,
            success=success,
            ip_address=request.client.host,
            user_agent=request.headers.get("user-agent", "")
        )
Implementation Priority¶
Week 1 (Critical):
- Fix hardcoded secrets (Issues 1, 7, 26)
- Add rate limiting (Issues 3, 11)
- Implement proper password hashing (Issue 2)

Week 2-4 (High Priority):
- Add input validation (Issue 3)
- Implement database indexes (Issue 15)
- Add security headers (Issue 5)
- Fix token storage (Issue 20)

Month 2-3 (Medium Priority):
- Implement caching (Issue 16)
- Add comprehensive monitoring (Issues 36-41)
- Create test coverage (Issues 28-35)
- Security audit and penetration testing
This prioritized approach addresses the most critical security vulnerabilities first, followed by performance and operational improvements.