Improve DP attention (#4390)
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
@@ -1,6 +1,8 @@
from __future__ import annotations

import functools
import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING, Union

import torch
@@ -14,6 +16,8 @@ from sglang.srt.distributed import (
    tensor_model_parallel_all_reduce,
)

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -86,6 +90,27 @@ def get_attention_dp_size():
    return _DP_SIZE


@contextmanager
def disable_dp_size():
    """Patch the tp group temporarily until this function ends.

    This method is for draft workers of speculative decoding to run draft model
    with different tp degree from that of target model workers.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator
    """
    global _DP_SIZE
    assert _DP_SIZE is not None, "dp attention not initialized!"

    old_dp_size = _DP_SIZE
    _DP_SIZE = 1
    try:
        yield
    finally:
        _DP_SIZE = old_dp_size


def get_dp_local_info(forward_batch: ForwardBatch):
    dp_rank = get_attention_dp_rank()
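Note (not part of the diff): `disable_dp_size` temporarily reports a DP size of 1 and restores the old value afterwards, so a speculative-decoding draft worker can run with a different parallel layout than the target model. A rough usage sketch, with a hypothetical draft-worker call:

```python
from sglang.srt.layers.dp_attention import disable_dp_size

def run_draft_forward(draft_model_runner, forward_batch):
    # Hypothetical helper: while the draft model runs, DP attention is
    # treated as disabled; the original _DP_SIZE is restored even if the
    # forward pass raises.
    with disable_dp_size():
        return draft_model_runner.forward(forward_batch)
```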
@@ -159,7 +184,8 @@ def dp_gather(
        layer_id != "embedding" or get_attention_tp_rank() == 0
    ):
        assert (
            global_tokens.storage().data_ptr() != local_tokens.storage().data_ptr()
            global_tokens.untyped_storage().data_ptr()
            != local_tokens.untyped_storage().data_ptr()
        ), "aliasing between global_tokens and local_tokens not allowed"
        memcpy_triton(
            global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
@@ -174,8 +200,9 @@ def dp_gather(
        torch.ops.sglang.inplace_all_reduce(
            global_tokens, group_name=get_tp_group().unique_name
        )
    else:
        global_tokens = tensor_model_parallel_all_reduce(global_tokens)
        global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)


def dp_scatter(
@@ -186,6 +213,7 @@ def dp_scatter(
    # local_num_tokens is not necessarily the same as local_tokens.shape[0],
    # since local_tokens may be padded for cuda graph
    local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)

    local_tokens.fill_(0)
    assert local_tokens.is_contiguous()
    assert global_tokens.is_contiguous()
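Outside the diff, a minimal pure-PyTorch sketch of the gather/scatter idea: each DP rank writes its tokens into its slice of a shared buffer and an all-reduce sums the slices, while scatter copies the rank's slice back. The Triton copy and the custom in-place all-reduce op from the real code are replaced by plain tensor ops here:

```python
import torch

def toy_dp_gather(global_tokens, local_tokens, start, num, group=None):
    # Each DP rank fills only its own slice; the other slices stay zero
    # until the all-reduce sums contributions from all ranks.
    global_tokens.zero_()
    global_tokens[start : start + num] = local_tokens[:num]
    if group is not None:
        torch.distributed.all_reduce(global_tokens, group=group)

def toy_dp_scatter(local_tokens, global_tokens, start, num):
    # The inverse: copy this rank's slice of the gathered buffer back
    # into the (possibly padded) local buffer.
    local_tokens.zero_()
    local_tokens[:num] = global_tokens[start : start + num]
```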
@@ -23,6 +23,7 @@ import triton.language as tl
from torch import nn

from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)

@@ -54,7 +54,7 @@ class LoadBalanceMethod(Enum):
class DataParallelController:
    """A controller that dispatches requests to multiple data parallel workers."""

    def __init__(self, server_args, port_args) -> None:
    def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
        # Parse args
        self.max_total_num_tokens = None
        self.server_args = server_args
@@ -997,7 +997,7 @@ class Scheduler(SchedulerOutputProcessorMixin):

        # Handle DP attention
        if self.server_args.enable_dp_attention:
            ret = self.prepare_dp_attn_batch(ret)
            ret, _ = self.prepare_dp_attn_batch(ret)

        return ret
@@ -1269,39 +1269,72 @@ class Scheduler(SchedulerOutputProcessorMixin):
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
            global_num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
            if not self.spec_algorithm.is_none() and self.spec_algorithm.is_eagle():
                num_tokens = num_tokens * self.server_args.speculative_num_draft_tokens
            global_num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
            global_num_tokens_for_logprob = sum(
                [
                    # We should have at least 1 token for sample in every case.
                    max(extend_len - logprob_start_len, 1)
                    for logprob_start_len, extend_len in zip(
                        local_batch.extend_logprob_start_lens, local_batch.extend_lens
                    )
                ]
            )

        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
        if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
            can_cuda_graph = 1
        else:
            can_cuda_graph = 0

        if not self.spec_algorithm.is_none():
            # TODO(sang): Support cuda graph when idle batch is there.
            if local_batch is None or local_batch.forward_mode.is_idle():
                can_cuda_graph = 0

        is_extend_in_batch = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )
        local_info = torch.tensor(
            [
                num_tokens,
                can_cuda_graph,
                global_num_tokens_for_logprob,
                is_extend_in_batch,
            ],
            dtype=torch.int64,
        )
        global_info = torch.empty(
            (self.server_args.dp_size, self.attn_tp_size, 4),
            dtype=torch.int64,
        )
        torch.distributed.all_gather_into_tensor(
            global_num_tokens,
            local_num_tokens,
            global_info.flatten(),
            local_info,
            group=self.tp_cpu_group,
        )
        global_num_tokens = global_info[:, 0, 0].tolist()
        can_cuda_graph = min(global_info[:, 0, 1].tolist())
        global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
        is_extend_in_batch = global_info[:, 0, 3].tolist()

        if local_batch is None and global_num_tokens.max().item() > 0:
        if local_batch is None and max(global_num_tokens) > 0:
            local_batch = self.get_idle_batch()

        if local_batch is not None:
            local_batch.global_num_tokens = global_num_tokens.tolist()
            local_batch.global_num_tokens = global_num_tokens
            local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob

            # Check forward mode for cuda graph
            if not self.server_args.disable_cuda_graph:
                forward_mode_state = torch.tensor(
                    (1 if local_batch.forward_mode.is_decode_or_idle() else 0),
                    dtype=torch.int32,
                )
                torch.distributed.all_reduce(
                    forward_mode_state,
                    op=torch.distributed.ReduceOp.MIN,
                    group=self.tp_cpu_group,
                )
                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
            local_batch.can_run_dp_cuda_graph = can_cuda_graph

        return local_batch
        return local_batch, any(is_extend_in_batch)

    def get_idle_batch(self):
        idle_batch = ScheduleBatch.init_new(
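The hunk above collapses several collectives into one: each rank packs (num_tokens, can_cuda_graph, num_tokens_for_logprob, is_extend_in_batch) into a single int64 tensor, all-gathers it over the TP CPU group, and reads the per-DP-rank values from column 0. A rough standalone sketch of the same pattern, with a simplified two-field payload and an illustrative group argument (not from the diff):

```python
import torch
import torch.distributed as dist

def exchange_batch_info(num_tokens, can_cuda_graph, dp_size, attn_tp_size, group):
    # Pack all per-rank scalars into one tensor so a single all_gather
    # replaces several separate collectives.
    local_info = torch.tensor([num_tokens, can_cuda_graph], dtype=torch.int64)
    global_info = torch.empty((dp_size, attn_tp_size, 2), dtype=torch.int64)
    dist.all_gather_into_tensor(global_info.flatten(), local_info, group=group)
    # Ranks inside one attention-TP group report identical values,
    # so reading column 0 yields one entry per DP rank.
    global_num_tokens = global_info[:, 0, 0].tolist()
    can_cuda_graph = min(global_info[:, 0, 1].tolist())
    return global_num_tokens, can_cuda_graph
```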
@@ -33,7 +33,7 @@ from sglang.srt.model_executor.forward_batch_info import (
    ForwardBatch,
    ForwardMode,
)
from sglang.srt.utils import is_hip
from sglang.srt.utils import get_available_gpu_memory, is_hip

_is_hip = is_hip()

@@ -174,6 +174,7 @@ class CudaGraphRunner:
        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
        self.enable_dp_attention = model_runner.server_args.enable_dp_attention
        self.speculative_algorithm = model_runner.server_args.speculative_algorithm
        self.tp_size = model_runner.server_args.tp_size
        self.dp_size = model_runner.server_args.dp_size
@@ -236,7 +237,7 @@ class CudaGraphRunner:
        if self.enable_dp_attention:
            self.gathered_buffer = torch.zeros(
                (
                    self.max_bs * self.dp_size,
                    self.max_bs * self.dp_size * self.num_tokens_per_bs,
                    self.model_runner.model_config.hidden_size,
                ),
                dtype=self.model_runner.dtype,
@@ -276,13 +277,12 @@ class CudaGraphRunner:

    def can_run(self, forward_batch: ForwardBatch):
        if self.enable_dp_attention:
            min_num_tokens, max_num_tokens = min(
                forward_batch.global_num_tokens_cpu
            ), max(forward_batch.global_num_tokens_cpu)
            total_global_tokens = sum(forward_batch.global_num_tokens_cpu)

            is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
                (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
                total_global_tokens in self.graphs
                if self.disable_padding
                else max_num_tokens <= self.max_bs
                else total_global_tokens <= self.max_bs
            )
        else:
            is_bs_supported = (
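With DP attention, cuda-graph eligibility now keys on the summed token count across all DP ranks instead of requiring every rank to carry the same batch size. A rough sketch of the new check, assuming a dict of captured graphs keyed by total token count:

```python
def can_run_dp_graph(global_num_tokens, graphs, max_bs, disable_padding):
    # graphs: dict mapping captured total token counts to CUDA graphs.
    total = sum(global_num_tokens)
    if disable_padding:
        return total in graphs   # need a graph captured at exactly this size
    return total <= max_bs       # otherwise pad up to a captured size
```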
@@ -304,6 +304,9 @@ class CudaGraphRunner:
    def capture(self):
        with graph_capture() as graph_capture_context:
            self.stream = graph_capture_context.stream
            avail_mem = get_available_gpu_memory(
                self.model_runner.device, self.model_runner.gpu_id, empty_cache=False
            )
            # Reverse the order to enable better memory sharing across cuda graphs.
            capture_range = (
                tqdm.tqdm(list(reversed(self.capture_bs)))
@@ -311,6 +314,16 @@ class CudaGraphRunner:
                else reversed(self.capture_bs)
            )
            for bs in capture_range:
                if get_tensor_model_parallel_rank() == 0:
                    avail_mem = get_available_gpu_memory(
                        self.model_runner.device,
                        self.model_runner.gpu_id,
                        empty_cache=False,
                    )
                    capture_range.set_description(
                        f"Capturing batches ({avail_mem=:.2f} GB)"
                    )

                with patch_model(
                    self.model_runner.model,
                    bs in self.compile_bs,
@@ -345,8 +358,18 @@ class CudaGraphRunner:
            mrope_positions = self.mrope_positions[:, :bs]

        if self.enable_dp_attention:
            global_num_tokens = [bs] * self.tp_size
            gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
            self.global_num_tokens_gpu.copy_(
                torch.tensor(
                    [
                        num_tokens // self.dp_size + (i < bs % self.dp_size)
                        for i in range(self.dp_size)
                    ],
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            )
            global_num_tokens = self.global_num_tokens_gpu
            gathered_buffer = self.gathered_buffer[:num_tokens]
        else:
            global_num_tokens = None
            gathered_buffer = None
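During capture, the per-rank token counts are no longer a flat `[bs] * tp_size`; the captured token budget is split as evenly as possible across the DP ranks, with any remainder going to the first ranks. A small generic sketch of that split (names mirror `num_tokens` and `dp_size` above, but this is not the exact code path):

```python
def split_tokens(num_tokens: int, dp_size: int) -> list[int]:
    # Distribute num_tokens over dp_size ranks as evenly as possible:
    # the first (num_tokens % dp_size) ranks get one extra token.
    return [
        num_tokens // dp_size + (i < num_tokens % dp_size)
        for i in range(dp_size)
    ]

# Example: 10 tokens over 4 DP ranks -> [3, 3, 2, 2]
assert split_tokens(10, 4) == [3, 3, 2, 2]
```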
@@ -371,7 +394,7 @@ class CudaGraphRunner:
            encoder_lens=encoder_lens,
            return_logprob=False,
            positions=positions,
            global_num_tokens_cpu=global_num_tokens,
            global_num_tokens_gpu=global_num_tokens,
            gathered_buffer=gathered_buffer,
            mrope_positions=mrope_positions,
            spec_algorithm=self.model_runner.spec_algorithm,
@@ -392,6 +415,9 @@ class CudaGraphRunner:

        # Run and capture
        def run_once():
            # Clean intermediate result cache for DP attention
            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None

            logits_output = forward(input_ids, forward_batch.positions, forward_batch)
            return logits_output.next_token_logits, logits_output.hidden_states
@@ -426,7 +452,7 @@ class CudaGraphRunner:
            self.capture_hidden_mode = hidden_mode_from_spec_info
            self.capture()

    def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
    def replay_prepare(self, forward_batch: ForwardBatch):
        self.recapture_if_needed(forward_batch)

        raw_bs = forward_batch.batch_size
@@ -435,7 +461,7 @@ class CudaGraphRunner:
        # Pad
        if self.enable_dp_attention:
            index = bisect.bisect_left(
                self.capture_bs, max(forward_batch.global_num_tokens_cpu)
                self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
            )
        else:
            index = bisect.bisect_left(self.capture_bs, raw_bs)
@@ -459,6 +485,8 @@ class CudaGraphRunner:
            self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
        if forward_batch.mrope_positions is not None:
            self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
        if self.enable_dp_attention:
            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)

        if hasattr(forward_batch.spec_info, "hidden_states"):
            self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
@@ -475,14 +503,29 @@ class CudaGraphRunner:
            seq_lens_cpu=self.seq_lens_cpu,
        )

        # Store fields
        self.raw_bs = raw_bs
        self.raw_num_token = raw_num_token
        self.bs = bs

    def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
        if not skip_attn_backend_init:
            self.replay_prepare(forward_batch)
        else:
            # In speculative decoding, these two fields are still needed.
            self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
            self.positions[: self.raw_num_token].copy_(forward_batch.positions)

        # Replay
        self.graphs[bs].replay()
        next_token_logits, hidden_states = self.output_buffers[bs]
        self.graphs[self.bs].replay()
        next_token_logits, hidden_states = self.output_buffers[self.bs]

        logits_output = LogitsProcessorOutput(
            next_token_logits=next_token_logits[:raw_num_token],
            next_token_logits=next_token_logits[: self.raw_num_token],
            hidden_states=(
                hidden_states[:raw_num_token] if hidden_states is not None else None
                hidden_states[: self.raw_num_token]
                if hidden_states is not None
                else None
            ),
        )
        return logits_output
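Replay is now split into `replay_prepare` (pick the padded batch size, copy inputs, stash `raw_bs` / `raw_num_token` / `bs` on `self`) and the actual graph launch, so speculative decoding can skip re-preparation and only refresh the input tensors. A rough caller-side sketch of the two paths, with names taken from the code above:

```python
# Normal pass: full preparation, then launch the captured graph.
logits = runner.replay(forward_batch)

# Speculative-decoding passes reuse the previously prepared buffers and
# padded batch size; only input_ids/positions are refreshed before the
# graph is replayed.
logits = runner.replay(forward_batch, skip_attn_backend_init=True)
```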
@@ -38,7 +38,7 @@ import triton
import triton.language as tl

from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.utils import get_compiler_backend, next_power_of_2
from sglang.srt.utils import get_compiler_backend

if TYPE_CHECKING:
    from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -263,15 +263,24 @@ class ForwardBatch:
            extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
        )

        # For DP attention
        if batch.global_num_tokens is not None:
            ret.global_num_tokens_cpu = batch.global_num_tokens
            max_len = max(ret.global_num_tokens_cpu)
            ret.global_num_tokens_gpu = torch.tensor(
                batch.global_num_tokens, dtype=torch.int64
            ).to(device, non_blocking=True)

            ret.global_num_tokens_for_logprob_cpu = batch.global_num_tokens_for_logprob
            ret.global_num_tokens_for_logprob_gpu = torch.tensor(
                batch.global_num_tokens_for_logprob, dtype=torch.int64
            ).to(device, non_blocking=True)

            sum_len = sum(batch.global_num_tokens)
            ret.gathered_buffer = torch.zeros(
                (max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
                (sum_len, model_runner.model_config.hidden_size),
                dtype=model_runner.dtype,
                device=device,
            )

        if ret.forward_mode.is_idle():
            ret.positions = torch.empty((0,), device=device)
            return ret
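The gathered buffer shrinks from `max_len * tp_size` rows (padded layout) to exactly `sum(global_num_tokens)` rows (packed layout). A quick illustration of the row counts, with made-up numbers:

```python
global_num_tokens = [7, 3, 5, 1]   # tokens per DP rank in one step
tp_size = 4

padded_rows = max(global_num_tokens) * tp_size   # old layout: 7 * 4 = 28 rows
packed_rows = sum(global_num_tokens)             # new layout: 7 + 3 + 5 + 1 = 16 rows
```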
@@ -26,15 +26,20 @@ from transformers import PretrainedConfig
from vllm import _custom_ops as ops

from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    get_tp_group,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
    decode_attention_fwd_grouped_rope,
)
from sglang.srt.layers.dp_attention import (
    dp_gather,
    dp_scatter,
    get_attention_dp_size,
    get_attention_tp_rank,
    get_attention_tp_size,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
@@ -230,6 +235,7 @@ class DeepseekV2Attention(nn.Module):
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        layer_id=None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
@@ -241,10 +247,14 @@ class DeepseekV2Attention(nn.Module):
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.dp_size = get_attention_dp_size()
        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size
        assert num_heads % attn_tp_size == 0
        self.num_local_heads = num_heads // attn_tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
@@ -272,6 +282,8 @@ class DeepseekV2Attention(nn.Module):
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("q_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )

        self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -296,6 +308,9 @@ class DeepseekV2Attention(nn.Module):
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
            reduce_results=reduce_results,
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        rope_scaling["rope_type"] = "deepseek_yarn"
        self.rotary_emb = get_rope_wrapper(
@@ -330,6 +345,12 @@ class DeepseekV2Attention(nn.Module):
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        if hidden_states.shape[0] == 0:
            assert (
                not self.o_proj.reduce_results
            ), "short-circuiting allreduce will lead to hangs"
            return hidden_states

        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
@@ -385,8 +406,8 @@ class DeepseekV2AttentionMLA(nn.Module):
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        layer_id=None,
        use_dp=False,
        reduce_results: bool = True,
        layer_id: int = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
@@ -398,96 +419,66 @@ class DeepseekV2AttentionMLA(nn.Module):
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.dp_size = get_attention_dp_size()
        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads if use_dp else num_heads // tp_size
        assert num_heads % attn_tp_size == 0
        self.num_local_heads = num_heads // attn_tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        if use_dp:
            # For data parallel attention
            if self.q_lora_rank is not None:
                self.q_a_proj = ReplicatedLinear(
                    self.hidden_size,
                    self.q_lora_rank,
                    bias=False,
                    quant_config=quant_config,
                    prefix=add_prefix("q_a_proj", prefix),
                )
                self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
                self.q_b_proj = ReplicatedLinear(
                    q_lora_rank,
                    self.num_heads * self.qk_head_dim,
                    bias=False,
                    quant_config=quant_config,
                    prefix=add_prefix("q_b_proj", prefix),
                )
            else:
                self.q_proj = ReplicatedLinear(
                    self.hidden_size,
                    self.num_heads * self.qk_head_dim,
                    bias=False,
                    quant_config=quant_config,
                    prefix=add_prefix("q_proj", prefix),
                )
            self.kv_b_proj = ReplicatedLinear(
                self.kv_lora_rank,
                self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("kv_b_proj", prefix),
            )
            # O projection.
            self.o_proj = ReplicatedLinear(
                self.num_heads * self.v_head_dim,
        # For tensor parallel attention
        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("o_proj", prefix),
                prefix=add_prefix("q_a_proj", prefix),
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_b_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
        else:
            # For tensor parallel attention
            if self.q_lora_rank is not None:
                self.q_a_proj = ReplicatedLinear(
                    self.hidden_size,
                    self.q_lora_rank,
                    bias=False,
                    quant_config=quant_config,
                    prefix=add_prefix("q_a_proj", prefix),
                )
                self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
                self.q_b_proj = ColumnParallelLinear(
                    q_lora_rank,
                    self.num_heads * self.qk_head_dim,
                    bias=False,
                    quant_config=quant_config,
                    prefix=add_prefix("q_b_proj", prefix),
                )
            else:
                self.q_proj = ColumnParallelLinear(
                    self.hidden_size,
                    self.num_heads * self.qk_head_dim,
                    bias=False,
                    quant_config=quant_config,
                    prefix=add_prefix("q_proj", prefix),
                )
            self.kv_b_proj = ColumnParallelLinear(
                self.kv_lora_rank,
                self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("kv_b_proj", prefix),
            )
            # O projection.
            self.o_proj = RowParallelLinear(
                self.num_heads * self.v_head_dim,
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("o_proj", prefix),
                prefix=add_prefix("q_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_b_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=add_prefix("o_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )

        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
@@ -542,38 +533,49 @@ class DeepseekV2AttentionMLA(nn.Module):
        self.w_vc = None
        self.w_scale = None

        self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"]
        self.flashinfer_mla_disable_ragged = global_server_args_dict[
            "flashinfer_mla_disable_ragged"
        ]
        self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"

    def no_absorb(self, forward_batch: ForwardBatch) -> bool:
        if self.enable_flashinfer_mla:
            # Flashinfer MLA: Do not absorb when enabling ragged prefill
            return (
                not self.flashinfer_mla_disable_ragged
                and forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and forward_batch.extend_prefix_lens.sum() == 0
            )
        else:
            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
            return (
                forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and forward_batch.extend_prefix_lens.sum() == 0
            )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        if hidden_states.shape[0] == 0:
            assert (
                not self.o_proj.reduce_results
            ), "short-circuiting allreduce will lead to hangs"
            return hidden_states

        def no_absorb() -> bool:
            if global_server_args_dict["enable_flashinfer_mla"]:
                # Flashinfer MLA: Do not absorb when enabling ragged prefill
                return (
                    not global_server_args_dict["flashinfer_mla_disable_ragged"]
                    and forward_batch.forward_mode.is_extend()
                    and not forward_batch.forward_mode.is_target_verify()
                    and not forward_batch.forward_mode.is_draft_extend()
                    and forward_batch.extend_prefix_lens.sum() == 0
                )
            else:
                # Triton: Use normal computation for prefill and use weight absorption for extend/decode
                return (
                    forward_batch.forward_mode.is_extend()
                    and not forward_batch.forward_mode.is_target_verify()
                    and not forward_batch.forward_mode.is_draft_extend()
                    and forward_batch.extend_prefix_lens.sum() == 0
                )

        if no_absorb():
        if self.no_absorb(forward_batch):
            return self.forward_normal(positions, hidden_states, forward_batch)
        else:
            if _is_hip:
                if (
                    os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
                    self.rocm_fused_decode_mla
                    and forward_batch.forward_mode.is_decode()
                ):
                    return self.forward_absorb_fused_mla_rope(
@@ -845,34 +847,6 @@ class DeepseekV2AttentionMLA(nn.Module):
        return output


def all_gather(
    input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
):
    all_lens = forward_batch.global_num_tokens_cpu
    max_len = max(forward_batch.global_num_tokens_cpu)

    if world_size == 1:
        return input_tensor, 0, all_lens[0]

    padded_tensor = torch.nn.functional.pad(
        input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
    )

    group.all_gather_into_tensor(forward_batch.gathered_buffer, padded_tensor)

    gathered_tensors = torch.concat(
        [
            forward_batch.gathered_buffer[i * max_len : i * max_len + all_lens[i]]
            for i in range(world_size)
        ]
    )

    start_index = 0 if rank == 0 else sum(all_lens[:rank])
    end_index = start_index + all_lens[rank]

    return gathered_tensors, start_index, end_index


class DeepseekV2DecoderLayer(nn.Module):

    def __init__(
@@ -888,14 +862,10 @@ class DeepseekV2DecoderLayer(nn.Module):
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.enable_dp_attention = (
            not global_server_args_dict["disable_mla"]
            and global_server_args_dict["enable_dp_attention"]
        )
        if self.enable_dp_attention:
            self.tp_rank = get_tensor_model_parallel_rank()
            self.tp_size = get_tensor_model_parallel_world_size()
            self.tp_group = get_tp_group()
        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
        self.layer_id = layer_id
        self.dp_size = get_attention_dp_size()

        if not global_server_args_dict["disable_mla"]:
            self.self_attn = DeepseekV2AttentionMLA(
                config=config,
@@ -913,7 +883,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                max_position_embeddings=max_position_embeddings,
                quant_config=quant_config,
                layer_id=layer_id,
                use_dp=self.enable_dp_attention,
                reduce_results=False,
                prefix=add_prefix("self_attn", prefix),
            )
        else:
@@ -933,8 +903,10 @@ class DeepseekV2DecoderLayer(nn.Module):
                max_position_embeddings=max_position_embeddings,
                quant_config=quant_config,
                layer_id=layer_id,
                reduce_results=False,
                prefix=add_prefix("self_attn", prefix),
            )

        if is_nextn or (
            config.n_routed_experts is not None
            and layer_id >= config.first_k_dense_replace
@@ -965,33 +937,47 @@ class DeepseekV2DecoderLayer(nn.Module):
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Self Attention
        if not forward_batch.forward_mode.is_idle():
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(hidden_states, residual)
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
            )
            hidden_states, residual = self.post_attention_layernorm(
                hidden_states, residual
        # Scatter
        if self.dp_size != 1:
            # important: forward batch.gathered_buffer is used both after scatter and after gather.
            # be careful about this!
            hidden_states, global_hidden_states = (
                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                hidden_states,
            )
            dp_scatter(hidden_states, global_hidden_states, forward_batch)

        # Self Attention
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )

        # Gather
        if get_tensor_model_parallel_world_size() > 1:
            # all gather and all reduce
            if self.dp_size != 1:
                hidden_states, local_hidden_states = (
                    forward_batch.gathered_buffer,
                    hidden_states,
                )
                dp_gather(
                    hidden_states, local_hidden_states, forward_batch, self.layer_id
                )
            else:
                hidden_states = tensor_model_parallel_all_reduce(hidden_states)

        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)

        # Fully Connected
        if self.enable_dp_attention:
            hidden_states, start_idx, end_idx = all_gather(
                hidden_states, forward_batch, self.tp_rank, self.tp_size, self.tp_group
            )
            hidden_states = self.mlp(hidden_states)
            hidden_states = hidden_states[start_idx:end_idx]
        else:
            hidden_states = self.mlp(hidden_states)

        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
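The decoder layer now interleaves DP scatter and gather around attention: hidden states arrive in the gathered layout (one row per token across all DP ranks), get scattered down to this rank's tokens before attention, and are gathered again (or all-reduced when DP is off) before the shared MLP/MoE. A condensed, illustrative sketch of that control flow, assuming the `dp_scatter`/`dp_gather` helpers from earlier in the diff and omitting the layernorm/residual handling:

```python
def decoder_layer_flow(hidden_states, positions, forward_batch, layer):
    # Scatter: keep only this DP rank's tokens for attention.
    if layer.dp_size != 1:
        local = forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]
        dp_scatter(local, hidden_states, forward_batch)
        hidden_states = local

    hidden_states = layer.self_attn(
        positions=positions, hidden_states=hidden_states, forward_batch=forward_batch
    )

    # Gather: rebuild the full token set (or all-reduce TP partial sums)
    # so the MLP/MoE sees every token exactly once.
    if layer.dp_size != 1:
        dp_gather(
            forward_batch.gathered_buffer, hidden_states, forward_batch, layer.layer_id
        )
        hidden_states = forward_batch.gathered_buffer
    else:
        hidden_states = tensor_model_parallel_all_reduce(hidden_states)
    return hidden_states
```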
@@ -1027,12 +1013,27 @@ class DeepseekV2Model(nn.Module):
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.dp_size = get_attention_dp_size()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:

        # Gather
        if self.dp_size != 1:
            input_ids, local_input_ids = (
                torch.empty(
                    (forward_batch.gathered_buffer.shape[0],),
                    dtype=input_ids.dtype,
                    device=input_ids.device,
                ),
                input_ids,
            )
            dp_gather(input_ids, local_input_ids, forward_batch, "embedding")

        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i in range(len(self.layers)):
@@ -1059,22 +1060,14 @@ class DeepseekV2ForCausalLM(nn.Module):
        self.model = DeepseekV2Model(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        if global_server_args_dict["enable_dp_attention"]:
            self.lm_head = ReplicatedLinear(
                config.hidden_size,
                config.vocab_size,
                bias=False,
                prefix=add_prefix("lm_head", prefix),
            )
            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )
            self.logits_processor = LogitsProcessor(config)
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
        )
        self.logits_processor = LogitsProcessor(config)
        self.dp_size = get_attention_dp_size()

    @torch.no_grad()
    def forward(
@@ -1084,6 +1077,16 @@ class DeepseekV2ForCausalLM(nn.Module):
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, forward_batch)

        if self.dp_size != 1:
            # important: forward batch.gathered_buffer is used both after scatter and after gather.
            # be careful about this!
            hidden_states, global_hidden_states = (
                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                hidden_states,
            )
            dp_scatter(hidden_states, global_hidden_states, forward_batch)

        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )
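After the last decoder layer the hidden states are still in the gathered layout (one row for every token on every DP rank), so the causal-LM forward scatters them back to this rank's tokens before computing logits with the now-shared `ParallelLMHead`. A toy illustration of the slicing involved (the start/num values stand in for what `get_dp_local_info` would return):

```python
import torch

hidden_size = 4
global_num_tokens = [3, 2]          # tokens on DP rank 0 and 1
global_hidden = torch.randn(sum(global_num_tokens), hidden_size)

rank = 1
start = sum(global_num_tokens[:rank])                # 3
num = global_num_tokens[rank]                        # 2
local_hidden = global_hidden[start : start + num]    # rows belonging to rank 1
```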
@@ -262,14 +262,14 @@ class ServerArgs:

        # Data parallelism attention
        if self.enable_dp_attention:
            self.dp_size = self.tp_size
            assert self.tp_size % self.dp_size == 0
            self.chunked_prefill_size = self.chunked_prefill_size // 2
            self.schedule_conservativeness = self.schedule_conservativeness * 0.3
            assert (
                self.dp_size > 1
            ), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
            assert self.tp_size % self.dp_size == 0
            self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
            logger.warning(
                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
                "Data parallel size is adjusted to be the same as tensor parallel size. "
            )

        # Speculative Decoding
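Per the hunk above, `--dp-size` is no longer forced to equal `--tp-size`: it must be greater than 1 and a divisor of the TP size, and the chunked prefill size is divided by the DP size instead of being halved. A quick numeric illustration with made-up values:

```python
tp_size = 8
dp_size = 4                      # must satisfy 1 < dp_size and tp_size % dp_size == 0
chunked_prefill_size = 8192

assert dp_size > 1 and tp_size % dp_size == 0
chunked_prefill_size //= dp_size     # 8192 // 4 = 2048 tokens per chunk
```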