[Feature] Comprehensive Hybrid Parallelism Support (#6389)

Author:       Cheng Wan
Date:         2025-06-20 14:43:11 -07:00
Committed by: GitHub
Parent:       0998808009
Commit:       e879d8b7a8

14 changed files with 3689 additions and 108 deletions

@@ -28,9 +28,9 @@ from sglang.srt.layers.dp_attention import (
     attn_tp_reduce_scatter,
     dp_gather_partial,
     dp_scatter,
+    get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
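
The hunk above switches the accessor from get_local_attention_dp_size to get_attention_dp_size; the semantic difference between the local and global attention-DP sizes is not visible in this diff. As rough orientation only, here is a minimal sketch (not the sglang source; the helper name attention_parallel_sizes is hypothetical) of how the attention-parallel sizes typically relate under DP attention, assuming the tensor-parallel group is evenly partitioned into dp_size attention replicas:

    # Hypothetical sketch: relation of the attention-parallel sizes under
    # DP attention, assuming dp_size evenly partitions the TP group.
    def attention_parallel_sizes(tp_size: int, dp_size: int) -> tuple[int, int]:
        assert tp_size % dp_size == 0, "dp_size must divide tp_size"
        attn_tp_size = tp_size // dp_size  # TP ranks inside one attention replica
        attn_dp_size = dp_size             # number of attention replicas
        return attn_tp_size, attn_dp_size

    # e.g. tp_size=8, dp_size=2 -> each attention replica spans 4 TP ranks
    print(attention_parallel_sizes(8, 2))  # (4, 2)
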
@@ -229,7 +229,7 @@ class CommunicateContext:
     process_group_sizes: Dict[ScatterMode, int]
     attn_tp_rank: int
     attn_tp_size: int
-    local_attn_dp_size: int
+    attn_dp_size: int
     tp_size: int
 
     def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
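
For reference, the dataclass after this hunk reads roughly as below. This is a sketch stitched from the diff, not the sglang source: the ScatterMode stub (only SCATTERED is confirmed by a later hunk) and the body of is_same_group_size are assumptions inferred from the names.

    import dataclasses
    import enum
    from typing import Dict

    class ScatterMode(enum.Enum):
        # Stub standing in for sglang's ScatterMode; members other than
        # SCATTERED are assumptions for illustration.
        SCATTERED = enum.auto()
        FULL = enum.auto()

    @dataclasses.dataclass
    class CommunicateContext:
        process_group_sizes: Dict[ScatterMode, int]
        attn_tp_rank: int
        attn_tp_size: int
        attn_dp_size: int  # renamed from local_attn_dp_size in this commit
        tp_size: int

        def is_same_group_size(self, a: ScatterMode, b: ScatterMode) -> bool:
            # Assumed body, inferred from the method name.
            return self.process_group_sizes[a] == self.process_group_sizes[b]
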
@@ -239,7 +239,7 @@ class CommunicateContext:
     def init_new(cls):
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()
-        local_attn_dp_size = get_local_attention_dp_size()
+        attn_dp_size = get_attention_dp_size()
         tp_size = get_tensor_model_parallel_world_size()
         process_group_sizes = {
             ScatterMode.SCATTERED: 1,
@@ -251,7 +251,7 @@ class CommunicateContext:
             process_group_sizes=process_group_sizes,
             attn_tp_rank=attn_tp_rank,
             attn_tp_size=attn_tp_size,
-            local_attn_dp_size=local_attn_dp_size,
+            attn_dp_size=attn_dp_size,
             tp_size=tp_size,
         )
 
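
The two init hunks combine into the constructor below. This is a reconstruction from the diff itself: the @classmethod decorator is assumed from the cls argument, the elided process_group_sizes entries are left as a placeholder comment, and the accessors are assumed to come from the import block patched above.

    @classmethod
    def init_new(cls):
        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()
        attn_dp_size = get_attention_dp_size()  # was get_local_attention_dp_size()
        tp_size = get_tensor_model_parallel_world_size()
        process_group_sizes = {
            ScatterMode.SCATTERED: 1,
            # ... remaining ScatterMode entries are not shown in the diff ...
        }
        return cls(
            process_group_sizes=process_group_sizes,
            attn_tp_rank=attn_tp_rank,
            attn_tp_size=attn_tp_size,
            attn_dp_size=attn_dp_size,
            tp_size=tp_size,
        )
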
@@ -385,7 +385,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
                 attn_tp_all_gather(
                     list(residual.tensor_split(context.attn_tp_size)), local_residual
                 )
-        if context.local_attn_dp_size != 1:
+        if context.attn_dp_size != 1:
            if context.attn_tp_rank == 0:
                hidden_states += residual
            hidden_states, local_hidden_states = (
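
One plausible reading of the attn_tp_rank == 0 guard in the last hunk: hidden_states is still a partial sum across the attn_tp_size ranks of one attention replica, so the residual must be folded in exactly once before the gather and the eventual reduction. A single-process toy illustration of that invariant (plain floats standing in for tensors; fold_residual is a hypothetical name, not sglang code):

    # Each entry of `partials` plays the role of one attention-TP rank's
    # partial hidden state; summing them models the later reduction.
    def fold_residual(partials: list[float], residual: float) -> float:
        contributions = []
        for attn_tp_rank, hidden in enumerate(partials):
            if attn_tp_rank == 0:
                hidden += residual  # added once, by rank 0 only
            contributions.append(hidden)
        return sum(contributions)

    # 1 + 2 + 3 plus the residual counted exactly once = 16
    assert fold_residual([1.0, 2.0, 3.0], residual=10.0) == 16.0
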