sglang/python/sglang/srt/model_parallel.py

"""
Common utilities for torch model parallelism.
"""
from typing import Optional
import torch
from torch.distributed.device_mesh import DeviceMesh
try:
from torch.distributed.tensor import DTensor, Shard
except ImportError:
# torch 2.4 or older
from torch.distributed._tensor import DTensor, Shard
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
)


class ColwiseParallelSharded(ColwiseParallel):
    """
    A version of ColwiseParallel where the local weight has already been
    sharded. This is used for the fused wqkv case, where wq, wk, and wv are
    sharded individually during loading, before being fused.
    """

    # Override the _partition_linear_fn in ColwiseParallel
    def _partition_linear_fn(self, name, module, device_mesh):
        # Colwise-shard weight/bias as Shard(0). Since Linear computes
        # input @ weight^T + bias, a Shard(0) weight (i.e. Shard(1) of
        # weight^T) shards the output dimension, which is what Colwise means.
        for name, param in module.named_parameters():
            dtensor = DTensor.from_local(param, device_mesh, [Shard(0)])
            dist_param = torch.nn.Parameter(dtensor, requires_grad=False)
            module.register_parameter(name, dist_param)
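
# A minimal sketch of the pre-shard-then-fuse flow this class assumes: each
# rank keeps only its row slice of wq/wk/wv before fusing, so the fused local
# weight is already the Shard(0) piece that _partition_linear_fn wraps. The
# `tp_rank`/`tp_size` names and tensor shapes here are illustrative, not part
# of this module:
#
#     tp_rank = device_mesh.get_local_rank()
#     tp_size = device_mesh.size()
#     wq_local = wq.chunk(tp_size, dim=0)[tp_rank]
#     wk_local = wk.chunk(tp_size, dim=0)[tp_rank]
#     wv_local = wv.chunk(tp_size, dim=0)[tp_rank]
#     fused_local = torch.cat([wq_local, wk_local, wv_local], dim=0)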


class RowwiseParallelMaybeWait(RowwiseParallel):
    """
    A version of RowwiseParallel that waits for the output (establishing a
    dependency between the comm stream and the compute stream, in the CUDA
    sense) before going into the next op. This is needed to work around the
    current interaction between AsyncCollectiveTensor and custom ops, such as
    `class RMSNorm(CustomOp)`.
    """

    @staticmethod
    def _prepare_output_fn(
        output_layouts, use_local_output, mod, outputs, device_mesh
    ):
        outputs = super(
            RowwiseParallelMaybeWait, RowwiseParallelMaybeWait
        )._prepare_output_fn(
            output_layouts, use_local_output, mod, outputs, device_mesh
        )
        # Wait for the pending collective so downstream ops see a plain tensor
        if isinstance(outputs, AsyncCollectiveTensor):
            return outputs.wait()
        else:
            return outputs
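
# A sketch of the interaction this guards against (names are illustrative):
# the rowwise all-reduce returns an AsyncCollectiveTensor, and a custom op
# consuming it may bypass the implicit wait that ordinary torch ops trigger.
#
#     y = rowwise_linear(x)    # AsyncCollectiveTensor, all-reduce in flight
#     z = custom_rms_norm(y)   # custom op may read y before the collective lands
#
# Calling `.wait()` in _prepare_output_fn forces the collective to complete,
# ordering the compute stream after the comm stream in the CUDA sense.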


def tensor_parallel(
    module: torch.nn.Module,
    device_mesh: Optional[DeviceMesh] = None,
):
    """
    Tensor parallelize the model across the given device mesh.

    Args:
        module (`torch.nn.Module`):
            The module to tensor parallelize.
        device_mesh (`torch.distributed.DeviceMesh`):
            The device mesh to use for tensor parallelism.
    """

    # Tensor parallelize an nn.Module based on the `_tp_plan` attribute of the
    # module. No-op if the `_tp_plan` attribute does not exist on the module.
    # This is a helper function to be used with `model.apply` to recursively
    # parallelize a model.
    def tplize(mod: torch.nn.Module) -> None:
        tp_plan = getattr(mod, "_tp_plan", None)
        if tp_plan is None:
            return
        for child_name, tp_style in tp_plan.items():
            submod = mod.get_submodule(child_name)
            if tp_style == "Colwise":
                parallelize_module(submod, device_mesh, ColwiseParallel())
            elif tp_style == "Rowwise":
                parallelize_module(submod, device_mesh, RowwiseParallelMaybeWait())
            elif tp_style == "Colwise_Sharded":
                parallelize_module(submod, device_mesh, ColwiseParallelSharded())
            else:
                raise ValueError(f"Unknown TP style {tp_style}")

    # `apply` is a native method of `nn.Module` that recursively applies a
    # function to every submodule.
    module.apply(tplize)
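
# Example usage (a minimal sketch; the process-group setup and the `model`
# variable are assumptions, not part of this module):
#
#     import torch.distributed as dist
#     from torch.distributed.device_mesh import init_device_mesh
#
#     dist.init_process_group("nccl")
#     mesh = init_device_mesh("cuda", (dist.get_world_size(),))
#     # `model` is any module whose submodules carry a `_tp_plan` dict, e.g.
#     # {"gate_up_proj": "Colwise", "down_proj": "Rowwise"}.
#     tensor_parallel(model, mesh)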