[Feature] Comprehensive Hybrid Parallelism Support (#6389)
This commit is contained in:
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
|
||||
configure_logger,
|
||||
get_bool_env_var,
|
||||
kill_process_tree,
|
||||
require_mlp_sync,
|
||||
require_mlp_tp_gather,
|
||||
set_gpu_proc_affinity,
|
||||
suppress_other_loggers,
|
||||
)
|
||||
@@ -243,7 +245,7 @@ def extend(reqs, model_runner):
|
||||
enable_custom_logit_processor=False,
|
||||
)
|
||||
batch.prepare_for_extend()
|
||||
_maybe_prepare_dp_attn_batch(batch, model_runner)
|
||||
_maybe_prepare_mlp_sync_batch(batch, model_runner)
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
|
||||
logits_output, _ = model_runner.forward(forward_batch)
|
||||
@@ -255,7 +257,7 @@ def extend(reqs, model_runner):
|
||||
def decode(input_token_ids, batch, model_runner):
|
||||
batch.output_ids = input_token_ids
|
||||
batch.prepare_for_decode()
|
||||
_maybe_prepare_dp_attn_batch(batch, model_runner)
|
||||
_maybe_prepare_mlp_sync_batch(batch, model_runner)
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
|
||||
logits_output, _ = model_runner.forward(forward_batch)
|
||||
@@ -263,18 +265,18 @@ def decode(input_token_ids, batch, model_runner):
|
||||
return next_token_ids, logits_output.next_token_logits
|
||||
|
||||
|
||||
def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
    """Run the scheduler's raw MLP-sync batch preparation when it is required.

    Does nothing unless ``require_mlp_sync(model_runner.server_args)`` is
    truthy. Otherwise forwards ``batch`` to
    ``Scheduler.prepare_mlp_sync_batch_raw`` with settings read from
    ``model_runner``; ``attn_tp_size`` is fixed to 1, no idle-batch factory is
    supplied, and speculative decoding is reported as disabled.
    """
    server_args = model_runner.server_args
    if not require_mlp_sync(server_args):
        return

    Scheduler.prepare_mlp_sync_batch_raw(
        batch,
        dp_size=server_args.dp_size,
        attn_tp_size=1,
        moe_dense_tp_size=server_args.moe_dense_tp_size,
        tp_cpu_group=model_runner.tp_group.cpu_group,
        get_idle_batch=None,
        disable_cuda_graph=server_args.disable_cuda_graph,
        spec_algorithm=SpeculativeAlgorithm.NONE,
        speculative_num_draft_tokens=None,
        require_mlp_tp_gather=require_mlp_tp_gather(server_args),
    )
