[Feature] Comprehensive Hybrid Parallelism Support (#6389)

Author: Cheng Wan
Date: 2025-06-20 14:43:11 -07:00
Committed by: GitHub
Parent: 0998808009
Commit: e879d8b7a8
14 changed files with 3689 additions and 108 deletions


@@ -71,6 +71,8 @@ from sglang.srt.utils import (
     configure_logger,
     get_bool_env_var,
     kill_process_tree,
+    require_mlp_sync,
+    require_mlp_tp_gather,
     set_gpu_proc_affinity,
     suppress_other_loggers,
 )
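
The two helpers imported here are predicates over the server arguments; later in this diff both are called with model_runner.server_args. A quick way to probe what a given configuration will do is to evaluate them directly. A minimal sketch, assuming ServerArgs is built with its usual dataclass fields; the specific flag values below are illustrative, not taken from this commit:

from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import require_mlp_sync, require_mlp_tp_gather

# Illustrative configuration only: DP attention across two data-parallel ranks.
args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model path
    tp_size=2,
    dp_size=2,
    enable_dp_attention=True,
)

# Assumption based on the names: require_mlp_sync reports whether batches must be
# synchronized across ranks before the MLP/MoE layers, and require_mlp_tp_gather
# whether activations must additionally be gathered over the full TP group.
print(require_mlp_sync(args), require_mlp_tp_gather(args))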
@@ -243,7 +245,7 @@ def extend(reqs, model_runner):
         enable_custom_logit_processor=False,
     )
     batch.prepare_for_extend()
-    _maybe_prepare_dp_attn_batch(batch, model_runner)
+    _maybe_prepare_mlp_sync_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output, _ = model_runner.forward(forward_batch)
@@ -255,7 +257,7 @@ def decode(input_token_ids, batch, model_runner):
 def decode(input_token_ids, batch, model_runner):
     batch.output_ids = input_token_ids
     batch.prepare_for_decode()
-    _maybe_prepare_dp_attn_batch(batch, model_runner)
+    _maybe_prepare_mlp_sync_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output, _ = model_runner.forward(forward_batch)
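
For orientation, the extend and decode helpers changed above are driven as a simple prefill-then-decode loop in this benchmark script. A rough sketch of that flow; the harness name is hypothetical and only extend()/decode() come from this file:

def run_one_batch(reqs, model_runner, decode_steps=8):
    # Prefill: prepare_for_extend, the MLP-sync preparation, then one forward pass
    # (assumption: extend also returns the batch, since decode needs it).
    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)

    # Decode: feed each step's sampled tokens back in; the same
    # _maybe_prepare_mlp_sync_batch hook runs before every forward pass.
    for _ in range(decode_steps):
        next_token_ids, next_token_logits = decode(next_token_ids, batch, model_runner)

    return next_token_ids, next_token_logits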
@@ -263,18 +265,18 @@ def decode(input_token_ids, batch, model_runner):
     return next_token_ids, logits_output.next_token_logits
-def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
-    if model_runner.server_args.enable_dp_attention:
-        Scheduler.prepare_dp_attn_batch_raw(
+def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
+    if require_mlp_sync(model_runner.server_args):
+        Scheduler.prepare_mlp_sync_batch_raw(
             batch,
             dp_size=model_runner.server_args.dp_size,
             attn_tp_size=1,
-            moe_dense_tp_size=model_runner.server_args.moe_dense_tp_size,
             tp_cpu_group=model_runner.tp_group.cpu_group,
             get_idle_batch=None,
             disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
             spec_algorithm=SpeculativeAlgorithm.NONE,
             speculative_num_draft_tokens=None,
+            require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
         )
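
The bodies of require_mlp_sync and require_mlp_tp_gather live in sglang.srt.utils and are not part of this hunk. As a hedged sketch of their shape only: the old call site gated on enable_dp_attention, so the new predicate covers at least that case; everything else below is an assumption, not the library's actual logic:

# Hypothetical sketch, not the real implementation in sglang.srt.utils.
def require_mlp_sync(server_args) -> bool:
    # At minimum this must cover the old gate (enable_dp_attention); the real
    # helper may also account for other hybrid-parallel modes added by #6389.
    return bool(server_args.enable_dp_attention)

def require_mlp_tp_gather(server_args) -> bool:
    # Assumption: gathering over the TP group is only needed when the MLP runs
    # at a different parallel width than attention (e.g. DP attention with TP > 1).
    return require_mlp_sync(server_args) and server_args.tp_size > 1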