重计算数据保存
使用场景
RoundPipe 默认采用激活重计算:反向传播时重新执行前向计算来恢复中间激活值。大多数情况下,重计算产生的结果与原始前向完全一致(配合 preserve_rng_state=True 保证随机行为一致)。
但某些模型层的 forward 中包含不稳定的操作:即使输入相同,两次执行也可能产生不同的结果,或者某些操作会触发 GPU-CPU 同步导致性能下降。典型场景:
- MoE 模型的 expert 路由:路由决策可能依赖于
torch.topk的选择,重计算时需要保证路由结果一致 - 动态形状操作:如
torch.nonzero()返回的 tensor 形状取决于输入值,且会触发 GPU-CPU 同步 - 条件分支:forward 中根据中间计算结果选择不同的执行路径
对于这些场景,RoundPipe 提供了 save_for_recompute API,允许在前向传播时保存关键数据,在重计算时直接使用保存的数据而非重新计算。
API
save_for_recompute(*data)
在前向传播阶段保存数据,供后续重计算阶段使用。
from roundpipe import save_for_recompute
save_for_recompute(routing_indices, expert_mask)
- 每层的 forward 中最多调用一次
- 保存的 tensor 必须不需要梯度(
requires_grad=False) - 如果当前未启用梯度计算,此函数为空操作
doing_recompute()
检查当前是否处于重计算阶段。
from roundpipe import doing_recompute
if doing_recompute():
# 当前是重计算阶段
...
else:
# 当前是正常前向传播
...
get_recompute_data()
在重计算阶段获取之前通过 save_for_recompute 保存的数据。
from roundpipe import get_recompute_data
data = get_recompute_data() # 返回 tuple
- 只能在重计算上下文中调用,否则会抛出
AssertionError - 即使只保存了一个值,也以 tuple 形式返回
示例
MoE expert 路由重放
import torch
import torch.nn as nn
from roundpipe import save_for_recompute, doing_recompute, get_recompute_data
class MoELayer(nn.Module):
def __init__(self, num_experts, hidden_size, expert_size):
super().__init__()
self.gate = nn.Linear(hidden_size, num_experts)
self.experts = nn.ModuleList([
nn.Linear(hidden_size, expert_size)
for _ in range(num_experts)
])
def forward(self, x):
if doing_recompute():
# 重计算阶段:直接使用保存的路由结果
selected_experts, = get_recompute_data()
else:
# 正常前向:计算路由并保存
gate_logits = self.gate(x)
routing_weights = torch.softmax(gate_logits, dim=-1)
selected_experts = torch.topk(routing_weights, k=2, dim=-1).indices
save_for_recompute(selected_experts)
# 使用路由结果执行 expert 计算
# ...
return output
避免 GPU-CPU 同步
class SparseLayer(nn.Module):
def forward(self, x):
if doing_recompute():
mask, = get_recompute_data()
else:
# torch.nonzero() 会触发 GPU-CPU 同步(需要知道非零元素数量)
# 保存结果避免重计算时再次同步
mask = torch.nonzero(x > 0.5)
save_for_recompute(mask)
# 使用 mask 进行稀疏计算
return x[mask]
条件分支
class ConditionalLayer(nn.Module):
def forward(self, x):
if doing_recompute():
use_branch_a, = get_recompute_data()
else:
# 根据输入统计量选择分支
use_branch_a = (x.mean() > 0).item() # 触发 GPU-CPU 同步
save_for_recompute(use_branch_a)
if use_branch_a:
return self.branch_a(x)
else:
return self.branch_b(x)