[Feature] Comprehensive Hybrid Parallelism Support (#6389)
@@ -28,9 +28,9 @@ from sglang.srt.layers.dp_attention import (
     attn_tp_reduce_scatter,
     dp_gather_partial,
     dp_scatter,
+    get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -229,7 +229,7 @@ class CommunicateContext:
     process_group_sizes: Dict[ScatterMode, int]
     attn_tp_rank: int
     attn_tp_size: int
-    local_attn_dp_size: int
+    attn_dp_size: int
     tp_size: int
 
     def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
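For context, a minimal self-contained sketch of the dataclass this hunk touches. The ScatterMode members besides SCATTERED and the is_same_group_size body are assumptions for illustration, not taken from this diff:

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Dict


class ScatterMode(Enum):
    SCATTERED = auto()
    TP_ATTN_FULL = auto()  # assumed member name
    FULL = auto()          # assumed member name


@dataclass
class CommunicateContext:
    process_group_sizes: Dict[ScatterMode, int]
    attn_tp_rank: int
    attn_tp_size: int
    attn_dp_size: int
    tp_size: int

    def is_same_group_size(self, a: ScatterMode, b: ScatterMode) -> bool:
        # Assumed body: two scatter modes need no re-sharding when their
        # process groups have the same size.
        return self.process_group_sizes[a] == self.process_group_sizes[b]
```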
@@ -239,7 +239,7 @@ class CommunicateContext:
     def init_new(cls):
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()
-        local_attn_dp_size = get_local_attention_dp_size()
+        attn_dp_size = get_attention_dp_size()
         tp_size = get_tensor_model_parallel_world_size()
         process_group_sizes = {
             ScatterMode.SCATTERED: 1,
@@ -251,7 +251,7 @@ class CommunicateContext:
             process_group_sizes=process_group_sizes,
             attn_tp_rank=attn_tp_rank,
             attn_tp_size=attn_tp_size,
-            local_attn_dp_size=local_attn_dp_size,
+            attn_dp_size=attn_dp_size,
             tp_size=tp_size,
         )
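Continuing the sketch above, a hedged reconstruction of how the init_new factory assembles the context. The fake_* getters are placeholders for the real helpers in sglang.srt.layers.dp_attention and the distributed state; only the ScatterMode.SCATTERED: 1 entry of the mapping is taken from this diff:

```python
def fake_get_attention_tp_rank() -> int:
    return 0


def fake_get_attention_tp_size() -> int:
    return 2


def fake_get_attention_dp_size() -> int:
    return 4


def fake_get_tensor_model_parallel_world_size() -> int:
    return 8


def init_new_sketch() -> CommunicateContext:
    attn_tp_size = fake_get_attention_tp_size()
    tp_size = fake_get_tensor_model_parallel_world_size()
    return CommunicateContext(
        process_group_sizes={
            ScatterMode.SCATTERED: 1,                # from the diff
            ScatterMode.TP_ATTN_FULL: attn_tp_size,  # assumed entry
            ScatterMode.FULL: tp_size,               # assumed entry
        },
        attn_tp_rank=fake_get_attention_tp_rank(),
        attn_tp_size=attn_tp_size,
        attn_dp_size=fake_get_attention_dp_size(),
        tp_size=tp_size,
    )


context = init_new_sketch()
assert context.is_same_group_size(ScatterMode.SCATTERED, ScatterMode.SCATTERED)
```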
@@ -385,7 +385,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
             attn_tp_all_gather(
                 list(residual.tensor_split(context.attn_tp_size)), local_residual
             )
-        if context.local_attn_dp_size != 1:
+        if context.attn_dp_size != 1:
             if context.attn_tp_rank == 0:
                 hidden_states += residual
             hidden_states, local_hidden_states = (
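A single-process sketch of the gather pattern in this hunk; fake_attn_tp_all_gather is a stand-in, not the sglang collective. It shows why list(residual.tensor_split(n)) is passed as the output argument: tensor_split returns views into residual, so the per-rank shards land directly in the full tensor with no extra concatenation:

```python
import torch


def fake_attn_tp_all_gather(output_chunks, shards):
    # In a real attention-TP group, "rank" i's shard is written into
    # output_chunks[i]; here we just perform the copies locally.
    for out, shard in zip(output_chunks, shards):
        out.copy_(shard)


attn_tp_size = 4
shards = [torch.full((2, 3), float(rank)) for rank in range(attn_tp_size)]
residual = torch.empty(2 * attn_tp_size, 3)
fake_attn_tp_all_gather(list(residual.tensor_split(attn_tp_size)), shards)
assert torch.equal(residual[:2], torch.zeros(2, 3))
assert torch.equal(residual[-2:], torch.full((2, 3), 3.0))
```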