GradScaler
Why GradScaler Is Needed
In mixed-precision training, the model uses FP16 for forward and backward computation to improve performance. However, FP16 has a limited numerical range — small gradient values can underflow to zero, causing training instability or failure to converge.
GradScaler addresses this through gradient scaling:
- After computing the loss, it multiplies the loss by a large scale factor (e.g., 65536).
- Backpropagation produces proportionally larger gradients, preventing underflow.
- Before the optimizer update, gradients are divided by the scale factor to restore their true values.
- If inf/NaN values are detected in the gradients (indicating the scale factor is too large), the update is skipped and the scale factor is reduced.
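In plain PyTorch terms (without RoundPipe), the mechanism above corresponds roughly to the following sketch. This is only an illustration of the idea, not RoundPipe's implementation; `model`, `loss`, and `optimizer` stand in for your own objects:

```python
import torch

scale = 65536.0                      # initial scale factor

(loss * scale).backward()            # gradients come out scale× larger, avoiding FP16 underflow

grads = [p.grad for p in model.parameters() if p.grad is not None]
if any(not torch.isfinite(g).all() for g in grads):
    scale /= 2.0                     # inf/NaN found: skip this update and shrink the scale
else:
    for g in grads:
        g /= scale                   # restore the true gradient values
    optimizer.step()
    optimizer.zero_grad()
```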
Basic Usage
RoundPipe provides its own GradScaler implementation. The interface is fully compatible with torch.amp.GradScaler, but it additionally supports RoundPipe's asynchronous optimizer execution model.
A complete mixed-precision training loop:
```python
import torch

from roundpipe import RoundPipe, OptimizerCtx, GradScaler
from roundpipe.optim import Adam

# Assumes my_model, criterion, dataloader, and num_microbatch are defined elsewhere.
model = RoundPipe(my_model.to(torch.float16), optim_dtype=torch.float32)
with OptimizerCtx():
    optimizer = Adam(model.parameters(), lr=1e-3)
scaler = GradScaler()

for data, labels in dataloader:
    # 1. forward_backward: call scaler.scale() inside loss_fn
    loss = model.forward_backward(
        input_args=(data.to(torch.float16),),
        label=labels,
        loss_fn=lambda outputs, labels: scaler.scale(
            criterion(outputs.float(), labels)
        ) / num_microbatch,
    )

    # 2. step: call scaler.step() inside step_fn
    model.step(lambda: (
        scaler.step(optimizer),
        optimizer.zero_grad(),
    ))

    # 3. update: adjust the scale factor
    scaler.update()

    # Get the true loss value (divide by the scale factor)
    real_loss = loss.item() / scaler.get_scale()
```
What each step does:
- `scaler.scale(loss)`: Multiplies the loss by the scale factor. Placed inside `loss_fn` so that every microbatch's backward pass uses the scaled loss.
- `scaler.step(optimizer)`: Automatically runs `unscale_` (divides gradients by the scale factor), checks for inf/NaN, and calls `optimizer.step()`. If inf/NaN is found, the update is skipped.
- `scaler.update()`: Adjusts the scale factor based on whether the update was skipped. Must be called on the main thread.
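Because the interface matches torch.amp.GradScaler, a drop in the scale after `update()` means inf/NaN gradients were found and the optimizer step was skipped. A small sketch, assuming the same backoff behavior as PyTorch's scaler:

```python
prev_scale = scaler.get_scale()
scaler.update()
if scaler.get_scale() < prev_scale:
    # The scale was reduced, so inf/NaN gradients were detected and
    # this iteration's optimizer step was skipped.
    print("gradient overflow: step skipped, scale lowered to", scaler.get_scale())
```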
Differences from PyTorch's GradScaler
The key difference between RoundPipe's GradScaler and PyTorch's native torch.amp.GradScaler is thread safety.
RoundPipe's optimizer updates run asynchronously in a background thread by default (via model.step(is_async=True)). scaler.scale() is called on the GPU computation thread, while scaler.step() and scaler.update() run on the optimizer thread. PyTorch's native GradScaler does not support this cross-thread usage pattern, so you must use RoundPipe's version.
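For orientation, here is how the three scaler calls from the Basic Usage loop map onto these threads (a schematic summary in comments, not additional API):

```python
# GPU computation thread (inside loss_fn, once per microbatch):
#     scaler.scale(...)
#
# Background optimizer thread (inside step_fn passed to model.step):
#     scaler.step(optimizer); optimizer.zero_grad()
#
# Main thread (after model.step has been issued):
#     scaler.update()
```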
Do not use PyTorch's native GradScaler
```python
# ❌ Wrong: PyTorch's GradScaler does not support async optimizers
scaler = torch.amp.GradScaler()

# ✅ Correct: use RoundPipe's GradScaler
from roundpipe import GradScaler
scaler = GradScaler()
```
Manual Unscaling
If you need to perform additional operations on gradients before the optimizer update (e.g., gradient clipping), call unscale_ manually:
```python
def step_fn():
    # Unscale first so clipping operates on the true gradient values
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(optimizer.param_groups[0]['params'], max_norm=1.0)
    scaler.step(optimizer)
    optimizer.zero_grad()

model.step(step_fn)
```
If you do not call unscale_ manually, scaler.step() performs it automatically — no extra action needed.