Skip to content

Custom Microbatch Splitting

RoundPipe splits input data into microbatches for pipeline execution. By default it infers the splitting strategy automatically, but non-standard input/output formats may require you to specify custom split and merge strategies.

Default Splitting Behavior

Automatic Input Splitting

When split_input is None (default), RoundPipe recursively traverses the nested input structure (tuples, lists, dicts) and applies these rules to each leaf node:

  • Multi-dimensional tensors: split evenly along dimension 0 (the batch dimension).
  • Scalar tensors and non-tensor types: replicated to every microbatch.
# Example: automatic splitting
# input args = (images,), images.shape = (12, 3, 224, 224)
# num_microbatch = 4
# β†’ each microbatch gets images.shape = (3, 3, 224, 224)

Automatic Label Splitting

When split_label is None (default), the same rules as input splitting apply.

Automatic Output Merging

When merge_output is None or True (default), RoundPipe recursively traverses the output structure and applies these rules to each leaf node:

  • Multi-dimensional tensors: concatenated along dimension 0 (torch.cat).
  • Scalar tensors: averaged across microbatches.
  • Non-tensor types: checked for equality across all microbatches; returns the value if equal, raises an error otherwise.

split_input

When the defaults don't fit your needs, use split_input to specify a custom strategy. There are two approaches:

Approach 1: Split Spec

Use TensorChunkSpec and _Replicate to annotate how each input should be split:

from torch.distributed.pipelining.microbatch import TensorChunkSpec, _Replicate
from roundpipe import RoundPipeRunConfig

# Scenario: model takes (images, scale_factor)
# Split images along the batch dim; replicate scale_factor to every microbatch
config = RoundPipeRunConfig(
    split_input=(
        (TensorChunkSpec(0), _Replicate),  # args spec
        None,                               # kwargs spec (auto-infer)
    )
)

# Scenario: an input needs splitting along a non-batch dimension
# e.g., (tokens, position_ids) where position_ids splits along dim=1
config = RoundPipeRunConfig(
    split_input=(
        (TensorChunkSpec(0), TensorChunkSpec(1)),  # args spec
        None,
    )
)

The spec structure must match the input structure one-to-one. The args spec is a tuple and the kwargs spec is a dict (or None for auto-inference).

Approach 2: Custom Function

When specs aren't expressive enough, provide a fully custom split function:

def custom_split(args, kwargs, num_microbatch):
    """
    args: original positional argument tuple
    kwargs: original keyword argument dict
    num_microbatch: number of microbatches
    Returns: (args_list, kwargs_list), each of length num_microbatch
    """
    images, masks = args
    chunks_images = images.chunk(num_microbatch)
    chunks_masks = masks.chunk(num_microbatch)
    args_list = [(img, mask) for img, mask in zip(chunks_images, chunks_masks)]
    kwargs_list = [kwargs] * num_microbatch  # replicate kwargs to every microbatch
    return args_list, kwargs_list

config = RoundPipeRunConfig(split_input=custom_split)

A more complex example β€” input is a dict with per-field splitting strategies:

def split_transformer_input(args, kwargs, num_microbatch):
    # kwargs = {"input_ids": ..., "attention_mask": ..., "position_ids": ...}
    input_ids_chunks = kwargs["input_ids"].chunk(num_microbatch)
    mask_chunks = kwargs["attention_mask"].chunk(num_microbatch)
    pos_chunks = kwargs["position_ids"].chunk(num_microbatch)
    args_list = [()] * num_microbatch
    kwargs_list = [
        {"input_ids": ids, "attention_mask": mask, "position_ids": pos}
        for ids, mask, pos in zip(input_ids_chunks, mask_chunks, pos_chunks)
    ]
    return args_list, kwargs_list

config = RoundPipeRunConfig(split_input=split_transformer_input)

split_label

Similar to split_input, split_label controls how labels are split.

Split Spec

from torch.distributed.pipelining.microbatch import TensorChunkSpec, _Replicate

# Label is a tuple: (targets, sample_weights)
# Both split along the batch dimension
config = RoundPipeRunConfig(
    split_label=(TensorChunkSpec(0), TensorChunkSpec(0))
)

# Label is a tuple: (targets, class_weights)
# targets split along batch dim; class_weights replicated (shared across samples)
config = RoundPipeRunConfig(
    split_label=(TensorChunkSpec(0), _Replicate)
)

Custom Function

def custom_split_label(label, num_microbatch):
    """
    label: original label
    num_microbatch: number of microbatches
    Returns: List[label] of length num_microbatch
    """
    targets, weights = label
    chunks_targets = targets.chunk(num_microbatch)
    chunks_weights = weights.chunk(num_microbatch)
    return [(t, w) for t, w in zip(chunks_targets, chunks_weights)]

config = RoundPipeRunConfig(split_label=custom_split_label)

merge_output

Controls how per-microbatch outputs are combined into the final output.

Split Spec

Similar to splitting, you can use specs to annotate how each output field should be merged:

from torch.distributed.pipelining.microbatch import TensorChunkSpec, _Replicate, _CustomReducer

# Output is (logits, hidden_states)
# Both concatenated along the batch dimension
config = RoundPipeRunConfig(
    merge_output=(TensorChunkSpec(0), TensorChunkSpec(0))
)

# Custom reducer: sum the losses
sum_reducer = _CustomReducer(torch.tensor(0.0), lambda x, y: x + y)
config = RoundPipeRunConfig(merge_output=sum_reducer)

Custom Function

def custom_merge(outputs):
    """
    outputs: List[output], one per microbatch
    Returns: merged output
    """
    # Example: HuggingFace model output is an object; only logits are needed
    logits = torch.cat([out.logits for out in outputs], dim=0)
    return logits

config = RoundPipeRunConfig(merge_output=custom_merge)

Disabling Merging

Set merge_output=False to disable output merging. Each leaf variable in the output will be returned as a RoundPipePackedData (a list subclass) containing per-microbatch outputs along with their CUDA transfer events.

config = RoundPipeRunConfig(merge_output=False)
output = model(data, roundpipe_run_config=config)
# Tensors in the output are RoundPipePackedData objects
output.synchronize()  # Wait for all transfers to complete
# Then access per-microbatch outputs like a regular list

This is useful when piping one RoundPipe model's output directly into another, avoiding unnecessary synchronization and data copies. wrap_model_to_roundpipe automatically sets merge_output=False when recursively wrapping non-final modules.