FluxCoordinator¶
Low-level API for custom training loops and advanced use cases.
Overview¶
FluxCoordinator is the core orchestration layer that FluxTrainer uses internally. Use it when you need fine-grained control over:
- Training step execution
- Weight synchronization timing
- Async/sync decisions
- Custom batching logic
Basic Usage¶
from flux.coordinator import FluxCoordinator
coordinator = FluxCoordinator(
config=config,
reward_function=my_reward,
)
await coordinator.initialize()
async for result in coordinator.run_training_async(prompts, num_steps=1000):
print(f"Step {result.step}: loss={result.training_result.loss}")
await coordinator.shutdown()
Synchronous Usage¶
StepResult¶
Returned for each training step:
@dataclass
class StepResult:
step: int
training_result: TrainingStepResult | None
staleness_metrics: StalenessMetrics | None
async_decision: AsyncDecision | None
batch_size: int
num_trajectories: int
elapsed_ms: float
metrics: dict[str, float]
Key Methods¶
initialize() / shutdown()¶
await coordinator.initialize() # Set up all components
await coordinator.shutdown() # Clean shutdown
run_training() / run_training_async()¶
Main training loop, sync and async versions.
save_checkpoint() / load_checkpoint()¶
get_statistics()¶
Mode Gate¶
The ModeGate is a state machine that controls sync/async training transitions.
AsyncMode Enum¶
from flux.controller import AsyncMode
class AsyncMode(Enum):
SYNC_BARRIER = auto() # Waiting for all in-flight rollouts
ASYNC_RUNNING = auto() # Normal async operation
THROTTLED = auto() # Capacity exhausted, backpressure active
ModeGateState¶
@dataclass
class ModeGateState:
mode: AsyncMode
reason: str
staleness: float
capacity: int
in_flight: int
buffer_fill_ratio: float = 0.0
ModeGate Class¶
from flux.controller import ModeGate, ModeGateConfig
config = ModeGateConfig(
staleness_threshold=0.3,
capacity_low_watermark=0,
buffer_high_watermark=0.9,
min_barrier_duration_ms=100.0,
)
gate = ModeGate(config)
# Evaluate current state
state = gate.evaluate(
staleness=0.25,
capacity=10,
buffer_fill_ratio=0.5,
in_flight=5,
)
# Check if new rollouts can be submitted
if gate.can_submit_rollout():
submit_rollout()
# Enforce sync barrier (async)
await gate.enforce_barrier(wait_for_in_flight_fn, timeout=30.0)
ModeGateIntegration¶
Helper class for coordinator integration:
from flux.controller import ModeGateIntegration
integration = ModeGateIntegration(
mode_gate=gate,
staleness_manager=staleness_mgr,
trajectory_buffer=buffer,
)
# Check and enforce mode transitions
state = await integration.check_and_enforce(wait_for_in_flight)
# Get available rollout slots
slots = integration.get_rollout_slots()
State Transitions¶
| Current State | Condition | New State | Action |
|---|---|---|---|
| ASYNC_RUNNING | staleness > threshold |
SYNC_BARRIER | Wait for in-flight |
| ASYNC_RUNNING | capacity <= 0 |
THROTTLED | Pause rollouts |
| ASYNC_RUNNING | buffer > 90% |
THROTTLED | Pause rollouts |
| SYNC_BARRIER | In-flight drained | ASYNC_RUNNING | Resume |
| THROTTLED | Capacity recovered | ASYNC_RUNNING | Resume |
See Also¶
- FluxTrainer - High-level API
- Architecture - System design