diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py
index f03b6333f..2dec296a7 100644
--- a/python/sglang/srt/layers/communicator.py
+++ b/python/sglang/srt/layers/communicator.py
@@ -32,6 +32,8 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    get_global_dp_buffer,
+    get_local_dp_buffer,
 )
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -319,7 +321,7 @@ class CommunicateSimpleFn:
         context: CommunicateContext,
     ) -> torch.Tensor:
         hidden_states, local_hidden_states = (
-            forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+            get_local_dp_buffer(),
             hidden_states,
         )
         attn_tp_all_gather_into_tensor(
@@ -408,9 +410,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
     ):
         if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
             residual, local_residual = (
-                torch.empty_like(
-                    forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]
-                ),
+                get_local_dp_buffer(),
                 residual,
             )
             attn_tp_all_gather_into_tensor(residual, local_residual)
@@ -424,7 +424,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
             residual = hidden_states
             hidden_states = layernorm(hidden_states)
             hidden_states, local_hidden_states = (
-                torch.empty_like(forward_batch.gathered_buffer),
+                get_global_dp_buffer(),
                 hidden_states,
             )
             dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
@@ -548,7 +548,7 @@ class CommunicateSummableTensorPairFn:
         allow_reduce_scatter: bool = False,
     ):
         hidden_states, global_hidden_states = (
-            forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+            get_local_dp_buffer(),
            hidden_states,
         )
         if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
