From 8e66fbecee8b02becc7efd21ad6ffa4044f8a931 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Thu, 13 Mar 2025 08:23:56 -0700
Subject: [PATCH] Improve DP attention (#4390)

Co-authored-by: dhou-xai
Co-authored-by: SangBin Cho
---
 python/sglang/srt/layers/dp_attention.py     |  32 +-
 python/sglang/srt/layers/logits_processor.py |   1 +
 .../srt/managers/data_parallel_controller.py |   2 +-
 python/sglang/srt/managers/scheduler.py      |  69 +++-
 .../srt/model_executor/cuda_graph_runner.py  |  75 +++-
 .../srt/model_executor/forward_batch_info.py |  17 +-
 python/sglang/srt/models/deepseek_v2.py      | 363 +++++++++---------
 python/sglang/srt/server_args.py             |  10 +-
 test/srt/test_dp_attention.py                |   2 +
 9 files changed, 345 insertions(+), 226 deletions(-)

diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py
index f8b756f52..42d4a1457 100644
--- a/python/sglang/srt/layers/dp_attention.py
+++ b/python/sglang/srt/layers/dp_attention.py
@@ -1,6 +1,8 @@
 from __future__ import annotations
 
 import functools
+import logging
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, Union
 
 import torch
@@ -14,6 +16,8 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 
+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
@@ -86,6 +90,27 @@ def get_attention_dp_size():
     return _DP_SIZE
 
 
+@contextmanager
+def disable_dp_size():
+    """Temporarily set the DP size to 1 until this context manager exits.
+
+    This is used by the draft workers of speculative decoding, which run the
+    draft model with a tp degree different from that of the target model
+    workers.
+
+    The original DP size is restored on exit.
+    """
+    global _DP_SIZE
+    assert _DP_SIZE is not None, "dp attention not initialized!"
+
+    old_dp_size = _DP_SIZE
+    _DP_SIZE = 1
+    try:
+        yield
+    finally:
+        _DP_SIZE = old_dp_size
+
+
 def get_dp_local_info(forward_batch: ForwardBatch):
     dp_rank = get_attention_dp_rank()
 
@@ -159,7 +184,8 @@ def dp_gather(
         layer_id != "embedding" or get_attention_tp_rank() == 0
     ):
         assert (
-            global_tokens.storage().data_ptr() != local_tokens.storage().data_ptr()
+            global_tokens.untyped_storage().data_ptr()
+            != local_tokens.untyped_storage().data_ptr()
         ), "aliasing between global_tokens and local_tokens not allowed"
         memcpy_triton(
             global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
@@ -174,8 +200,9 @@ def dp_gather(
         torch.ops.sglang.inplace_all_reduce(
             global_tokens, group_name=get_tp_group().unique_name
         )
+
     else:
-        global_tokens = tensor_model_parallel_all_reduce(global_tokens)
+        global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)
 
 
 def dp_scatter(
@@ -186,6 +213,7 @@ def dp_scatter(
     # local_num_tokens is not necessarily the same as local_tokens.shape[0],
    # since local_tokens may be padded for cuda graph
     local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
+    local_tokens.fill_(0)
 
     assert local_tokens.is_contiguous()
     assert global_tokens.is_contiguous()
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index a0e38b022..b398e052d 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -23,6 +23,7 @@ import triton.language as tl
 from torch import nn
 
 from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )
diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py
index f1d669fc8..627d72c7b 100644
--- a/python/sglang/srt/managers/data_parallel_controller.py
+++ b/python/sglang/srt/managers/data_parallel_controller.py
@@ -54,7 +54,7 @@ class LoadBalanceMethod(Enum):
 class DataParallelController:
     """A controller that dispatches requests to multiple data parallel workers."""
 
-    def __init__(self, server_args, port_args) -> None:
+    def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
         # Parse args
         self.max_total_num_tokens = None
         self.server_args = server_args
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 68bef7e08..c91905f5f 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -997,7 +997,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
 
         # Handle DP attention
         if self.server_args.enable_dp_attention:
-            ret = self.prepare_dp_attn_batch(ret)
+            ret, _ = self.prepare_dp_attn_batch(ret)
 
         return ret
 
@@ -1269,39 +1269,72 @@ class Scheduler(SchedulerOutputProcessorMixin):
         # Check if other DP workers have running batches
         if local_batch is None:
             num_tokens = 0
+            global_num_tokens_for_logprob = 0
         elif local_batch.forward_mode.is_decode():
             num_tokens = local_batch.batch_size()
+            if not self.spec_algorithm.is_none() and self.spec_algorithm.is_eagle():
+                num_tokens = num_tokens * self.server_args.speculative_num_draft_tokens
+            global_num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
+            global_num_tokens_for_logprob = sum(
+                [
+                    # We should have at least 1 token for sampling in every case.
+                    max(extend_len - logprob_start_len, 1)
+                    for logprob_start_len, extend_len in zip(
+                        local_batch.extend_logprob_start_lens, local_batch.extend_lens
+                    )
+                ]
+            )
 
-        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
-        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
+        if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
+            can_cuda_graph = 1
+        else:
+            can_cuda_graph = 0
+
+        if not self.spec_algorithm.is_none():
+            # TODO(sang): Support cuda graph when an idle batch is present.
+            if local_batch is None or local_batch.forward_mode.is_idle():
+                can_cuda_graph = 0
+
+        is_extend_in_batch = (
+            local_batch.forward_mode.is_extend() if local_batch else False
+        )
+        local_info = torch.tensor(
+            [
+                num_tokens,
+                can_cuda_graph,
+                global_num_tokens_for_logprob,
+                is_extend_in_batch,
+            ],
+            dtype=torch.int64,
+        )
+        global_info = torch.empty(
+            (self.server_args.dp_size, self.attn_tp_size, 4),
+            dtype=torch.int64,
+        )
         torch.distributed.all_gather_into_tensor(
-            global_num_tokens,
-            local_num_tokens,
+            global_info.flatten(),
+            local_info,
             group=self.tp_cpu_group,
         )
+        global_num_tokens = global_info[:, 0, 0].tolist()
+        can_cuda_graph = min(global_info[:, 0, 1].tolist())
+        global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
+        is_extend_in_batch = global_info[:, 0, 3].tolist()
 
-        if local_batch is None and global_num_tokens.max().item() > 0:
+        if local_batch is None and max(global_num_tokens) > 0:
             local_batch = self.get_idle_batch()
 
         if local_batch is not None:
-            local_batch.global_num_tokens = global_num_tokens.tolist()
+            local_batch.global_num_tokens = global_num_tokens
+            local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
 
             # Check forward mode for cuda graph
             if not self.server_args.disable_cuda_graph:
-                forward_mode_state = torch.tensor(
-                    (1 if local_batch.forward_mode.is_decode_or_idle() else 0),
-                    dtype=torch.int32,
-                )
-                torch.distributed.all_reduce(
-                    forward_mode_state,
-                    op=torch.distributed.ReduceOp.MIN,
-                    group=self.tp_cpu_group,
-                )
-                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+                local_batch.can_run_dp_cuda_graph = can_cuda_graph
 
-        return local_batch
+        return local_batch, any(is_extend_in_batch)
 
     def get_idle_batch(self):
         idle_batch = ScheduleBatch.init_new(
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 99f54b7f9..99ef14c2c 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -33,7 +33,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import get_available_gpu_memory, is_hip
 
 _is_hip = is_hip()
 
@@ -174,6 +174,7 @@ class CudaGraphRunner:
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
         self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.speculative_algorithm = model_runner.server_args.speculative_algorithm
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
 
@@ -236,7 +237,7 @@ class CudaGraphRunner:
         if self.enable_dp_attention:
             self.gathered_buffer = torch.zeros(
                 (
-                    self.max_bs * self.dp_size,
+                    self.max_bs * self.dp_size * self.num_tokens_per_bs,
                     self.model_runner.model_config.hidden_size,
                 ),
                 dtype=self.model_runner.dtype,
@@ -276,13 +277,12 @@ class CudaGraphRunner:
     def 
can_run(self, forward_batch: ForwardBatch): if self.enable_dp_attention: - min_num_tokens, max_num_tokens = min( - forward_batch.global_num_tokens_cpu - ), max(forward_batch.global_num_tokens_cpu) + total_global_tokens = sum(forward_batch.global_num_tokens_cpu) + is_bs_supported = forward_batch.can_run_dp_cuda_graph and ( - (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs) + total_global_tokens in self.graphs if self.disable_padding - else max_num_tokens <= self.max_bs + else total_global_tokens <= self.max_bs ) else: is_bs_supported = ( @@ -304,6 +304,9 @@ class CudaGraphRunner: def capture(self): with graph_capture() as graph_capture_context: self.stream = graph_capture_context.stream + avail_mem = get_available_gpu_memory( + self.model_runner.device, self.model_runner.gpu_id, empty_cache=False + ) # Reverse the order to enable better memory sharing across cuda graphs. capture_range = ( tqdm.tqdm(list(reversed(self.capture_bs))) @@ -311,6 +314,16 @@ class CudaGraphRunner: else reversed(self.capture_bs) ) for bs in capture_range: + if get_tensor_model_parallel_rank() == 0: + avail_mem = get_available_gpu_memory( + self.model_runner.device, + self.model_runner.gpu_id, + empty_cache=False, + ) + capture_range.set_description( + f"Capturing batches ({avail_mem=:.2f} GB)" + ) + with patch_model( self.model_runner.model, bs in self.compile_bs, @@ -345,8 +358,18 @@ class CudaGraphRunner: mrope_positions = self.mrope_positions[:, :bs] if self.enable_dp_attention: - global_num_tokens = [bs] * self.tp_size - gathered_buffer = self.gathered_buffer[: bs * self.tp_size] + self.global_num_tokens_gpu.copy_( + torch.tensor( + [ + num_tokens // self.dp_size + (i < bs % self.dp_size) + for i in range(self.dp_size) + ], + dtype=torch.int32, + device=input_ids.device, + ) + ) + global_num_tokens = self.global_num_tokens_gpu + gathered_buffer = self.gathered_buffer[:num_tokens] else: global_num_tokens = None gathered_buffer = None @@ -371,7 +394,7 @@ class CudaGraphRunner: encoder_lens=encoder_lens, return_logprob=False, positions=positions, - global_num_tokens_cpu=global_num_tokens, + global_num_tokens_gpu=global_num_tokens, gathered_buffer=gathered_buffer, mrope_positions=mrope_positions, spec_algorithm=self.model_runner.spec_algorithm, @@ -392,6 +415,9 @@ class CudaGraphRunner: # Run and capture def run_once(): + # Clean intermediate result cache for DP attention + forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None + logits_output = forward(input_ids, forward_batch.positions, forward_batch) return logits_output.next_token_logits, logits_output.hidden_states @@ -426,7 +452,7 @@ class CudaGraphRunner: self.capture_hidden_mode = hidden_mode_from_spec_info self.capture() - def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False): + def replay_prepare(self, forward_batch: ForwardBatch): self.recapture_if_needed(forward_batch) raw_bs = forward_batch.batch_size @@ -435,7 +461,7 @@ class CudaGraphRunner: # Pad if self.enable_dp_attention: index = bisect.bisect_left( - self.capture_bs, max(forward_batch.global_num_tokens_cpu) + self.capture_bs, sum(forward_batch.global_num_tokens_cpu) ) else: index = bisect.bisect_left(self.capture_bs, raw_bs) @@ -459,6 +485,8 @@ class CudaGraphRunner: self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens) if forward_batch.mrope_positions is not None: self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions) + if self.enable_dp_attention: + 
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu) if hasattr(forward_batch.spec_info, "hidden_states"): self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states @@ -475,14 +503,29 @@ class CudaGraphRunner: seq_lens_cpu=self.seq_lens_cpu, ) + # Store fields + self.raw_bs = raw_bs + self.raw_num_token = raw_num_token + self.bs = bs + + def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False): + if not skip_attn_backend_init: + self.replay_prepare(forward_batch) + else: + # In speculative decoding, these two fields are still needed. + self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids) + self.positions[: self.raw_num_token].copy_(forward_batch.positions) + # Replay - self.graphs[bs].replay() - next_token_logits, hidden_states = self.output_buffers[bs] + self.graphs[self.bs].replay() + next_token_logits, hidden_states = self.output_buffers[self.bs] logits_output = LogitsProcessorOutput( - next_token_logits=next_token_logits[:raw_num_token], + next_token_logits=next_token_logits[: self.raw_num_token], hidden_states=( - hidden_states[:raw_num_token] if hidden_states is not None else None + hidden_states[: self.raw_num_token] + if hidden_states is not None + else None ), ) return logits_output diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 11d90882b..b732b033e 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -38,7 +38,7 @@ import triton import triton.language as tl from sglang.srt.layers.rotary_embedding import MRotaryEmbedding -from sglang.srt.utils import get_compiler_backend, next_power_of_2 +from sglang.srt.utils import get_compiler_backend if TYPE_CHECKING: from sglang.srt.layers.attention.base_attn_backend import AttentionBackend @@ -263,15 +263,24 @@ class ForwardBatch: extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu, ) + # For DP attention if batch.global_num_tokens is not None: ret.global_num_tokens_cpu = batch.global_num_tokens - max_len = max(ret.global_num_tokens_cpu) + ret.global_num_tokens_gpu = torch.tensor( + batch.global_num_tokens, dtype=torch.int64 + ).to(device, non_blocking=True) + + ret.global_num_tokens_for_logprob_cpu = batch.global_num_tokens_for_logprob + ret.global_num_tokens_for_logprob_gpu = torch.tensor( + batch.global_num_tokens_for_logprob, dtype=torch.int64 + ).to(device, non_blocking=True) + + sum_len = sum(batch.global_num_tokens) ret.gathered_buffer = torch.zeros( - (max_len * model_runner.tp_size, model_runner.model_config.hidden_size), + (sum_len, model_runner.model_config.hidden_size), dtype=model_runner.dtype, device=device, ) - if ret.forward_mode.is_idle(): ret.positions = torch.empty((0,), device=device) return ret diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index d0ca14feb..f51107b82 100755 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -26,15 +26,20 @@ from transformers import PretrainedConfig from vllm import _custom_ops as ops from sglang.srt.distributed import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - get_tp_group, tensor_model_parallel_all_reduce, ) from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import ( decode_attention_fwd_grouped_rope, ) +from sglang.srt.layers.dp_attention import 
( + dp_gather, + dp_scatter, + get_attention_dp_size, + get_attention_tp_rank, + get_attention_tp_size, +) from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( ColumnParallelLinear, @@ -230,6 +235,7 @@ class DeepseekV2Attention(nn.Module): max_position_embeddings: int = 8192, quant_config: Optional[QuantizationConfig] = None, layer_id=None, + reduce_results: bool = True, prefix: str = "", ) -> None: super().__init__() @@ -241,10 +247,14 @@ class DeepseekV2Attention(nn.Module): self.v_head_dim = v_head_dim self.q_lora_rank = q_lora_rank self.kv_lora_rank = kv_lora_rank + + self.dp_size = get_attention_dp_size() + attn_tp_rank = get_attention_tp_rank() + attn_tp_size = get_attention_tp_size() + self.num_heads = num_heads - tp_size = get_tensor_model_parallel_world_size() - assert num_heads % tp_size == 0 - self.num_local_heads = num_heads // tp_size + assert num_heads % attn_tp_size == 0 + self.num_local_heads = num_heads // attn_tp_size self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -272,6 +282,8 @@ class DeepseekV2Attention(nn.Module): bias=False, quant_config=quant_config, prefix=add_prefix("q_proj", prefix), + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, ) self.kv_a_proj_with_mqa = ReplicatedLinear( @@ -296,6 +308,9 @@ class DeepseekV2Attention(nn.Module): bias=False, quant_config=quant_config, prefix=add_prefix("o_proj", prefix), + reduce_results=reduce_results, + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, ) rope_scaling["rope_type"] = "deepseek_yarn" self.rotary_emb = get_rope_wrapper( @@ -330,6 +345,12 @@ class DeepseekV2Attention(nn.Module): hidden_states: torch.Tensor, forward_batch: ForwardBatch, ) -> torch.Tensor: + if hidden_states.shape[0] == 0: + assert ( + not self.o_proj.reduce_results + ), "short-circuiting allreduce will lead to hangs" + return hidden_states + if self.q_lora_rank is not None: q = self.q_a_proj(hidden_states)[0] q = self.q_a_layernorm(q) @@ -385,8 +406,8 @@ class DeepseekV2AttentionMLA(nn.Module): rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, quant_config: Optional[QuantizationConfig] = None, - layer_id=None, - use_dp=False, + reduce_results: bool = True, + layer_id: int = None, prefix: str = "", ) -> None: super().__init__() @@ -398,96 +419,66 @@ class DeepseekV2AttentionMLA(nn.Module): self.v_head_dim = v_head_dim self.q_lora_rank = q_lora_rank self.kv_lora_rank = kv_lora_rank + self.dp_size = get_attention_dp_size() + attn_tp_rank = get_attention_tp_rank() + attn_tp_size = get_attention_tp_size() + self.num_heads = num_heads - tp_size = get_tensor_model_parallel_world_size() - assert num_heads % tp_size == 0 - self.num_local_heads = num_heads if use_dp else num_heads // tp_size + assert num_heads % attn_tp_size == 0 + self.num_local_heads = num_heads // attn_tp_size self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - if use_dp: - # For data parallel attention - if self.q_lora_rank is not None: - self.q_a_proj = ReplicatedLinear( - self.hidden_size, - self.q_lora_rank, - bias=False, - quant_config=quant_config, - prefix=add_prefix("q_a_proj", prefix), - ) - self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) - self.q_b_proj = ReplicatedLinear( - q_lora_rank, - self.num_heads * self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=add_prefix("q_b_proj", prefix), - ) - else: - self.q_proj = 
ReplicatedLinear( - self.hidden_size, - self.num_heads * self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=add_prefix("q_proj", prefix), - ) - self.kv_b_proj = ReplicatedLinear( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), - bias=False, - quant_config=quant_config, - prefix=add_prefix("kv_b_proj", prefix), - ) - # O projection. - self.o_proj = ReplicatedLinear( - self.num_heads * self.v_head_dim, + # For tensor parallel attention + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear( self.hidden_size, + self.q_lora_rank, bias=False, quant_config=quant_config, - prefix=add_prefix("o_proj", prefix), + prefix=add_prefix("q_a_proj", prefix), + ) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=add_prefix("q_b_proj", prefix), + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, ) else: - # For tensor parallel attention - if self.q_lora_rank is not None: - self.q_a_proj = ReplicatedLinear( - self.hidden_size, - self.q_lora_rank, - bias=False, - quant_config=quant_config, - prefix=add_prefix("q_a_proj", prefix), - ) - self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear( - q_lora_rank, - self.num_heads * self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=add_prefix("q_b_proj", prefix), - ) - else: - self.q_proj = ColumnParallelLinear( - self.hidden_size, - self.num_heads * self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=add_prefix("q_proj", prefix), - ) - self.kv_b_proj = ColumnParallelLinear( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), - bias=False, - quant_config=quant_config, - prefix=add_prefix("kv_b_proj", prefix), - ) - # O projection. - self.o_proj = RowParallelLinear( - self.num_heads * self.v_head_dim, + self.q_proj = ColumnParallelLinear( self.hidden_size, + self.num_heads * self.qk_head_dim, bias=False, quant_config=quant_config, - prefix=add_prefix("o_proj", prefix), + prefix=add_prefix("q_proj", prefix), + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, ) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=add_prefix("kv_b_proj", prefix), + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, + ) + # O projection. 
+ self.o_proj = RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=add_prefix("o_proj", prefix), + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, + ) self.kv_a_proj_with_mqa = ReplicatedLinear( self.hidden_size, @@ -542,38 +533,49 @@ class DeepseekV2AttentionMLA(nn.Module): self.w_vc = None self.w_scale = None + self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"] + self.flashinfer_mla_disable_ragged = global_server_args_dict[ + "flashinfer_mla_disable_ragged" + ] + self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1" + + def no_absorb(self, forward_batch: ForwardBatch) -> bool: + if self.enable_flashinfer_mla: + # Flashinfer MLA: Do not absorb when enabling ragged prefill + return ( + not self.flashinfer_mla_disable_ragged + and forward_batch.forward_mode.is_extend() + and not forward_batch.forward_mode.is_target_verify() + and not forward_batch.forward_mode.is_draft_extend() + and forward_batch.extend_prefix_lens.sum() == 0 + ) + else: + # Triton: Use normal computation for prefill and use weight absorption for extend/decode + return ( + forward_batch.forward_mode.is_extend() + and not forward_batch.forward_mode.is_target_verify() + and not forward_batch.forward_mode.is_draft_extend() + and forward_batch.extend_prefix_lens.sum() == 0 + ) + def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, forward_batch: ForwardBatch, ) -> torch.Tensor: + if hidden_states.shape[0] == 0: + assert ( + not self.o_proj.reduce_results + ), "short-circuiting allreduce will lead to hangs" + return hidden_states - def no_absorb() -> bool: - if global_server_args_dict["enable_flashinfer_mla"]: - # Flashinfer MLA: Do not absorb when enabling ragged prefill - return ( - not global_server_args_dict["flashinfer_mla_disable_ragged"] - and forward_batch.forward_mode.is_extend() - and not forward_batch.forward_mode.is_target_verify() - and not forward_batch.forward_mode.is_draft_extend() - and forward_batch.extend_prefix_lens.sum() == 0 - ) - else: - # Triton: Use normal computation for prefill and use weight absorption for extend/decode - return ( - forward_batch.forward_mode.is_extend() - and not forward_batch.forward_mode.is_target_verify() - and not forward_batch.forward_mode.is_draft_extend() - and forward_batch.extend_prefix_lens.sum() == 0 - ) - - if no_absorb(): + if self.no_absorb(forward_batch): return self.forward_normal(positions, hidden_states, forward_batch) else: if _is_hip: if ( - os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1" + self.rocm_fused_decode_mla and forward_batch.forward_mode.is_decode() ): return self.forward_absorb_fused_mla_rope( @@ -845,34 +847,6 @@ class DeepseekV2AttentionMLA(nn.Module): return output -def all_gather( - input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group -): - all_lens = forward_batch.global_num_tokens_cpu - max_len = max(forward_batch.global_num_tokens_cpu) - - if world_size == 1: - return input_tensor, 0, all_lens[0] - - padded_tensor = torch.nn.functional.pad( - input_tensor, (0, 0, 0, max_len - input_tensor.shape[0]) - ) - - group.all_gather_into_tensor(forward_batch.gathered_buffer, padded_tensor) - - gathered_tensors = torch.concat( - [ - forward_batch.gathered_buffer[i * max_len : i * max_len + all_lens[i]] - for i in range(world_size) - ] - ) - - start_index = 0 if rank == 0 else sum(all_lens[:rank]) - end_index = start_index + all_lens[rank] - - return 
gathered_tensors, start_index, end_index
-
-
 class DeepseekV2DecoderLayer(nn.Module):
 
     def __init__(
@@ -888,14 +862,10 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        self.enable_dp_attention = (
-            not global_server_args_dict["disable_mla"]
-            and global_server_args_dict["enable_dp_attention"]
-        )
-        if self.enable_dp_attention:
-            self.tp_rank = get_tensor_model_parallel_rank()
-            self.tp_size = get_tensor_model_parallel_world_size()
-            self.tp_group = get_tp_group()
+        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
+        self.layer_id = layer_id
+        self.dp_size = get_attention_dp_size()
+
         if not global_server_args_dict["disable_mla"]:
             self.self_attn = DeepseekV2AttentionMLA(
                 config=config,
@@ -913,7 +883,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 max_position_embeddings=max_position_embeddings,
                 quant_config=quant_config,
                 layer_id=layer_id,
-                use_dp=self.enable_dp_attention,
+                reduce_results=False,
                 prefix=add_prefix("self_attn", prefix),
             )
         else:
@@ -933,8 +903,10 @@ class DeepseekV2DecoderLayer(nn.Module):
                 max_position_embeddings=max_position_embeddings,
                 quant_config=quant_config,
                 layer_id=layer_id,
+                reduce_results=False,
                 prefix=add_prefix("self_attn", prefix),
             )
+
         if is_nextn or (
             config.n_routed_experts is not None
             and layer_id >= config.first_k_dense_replace
@@ -965,33 +937,47 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        # Self Attention
-        if not forward_batch.forward_mode.is_idle():
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
-            else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
 
-            hidden_states = self.self_attn(
-                positions=positions,
-                hidden_states=hidden_states,
-                forward_batch=forward_batch,
-            )
-            hidden_states, residual = self.post_attention_layernorm(
-                hidden_states, residual
+        # Scatter
+        if self.dp_size != 1:
+            # important: forward_batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
             )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
+        # Self Attention
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        # Gather
+        if get_tensor_model_parallel_world_size() > 1:
+            # all gather and all reduce
+            if self.dp_size != 1:
+                hidden_states, local_hidden_states = (
+                    forward_batch.gathered_buffer,
+                    hidden_states,
+                )
+                dp_gather(
+                    hidden_states, local_hidden_states, forward_batch, self.layer_id
+                )
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
 
         # Fully Connected
-        if self.enable_dp_attention:
-            hidden_states, start_idx, end_idx = all_gather(
-                hidden_states, forward_batch, self.tp_rank, self.tp_size, self.tp_group
-            )
-            hidden_states = self.mlp(hidden_states)
-            hidden_states = hidden_states[start_idx:end_idx]
-        else:
-            hidden_states = self.mlp(hidden_states)
-
+        hidden_states = self.mlp(hidden_states)
         return hidden_states, residual
@@ -1027,12 +1013,27 @@ class DeepseekV2Model(nn.Module):
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
+        self.dp_size = get_attention_dp_size()
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
+
+        # Gather
+        if self.dp_size != 1:
+            input_ids, local_input_ids = (
+                torch.empty(
+                    (forward_batch.gathered_buffer.shape[0],),
+                    dtype=input_ids.dtype,
+                    device=input_ids.device,
+                ),
+                input_ids,
+            )
+            dp_gather(input_ids, local_input_ids, forward_batch, "embedding")
+
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
@@ -1059,22 +1060,14 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.model = DeepseekV2Model(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
-        if global_server_args_dict["enable_dp_attention"]:
-            self.lm_head = ReplicatedLinear(
-                config.hidden_size,
-                config.vocab_size,
-                bias=False,
-                prefix=add_prefix("lm_head", prefix),
-            )
-            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
-        else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("lm_head", prefix),
-            )
-            self.logits_processor = LogitsProcessor(config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+        )
+        self.logits_processor = LogitsProcessor(config)
+        self.dp_size = get_attention_dp_size()
 
     @torch.no_grad()
     def forward(
@@ -1084,6 +1077,16 @@ class DeepseekV2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
+
+        if self.dp_size != 1:
+            # important: forward_batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head, forward_batch
         )
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index dc927f096..fc3cb5cb2 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -262,14 +262,14 @@ class ServerArgs:
 
         # Data parallelism attention
         if self.enable_dp_attention:
-            self.dp_size = self.tp_size
-            assert self.tp_size % self.dp_size == 0
-            self.chunked_prefill_size = self.chunked_prefill_size // 2
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
+            assert (
+                self.dp_size > 1
+            ), "Please set dp-size > 1 (valid range: 1 < dp-size <= tp-size)."
+            assert self.tp_size % self.dp_size == 0
+            self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
             logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues."
             )
 
         # Speculative Decoding
diff --git a/test/srt/test_dp_attention.py b/test/srt/test_dp_attention.py
index 5de03a461..d24507ae2 100644
--- a/test/srt/test_dp_attention.py
+++ b/test/srt/test_dp_attention.py
@@ -25,6 +25,8 @@ class TestDPAttention(unittest.TestCase):
                 "--tp",
                 "2",
                 "--enable-dp-attention",
+                "--dp",
+                "2",
             ],
         )
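The sketches below are illustrative additions, not part of the patch; every name and size in them is an assumption made for demonstration.

dp_gather and dp_scatter treat the hidden states of all DP ranks as one flat buffer: each rank owns a contiguous slice whose offset is the cumulative token count of the lower-numbered ranks, which is the arithmetic that get_dp_local_info encapsulates. A minimal single-process sketch of that layout:

import torch

# Assumed sizes: 4 DP ranks with uneven token counts, hidden size 8.
hidden_size = 8
global_num_tokens = [3, 1, 4, 2]
gathered = torch.zeros(sum(global_num_tokens), hidden_size)

for dp_rank, local_len in enumerate(global_num_tokens):
    # Same arithmetic as get_dp_local_info: a rank's slice starts right
    # after the tokens of all lower-numbered ranks.
    local_start = sum(global_num_tokens[:dp_rank])

    # "Scatter": a rank reads back only its own slice of the global buffer.
    local_tokens = gathered[local_start : local_start + local_len]

    # "Gather": a rank writes its slice; an all-reduce across ranks then
    # merges the partially filled buffers (simulated here in one process).
    local_tokens.copy_(torch.full((local_len, hidden_size), float(dp_rank)))

assert gathered.shape[0] == sum(global_num_tokens)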
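The aliasing assert in dp_gather now compares untyped_storage().data_ptr() instead of the deprecated storage().data_ptr(). A small sketch, with assumed tensors, of what that check catches: a view shares its base tensor's allocation, so it must not be used as a copy source while the global buffer it aliases is being written.

import torch

buf = torch.zeros(16, 4)
view = buf[2:6]            # a view that aliases buf's allocation
fresh = torch.zeros(4, 4)  # an independent allocation

# A view reports the same underlying storage pointer as its base tensor ...
assert view.untyped_storage().data_ptr() == buf.untyped_storage().data_ptr()
# ... while a fresh tensor does not, so it is a safe copy source.
assert fresh.untyped_storage().data_ptr() != buf.untyped_storage().data_ptr()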
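prepare_dp_attn_batch now folds what used to be two collectives, an all-gather of token counts plus an all-reduce of the cuda-graph flag, into a single all-gather of one packed int64 tensor. A runnable single-process sketch of the packing pattern, collapsing the attention-TP dimension and assuming a gloo group of world size 1:

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

# Pack all per-rank scalars into one tensor so a single collective suffices.
num_tokens, can_cuda_graph, num_tokens_for_logprob, is_extend = 7, 1, 7, 0
local_info = torch.tensor(
    [num_tokens, can_cuda_graph, num_tokens_for_logprob, is_extend],
    dtype=torch.int64,
)
world_size = dist.get_world_size()
global_info = torch.empty((world_size, 4), dtype=torch.int64)
dist.all_gather_into_tensor(global_info.flatten(), local_info)

# Unpack one column per field; min() over the flags reproduces the
# ReduceOp.MIN semantics of the collective this replaces.
global_num_tokens = global_info[:, 0].tolist()
can_cuda_graph = min(global_info[:, 1].tolist())

dist.destroy_process_group()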
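With DP attention, can_run and replay_prepare key the captured cuda graphs on the sum of global_num_tokens_cpu rather than its maximum, padding that sum up to the nearest captured size. A sketch of the selection logic with assumed capture sizes:

import bisect

capture_bs = [1, 2, 4, 8, 16, 32]      # assumed captured sizes, kept sorted
global_num_tokens_cpu = [3, 1, 4, 2]   # assumed per-DP-rank token counts

total_tokens = sum(global_num_tokens_cpu)
index = bisect.bisect_left(capture_bs, total_tokens)
assert index < len(capture_bs), "batch exceeds the largest captured graph"
padded_bs = capture_bs[index]          # replay the graph captured at this size

print(total_tokens, "->", padded_bs)   # prints: 10 -> 16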