Adaptive Async in Practice¶
Master the adaptive async controller for optimal training performance.
Time: 60 minutes
Prerequisites: Multi-GPU Training
Overview¶
The adaptive async controller is what makes Flux unique. In this tutorial, you'll learn:
- How the PID controller works
- Tuning staleness targets
- Monitoring controller behavior
- Debugging common issues
- Advanced configuration
Understanding the Controller¶
The Control Loop¶
graph LR
A[Measure Staleness] --> B[Compute Error]
B --> C[PID Controller]
C --> D[Adjust Async Ratio]
D --> E[Training Step]
E --> A
- Measure: Compute staleness from batch
- Error: target_staleness - current_staleness
- PID: Compute adjustment using P, I, D terms
- Adjust: Update async ratio within bounds
- Repeat: Continuous feedback loop
Key Variables¶
# Target: Where we want staleness to be
target_staleness = 0.15
# Current: Measured staleness (EMA smoothed)
current_staleness = 0.12
# Error: Difference from target
error = target_staleness - current_staleness # 0.03
# Async ratio: Current position on sync-async spectrum
async_ratio = 0.45 # 45% async
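Putting the control loop and these variables together, a single update can be sketched as follows. The `pid_step` helper and its default bounds are illustrative, not Flux API; the gains mirror the `kp`/`ki`/`kd` configuration keys.

```python
# One PID update: turn the staleness error into a new async ratio.
# (Illustrative sketch; `pid_step` is not part of the Flux API.)
def pid_step(error, integral, prev_error, async_ratio,
             kp=0.1, ki=0.01, kd=0.05,
             min_ratio=0.1, max_ratio=0.9):
    integral = integral + error      # I term: accumulated error
    derivative = error - prev_error  # D term: rate of change of the error
    adjustment = kp * error + ki * integral + kd * derivative
    # A positive error (staleness below target) pushes the ratio up: more async.
    new_ratio = min(max_ratio, max(min_ratio, async_ratio + adjustment))
    return new_ratio, integral

# With the example values above (error = 0.03, async_ratio = 0.45):
ratio, integral = pid_step(error=0.03, integral=0.0,
                           prev_error=0.0, async_ratio=0.45)
# ratio -> 0.4548: a small nudge toward more async
```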
Configuration Deep Dive¶
Basic Configuration¶
adaptive_async:
  # Target staleness level
  target_staleness: 0.15
  # Bounds on async ratio
  min_async_ratio: 0.1  # Never fully sync
  max_async_ratio: 0.9  # Never fully async
  # PID gains
  kp: 0.1   # Proportional
  ki: 0.01  # Integral
  kd: 0.05  # Derivative
  # Smoothing
  ema_alpha: 0.1  # EMA smoothing for staleness
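The `ema_alpha` knob determines how strongly the measured staleness is smoothed before the controller sees it. A minimal sketch of that smoothing (the `ema_update` name is illustrative):

```python
def ema_update(prev, raw, alpha=0.1):
    """Exponential moving average: smaller alpha = heavier smoothing."""
    return alpha * raw + (1 - alpha) * prev

# Three noisy high readings (0.30) barely move a smoothed value of 0.10:
smoothed = 0.10
for raw in (0.30, 0.30, 0.30):
    smoothed = ema_update(smoothed, raw)
# smoothed -> ~0.154, so a transient spike won't jerk the controller around
```

Lowering `ema_alpha` (e.g. to 0.05) makes the controller even less sensitive to measurement noise, at the cost of reacting more slowly to genuine drift.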
PID Tuning Guide¶
| Gain | Effect | Too Low | Too High |
|---|---|---|---|
| kp | Immediate response | Slow adaptation | Oscillation |
| ki | Steady-state correction | Offset from target | Windup, overshoot |
| kd | Dampening | Overshoot | Noise sensitivity |
Starting Point: kp=0.1, ki=0.01, kd=0.05
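A quick way to sanity-check these starting gains is a toy closed-loop simulation. The plant model below (staleness ≈ 0.3 × async ratio) is invented purely for illustration and is not Flux's real dynamics; it only shows the default gains converging smoothly to the target:

```python
def simulate(kp=0.1, ki=0.01, kd=0.05, target=0.15, steps=200):
    ratio, integral, prev_error = 0.45, 0.0, 0.0
    staleness = 0.3 * ratio
    for _ in range(steps):
        staleness = 0.3 * ratio  # toy plant: more async -> more staleness
        error = target - staleness
        integral += error
        adjustment = kp * error + ki * integral + kd * (error - prev_error)
        prev_error = error
        ratio = min(0.9, max(0.1, ratio + adjustment))
    return ratio, staleness

ratio, staleness = simulate()  # settles near ratio = 0.5, staleness = 0.15
```

Cranking `kp` up in this toy model speeds convergence at first, but a large enough value produces exactly the oscillation the table above warns about.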
Staleness Components¶
Staleness is computed from three signals:
# 1. KL Divergence
kl = compute_kl(current_policy, behavior_policy)
kl_contrib = min(1.0, kl / 0.1)
# 2. Importance Weight Variance
iw_var = compute_iw_variance(batch)
iw_contrib = min(1.0, iw_var / 2.0)
# 3. Version Gap
version_gap = current_version - mean(batch_versions)
version_contrib = min(1.0, version_gap / 5)
# Combined (weighted sum)
staleness = 0.4 * kl_contrib + 0.3 * iw_contrib + 0.3 * version_contrib
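The same computation wrapped as a runnable function (the `compute_staleness` wrapper is illustrative; the constants are exactly the ones above):

```python
def compute_staleness(kl, iw_var, version_gap):
    """Weighted sum of three capped signals; result is always in [0, 1]."""
    kl_contrib = min(1.0, kl / 0.1)              # KL divergence, saturates at 0.1
    iw_contrib = min(1.0, iw_var / 2.0)          # IW variance, saturates at 2.0
    version_contrib = min(1.0, version_gap / 5)  # version gap, saturates at 5
    return 0.4 * kl_contrib + 0.3 * iw_contrib + 0.3 * version_contrib

s = compute_staleness(kl=0.05, iw_var=1.0, version_gap=2)  # -> 0.47
```

Because each contribution is capped at 1.0, a single pathological signal (say, an extreme KL) can contribute at most its weight to the total.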
Monitoring¶
Key Metrics to Track¶
@trainer.add_step_callback
def log_adaptive_metrics(result):
    m = result.metrics
    print(f"Step {result.step}: "
          f"staleness={m['staleness']:.3f} "
          f"(target={m['target_staleness']:.3f}), "
          f"async_ratio={m['async_ratio']:.3f}, "
          f"error={m['pid_error']:.4f}")
Healthy Behavior¶
Step 100: staleness=0.12 (target=0.15), async_ratio=0.42, error=0.030
Step 200: staleness=0.14 (target=0.15), async_ratio=0.48, error=0.010
Step 300: staleness=0.16 (target=0.15), async_ratio=0.45, error=-0.010
Step 400: staleness=0.15 (target=0.15), async_ratio=0.46, error=0.000
↑ Controller stabilizes around target
Warning Signs¶
# Bad: Staleness stuck high
Step 100: staleness=0.35 (target=0.15), async_ratio=0.10 # Stuck at min!
# Bad: Oscillating wildly
Step 100: staleness=0.05, async_ratio=0.80
Step 101: staleness=0.30, async_ratio=0.15
Step 102: staleness=0.08, async_ratio=0.75
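Both failure patterns can be flagged mechanically from the logged metrics. A sketch with rule-of-thumb thresholds (`diagnose`, its 50-step window, and the 0.3 swing threshold are all assumptions, not Flux defaults):

```python
def diagnose(staleness_hist, ratio_hist, target, min_ratio=0.1):
    """Flag 'stuck_high' and 'oscillating' patterns over the last 50 steps."""
    issues = []
    window_s = staleness_hist[-50:]
    window_r = ratio_hist[-50:]
    # Stuck high: staleness persistently above target while ratio sits at min
    if all(s > target + 0.1 for s in window_s) and all(
            abs(r - min_ratio) < 1e-6 for r in window_r):
        issues.append("stuck_high")
    # Oscillating: large step-to-step swings in the async ratio
    swings = [abs(a - b) for a, b in zip(window_r, window_r[1:])]
    if swings and max(swings) > 0.3:
        issues.append("oscillating")
    return issues

issues = diagnose([0.35] * 60, [0.1] * 60, target=0.15)  # -> ["stuck_high"]
```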
Tuning Strategies¶
Strategy 1: Conservative Start¶
For critical training or new tasks:
adaptive_async:
  target_staleness: 0.10  # Low staleness
  max_async_ratio: 0.5    # Limited async
  kp: 0.05                # Slower adaptation
Strategy 2: High Throughput¶
When training is stable and you need speed:
adaptive_async:
  target_staleness: 0.25  # Higher staleness OK
  max_async_ratio: 0.9    # More async allowed
  kp: 0.15                # Faster adaptation
Strategy 3: Phase-Aware¶
Different settings for different training phases:
adaptive_async:
  target_staleness: 0.15
  phase_schedule:
    warmup:
      steps: 100
      target_staleness: 0.05  # Very fresh during warmup
    main:
      target_staleness: 0.15  # Normal during main training
    cooldown:
      steps: 50
      target_staleness: 0.25  # More async during cooldown
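One way such a schedule could resolve to a per-step target is sketched below; the assumption that `cooldown.steps` counts back from the end of a known total is mine, not documented Flux behavior:

```python
def target_for_step(step, total_steps, schedule):
    """Pick the target staleness for a step from a warmup/main/cooldown schedule."""
    warmup, cooldown = schedule["warmup"], schedule["cooldown"]
    if step < warmup["steps"]:                   # first N steps: warmup
        return warmup["target_staleness"]
    if step >= total_steps - cooldown["steps"]:  # last M steps: cooldown
        return cooldown["target_staleness"]
    return schedule["main"]["target_staleness"]  # everything in between

schedule = {
    "warmup":   {"steps": 100, "target_staleness": 0.05},
    "main":     {"target_staleness": 0.15},
    "cooldown": {"steps": 50,  "target_staleness": 0.25},
}
```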
Debugging Common Issues¶
Issue: Staleness Too High¶
Symptoms: staleness > target + 0.1 consistently
Causes:
1. Too many stale trajectories in buffer
2. Slow weight sync
3. SGLang inference too slow
Solutions:
adaptive_async:
  max_version_gap: 3    # Reduce max staleness
  sync_threshold: 0.20  # Sync sooner
weight_sync:
  method: delta  # Faster sync
  interval: 1    # Sync every step
Issue: Async Ratio Stuck at Minimum¶
Symptoms: async_ratio = 0.1 (min) for many steps
Causes:
1. Target staleness too low
2. High staleness variance
3. Controller can't recover
Solutions:
adaptive_async:
  target_staleness: 0.20  # Increase target
  kp: 0.15                # Faster response
  min_async_ratio: 0.2    # Higher floor
Issue: Oscillating Controller¶
Symptoms: Rapid swings between min/max async ratio
Causes:
1. PID gains too high
2. Measurement noise
3. Inconsistent workload
Solutions:
adaptive_async:
  kp: 0.05         # Reduce proportional
  kd: 0.1          # Increase dampening
  ema_alpha: 0.05  # More smoothing
Visualization¶
Plot Controller Behavior¶
import matplotlib.pyplot as plt
# Collect metrics during training
history = {"staleness": [], "async_ratio": [], "error": []}
@trainer.add_step_callback
def collect_metrics(result):
    history["staleness"].append(result.metrics["staleness"])
    history["async_ratio"].append(result.metrics["async_ratio"])
    history["error"].append(result.metrics["pid_error"])
# After training, plot
fig, axes = plt.subplots(3, 1, figsize=(10, 8))
axes[0].plot(history["staleness"])
axes[0].axhline(y=0.15, color='r', linestyle='--', label='target')
axes[0].set_ylabel("Staleness")
axes[0].legend()
axes[1].plot(history["async_ratio"])
axes[1].set_ylabel("Async Ratio")
axes[2].plot(history["error"])
axes[2].axhline(y=0, color='r', linestyle='--')
axes[2].set_ylabel("PID Error")
axes[2].set_xlabel("Step")
plt.tight_layout()
plt.savefig("controller_behavior.png")
Advanced: Custom Controller¶
For research, you can implement a custom controller:
from flux.controller import BaseAsyncController
class CustomController(BaseAsyncController):
    def __init__(self, config):
        super().__init__(config)
        self.history = []

    def update(self, staleness: float) -> float:
        self.history.append(staleness)
        # Custom logic: use the recent trend in staleness
        if len(self.history) > 10:
            trend = self.history[-1] - self.history[-10]
        else:
            trend = 0.0
        # Adjust based on trend
        if trend > 0.02:  # Staleness increasing
            self.async_ratio *= 0.95  # Reduce async
        elif trend < -0.02:  # Staleness decreasing
            self.async_ratio *= 1.05  # Increase async
        self.async_ratio = max(0.1, min(0.9, self.async_ratio))
        return self.async_ratio
# Use custom controller
config = FluxConfig(
    adaptive_async={"controller": CustomController}
)
Next Steps¶
- Production Deployment - Scale to production
- Performance Optimization - Maximum throughput
- Concepts: Adaptive Async - Theory deep dive