[Feature] Comprehensive Hybrid Parallelism Support (#6389)

This commit is contained in:
Cheng Wan
2025-06-20 14:43:11 -07:00
committed by GitHub
parent 0998808009
commit e879d8b7a8
14 changed files with 3689 additions and 108 deletions

View File

@@ -149,6 +149,8 @@ from sglang.srt.utils import (
kill_itself_when_parent_died,
point_to_point_pyobj,
pyspy_dump_schedulers,
require_mlp_sync,
require_mlp_tp_gather,
set_gpu_proc_affinity,
set_random_seed,
suppress_other_loggers,
@@ -1471,9 +1473,8 @@ class Scheduler(
else:
ret = None
# Handle DP attention
if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
ret, _ = self.prepare_dp_attn_batch(ret)
if require_mlp_sync(self.server_args):
ret, _ = self.prepare_mlp_sync_batch(ret)
return ret
@@ -1775,12 +1776,11 @@ class Scheduler(
self.return_health_check_ct -= 1
self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
return self.prepare_dp_attn_batch_raw(
def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
return self.prepare_mlp_sync_batch_raw(
local_batch,
dp_size=self.server_args.dp_size,
attn_tp_size=self.attn_tp_size,
moe_dense_tp_size=self.server_args.moe_dense_tp_size,
tp_cpu_group=self.tp_cpu_group,
get_idle_batch=self.get_idle_batch,
disable_cuda_graph=self.server_args.disable_cuda_graph,
@@ -1789,14 +1789,14 @@ class Scheduler(
enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
enable_deepep_moe=self.server_args.enable_deepep_moe,
deepep_mode=DeepEPMode[self.server_args.deepep_mode],
require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
)
@staticmethod
def prepare_dp_attn_batch_raw(
def prepare_mlp_sync_batch_raw(
local_batch: ScheduleBatch,
dp_size,
attn_tp_size: int,
moe_dense_tp_size: Optional[int],
tp_cpu_group,
get_idle_batch,
disable_cuda_graph: bool,
@@ -1805,6 +1805,7 @@ class Scheduler(
enable_two_batch_overlap: bool,
enable_deepep_moe: bool,
deepep_mode: DeepEPMode,
require_mlp_tp_gather: bool,
):
# Check if other DP workers have running batches
if local_batch is None:
@@ -1879,7 +1880,7 @@ class Scheduler(
if local_batch is not None:
# TODO: handle the case when moe_dense_tp_size != 1
if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
if not require_mlp_tp_gather:
local_batch.global_num_tokens = [num_tokens]
local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
else: