Runtime Context
RoundPipe uses thread-local context managers to track the current execution phase (forward pass, recomputation, optimizer update). Users can use these APIs to save data during a model's forward pass for use during recomputation, or to correctly access parameters during optimizer-related operations.
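As a rough illustration of thread-local phase tracking, the sketch below uses only the Python standard library. It is not RoundPipe's implementation: the names `_state`, `current_phase`, `phase`, and `phase_in_other_thread` are hypothetical and exist only for this example.

```python
import threading
from contextlib import contextmanager

# Hypothetical stand-in for per-thread phase state; not RoundPipe's code.
_state = threading.local()


def current_phase() -> str:
    """Return this thread's phase, defaulting to 'forward'."""
    return getattr(_state, "phase", "forward")


@contextmanager
def phase(name: str):
    """Mark the enclosed scope as running in the given phase."""
    previous = current_phase()
    _state.phase = name
    try:
        yield
    finally:
        # Restore the outer phase even if the body raises.
        _state.phase = previous


def phase_in_other_thread() -> str:
    """Report the phase seen by a freshly started thread."""
    seen = []
    worker = threading.Thread(target=lambda: seen.append(current_phase()))
    worker.start()
    worker.join()
    return seen[0]
```

Because the state lives in `threading.local()`, entering `phase("recompute")` on one thread leaves every other thread at its own default. This is the property that lets a phase check be answered per thread.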
roundpipe.save_for_recompute
roundpipe.save_for_recompute(*data: Any) -> None
Save data in the current forward pass context for use during the subsequent recomputation phase.
RoundPipe re-executes the forward pass during backward propagation to recover intermediate activations. If certain intermediate results in the model cause GPU-CPU synchronization (e.g., torch.nonzero()), you can save them during the forward pass using this function to avoid synchronization overhead during recomputation.
Parameters:
*data: Arbitrary data to save. Saved tensors must not require gradients (requires_grad=False).
Notes:
- Can be called at most once per layer's forward pass.
- If gradient computation is not currently enabled, this function is a no-op.
- Saved data can be retrieved during the recomputation phase via get_recompute_data().
Example:
import torch
import torch.nn as nn
from roundpipe import save_for_recompute, doing_recompute, get_recompute_data

class MyLayer(nn.Module):
    def forward(self, x):
        if doing_recompute():
            # Recomputation phase: use the saved mask directly
            mask, = get_recompute_data()
        else:
            # Forward pass phase: generate and save the mask
            mask = torch.nonzero(x)
            save_for_recompute(mask)
        return x, mask
roundpipe.doing_recompute
roundpipe.doing_recompute() -> bool
Check whether the current scope is in the recomputation phase.
RoundPipe re-executes the forward pass during backward propagation to recover activations. This function can be used within forward pass code to distinguish between the initial forward computation and the recomputation phase, allowing different logic to be executed accordingly.
Returns:
bool: True if currently within a recomputation context; False otherwise.
roundpipe.get_recompute_data
roundpipe.get_recompute_data() -> tuple
Retrieve data saved during the forward pass via save_for_recompute().
Returns:
tuple: The saved data. Even if only a single item was saved, it is returned as a tuple.
Notes:
- This function can only be called within a recomputation context. Calling it outside a recomputation context will raise an AssertionError.
roundpipe.OptimizerCtx
class roundpipe.OptimizerCtx
A context manager that marks the current scope as performing optimizer-related operations.
Within this context, RoundPipe redirects .parameters() and .named_parameters() to .optim_parameters() and .optim_named_parameters(). This allows users to create optimizers using the standard PyTorch pattern without explicitly calling the optimizer-specific parameter interfaces.
Use Cases:
Use this context manager when creating an optimizer or performing other operations that need access to optimizer parameters.
Example:
from roundpipe import RoundPipe, OptimizerCtx
from roundpipe.optim import Adam

model = RoundPipe(my_model)

# Method 1: Using OptimizerCtx (recommended, consistent with PyTorch conventions)
with OptimizerCtx():
    optimizer = Adam(model.parameters(), lr=0.001)

# Method 2: Using optim_parameters directly (equivalent)
optimizer = Adam(model.optim_parameters(), lr=0.001)
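To make the redirection described above more concrete, here is a minimal standard-library sketch of how a context flag could switch what `.parameters()` returns. It only illustrates the idea and makes no claim about how RoundPipe implements it: `_flags`, `toy_optimizer_ctx`, and `ToyWrapper` are hypothetical, and the two parameter lists are placeholder strings rather than tensors.

```python
import threading
from contextlib import contextmanager

# Hypothetical per-thread flag; not RoundPipe's internals.
_flags = threading.local()


@contextmanager
def toy_optimizer_ctx():
    """Mark the enclosed scope as optimizer-related on this thread."""
    previous = getattr(_flags, "in_optimizer", False)
    _flags.in_optimizer = True
    try:
        yield
    finally:
        # Restore the previous value so nested or failing scopes unwind cleanly.
        _flags.in_optimizer = previous


class ToyWrapper:
    """Holds two parameter views and picks one based on the active context."""

    def __init__(self, compute_params, optim_params):
        self._compute_params = list(compute_params)
        self._optim_params = list(optim_params)

    def optim_parameters(self):
        """Always return the optimizer-side view."""
        return iter(self._optim_params)

    def parameters(self):
        """Return the optimizer-side view only inside toy_optimizer_ctx()."""
        if getattr(_flags, "in_optimizer", False):
            return self.optim_parameters()
        return iter(self._compute_params)
```

With this shape, code written against the usual `.parameters()` call picks up the optimizer-side view inside the context and the default view outside it, so both methods in the example resolve to the same parameters.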