[Feature] Comprehensive Hybrid Parallelism Support (#6389)
This commit is contained in:
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
|
||||
configure_logger,
|
||||
get_bool_env_var,
|
||||
kill_process_tree,
|
||||
require_mlp_sync,
|
||||
require_mlp_tp_gather,
|
||||
set_gpu_proc_affinity,
|
||||
suppress_other_loggers,
|
||||
)
|
||||
@@ -243,7 +245,7 @@ def extend(reqs, model_runner):
|
||||
enable_custom_logit_processor=False,
|
||||
)
|
||||
batch.prepare_for_extend()
|
||||
_maybe_prepare_dp_attn_batch(batch, model_runner)
|
||||
_maybe_prepare_mlp_sync_batch(batch, model_runner)
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
|
||||
logits_output, _ = model_runner.forward(forward_batch)
|
||||
@@ -255,7 +257,7 @@ def extend(reqs, model_runner):
|
||||
def decode(input_token_ids, batch, model_runner):
|
||||
batch.output_ids = input_token_ids
|
||||
batch.prepare_for_decode()
|
||||
_maybe_prepare_dp_attn_batch(batch, model_runner)
|
||||
_maybe_prepare_mlp_sync_batch(batch, model_runner)
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
|
||||
logits_output, _ = model_runner.forward(forward_batch)
|
||||
@@ -263,18 +265,18 @@ def decode(input_token_ids, batch, model_runner):
|
||||
return next_token_ids, logits_output.next_token_logits
|
||||
|
||||
|
||||
def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
|
||||
if model_runner.server_args.enable_dp_attention:
|
||||
Scheduler.prepare_dp_attn_batch_raw(
|
||||
def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
|
||||
if require_mlp_sync(model_runner.server_args):
|
||||
Scheduler.prepare_mlp_sync_batch_raw(
|
||||
batch,
|
||||
dp_size=model_runner.server_args.dp_size,
|
||||
attn_tp_size=1,
|
||||
moe_dense_tp_size=model_runner.server_args.moe_dense_tp_size,
|
||||
tp_cpu_group=model_runner.tp_group.cpu_group,
|
||||
get_idle_batch=None,
|
||||
disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
|
||||
spec_algorithm=SpeculativeAlgorithm.NONE,
|
||||
speculative_num_draft_tokens=None,
|
||||
require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user