@@ -569,7 +569,7 @@ class CommunicateSummableTensorPairFn:
             hidden_states += residual
             residual = None
         hidden_states, local_hidden_states = (
-            forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+            get_local_dp_buffer(),
             hidden_states,
         )
         attn_tp_all_gather_into_tensor(
diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py
index 21d44561d..3d5d30890 100644
--- a/python/sglang/srt/layers/dp_attention.py
+++ b/python/sglang/srt/layers/dp_attention.py
@@ -4,7 +4,7 @@ import functools
 import logging
 from contextlib import contextmanager
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import torch
 import triton
@@ -18,21 +18,26 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.configs.model_config import ModelConfig
+    from sglang.srt.server_args import ServerArgs
+
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
-_ATTN_TP_GROUP = None
-_ATTN_TP_RANK = None
-_ATTN_TP_SIZE = None
-_ATTN_DP_RANK = None
-_ATTN_DP_SIZE = None
-_LOCAL_ATTN_DP_SIZE = None
-_LOCAL_ATTN_DP_RANK = None
+_ATTN_TP_GROUP: Optional[GroupCoordinator] = None
+_ATTN_TP_RANK: Optional[int] = None
+_ATTN_TP_SIZE: Optional[int] = None
+_ATTN_DP_RANK: Optional[int] = None
+_ATTN_DP_SIZE: Optional[int] = None
+_LOCAL_ATTN_DP_SIZE: Optional[int] = None
+_LOCAL_ATTN_DP_RANK: Optional[int] = None
+_ENABLE_DP_ATTENTION_FLAG: bool = False
 
 
-class DPPaddingMode(IntEnum):
+class DpPaddingMode(IntEnum):
 
     # Padding tokens to max length and then gather tokens using
     # `all_gather_into_tensor`
     MAX_LEN = auto()
@@ -40,13 +45,13 @@ class DPPaddingMode(IntEnum):
     SUM_LEN = auto()
 
     def is_max_len(self):
-        return self == DPPaddingMode.MAX_LEN
+        return self == DpPaddingMode.MAX_LEN
 
     def is_sum_len(self):
-        return self == DPPaddingMode.SUM_LEN
+        return self == DpPaddingMode.SUM_LEN
 
     @classmethod
-    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DPPaddingMode:
+    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DpPaddingMode:
         # we choose the mode that minimizes the communication cost
         max_len = max(global_num_tokens)
         sum_len = sum(global_num_tokens)
@@ -56,10 +61,76 @@ class DPPaddingMode(IntEnum):
         return cls.SUM_LEN
 
     @classmethod
-    def get_default_mode_in_cuda_graph(cls) -> DPPaddingMode:
+    def get_default_mode_in_cuda_graph(cls) -> DpPaddingMode:
         return cls.MAX_LEN
 
 
+class _DpGatheredBufferWrapper:
+
+    _hidden_size: int
+    _dtype: torch.dtype
+    _device: torch.device
+    _global_dp_buffer_len: int
+    _local_dp_buffer_len: int
+
+    @classmethod
+    def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
+        cls._hidden_size = hidden_size
+        cls._dtype = dtype
+        cls._device = device
+
+    @classmethod
+    def set_dp_buffer_len(cls, global_dp_buffer_len: int, local_dp_buffer_len: int):
+        cls._global_dp_buffer_len = global_dp_buffer_len
+        cls._local_dp_buffer_len = local_dp_buffer_len
+
+    @classmethod
+    def get_global_dp_buffer(cls) -> torch.Tensor:
+        return torch.empty(
+            (cls._global_dp_buffer_len, cls._hidden_size),
+            dtype=cls._dtype,
+            device=cls._device,
+        )
+
+    @classmethod
+    def get_local_dp_buffer(cls) -> torch.Tensor:
+        return torch.empty(
+            (cls._local_dp_buffer_len, cls._hidden_size),
+            dtype=cls._dtype,
+            device=cls._device,
+        )
+
+    @classmethod
+    def get_global_dp_buffer_len(cls) -> int:
+        return cls._global_dp_buffer_len
+
+    @classmethod
+    def get_local_dp_buffer_len(cls) -> int:
+        return cls._local_dp_buffer_len
+
+
+def set_dp_buffer_len(global_dp_buffer_len: int, local_dp_buffer_len: int):
+    _DpGatheredBufferWrapper.set_dp_buffer_len(
+        global_dp_buffer_len, local_dp_buffer_len
+    )
+
+
+def get_global_dp_buffer() -> torch.Tensor:
+    return _DpGatheredBufferWrapper.get_global_dp_buffer()
+
+
+def get_local_dp_buffer() -> torch.Tensor:
+    return _DpGatheredBufferWrapper.get_local_dp_buffer()
+
+
+def get_global_dp_buffer_len() -> int:
+    return _DpGatheredBufferWrapper.get_global_dp_buffer_len()
+
+
+def get_local_dp_buffer_len() -> int:
+    return _DpGatheredBufferWrapper.get_local_dp_buffer_len()
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
@@ -89,18 +160,24 @@ def compute_dp_attention_local_info(
 
 
 def initialize_dp_attention(
-    enable_dp_attention: bool,
-    tp_rank: int,
-    tp_size: int,
-    dp_size: int,
-    moe_dense_tp_size: int,
-    pp_size: int,
+    server_args: ServerArgs,
+    model_config: ModelConfig,
 ):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
-    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK
+    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK, _ENABLE_DP_ATTENTION_FLAG
 
     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
 
+    enable_dp_attention = server_args.enable_dp_attention
+    tp_size = server_args.tp_size
+    dp_size = server_args.dp_size
+    moe_dense_tp_size = server_args.moe_dense_tp_size
+    pp_size = server_args.pp_size
+
+    tp_rank = get_tensor_model_parallel_rank()
+
+    _ENABLE_DP_ATTENTION_FLAG = enable_dp_attention
+
     _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
@@ -135,38 +212,48 @@ def initialize_dp_attention(
         group_name="attention_tp",
     )
 
+    _DpGatheredBufferWrapper.set_metadata(
+        hidden_size=model_config.hidden_size,
+        dtype=model_config.dtype,
+        device=torch.device("cuda"),
+    )
 
-def get_attention_tp_group():
+
+def is_dp_attention_enabled() -> bool:
+    return _ENABLE_DP_ATTENTION_FLAG
+
+
+def get_attention_tp_group() -> GroupCoordinator:
     assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
     return _ATTN_TP_GROUP
 
 
-def get_attention_tp_rank():
+def get_attention_tp_rank() -> int:
     assert _ATTN_TP_RANK is not None, "dp attention not initialized!"
     return _ATTN_TP_RANK
 
 
-def get_attention_tp_size():
+def get_attention_tp_size() -> int:
     assert _ATTN_TP_SIZE is not None, "dp attention not initialized!"
     return _ATTN_TP_SIZE
 
 
-def get_attention_dp_rank():
+def get_attention_dp_rank() -> int:
     assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
     return _ATTN_DP_RANK
 
 
-def get_attention_dp_size():
+def get_attention_dp_size() -> int:
     assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
     return _ATTN_DP_SIZE
 
 
-def get_local_attention_dp_rank():
+def get_local_attention_dp_rank() -> int:
     assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
     return _LOCAL_ATTN_DP_RANK
 
 
-def get_local_attention_dp_size():
+def get_local_attention_dp_size() -> int:
     assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
     return _LOCAL_ATTN_DP_SIZE
 
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index 3384f5efa..711aba03f 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -27,7 +27,7 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
-    DPPaddingMode,
+    DpPaddingMode,
     attn_tp_all_gather,
     attn_tp_all_gather_into_tensor,
     dp_gather_replicate,
@@ -35,7 +35,9 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_rank,
     get_attention_dp_size,
     get_attention_tp_size,
+    get_global_dp_buffer,
     get_local_attention_dp_size,
+    set_dp_buffer_len,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -108,14 +110,12 @@ class LogitsMetadata:
     # The start position of local hidden states.
     dp_local_start_pos: Optional[torch.Tensor] = None
     dp_local_num_tokens: Optional[torch.Tensor] = None
-    gathered_buffer: Optional[torch.Tensor] = None
-    # Buffer to gather logits from all ranks.
-    forward_batch_gathered_buffer: Optional[torch.Tensor] = None
+    global_dp_buffer_len: Optional[int] = None
     # Number of tokens to sample per DP rank
     global_num_tokens_for_logprob_cpu: Optional[torch.Tensor] = None
     global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
     # The gather mode for DP attention
-    dp_padding_mode: Optional[DPPaddingMode] = None
+    dp_padding_mode: Optional[DpPaddingMode] = None
     # for padding
     padded_static_len: int = -1
 
@@ -164,11 +164,10 @@ class LogitsMetadata:
             global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
             dp_local_start_pos=forward_batch.dp_local_start_pos,
             dp_local_num_tokens=forward_batch.dp_local_num_tokens,
-            gathered_buffer=forward_batch.gathered_buffer,
-            forward_batch_gathered_buffer=forward_batch.gathered_buffer,
+            global_dp_buffer_len=forward_batch.global_dp_buffer_len,
             global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
             global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
-            dp_padding_mode=DPPaddingMode.SUM_LEN,
+            dp_padding_mode=DpPaddingMode.SUM_LEN,
         )
 
     def compute_dp_attention_metadata(self):
@@ -188,16 +187,11 @@ class LogitsMetadata:
 
         if self.global_num_tokens_for_logprob_cpu is not None:
             # create a smaller buffer to reduce peak memory usage
-            self.gathered_buffer = torch.empty(
-                (
-                    sum(self.global_num_tokens_for_logprob_cpu),
-                    self.gathered_buffer.shape[1],
-                ),
-                dtype=self.gathered_buffer.dtype,
-                device=self.gathered_buffer.device,
-            )
+            self.global_dp_buffer_len = sum(self.global_num_tokens_for_logprob_cpu)
         else:
-            self.gathered_buffer = torch.empty_like(self.gathered_buffer)
+            self.global_dp_buffer_len = self.global_dp_buffer_len
+
+        set_dp_buffer_len(self.global_dp_buffer_len, self.dp_local_num_tokens)
 
 
 class LogitsProcessor(nn.Module):
@@ -443,7 +437,7 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata()
             hidden_states, local_hidden_states = (
-                logits_metadata.gathered_buffer,
+                get_global_dp_buffer(),
                 hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index 75644b588..cf4222cc7 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -6,7 +6,10 @@ import torch.distributed as dist
 from torch import nn
 
 from sglang.srt.distributed import get_tp_group
-from sglang.srt.layers.dp_attention import get_attention_tp_group
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_group,
+    is_dp_attention_enabled,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -32,7 +35,7 @@ class Sampler(nn.Module):
         self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
         self.tp_sync_group = get_tp_group().device_group
 
-        if global_server_args_dict["enable_dp_attention"]:
+        if is_dp_attention_enabled():
             self.tp_sync_group = get_attention_tp_group().device_group
 
     def forward(
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 55b1a9ec7..8b1b11bdf 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -84,7 +84,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "device",
     "disable_chunked_prefix_cache",
     "disable_radix_cache",
-    "enable_dp_attention",
     "enable_two_batch_overlap",
     "tbo_token_distribution_threshold",
     "enable_dp_lm_head",
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 537dab9eb..cc87910ac 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -34,9 +34,10 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
 )
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.dp_attention import (
-    DPPaddingMode,
+    DpPaddingMode,
     get_attention_tp_rank,
     get_attention_tp_size,
+    set_dp_buffer_len,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -349,30 +350,15 @@ class CudaGraphRunner:
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token * self.dp_size,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (1,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
         else:
             self.global_num_tokens_gpu = None
             self.global_num_tokens_for_logprob_gpu = None
-            self.gathered_buffer = None
 
         self.custom_mask = torch.ones(
             (
@@ -556,7 +542,7 @@
                     device=input_ids.device,
                 )
             )
-            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
+            global_dp_buffer_len = num_tokens * self.dp_size
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
@@ -572,9 +558,9 @@
                     device=input_ids.device,
                 )
             )
-            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_dp_buffer_len = num_tokens
         else:
-            gathered_buffer = None
+            global_dp_buffer_len = None
 
         spec_info = self.get_spec_info(num_tokens)
         if self.capture_hidden_mode != CaptureHiddenMode.FULL:
@@ -607,8 +593,8 @@
             positions=positions,
             global_num_tokens_gpu=self.global_num_tokens_gpu,
             global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
-            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
-            gathered_buffer=gathered_buffer,
+            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+            global_dp_buffer_len=global_dp_buffer_len,
             mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
@@ -637,6 +623,7 @@
         def run_once():
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
 
             kwargs = {}
             if (
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index c019d7e3f..da2d81fc5 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -40,9 +40,10 @@ import triton.language as tl
 
 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.dp_attention import (
-    DPPaddingMode,
+    DpPaddingMode,
     get_attention_dp_rank,
     get_attention_tp_size,
+    set_dp_buffer_len,
 )
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.utils import (
@@ -274,13 +275,13 @@
     global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
     global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
     # The padding mode for DP attention
-    dp_padding_mode: Optional[DPPaddingMode] = None
+    dp_padding_mode: Optional[DpPaddingMode] = None
     # for extend, local start pos and num tokens is different in logits processor
     # this will be computed in get_dp_local_info
     # this will be recomputed in LogitsMetadata.from_forward_batch
     dp_local_start_pos: Optional[torch.Tensor] = None  # cached info at runtime
     dp_local_num_tokens: Optional[torch.Tensor] = None  # cached info at runtime
-    gathered_buffer: Optional[torch.Tensor] = None
+    global_dp_buffer_len: Optional[int] = None
     is_extend_in_batch: bool = False
     can_run_dp_cuda_graph: bool = False
     global_forward_mode: Optional[ForwardMode] = None
@@ -628,7 +629,7 @@
                     (global_num_tokens[i] - 1) // attn_tp_size + 1
                 ) * attn_tp_size
 
-        dp_padding_mode = DPPaddingMode.get_dp_padding_mode(global_num_tokens)
+        dp_padding_mode = DpPaddingMode.get_dp_padding_mode(global_num_tokens)
         self.dp_padding_mode = dp_padding_mode
 
         if dp_padding_mode.is_max_len():
@@ -642,17 +643,14 @@
         else:
             buffer_len = sum(global_num_tokens)
 
-        self.gathered_buffer = torch.zeros(
-            (buffer_len, model_runner.model_config.hidden_size),
-            dtype=model_runner.dtype,
-            device=model_runner.device,
-        )
-
         if len(global_num_tokens) > 1:
             num_tokens = global_num_tokens[get_attention_dp_rank()]
         else:
             num_tokens = global_num_tokens[0]
 
+        self.global_dp_buffer_len = buffer_len
+        set_dp_buffer_len(buffer_len, num_tokens)
+
         bs = self.batch_size
 
         if self.forward_mode.is_decode():
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index bb67c79f3..5222bff0a 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -603,12 +603,8 @@
             duplicate_tp_group=self.server_args.enable_pdmux,
         )
         initialize_dp_attention(
-            enable_dp_attention=self.server_args.enable_dp_attention,
-            tp_rank=self.tp_rank,
-            tp_size=self.tp_size,
-            dp_size=self.server_args.dp_size,
-            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
-            pp_size=self.server_args.pp_size,
+            server_args=self.server_args,
+            model_config=self.model_config,
         )
 
         min_per_gpu_memory = get_available_gpu_memory(
diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py
index e61dadadc..5b1ae6e69 100644
--- a/python/sglang/srt/models/deepseek_nextn.py
+++ b/python/sglang/srt/models/deepseek_nextn.py
@@ -22,6 +22,7 @@ from transformers import PretrainedConfig
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -56,7 +57,7 @@ class DeepseekModelNextN(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
             prefix=add_prefix("embed_tokens", prefix),
         )
 
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 235718ded..90efc4067 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -51,6 +51,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -1797,7 +1798,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
         self.layer_id = layer_id
         self.is_nextn = is_nextn
@@ -1917,7 +1917,9 @@
 
         should_allreduce_fusion = (
             self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
-            and not (self.enable_dp_attention and self.speculative_algorithm.is_eagle())
+            and not (
+                is_dp_attention_enabled() and self.speculative_algorithm.is_eagle()
+            )
             and not self.is_nextn
         )
 
@@ -2047,7 +2049,7 @@ class DeepseekV2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
         )
         self.alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.layers = nn.ModuleList(
diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py
index 7727e9605..e0f0b373d 100644
--- a/python/sglang/srt/models/glm4_moe.py
+++ b/python/sglang/srt/models/glm4_moe.py
@@ -40,6 +40,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -634,7 +635,6 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
         )
         rms_norm_eps = config.rms_norm_eps
         attention_bias = config.attention_bias
-        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.layer_id = layer_id
         self.self_attn = Glm4MoeAttention(
             hidden_size=self.hidden_size,
@@ -744,7 +744,7 @@ class Glm4MoeModel(DeepseekV2Model):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
         )
         self.alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.layers = nn.ModuleList(
diff --git a/python/sglang/srt/models/glm4_moe_nextn.py b/python/sglang/srt/models/glm4_moe_nextn.py
index 1a0793d8a..399f0f4e0 100644
--- a/python/sglang/srt/models/glm4_moe_nextn.py
+++ b/python/sglang/srt/models/glm4_moe_nextn.py
@@ -22,6 +22,7 @@ from transformers import PretrainedConfig
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -56,7 +57,7 @@ class Glm4MoeModelNextN(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
             prefix=add_prefix("embed_tokens", prefix),
         )
 
diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py
index e15bc5dc2..b5057fb3e 100644
--- a/python/sglang/srt/models/gpt_oss.py
+++ b/python/sglang/srt/models/gpt_oss.py
@@ -41,6 +41,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -565,7 +566,7 @@ class GptOssModel(nn.Module):
             self.embed_tokens = VocabParallelEmbedding(
                 config.vocab_size,
                 config.hidden_size,
-                enable_tp=not global_server_args_dict["enable_dp_attention"],
+                enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:
diff --git a/python/sglang/srt/models/llama4.py b/python/sglang/srt/models/llama4.py
index f9966351f..c0a2be43d 100644
--- a/python/sglang/srt/models/llama4.py
+++ b/python/sglang/srt/models/llama4.py
@@ -32,6 +32,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -45,7 +46,6 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
@@ -466,7 +466,7 @@ class Llama4Model(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("embed_tokens", prefix),
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
         )
         self.layers = make_layers(
             config.num_hidden_layers,
diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py
index 556a5bb8f..2b1ea57fd 100644
--- a/python/sglang/srt/models/qwen2.py
+++ b/python/sglang/srt/models/qwen2.py
@@ -27,6 +27,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -43,7 +44,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -273,7 +273,7 @@ class Qwen2Model(nn.Module):
                 config.vocab_size,
                 config.hidden_size,
                 quant_config=quant_config,
-                enable_tp=not global_server_args_dict["enable_dp_attention"],
+                enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:
diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index 2af1e919d..da7936c4d 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -46,6 +46,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -420,7 +421,7 @@ class Qwen2MoeModel(nn.Module):
             self.embed_tokens = VocabParallelEmbedding(
                 config.vocab_size,
                 config.hidden_size,
-                enable_tp=not global_server_args_dict["enable_dp_attention"],
+                enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:
diff --git a/python/sglang/srt/models/step3_vl.py b/python/sglang/srt/models/step3_vl.py
index b0c2e0a81..64bb2183c 100644
--- a/python/sglang/srt/models/step3_vl.py
+++ b/python/sglang/srt/models/step3_vl.py
@@ -25,7 +25,11 @@ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
-from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -437,7 +441,7 @@ class Step3TextModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
+            enable_tp=not is_dp_attention_enabled(),
             prefix=add_prefix("embed_tokens", prefix),
         )
 
diff --git a/python/sglang/srt/operations.py b/python/sglang/srt/operations.py
index 0a8c118df..f850bcd25 100644
--- a/python/sglang/srt/operations.py
+++ b/python/sglang/srt/operations.py
@@ -1,10 +1,17 @@
+from __future__ import annotations
+
 import os
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, Generator, List, Sequence, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Sequence, Union
 
 import torch
 
+from sglang.srt.layers.dp_attention import set_dp_buffer_len
+
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
 _ENABLE_PROFILE = bool(int(os.environ.get("SGLANG_OPERATIONS_ENABLE_PROFILE", "0")))
 
 if _ENABLE_PROFILE:
@@ -66,18 +73,26 @@
 Stage = List[ExecutionOperation]
 
 
 class _StageExecutor:
-    def __init__(self, debug_name: str, stages: List[Stage], inputs):
+    def __init__(self, debug_name: str, stages: List[Stage], inputs: dict):
         self._debug_name = debug_name
         self._stages = stages
         self._index = 0
         self._stage_state = _StateDict()
         self._stage_output = inputs
 
+        # handling DP attention
+        forward_batch: ForwardBatch = inputs["forward_batch"]
+        self._global_dp_buffer_len = forward_batch.global_dp_buffer_len
+        self._local_dp_buffer_len = forward_batch.input_ids.shape[0]
+
     def next(self):
         assert not self.done
 
         stage = self._stages[self._index]
 
+        if self._global_dp_buffer_len is not None:
+            set_dp_buffer_len(self._global_dp_buffer_len, self._local_dp_buffer_len)
+
         with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
             for op in stage:
                 with _annotate_region(debug_name=op.debug_name):
diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
index 8cc324158..e824fb1ae 100644
--- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
+++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Callable
 
 import torch
 
-from sglang.srt.layers.dp_attention import DPPaddingMode
+from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
 from sglang.srt.model_executor.cuda_graph_runner import (
     CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
@@ -105,30 +105,15 @@ class EAGLEDraftCudaGraphRunner:
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token * self.dp_size,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (1,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
         else:
             self.global_num_tokens_gpu = None
             self.global_num_tokens_for_logprob_gpu = None
-            self.gathered_buffer = None
 
         # Capture
         try:
@@ -193,7 +178,7 @@
                 )
             )
             global_num_tokens = self.global_num_tokens_gpu
-            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
+            global_dp_buffer_len = num_tokens * self.dp_size
             global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
@@ -211,11 +196,11 @@
                 )
             )
             global_num_tokens = self.global_num_tokens_gpu
-            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_dp_buffer_len = num_tokens
             global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
         else:
             global_num_tokens = None
-            gathered_buffer = None
+            global_dp_buffer_len = None
             global_num_tokens_for_logprob = None
 
         spec_info = EagleDraftInput(
@@ -239,8 +224,8 @@
             return_logprob=False,
             positions=positions,
             global_num_tokens_gpu=global_num_tokens,
-            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
-            gathered_buffer=gathered_buffer,
+            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+            global_dp_buffer_len=global_dp_buffer_len,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=(
@@ -258,6 +243,7 @@
         def run_once():
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
 
             # Backup two fields, which will be modified in-place in `draft_forward`.
             output_cache_loc_backup = forward_batch.out_cache_loc
diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
index 08d823a0b..4f4403fee 100644
--- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
+++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Callable
 
 import torch
 
-from sglang.srt.layers.dp_attention import DPPaddingMode
+from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
 from sglang.srt.model_executor.cuda_graph_runner import (
     CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
@@ -117,30 +117,15 @@ class EAGLEDraftExtendCudaGraphRunner:
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token * self.dp_size,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (1,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
         else:
             self.global_num_tokens_gpu = None
             self.global_num_tokens_for_logprob_gpu = None
-            self.gathered_buffer = None
 
         if hasattr(
             self.model_runner.model_config.hf_config, "draft_vocab_size"
@@ -222,7 +207,7 @@
                     device=self.input_ids.device,
                 )
             )
-            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
+            global_dp_buffer_len = num_tokens * self.dp_size
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
@@ -238,9 +223,9 @@
                     device=self.input_ids.device,
                 )
             )
-            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_dp_buffer_len = num_tokens
         else:
-            gathered_buffer = None
+            global_dp_buffer_len = None
 
         spec_info = EagleDraftInput(
             hidden_states=hidden_states,
@@ -264,8 +249,8 @@
             positions=positions,
             global_num_tokens_gpu=self.global_num_tokens_gpu,
             global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
-            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
-            gathered_buffer=gathered_buffer,
+            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+            global_dp_buffer_len=global_dp_buffer_len,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=CaptureHiddenMode.LAST,
@@ -288,6 +273,7 @@
         def run_once():
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
 
             # Backup two fields, which will be modified in-place in `draft_forward`.
             output_cache_loc_backup = forward_batch.out_cache_loc
diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py
index 23580a463..223ff0cbe 100644
--- a/python/sglang/srt/two_batch_overlap.py
+++ b/python/sglang/srt/two_batch_overlap.py
@@ -678,16 +678,12 @@ class TboForwardBatchPreparer:
         # TODO improve, e.g. unify w/ `init_raw`
         if (
             global_server_args_dict["moe_dense_tp_size"] == 1
-            and batch.gathered_buffer is not None
+            and batch.global_dp_buffer_len is not None
         ):
             sum_len = end_token_index - start_token_index
-            gathered_buffer = torch.zeros(
-                (sum_len, batch.gathered_buffer.shape[1]),
-                dtype=batch.gathered_buffer.dtype,
-                device=batch.gathered_buffer.device,
-            )
+            global_dp_buffer_len = sum_len
         else:
-            gathered_buffer = None
+            global_dp_buffer_len = None
 
         output_dict.update(
             dict(
@@ -706,7 +702,7 @@ class TboForwardBatchPreparer:
                 global_num_tokens_gpu=None,
                 global_num_tokens_cpu=None,
                 dp_padding_mode=None,
-                gathered_buffer=gathered_buffer,
+                global_dp_buffer_len=global_dp_buffer_len,
                 global_num_tokens_for_logprob_gpu=None,
                 global_num_tokens_for_logprob_cpu=None,
                 sampling_info=None,