|
||||
|
||||
|
||||
|
||||
@@ -55,6 +55,7 @@ from sglang.srt.mem_cache.memory_pool import (
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
from sglang.srt.utils import require_mlp_sync
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -649,10 +650,7 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
batch = self.get_next_disagg_decode_batch_to_run()
|
||||
self.cur_batch = batch
|
||||
|
||||
prepare_dp_attn_flag = (
|
||||
self.server_args.enable_dp_attention
|
||||
or self.server_args.enable_sp_layernorm
|
||||
)
|
||||
prepare_mlp_sync_flag = require_mlp_sync(self.server_args)
|
||||
|
||||
if batch:
|
||||
# Generate fake extend output.
|
||||
@@ -661,14 +659,14 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
self.stream_output(
|
||||
batch.reqs, any(req.return_logprob for req in batch.reqs)
|
||||
)
|
||||
if prepare_dp_attn_flag:
|
||||
if prepare_mlp_sync_flag:
|
||||
self._prepare_idle_batch_and_run(None)
|
||||
else:
|
||||
if prepare_dp_attn_flag:
|
||||
self.prepare_dp_attn_batch(batch)
|
||||
if prepare_mlp_sync_flag:
|
||||
self.prepare_mlp_sync_batch(batch)
|
||||
result = self.run_batch(batch)
|
||||
self.process_batch_result(batch, result)
|
||||
elif prepare_dp_attn_flag:
|
||||
elif prepare_mlp_sync_flag:
|
||||
batch, _ = self._prepare_idle_batch_and_run(None)
|
||||
|
||||
if batch is None and (
|
||||
@@ -699,10 +697,7 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
self.cur_batch = batch
|
||||
last_batch_in_queue = False
|
||||
|
||||
prepare_dp_attn_flag = (
|
||||
self.server_args.enable_dp_attention
|
||||
or self.server_args.enable_sp_layernorm
|
||||
)
|
||||
prepare_mlp_sync_flag = require_mlp_sync(self.server_args)
|
||||
|
||||
if batch:
|
||||
# Generate fake extend output.
|
||||
@@ -711,7 +706,7 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
self.stream_output(
|
||||
batch.reqs, any(req.return_logprob for req in batch.reqs)
|
||||
)
|
||||
if prepare_dp_attn_flag:
|
||||
if prepare_mlp_sync_flag:
|
||||
batch_, result = self._prepare_idle_batch_and_run(
|
||||
None, delay_process=True
|
||||
)
|
||||
@@ -719,8 +714,8 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
result_queue.append((batch_.copy(), result))
|
||||
last_batch_in_queue = True
|
||||
else:
|
||||
if prepare_dp_attn_flag:
|
||||
self.prepare_dp_attn_batch(batch)
|
||||
if prepare_mlp_sync_flag:
|
||||
self.prepare_mlp_sync_batch(batch)
|
||||
result = self.run_batch(batch)
|
||||
result_queue.append((batch.copy(), result))
|
||||
|
||||
@@ -735,7 +730,7 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
self.set_next_batch_sampling_info_done(tmp_batch)
|
||||
last_batch_in_queue = True
|
||||
|
||||
elif prepare_dp_attn_flag:
|
||||
elif prepare_mlp_sync_flag:
|
||||
batch, result = self._prepare_idle_batch_and_run(
|
||||
None, delay_process=True
|
||||
)
|
||||
@@ -765,13 +760,13 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
self.last_batch = batch
|
||||
self.last_batch_in_queue = last_batch_in_queue
|
||||
|
||||
def _prepare_idle_batch_and_run(self, batch, delay_process=False):
|
||||
batch, _ = self.prepare_dp_attn_batch(batch)
|
||||
def _prepare_idle_batch_and_run(self: Scheduler, batch, delay_process=False):
|
||||
batch, _ = self.prepare_mlp_sync_batch(batch)
|
||||
result = None
|
||||
if batch:
|
||||
result = self.run_batch(batch)
|
||||
if not delay_process:
|
||||
self.process_batch_result(batch, result)
|
||||
self.prepare_mlp_sync_batch(batch, result)
|
||||
return batch, result
|
||||
|
||||
def get_next_disagg_decode_batch_to_run(
|
||||
|
||||
@@ -45,6 +45,7 @@ from sglang.srt.disaggregation.utils import (
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.utils import require_mlp_sync
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.distributed import ProcessGroup
|
||||
@@ -274,12 +275,8 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
self.process_prefill_chunk()
|
||||
batch = self.get_new_batch_prefill()
|
||||
|
||||
# Handle DP attention
|
||||
if (
|
||||
self.server_args.enable_dp_attention
|
||||
or self.server_args.enable_sp_layernorm
|
||||
):
|
||||
batch, _ = self.prepare_dp_attn_batch(batch)
|
||||
if require_mlp_sync(self.server_args):
|
||||
batch, _ = self.prepare_mlp_sync_batch(batch)
|
||||
self.cur_batch = batch
|
||||
|
||||
if batch:
|
||||
@@ -312,12 +309,8 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
self.process_prefill_chunk()
|
||||
batch = self.get_new_batch_prefill()
|
||||
|
||||
# Handle DP attention
|
||||
if (
|
||||
self.server_args.enable_dp_attention
|
||||
or self.server_args.enable_sp_layernorm
|
||||
):
|
||||
batch, _ = self.prepare_dp_attn_batch(batch)
|
||||
if require_mlp_sync(self.server_args):
|
||||
batch, _ = self.prepare_mlp_sync_batch(batch)
|
||||
self.cur_batch = batch
|
||||
if batch:
|
||||
result = self.run_batch(batch)
|
||||
|
||||
@@ -28,9 +28,9 @@ from sglang.srt.layers.dp_attention import (
|
||||
attn_tp_reduce_scatter,
|
||||
dp_gather_partial,
|
||||
dp_scatter,
|
||||
get_attention_dp_size,
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
get_local_attention_dp_size,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
@@ -229,7 +229,7 @@ class CommunicateContext:
|
||||
process_group_sizes: Dict[ScatterMode, int]
|
||||
attn_tp_rank: int
|
||||
attn_tp_size: int
|
||||
local_attn_dp_size: int
|
||||
attn_dp_size: int
|
||||
tp_size: int
|
||||
|
||||
def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
|
||||
@@ -239,7 +239,7 @@ class CommunicateContext:
|
||||
def init_new(cls):
|
||||
attn_tp_rank = get_attention_tp_rank()
|
||||
attn_tp_size = get_attention_tp_size()
|
||||
local_attn_dp_size = get_local_attention_dp_size()
|
||||
attn_dp_size = get_attention_dp_size()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
process_group_sizes = {
|
||||
ScatterMode.SCATTERED: 1,
|
||||
@@ -251,7 +251,7 @@ class CommunicateContext:
|
||||
process_group_sizes=process_group_sizes,
|
||||
attn_tp_rank=attn_tp_rank,
|
||||
attn_tp_size=attn_tp_size,
|
||||
local_attn_dp_size=local_attn_dp_size,
|
||||
attn_dp_size=attn_dp_size,
|
||||
tp_size=tp_size,
|
||||
)
|
||||
|
||||
@@ -385,7 +385,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
|
||||
attn_tp_all_gather(
|
||||
list(residual.tensor_split(context.attn_tp_size)), local_residual
|
||||
)
|
||||
if context.local_attn_dp_size != 1:
|
||||
if context.attn_dp_size != 1:
|
||||
if context.attn_tp_rank == 0:
|
||||
hidden_states += residual
|
||||
hidden_states, local_hidden_states = (
|
||||
|
||||
@@ -165,7 +165,8 @@ def disable_dp_size():
|
||||
|
||||
|
||||
def get_dp_local_info(forward_batch: ForwardBatch):
|
||||
dp_rank = get_local_attention_dp_rank()
|
||||
# `get_dp_local_info` is only called in global DP gather and scatter. We use global DP rank here.
|
||||
dp_rank = get_attention_dp_rank()
|
||||
|
||||
if forward_batch.dp_local_start_pos is None:
|
||||
cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
|
||||
|
||||
@@ -30,9 +30,9 @@ from sglang.srt.layers.dp_attention import (
|
||||
attn_tp_all_gather,
|
||||
dp_gather_replicate,
|
||||
dp_scatter,
|
||||
get_attention_dp_rank,
|
||||
get_attention_dp_size,
|
||||
get_attention_tp_size,
|
||||
get_local_attention_dp_rank,
|
||||
get_local_attention_dp_size,
|
||||
)
|
||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
@@ -171,7 +171,7 @@ class LogitsMetadata:
|
||||
return
|
||||
|
||||
cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
|
||||
dp_rank = get_local_attention_dp_rank()
|
||||
dp_rank = get_attention_dp_rank()
|
||||
if dp_rank == 0:
|
||||
dp_local_start_pos = torch.zeros_like(
|
||||
self.global_num_tokens_for_logprob_gpu[0]
|
||||
|
||||
@@ -149,6 +149,8 @@ from sglang.srt.utils import (
|
||||
kill_itself_when_parent_died,
|
||||
point_to_point_pyobj,
|
||||
pyspy_dump_schedulers,
|
||||
require_mlp_sync,
|
||||
require_mlp_tp_gather,
|
||||
set_gpu_proc_affinity,
|
||||
set_random_seed,
|
||||
suppress_other_loggers,
|
||||
@@ -1471,9 +1473,8 @@ class Scheduler(
|
||||
else:
|
||||
ret = None
|
||||
|
||||
# Handle DP attention
|
||||
if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
|
||||
ret, _ = self.prepare_dp_attn_batch(ret)
|
||||
if require_mlp_sync(self.server_args):
|
||||
ret, _ = self.prepare_mlp_sync_batch(ret)
|
||||
|
||||
return ret
|
||||
|
||||
@@ -1775,12 +1776,11 @@ class Scheduler(
|
||||
self.return_health_check_ct -= 1
|
||||
self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
|
||||
|
||||
def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
|
||||
return self.prepare_dp_attn_batch_raw(
|
||||
def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
|
||||
return self.prepare_mlp_sync_batch_raw(
|
||||
local_batch,
|
||||
dp_size=self.server_args.dp_size,
|
||||
attn_tp_size=self.attn_tp_size,
|
||||
moe_dense_tp_size=self.server_args.moe_dense_tp_size,
|
||||
tp_cpu_group=self.tp_cpu_group,
|
||||
get_idle_batch=self.get_idle_batch,
|
||||
disable_cuda_graph=self.server_args.disable_cuda_graph,
|
||||
@@ -1789,14 +1789,14 @@ class Scheduler(
|
||||
enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
|
||||
enable_deepep_moe=self.server_args.enable_deepep_moe,
|
||||
deepep_mode=DeepEPMode[self.server_args.deepep_mode],
|
||||
require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def prepare_dp_attn_batch_raw(
|
||||
def prepare_mlp_sync_batch_raw(
|
||||
local_batch: ScheduleBatch,
|
||||
dp_size,
|
||||
attn_tp_size: int,
|
||||
moe_dense_tp_size: Optional[int],
|
||||
tp_cpu_group,
|
||||
get_idle_batch,
|
||||
disable_cuda_graph: bool,
|
||||
@@ -1805,6 +1805,7 @@ class Scheduler(
|
||||
enable_two_batch_overlap: bool,
|
||||
enable_deepep_moe: bool,
|
||||
deepep_mode: DeepEPMode,
|
||||
require_mlp_tp_gather: bool,
|
||||
):
|
||||
# Check if other DP workers have running batches
|
||||
if local_batch is None:
|
||||
@@ -1879,7 +1880,7 @@ class Scheduler(
|
||||
|
||||
if local_batch is not None:
|
||||
# TODO: handle the case when moe_dense_tp_size != 1
|
||||
if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
|
||||
if not require_mlp_tp_gather:
|
||||
local_batch.global_num_tokens = [num_tokens]
|
||||
local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
|
||||
else:
|
||||
|
||||
@@ -46,6 +46,9 @@ from sglang.srt.utils import (
|
||||
get_available_gpu_memory,
|
||||
get_device_memory_capacity,
|
||||
rank0_log,
|
||||
require_attn_tp_gather,
|
||||
require_gathered_buffer,
|
||||
require_mlp_tp_gather,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -207,8 +210,9 @@ class CudaGraphRunner:
|
||||
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
|
||||
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
||||
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
|
||||
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
|
||||
self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
|
||||
self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
|
||||
self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
|
||||
self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
|
||||
self.enable_two_batch_overlap = (
|
||||
model_runner.server_args.enable_two_batch_overlap
|
||||
)
|
||||
@@ -299,18 +303,28 @@ class CudaGraphRunner:
|
||||
else:
|
||||
self.encoder_lens = None
|
||||
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
# TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(
|
||||
self.max_bs * self.dp_size * self.num_tokens_per_bs,
|
||||
self.model_runner.model_config.hidden_size,
|
||||
),
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
self.global_num_tokens_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
if self.require_gathered_buffer:
|
||||
if self.require_mlp_tp_gather:
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(
|
||||
self.max_bs * self.dp_size * self.num_tokens_per_bs,
|
||||
self.model_runner.model_config.hidden_size,
|
||||
),
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
self.global_num_tokens_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
else:
|
||||
assert self.require_attn_tp_gather
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(
|
||||
self.max_bs * self.num_tokens_per_bs,
|
||||
self.model_runner.model_config.hidden_size,
|
||||
),
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
|
||||
|
||||
# Capture
|
||||
try:
|
||||
@@ -322,7 +336,7 @@ class CudaGraphRunner:
|
||||
)
|
||||
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
if self.require_mlp_tp_gather:
|
||||
total_batch_size = (
|
||||
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
||||
if self.model_runner.spec_algorithm.is_eagle()
|
||||
@@ -459,7 +473,7 @@ class CudaGraphRunner:
|
||||
{k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
|
||||
)
|
||||
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
if self.require_mlp_tp_gather:
|
||||
self.global_num_tokens_gpu.copy_(
|
||||
torch.tensor(
|
||||
[
|
||||
@@ -472,6 +486,16 @@ class CudaGraphRunner:
|
||||
)
|
||||
global_num_tokens = self.global_num_tokens_gpu
|
||||
gathered_buffer = self.gathered_buffer[:num_tokens]
|
||||
elif self.require_attn_tp_gather:
|
||||
self.global_num_tokens_gpu.copy_(
|
||||
torch.tensor(
|
||||
[num_tokens],
|
||||
dtype=torch.int32,
|
||||
device=input_ids.device,
|
||||
)
|
||||
)
|
||||
global_num_tokens = self.global_num_tokens_gpu
|
||||
gathered_buffer = self.gathered_buffer[:num_tokens]
|
||||
else:
|
||||
global_num_tokens = None
|
||||
gathered_buffer = None
|
||||
@@ -607,7 +631,7 @@ class CudaGraphRunner:
|
||||
raw_num_token = raw_bs * self.num_tokens_per_bs
|
||||
|
||||
# Pad
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
if self.require_mlp_tp_gather:
|
||||
total_batch_size = (
|
||||
sum(forward_batch.global_num_tokens_cpu) / self.num_tokens_per_bs
|
||||
if self.model_runner.spec_algorithm.is_eagle()
|
||||
@@ -642,7 +666,7 @@ class CudaGraphRunner:
|
||||
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
||||
if forward_batch.mrope_positions is not None:
|
||||
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
if self.require_gathered_buffer:
|
||||
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
|
||||
if enable_num_token_non_padded(self.model_runner.server_args):
|
||||
self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
|
||||
|
||||
@@ -1621,8 +1621,6 @@ class DeepseekV2Model(nn.Module):
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.dp_size = get_local_attention_dp_size()
|
||||
|
||||
def get_input_embeddings(self) -> torch.Tensor:
|
||||
return self.embed_tokens
|
||||
|
||||
@@ -1706,7 +1704,6 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.dp_size = get_local_attention_dp_size()
|
||||
|
||||
self._routed_experts_weights_of_layer = LazyValue(
|
||||
lambda: {
|
||||
|
||||
@@ -387,7 +387,6 @@ class ServerArgs:
|
||||
), "Please enable dp attention when setting enable_dp_attention. "
|
||||
|
||||
# DeepEP MoE
|
||||
self.enable_sp_layernorm = False
|
||||
if self.enable_deepep_moe:
|
||||
if self.deepep_mode == "auto":
|
||||
assert (
|
||||
@@ -397,9 +396,6 @@ class ServerArgs:
|
||||
logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
|
||||
self.disable_cuda_graph = True
|
||||
self.ep_size = self.tp_size
|
||||
self.enable_sp_layernorm = (
|
||||
self.dp_size < self.tp_size if self.enable_dp_attention else True
|
||||
)
|
||||
logger.warning(
|
||||
f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
|
||||
)
|
||||
|
||||
@@ -20,6 +20,11 @@ from sglang.srt.model_executor.forward_batch_info import (
|
||||
ForwardMode,
|
||||
)
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput
|
||||
from sglang.srt.utils import (
|
||||
require_attn_tp_gather,
|
||||
require_gathered_buffer,
|
||||
require_mlp_tp_gather,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.speculative.eagle_worker import EAGLEWorker
|
||||
@@ -39,8 +44,9 @@ class EAGLEDraftCudaGraphRunner:
|
||||
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
|
||||
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
||||
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
|
||||
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
|
||||
self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
|
||||
self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
|
||||
self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
|
||||
self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
|
||||
self.dp_size = self.model_runner.dp_size
|
||||
self.tp_size = self.model_runner.tp_size
|
||||
self.topk = model_runner.server_args.speculative_eagle_topk
|
||||
@@ -88,8 +94,7 @@ class EAGLEDraftCudaGraphRunner:
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
# TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
|
||||
if self.require_gathered_buffer:
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(
|
||||
self.max_num_token,
|
||||
@@ -97,12 +102,19 @@ class EAGLEDraftCudaGraphRunner:
|
||||
),
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
self.global_num_tokens_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
if self.require_mlp_tp_gather:
|
||||
self.global_num_tokens_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
else:
|
||||
assert self.require_attn_tp_gather
|
||||
self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
|
||||
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||
(1,), dtype=torch.int32
|
||||
)
|
||||
|
||||
# Capture
|
||||
try:
|
||||
@@ -114,8 +126,7 @@ class EAGLEDraftCudaGraphRunner:
|
||||
)
|
||||
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
if self.enable_dp_attention:
|
||||
# TODO(ch-wan): check --moe-dense-tp-size and --enable-dp-lm-head
|
||||
if self.require_mlp_tp_gather:
|
||||
if not forward_batch.can_run_dp_cuda_graph:
|
||||
return False
|
||||
total_batch_size = (
|
||||
@@ -153,7 +164,7 @@ class EAGLEDraftCudaGraphRunner:
|
||||
topk_index = self.topk_index[:num_seqs]
|
||||
hidden_states = self.hidden_states[:num_seqs]
|
||||
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
if self.require_mlp_tp_gather:
|
||||
self.global_num_tokens_gpu.copy_(
|
||||
torch.tensor(
|
||||
[
|
||||
@@ -177,6 +188,24 @@ class EAGLEDraftCudaGraphRunner:
|
||||
global_num_tokens = self.global_num_tokens_gpu
|
||||
gathered_buffer = self.gathered_buffer[:num_tokens]
|
||||
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
|
||||
elif self.require_attn_tp_gather:
|
||||
self.global_num_tokens_gpu.copy_(
|
||||
torch.tensor(
|
||||
[num_tokens],
|
||||
dtype=torch.int32,
|
||||
device=self.input_ids.device,
|
||||
)
|
||||
)
|
||||
self.global_num_tokens_for_logprob_gpu.copy_(
|
||||
torch.tensor(
|
||||
[num_tokens],
|
||||
dtype=torch.int32,
|
||||
device=self.input_ids.device,
|
||||
)
|
||||
)
|
||||
global_num_tokens = self.global_num_tokens_gpu
|
||||
gathered_buffer = self.gathered_buffer[:num_tokens]
|
||||
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
|
||||
else:
|
||||
global_num_tokens = None
|
||||
gathered_buffer = None
|
||||
@@ -259,7 +288,7 @@ class EAGLEDraftCudaGraphRunner:
|
||||
raw_num_token = raw_bs * self.num_tokens_per_bs
|
||||
|
||||
# Pad
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
if self.require_mlp_tp_gather:
|
||||
total_batch_size = (
|
||||
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
||||
if self.model_runner.spec_algorithm.is_eagle()
|
||||
@@ -286,7 +315,7 @@ class EAGLEDraftCudaGraphRunner:
|
||||
self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
|
||||
self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
|
||||
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
if self.require_gathered_buffer:
|
||||
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
|
||||
self.global_num_tokens_for_logprob_gpu.copy_(
|
||||
forward_batch.global_num_tokens_for_logprob_gpu
|
||||
|
||||
@@ -21,6 +21,11 @@ from sglang.srt.model_executor.forward_batch_info import (
|
||||
ForwardMode,
|
||||
)
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
|
||||
from sglang.srt.utils import (
|
||||
require_attn_tp_gather,
|
||||
require_gathered_buffer,
|
||||
require_mlp_tp_gather,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.speculative.eagle_worker import EAGLEWorker
|
||||
@@ -35,8 +40,9 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
self.output_buffers = {}
|
||||
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
|
||||
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
||||
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
|
||||
self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
|
||||
self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
|
||||
self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
|
||||
self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
|
||||
self.tp_size = self.model_runner.tp_size
|
||||
self.dp_size = model_runner.server_args.dp_size
|
||||
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
|
||||
@@ -92,7 +98,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
(self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32
|
||||
)
|
||||
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
if self.require_gathered_buffer:
|
||||
self.gathered_buffer = torch.zeros(
|
||||
(
|
||||
self.max_num_token,
|
||||
@@ -100,13 +106,19 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
),
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
self.global_num_tokens_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
|
||||
if self.require_mlp_tp_gather:
|
||||
self.global_num_tokens_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||
(self.dp_size,), dtype=torch.int32
|
||||
)
|
||||
else:
|
||||
assert self.require_attn_tp_gather
|
||||
self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
|
||||
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||
(1,), dtype=torch.int32
|
||||
)
|
||||
# Capture
|
||||
try:
|
||||
with model_capture_mode():
|
||||
@@ -117,7 +129,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
)
|
||||
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
if self.require_mlp_tp_gather:
|
||||
if not forward_batch.can_run_dp_cuda_graph:
|
||||
return False
|
||||
total_batch_size = (
|
||||
@@ -160,7 +172,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
positions = self.positions[:num_tokens]
|
||||
hidden_states = self.hidden_states[:num_tokens]
|
||||
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
if self.require_mlp_tp_gather:
|
||||
self.global_num_tokens_gpu.copy_(
|
||||
torch.tensor(
|
||||
[
|
||||
@@ -184,6 +196,24 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
global_num_tokens = self.global_num_tokens_gpu
|
||||
gathered_buffer = self.gathered_buffer[:num_tokens]
|
||||
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
|
||||
elif self.require_attn_tp_gather:
|
||||
self.global_num_tokens_gpu.copy_(
|
||||
torch.tensor(
|
||||
[num_tokens],
|
||||
dtype=torch.int32,
|
||||
device=self.input_ids.device,
|
||||
)
|
||||
)
|
||||
self.global_num_tokens_for_logprob_gpu.copy_(
|
||||
torch.tensor(
|
||||
[num_tokens],
|
||||
dtype=torch.int32,
|
||||
device=self.input_ids.device,
|
||||
)
|
||||
)
|
||||
global_num_tokens = self.global_num_tokens_gpu
|
||||
gathered_buffer = self.gathered_buffer[:num_tokens]
|
||||
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
|
||||
else:
|
||||
global_num_tokens = None
|
||||
gathered_buffer = None
|
||||
@@ -270,7 +300,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
# in the batch, which will not be counted as num_seqs
|
||||
raw_bs = forward_batch.batch_size
|
||||
num_tokens = forward_batch.input_ids.shape[0]
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
if self.require_mlp_tp_gather:
|
||||
total_batch_size = (
|
||||
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
||||
if self.model_runner.spec_algorithm.is_eagle()
|
||||
@@ -299,7 +329,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
|
||||
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
|
||||
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
if self.require_gathered_buffer:
|
||||
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
|
||||
self.global_num_tokens_for_logprob_gpu.copy_(
|
||||
forward_batch.global_num_tokens_for_logprob_gpu
|
||||
|
||||
@@ -2303,6 +2303,51 @@ class Withable(Generic[T]):
|
||||
self._value = None
|
||||
|
||||
|
||||
def require_mlp_tp_gather(server_args):
    """
    Tell whether the MLP input is produced by all-gather instead of all-reduce.

    Per the original author's note, this only happens when each MLP TP group
    contains multiple attention DP groups. Always False when DP attention is
    disabled; asserts ``dp_size > 1`` when it is enabled.
    """
    if not server_args.enable_dp_attention:
        return False

    assert server_args.dp_size > 1, "dp_size must be greater than 1"

    # TODO(ch-wan): some MoE models do not have dense layers
    if server_args.moe_dense_tp_size is None:
        return True

    # Without both DP LM head and DeepEP MoE the gather is always required.
    if not (server_args.enable_dp_lm_head and server_args.enable_deepep_moe):
        return True

    tp_per_dp_group = server_args.tp_size // server_args.dp_size
    return server_args.moe_dense_tp_size > tp_per_dp_group
|
||||
|
||||
|
||||
def require_attn_tp_gather(server_args):
    """
    Check if the input of attention is scattered.

    Asserts that ``moe_dense_tp_size`` is 1 or None. Returns False unless
    DeepEP MoE is enabled or ``moe_dense_tp_size == 1``; in that case the
    answer is True without DP attention, and ``dp_size < tp_size`` with it.
    """
    assert server_args.moe_dense_tp_size in [1, None]

    if not (server_args.enable_deepep_moe or server_args.moe_dense_tp_size == 1):
        return False
    if not server_args.enable_dp_attention:
        return True
    return server_args.dp_size < server_args.tp_size
|
||||
|
||||
|
||||
def require_gathered_buffer(server_args):
    """Return a truthy value when either the MLP or the attention TP gather is required."""
    mlp_gather = require_mlp_tp_gather(server_args)
    # The attention check (and its assertion) only runs when the MLP check is falsy.
    return mlp_gather if mlp_gather else require_attn_tp_gather(server_args)
|
||||
|
||||
|
||||
def require_mlp_sync(server_args):
    """Return a truthy value when DP attention is enabled or a gathered buffer is required."""
    dp_attention = server_args.enable_dp_attention
    # The gathered-buffer check is only consulted when DP attention is off.
    return dp_attention if dp_attention else require_gathered_buffer(server_args)
|
||||
|
||||
|
||||
def merge_bias_tensor(
|
||||
lhs: Optional[torch.Tensor],
|
||||
rhs: Optional[torch.Tensor],
|
||||
|
||||
Reference in New Issue
Block a user