Improve DP attention (#4390)
Co-authored-by: dhou-xai <dhou@x.ai> Co-authored-by: SangBin Cho <rkooo567@gmail.com>
This commit is contained in:
@@ -1,6 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
|
import logging
|
||||||
|
from contextlib import contextmanager
|
||||||
from typing import TYPE_CHECKING, Union
|
from typing import TYPE_CHECKING, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -14,6 +16,8 @@ from sglang.srt.distributed import (
|
|||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
|
||||||
@@ -86,6 +90,27 @@ def get_attention_dp_size():
|
|||||||
return _DP_SIZE
|
return _DP_SIZE
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
def disable_dp_size():
    """Temporarily force the attention data-parallel size to 1.

    On entry the module-level ``_DP_SIZE`` is saved and replaced with 1; on
    exit the saved value is put back, including when the ``with`` body raises.
    This function takes no arguments and does not touch any TP group.

    NOTE(review): presumably intended for speculative-decoding draft workers
    that must run the draft model without DP attention -- confirm against the
    call sites.

    Raises:
        AssertionError: if ``_DP_SIZE`` is still ``None`` (DP attention was
            never initialized).
    """
    global _DP_SIZE
    assert _DP_SIZE is not None, "dp attention not initialized!"

    old_dp_size = _DP_SIZE
    _DP_SIZE = 1
    try:
        yield
    finally:
        # Always restore, so an exception in the body cannot leave the
        # process stuck with a DP size of 1.
        _DP_SIZE = old_dp_size
|
||||||
|
|
||||||
|
|
||||||
def get_dp_local_info(forward_batch: ForwardBatch):
|
def get_dp_local_info(forward_batch: ForwardBatch):
|
||||||
dp_rank = get_attention_dp_rank()
|
dp_rank = get_attention_dp_rank()
|
||||||
|
|
||||||
@@ -159,7 +184,8 @@ def dp_gather(
|
|||||||
layer_id != "embedding" or get_attention_tp_rank() == 0
|
layer_id != "embedding" or get_attention_tp_rank() == 0
|
||||||
):
|
):
|
||||||
assert (
|
assert (
|
||||||
global_tokens.storage().data_ptr() != local_tokens.storage().data_ptr()
|
global_tokens.untyped_storage().data_ptr()
|
||||||
|
!= local_tokens.untyped_storage().data_ptr()
|
||||||
), "aliasing between global_tokens and local_tokens not allowed"
|
), "aliasing between global_tokens and local_tokens not allowed"
|
||||||
memcpy_triton(
|
memcpy_triton(
|
||||||
global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
|
global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
|
||||||
@@ -174,8 +200,9 @@ def dp_gather(
|
|||||||
torch.ops.sglang.inplace_all_reduce(
|
torch.ops.sglang.inplace_all_reduce(
|
||||||
global_tokens, group_name=get_tp_group().unique_name
|
global_tokens, group_name=get_tp_group().unique_name
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
global_tokens = tensor_model_parallel_all_reduce(global_tokens)
|
global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)
|
||||||
|
|
||||||
|
|
||||||
def dp_scatter(
|
def dp_scatter(
|
||||||
@@ -186,6 +213,7 @@ def dp_scatter(
|
|||||||
# local_num_tokens is not necessarily the same as local_tokens.shape[0],
|
# local_num_tokens is not necessarily the same as local_tokens.shape[0],
|
||||||
# since local_tokens may be padded for cuda graph
|
# since local_tokens may be padded for cuda graph
|
||||||
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
|
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
|
||||||
|
|
||||||
local_tokens.fill_(0)
|
local_tokens.fill_(0)
|
||||||
assert local_tokens.is_contiguous()
|
assert local_tokens.is_contiguous()
|
||||||
assert global_tokens.is_contiguous()
|
assert global_tokens.is_contiguous()
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import triton.language as tl
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import (
|
||||||
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_gather,
|
tensor_model_parallel_all_gather,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ class LoadBalanceMethod(Enum):
|
|||||||
class DataParallelController:
|
class DataParallelController:
|
||||||
"""A controller that dispatches requests to multiple data parallel workers."""
|
"""A controller that dispatches requests to multiple data parallel workers."""
|
||||||
|
|
||||||
def __init__(self, server_args, port_args) -> None:
|
def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
|
||||||
# Parse args
|
# Parse args
|
||||||
self.max_total_num_tokens = None
|
self.max_total_num_tokens = None
|
||||||
self.server_args = server_args
|
self.server_args = server_args
|
||||||
|
|||||||
@@ -997,7 +997,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
|||||||
|
|
||||||
# Handle DP attention
|
# Handle DP attention
|
||||||
if self.server_args.enable_dp_attention:
|
if self.server_args.enable_dp_attention:
|
||||||
ret = self.prepare_dp_attn_batch(ret)
|
ret, _ = self.prepare_dp_attn_batch(ret)
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@@ -1269,39 +1269,72 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
|||||||
# Check if other DP workers have running batches
|
# Check if other DP workers have running batches
|
||||||
if local_batch is None:
|
if local_batch is None:
|
||||||
num_tokens = 0
|
num_tokens = 0
|
||||||
|
global_num_tokens_for_logprob = 0
|
||||||
elif local_batch.forward_mode.is_decode():
|
elif local_batch.forward_mode.is_decode():
|
||||||
num_tokens = local_batch.batch_size()
|
num_tokens = local_batch.batch_size()
|
||||||
|
if not self.spec_algorithm.is_none() and self.spec_algorithm.is_eagle():
|
||||||
|
num_tokens = num_tokens * self.server_args.speculative_num_draft_tokens
|
||||||
|
global_num_tokens_for_logprob = num_tokens
|
||||||
else:
|
else:
|
||||||
num_tokens = local_batch.extend_num_tokens
|
num_tokens = local_batch.extend_num_tokens
|
||||||
|
global_num_tokens_for_logprob = sum(
|
||||||
|
[
|
||||||
|
# We should have at least 1 token for sample in every case.
|
||||||
|
max(extend_len - logprob_start_len, 1)
|
||||||
|
for logprob_start_len, extend_len in zip(
|
||||||
|
local_batch.extend_logprob_start_lens, local_batch.extend_lens
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
|
if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
|
||||||
global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
|
can_cuda_graph = 1
|
||||||
|
else:
|
||||||
|
can_cuda_graph = 0
|
||||||
|
|
||||||
|
if not self.spec_algorithm.is_none():
|
||||||
|
# TODO(sang): Support cuda graph when idle batch is there.
|
||||||
|
if local_batch is None or local_batch.forward_mode.is_idle():
|
||||||
|
can_cuda_graph = 0
|
||||||
|
|
||||||
|
is_extend_in_batch = (
|
||||||
|
local_batch.forward_mode.is_extend() if local_batch else False
|
||||||
|
)
|
||||||
|
local_info = torch.tensor(
|
||||||
|
[
|
||||||
|
num_tokens,
|
||||||
|
can_cuda_graph,
|
||||||
|
global_num_tokens_for_logprob,
|
||||||
|
is_extend_in_batch,
|
||||||
|
],
|
||||||
|
dtype=torch.int64,
|
||||||
|
)
|
||||||
|
global_info = torch.empty(
|
||||||
|
(self.server_args.dp_size, self.attn_tp_size, 4),
|
||||||
|
dtype=torch.int64,
|
||||||
|
)
|
||||||
torch.distributed.all_gather_into_tensor(
|
torch.distributed.all_gather_into_tensor(
|
||||||
global_num_tokens,
|
global_info.flatten(),
|
||||||
local_num_tokens,
|
local_info,
|
||||||
group=self.tp_cpu_group,
|
group=self.tp_cpu_group,
|
||||||
)
|
)
|
||||||
|
global_num_tokens = global_info[:, 0, 0].tolist()
|
||||||
|
can_cuda_graph = min(global_info[:, 0, 1].tolist())
|
||||||
|
global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
|
||||||
|
is_extend_in_batch = global_info[:, 0, 3].tolist()
|
||||||
|
|
||||||
if local_batch is None and global_num_tokens.max().item() > 0:
|
if local_batch is None and max(global_num_tokens) > 0:
|
||||||
local_batch = self.get_idle_batch()
|
local_batch = self.get_idle_batch()
|
||||||
|
|
||||||
if local_batch is not None:
|
if local_batch is not None:
|
||||||
local_batch.global_num_tokens = global_num_tokens.tolist()
|
local_batch.global_num_tokens = global_num_tokens
|
||||||
|
local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
|
||||||
|
|
||||||
# Check forward mode for cuda graph
|
# Check forward mode for cuda graph
|
||||||
if not self.server_args.disable_cuda_graph:
|
if not self.server_args.disable_cuda_graph:
|
||||||
forward_mode_state = torch.tensor(
|
local_batch.can_run_dp_cuda_graph = can_cuda_graph
|
||||||
(1 if local_batch.forward_mode.is_decode_or_idle() else 0),
|
|
||||||
dtype=torch.int32,
|
|
||||||
)
|
|
||||||
torch.distributed.all_reduce(
|
|
||||||
forward_mode_state,
|
|
||||||
op=torch.distributed.ReduceOp.MIN,
|
|
||||||
group=self.tp_cpu_group,
|
|
||||||
)
|
|
||||||
local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
|
|
||||||
|
|
||||||
return local_batch
|
return local_batch, any(is_extend_in_batch)
|
||||||
|
|
||||||
def get_idle_batch(self):
|
def get_idle_batch(self):
|
||||||
idle_batch = ScheduleBatch.init_new(
|
idle_batch = ScheduleBatch.init_new(
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|||||||
ForwardBatch,
|
ForwardBatch,
|
||||||
ForwardMode,
|
ForwardMode,
|
||||||
)
|
)
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import get_available_gpu_memory, is_hip
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
|
|
||||||
@@ -174,6 +174,7 @@ class CudaGraphRunner:
|
|||||||
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
||||||
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
|
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
|
||||||
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
|
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
|
||||||
|
self.speculative_algorithm = model_runner.server_args.speculative_algorithm
|
||||||
self.tp_size = model_runner.server_args.tp_size
|
self.tp_size = model_runner.server_args.tp_size
|
||||||
self.dp_size = model_runner.server_args.dp_size
|
self.dp_size = model_runner.server_args.dp_size
|
||||||
|
|
||||||
@@ -236,7 +237,7 @@ class CudaGraphRunner:
|
|||||||
if self.enable_dp_attention:
|
if self.enable_dp_attention:
|
||||||
self.gathered_buffer = torch.zeros(
|
self.gathered_buffer = torch.zeros(
|
||||||
(
|
(
|
||||||
self.max_bs * self.dp_size,
|
self.max_bs * self.dp_size * self.num_tokens_per_bs,
|
||||||
self.model_runner.model_config.hidden_size,
|
self.model_runner.model_config.hidden_size,
|
||||||
),
|
),
|
||||||
dtype=self.model_runner.dtype,
|
dtype=self.model_runner.dtype,
|
||||||
@@ -276,13 +277,12 @@ class CudaGraphRunner:
|
|||||||
|
|
||||||
def can_run(self, forward_batch: ForwardBatch):
|
def can_run(self, forward_batch: ForwardBatch):
|
||||||
if self.enable_dp_attention:
|
if self.enable_dp_attention:
|
||||||
min_num_tokens, max_num_tokens = min(
|
total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
|
||||||
forward_batch.global_num_tokens_cpu
|
|
||||||
), max(forward_batch.global_num_tokens_cpu)
|
|
||||||
is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
|
is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
|
||||||
(min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
|
total_global_tokens in self.graphs
|
||||||
if self.disable_padding
|
if self.disable_padding
|
||||||
else max_num_tokens <= self.max_bs
|
else total_global_tokens <= self.max_bs
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
is_bs_supported = (
|
is_bs_supported = (
|
||||||
@@ -304,6 +304,9 @@ class CudaGraphRunner:
|
|||||||
def capture(self):
|
def capture(self):
|
||||||
with graph_capture() as graph_capture_context:
|
with graph_capture() as graph_capture_context:
|
||||||
self.stream = graph_capture_context.stream
|
self.stream = graph_capture_context.stream
|
||||||
|
avail_mem = get_available_gpu_memory(
|
||||||
|
self.model_runner.device, self.model_runner.gpu_id, empty_cache=False
|
||||||
|
)
|
||||||
# Reverse the order to enable better memory sharing across cuda graphs.
|
# Reverse the order to enable better memory sharing across cuda graphs.
|
||||||
capture_range = (
|
capture_range = (
|
||||||
tqdm.tqdm(list(reversed(self.capture_bs)))
|
tqdm.tqdm(list(reversed(self.capture_bs)))
|
||||||
@@ -311,6 +314,16 @@ class CudaGraphRunner:
|
|||||||
else reversed(self.capture_bs)
|
else reversed(self.capture_bs)
|
||||||
)
|
)
|
||||||
for bs in capture_range:
|
for bs in capture_range:
|
||||||
|
if get_tensor_model_parallel_rank() == 0:
|
||||||
|
avail_mem = get_available_gpu_memory(
|
||||||
|
self.model_runner.device,
|
||||||
|
self.model_runner.gpu_id,
|
||||||
|
empty_cache=False,
|
||||||
|
)
|
||||||
|
capture_range.set_description(
|
||||||
|
f"Capturing batches ({avail_mem=:.2f} GB)"
|
||||||
|
)
|
||||||
|
|
||||||
with patch_model(
|
with patch_model(
|
||||||
self.model_runner.model,
|
self.model_runner.model,
|
||||||
bs in self.compile_bs,
|
bs in self.compile_bs,
|
||||||
@@ -345,8 +358,18 @@ class CudaGraphRunner:
|
|||||||
mrope_positions = self.mrope_positions[:, :bs]
|
mrope_positions = self.mrope_positions[:, :bs]
|
||||||
|
|
||||||
if self.enable_dp_attention:
|
if self.enable_dp_attention:
|
||||||
global_num_tokens = [bs] * self.tp_size
|
self.global_num_tokens_gpu.copy_(
|
||||||
gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
|
torch.tensor(
|
||||||
|
[
|
||||||
|
num_tokens // self.dp_size + (i < bs % self.dp_size)
|
||||||
|
for i in range(self.dp_size)
|
||||||
|
],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=input_ids.device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
global_num_tokens = self.global_num_tokens_gpu
|
||||||
|
gathered_buffer = self.gathered_buffer[:num_tokens]
|
||||||
else:
|
else:
|
||||||
global_num_tokens = None
|
global_num_tokens = None
|
||||||
gathered_buffer = None
|
gathered_buffer = None
|
||||||
@@ -371,7 +394,7 @@ class CudaGraphRunner:
|
|||||||
encoder_lens=encoder_lens,
|
encoder_lens=encoder_lens,
|
||||||
return_logprob=False,
|
return_logprob=False,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
global_num_tokens_cpu=global_num_tokens,
|
global_num_tokens_gpu=global_num_tokens,
|
||||||
gathered_buffer=gathered_buffer,
|
gathered_buffer=gathered_buffer,
|
||||||
mrope_positions=mrope_positions,
|
mrope_positions=mrope_positions,
|
||||||
spec_algorithm=self.model_runner.spec_algorithm,
|
spec_algorithm=self.model_runner.spec_algorithm,
|
||||||
@@ -392,6 +415,9 @@ class CudaGraphRunner:
|
|||||||
|
|
||||||
# Run and capture
|
# Run and capture
|
||||||
def run_once():
|
def run_once():
|
||||||
|
# Clean intermediate result cache for DP attention
|
||||||
|
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
|
||||||
|
|
||||||
logits_output = forward(input_ids, forward_batch.positions, forward_batch)
|
logits_output = forward(input_ids, forward_batch.positions, forward_batch)
|
||||||
return logits_output.next_token_logits, logits_output.hidden_states
|
return logits_output.next_token_logits, logits_output.hidden_states
|
||||||
|
|
||||||
@@ -426,7 +452,7 @@ class CudaGraphRunner:
|
|||||||
self.capture_hidden_mode = hidden_mode_from_spec_info
|
self.capture_hidden_mode = hidden_mode_from_spec_info
|
||||||
self.capture()
|
self.capture()
|
||||||
|
|
||||||
def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
|
def replay_prepare(self, forward_batch: ForwardBatch):
|
||||||
self.recapture_if_needed(forward_batch)
|
self.recapture_if_needed(forward_batch)
|
||||||
|
|
||||||
raw_bs = forward_batch.batch_size
|
raw_bs = forward_batch.batch_size
|
||||||
@@ -435,7 +461,7 @@ class CudaGraphRunner:
|
|||||||
# Pad
|
# Pad
|
||||||
if self.enable_dp_attention:
|
if self.enable_dp_attention:
|
||||||
index = bisect.bisect_left(
|
index = bisect.bisect_left(
|
||||||
self.capture_bs, max(forward_batch.global_num_tokens_cpu)
|
self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||||
@@ -459,6 +485,8 @@ class CudaGraphRunner:
|
|||||||
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
||||||
if forward_batch.mrope_positions is not None:
|
if forward_batch.mrope_positions is not None:
|
||||||
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
|
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
|
||||||
|
if self.enable_dp_attention:
|
||||||
|
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
|
||||||
|
|
||||||
if hasattr(forward_batch.spec_info, "hidden_states"):
|
if hasattr(forward_batch.spec_info, "hidden_states"):
|
||||||
self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
|
self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
|
||||||
@@ -475,14 +503,29 @@ class CudaGraphRunner:
|
|||||||
seq_lens_cpu=self.seq_lens_cpu,
|
seq_lens_cpu=self.seq_lens_cpu,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Store fields
|
||||||
|
self.raw_bs = raw_bs
|
||||||
|
self.raw_num_token = raw_num_token
|
||||||
|
self.bs = bs
|
||||||
|
|
||||||
|
def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
|
||||||
|
if not skip_attn_backend_init:
|
||||||
|
self.replay_prepare(forward_batch)
|
||||||
|
else:
|
||||||
|
# In speculative decoding, these two fields are still needed.
|
||||||
|
self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
|
||||||
|
self.positions[: self.raw_num_token].copy_(forward_batch.positions)
|
||||||
|
|
||||||
# Replay
|
# Replay
|
||||||
self.graphs[bs].replay()
|
self.graphs[self.bs].replay()
|
||||||
next_token_logits, hidden_states = self.output_buffers[bs]
|
next_token_logits, hidden_states = self.output_buffers[self.bs]
|
||||||
|
|
||||||
logits_output = LogitsProcessorOutput(
|
logits_output = LogitsProcessorOutput(
|
||||||
next_token_logits=next_token_logits[:raw_num_token],
|
next_token_logits=next_token_logits[: self.raw_num_token],
|
||||||
hidden_states=(
|
hidden_states=(
|
||||||
hidden_states[:raw_num_token] if hidden_states is not None else None
|
hidden_states[: self.raw_num_token]
|
||||||
|
if hidden_states is not None
|
||||||
|
else None
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return logits_output
|
return logits_output
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ import triton
|
|||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
||||||
from sglang.srt.utils import get_compiler_backend, next_power_of_2
|
from sglang.srt.utils import get_compiler_backend
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
@@ -263,15 +263,24 @@ class ForwardBatch:
|
|||||||
extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
|
extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# For DP attention
|
||||||
if batch.global_num_tokens is not None:
|
if batch.global_num_tokens is not None:
|
||||||
ret.global_num_tokens_cpu = batch.global_num_tokens
|
ret.global_num_tokens_cpu = batch.global_num_tokens
|
||||||
max_len = max(ret.global_num_tokens_cpu)
|
ret.global_num_tokens_gpu = torch.tensor(
|
||||||
|
batch.global_num_tokens, dtype=torch.int64
|
||||||
|
).to(device, non_blocking=True)
|
||||||
|
|
||||||
|
ret.global_num_tokens_for_logprob_cpu = batch.global_num_tokens_for_logprob
|
||||||
|
ret.global_num_tokens_for_logprob_gpu = torch.tensor(
|
||||||
|
batch.global_num_tokens_for_logprob, dtype=torch.int64
|
||||||
|
).to(device, non_blocking=True)
|
||||||
|
|
||||||
|
sum_len = sum(batch.global_num_tokens)
|
||||||
ret.gathered_buffer = torch.zeros(
|
ret.gathered_buffer = torch.zeros(
|
||||||
(max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
|
(sum_len, model_runner.model_config.hidden_size),
|
||||||
dtype=model_runner.dtype,
|
dtype=model_runner.dtype,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
if ret.forward_mode.is_idle():
|
if ret.forward_mode.is_idle():
|
||||||
ret.positions = torch.empty((0,), device=device)
|
ret.positions = torch.empty((0,), device=device)
|
||||||
return ret
|
return ret
|
||||||
|
|||||||
@@ -26,15 +26,20 @@ from transformers import PretrainedConfig
|
|||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
|
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import (
|
||||||
get_tensor_model_parallel_rank,
|
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
get_tp_group,
|
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
|
from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
|
||||||
decode_attention_fwd_grouped_rope,
|
decode_attention_fwd_grouped_rope,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.dp_attention import (
|
||||||
|
dp_gather,
|
||||||
|
dp_scatter,
|
||||||
|
get_attention_dp_size,
|
||||||
|
get_attention_tp_rank,
|
||||||
|
get_attention_tp_size,
|
||||||
|
)
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.linear import (
|
from sglang.srt.layers.linear import (
|
||||||
ColumnParallelLinear,
|
ColumnParallelLinear,
|
||||||
@@ -230,6 +235,7 @@ class DeepseekV2Attention(nn.Module):
|
|||||||
max_position_embeddings: int = 8192,
|
max_position_embeddings: int = 8192,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
layer_id=None,
|
layer_id=None,
|
||||||
|
reduce_results: bool = True,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -241,10 +247,14 @@ class DeepseekV2Attention(nn.Module):
|
|||||||
self.v_head_dim = v_head_dim
|
self.v_head_dim = v_head_dim
|
||||||
self.q_lora_rank = q_lora_rank
|
self.q_lora_rank = q_lora_rank
|
||||||
self.kv_lora_rank = kv_lora_rank
|
self.kv_lora_rank = kv_lora_rank
|
||||||
|
|
||||||
|
self.dp_size = get_attention_dp_size()
|
||||||
|
attn_tp_rank = get_attention_tp_rank()
|
||||||
|
attn_tp_size = get_attention_tp_size()
|
||||||
|
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
assert num_heads % attn_tp_size == 0
|
||||||
assert num_heads % tp_size == 0
|
self.num_local_heads = num_heads // attn_tp_size
|
||||||
self.num_local_heads = num_heads // tp_size
|
|
||||||
self.scaling = self.qk_head_dim**-0.5
|
self.scaling = self.qk_head_dim**-0.5
|
||||||
self.rope_theta = rope_theta
|
self.rope_theta = rope_theta
|
||||||
self.max_position_embeddings = max_position_embeddings
|
self.max_position_embeddings = max_position_embeddings
|
||||||
@@ -272,6 +282,8 @@ class DeepseekV2Attention(nn.Module):
|
|||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("q_proj", prefix),
|
prefix=add_prefix("q_proj", prefix),
|
||||||
|
tp_rank=attn_tp_rank,
|
||||||
|
tp_size=attn_tp_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
||||||
@@ -296,6 +308,9 @@ class DeepseekV2Attention(nn.Module):
|
|||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("o_proj", prefix),
|
prefix=add_prefix("o_proj", prefix),
|
||||||
|
reduce_results=reduce_results,
|
||||||
|
tp_rank=attn_tp_rank,
|
||||||
|
tp_size=attn_tp_size,
|
||||||
)
|
)
|
||||||
rope_scaling["rope_type"] = "deepseek_yarn"
|
rope_scaling["rope_type"] = "deepseek_yarn"
|
||||||
self.rotary_emb = get_rope_wrapper(
|
self.rotary_emb = get_rope_wrapper(
|
||||||
@@ -330,6 +345,12 @@ class DeepseekV2Attention(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
if hidden_states.shape[0] == 0:
|
||||||
|
assert (
|
||||||
|
not self.o_proj.reduce_results
|
||||||
|
), "short-circuiting allreduce will lead to hangs"
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
if self.q_lora_rank is not None:
|
if self.q_lora_rank is not None:
|
||||||
q = self.q_a_proj(hidden_states)[0]
|
q = self.q_a_proj(hidden_states)[0]
|
||||||
q = self.q_a_layernorm(q)
|
q = self.q_a_layernorm(q)
|
||||||
@@ -385,8 +406,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
max_position_embeddings: int = 8192,
|
max_position_embeddings: int = 8192,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
layer_id=None,
|
reduce_results: bool = True,
|
||||||
use_dp=False,
|
layer_id: int = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -398,96 +419,66 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
self.v_head_dim = v_head_dim
|
self.v_head_dim = v_head_dim
|
||||||
self.q_lora_rank = q_lora_rank
|
self.q_lora_rank = q_lora_rank
|
||||||
self.kv_lora_rank = kv_lora_rank
|
self.kv_lora_rank = kv_lora_rank
|
||||||
|
self.dp_size = get_attention_dp_size()
|
||||||
|
attn_tp_rank = get_attention_tp_rank()
|
||||||
|
attn_tp_size = get_attention_tp_size()
|
||||||
|
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
assert num_heads % attn_tp_size == 0
|
||||||
assert num_heads % tp_size == 0
|
self.num_local_heads = num_heads // attn_tp_size
|
||||||
self.num_local_heads = num_heads if use_dp else num_heads // tp_size
|
|
||||||
self.scaling = self.qk_head_dim**-0.5
|
self.scaling = self.qk_head_dim**-0.5
|
||||||
self.rope_theta = rope_theta
|
self.rope_theta = rope_theta
|
||||||
self.max_position_embeddings = max_position_embeddings
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
|
||||||
if use_dp:
|
# For tensor parallel attention
|
||||||
# For data parallel attention
|
if self.q_lora_rank is not None:
|
||||||
if self.q_lora_rank is not None:
|
self.q_a_proj = ReplicatedLinear(
|
||||||
self.q_a_proj = ReplicatedLinear(
|
|
||||||
self.hidden_size,
|
|
||||||
self.q_lora_rank,
|
|
||||||
bias=False,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=add_prefix("q_a_proj", prefix),
|
|
||||||
)
|
|
||||||
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
|
|
||||||
self.q_b_proj = ReplicatedLinear(
|
|
||||||
q_lora_rank,
|
|
||||||
self.num_heads * self.qk_head_dim,
|
|
||||||
bias=False,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=add_prefix("q_b_proj", prefix),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.q_proj = ReplicatedLinear(
|
|
||||||
self.hidden_size,
|
|
||||||
self.num_heads * self.qk_head_dim,
|
|
||||||
bias=False,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=add_prefix("q_proj", prefix),
|
|
||||||
)
|
|
||||||
self.kv_b_proj = ReplicatedLinear(
|
|
||||||
self.kv_lora_rank,
|
|
||||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
|
||||||
bias=False,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=add_prefix("kv_b_proj", prefix),
|
|
||||||
)
|
|
||||||
# O projection.
|
|
||||||
self.o_proj = ReplicatedLinear(
|
|
||||||
self.num_heads * self.v_head_dim,
|
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
|
self.q_lora_rank,
|
||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("o_proj", prefix),
|
prefix=add_prefix("q_a_proj", prefix),
|
||||||
|
)
|
||||||
|
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
|
||||||
|
self.q_b_proj = ColumnParallelLinear(
|
||||||
|
q_lora_rank,
|
||||||
|
self.num_heads * self.qk_head_dim,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=add_prefix("q_b_proj", prefix),
|
||||||
|
tp_rank=attn_tp_rank,
|
||||||
|
tp_size=attn_tp_size,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# For tensor parallel attention
|
self.q_proj = ColumnParallelLinear(
|
||||||
if self.q_lora_rank is not None:
|
|
||||||
self.q_a_proj = ReplicatedLinear(
|
|
||||||
self.hidden_size,
|
|
||||||
self.q_lora_rank,
|
|
||||||
bias=False,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=add_prefix("q_a_proj", prefix),
|
|
||||||
)
|
|
||||||
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
|
|
||||||
self.q_b_proj = ColumnParallelLinear(
|
|
||||||
q_lora_rank,
|
|
||||||
self.num_heads * self.qk_head_dim,
|
|
||||||
bias=False,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=add_prefix("q_b_proj", prefix),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.q_proj = ColumnParallelLinear(
|
|
||||||
self.hidden_size,
|
|
||||||
self.num_heads * self.qk_head_dim,
|
|
||||||
bias=False,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=add_prefix("q_proj", prefix),
|
|
||||||
)
|
|
||||||
self.kv_b_proj = ColumnParallelLinear(
|
|
||||||
self.kv_lora_rank,
|
|
||||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
|
||||||
bias=False,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=add_prefix("kv_b_proj", prefix),
|
|
||||||
)
|
|
||||||
# O projection.
|
|
||||||
self.o_proj = RowParallelLinear(
|
|
||||||
self.num_heads * self.v_head_dim,
|
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
|
self.num_heads * self.qk_head_dim,
|
||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("o_proj", prefix),
|
prefix=add_prefix("q_proj", prefix),
|
||||||
|
tp_rank=attn_tp_rank,
|
||||||
|
tp_size=attn_tp_size,
|
||||||
)
|
)
|
||||||
|
self.kv_b_proj = ColumnParallelLinear(
|
||||||
|
self.kv_lora_rank,
|
||||||
|
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=add_prefix("kv_b_proj", prefix),
|
||||||
|
tp_rank=attn_tp_rank,
|
||||||
|
tp_size=attn_tp_size,
|
||||||
|
)
|
||||||
|
# O projection.
|
||||||
|
self.o_proj = RowParallelLinear(
|
||||||
|
self.num_heads * self.v_head_dim,
|
||||||
|
self.hidden_size,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
reduce_results=reduce_results,
|
||||||
|
prefix=add_prefix("o_proj", prefix),
|
||||||
|
tp_rank=attn_tp_rank,
|
||||||
|
tp_size=attn_tp_size,
|
||||||
|
)
|
||||||
|
|
||||||
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
@@ -542,38 +533,49 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
self.w_vc = None
|
self.w_vc = None
|
||||||
self.w_scale = None
|
self.w_scale = None
|
||||||
|
|
||||||
|
self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"]
|
||||||
|
self.flashinfer_mla_disable_ragged = global_server_args_dict[
|
||||||
|
"flashinfer_mla_disable_ragged"
|
||||||
|
]
|
||||||
|
self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
|
||||||
|
|
||||||
|
def no_absorb(self, forward_batch: ForwardBatch) -> bool:
|
||||||
|
if self.enable_flashinfer_mla:
|
||||||
|
# Flashinfer MLA: Do not absorb when enabling ragged prefill
|
||||||
|
return (
|
||||||
|
not self.flashinfer_mla_disable_ragged
|
||||||
|
and forward_batch.forward_mode.is_extend()
|
||||||
|
and not forward_batch.forward_mode.is_target_verify()
|
||||||
|
and not forward_batch.forward_mode.is_draft_extend()
|
||||||
|
and forward_batch.extend_prefix_lens.sum() == 0
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
|
||||||
|
return (
|
||||||
|
forward_batch.forward_mode.is_extend()
|
||||||
|
and not forward_batch.forward_mode.is_target_verify()
|
||||||
|
and not forward_batch.forward_mode.is_draft_extend()
|
||||||
|
and forward_batch.extend_prefix_lens.sum() == 0
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
if hidden_states.shape[0] == 0:
|
||||||
|
assert (
|
||||||
|
not self.o_proj.reduce_results
|
||||||
|
), "short-circuiting allreduce will lead to hangs"
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
def no_absorb() -> bool:
|
if self.no_absorb(forward_batch):
|
||||||
if global_server_args_dict["enable_flashinfer_mla"]:
|
|
||||||
# Flashinfer MLA: Do not absorb when enabling ragged prefill
|
|
||||||
return (
|
|
||||||
not global_server_args_dict["flashinfer_mla_disable_ragged"]
|
|
||||||
and forward_batch.forward_mode.is_extend()
|
|
||||||
and not forward_batch.forward_mode.is_target_verify()
|
|
||||||
and not forward_batch.forward_mode.is_draft_extend()
|
|
||||||
and forward_batch.extend_prefix_lens.sum() == 0
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
|
|
||||||
return (
|
|
||||||
forward_batch.forward_mode.is_extend()
|
|
||||||
and not forward_batch.forward_mode.is_target_verify()
|
|
||||||
and not forward_batch.forward_mode.is_draft_extend()
|
|
||||||
and forward_batch.extend_prefix_lens.sum() == 0
|
|
||||||
)
|
|
||||||
|
|
||||||
if no_absorb():
|
|
||||||
return self.forward_normal(positions, hidden_states, forward_batch)
|
return self.forward_normal(positions, hidden_states, forward_batch)
|
||||||
else:
|
else:
|
||||||
if _is_hip:
|
if _is_hip:
|
||||||
if (
|
if (
|
||||||
os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
|
self.rocm_fused_decode_mla
|
||||||
and forward_batch.forward_mode.is_decode()
|
and forward_batch.forward_mode.is_decode()
|
||||||
):
|
):
|
||||||
return self.forward_absorb_fused_mla_rope(
|
return self.forward_absorb_fused_mla_rope(
|
||||||
@@ -845,34 +847,6 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def all_gather(
|
|
||||||
input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
|
|
||||||
):
|
|
||||||
all_lens = forward_batch.global_num_tokens_cpu
|
|
||||||
max_len = max(forward_batch.global_num_tokens_cpu)
|
|
||||||
|
|
||||||
if world_size == 1:
|
|
||||||
return input_tensor, 0, all_lens[0]
|
|
||||||
|
|
||||||
padded_tensor = torch.nn.functional.pad(
|
|
||||||
input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
|
|
||||||
)
|
|
||||||
|
|
||||||
group.all_gather_into_tensor(forward_batch.gathered_buffer, padded_tensor)
|
|
||||||
|
|
||||||
gathered_tensors = torch.concat(
|
|
||||||
[
|
|
||||||
forward_batch.gathered_buffer[i * max_len : i * max_len + all_lens[i]]
|
|
||||||
for i in range(world_size)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
start_index = 0 if rank == 0 else sum(all_lens[:rank])
|
|
||||||
end_index = start_index + all_lens[rank]
|
|
||||||
|
|
||||||
return gathered_tensors, start_index, end_index
|
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV2DecoderLayer(nn.Module):
|
class DeepseekV2DecoderLayer(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -888,14 +862,10 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
rope_theta = getattr(config, "rope_theta", 10000)
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
rope_scaling = getattr(config, "rope_scaling", None)
|
rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||||
self.enable_dp_attention = (
|
self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
|
||||||
not global_server_args_dict["disable_mla"]
|
self.layer_id = layer_id
|
||||||
and global_server_args_dict["enable_dp_attention"]
|
self.dp_size = get_attention_dp_size()
|
||||||
)
|
|
||||||
if self.enable_dp_attention:
|
|
||||||
self.tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
self.tp_group = get_tp_group()
|
|
||||||
if not global_server_args_dict["disable_mla"]:
|
if not global_server_args_dict["disable_mla"]:
|
||||||
self.self_attn = DeepseekV2AttentionMLA(
|
self.self_attn = DeepseekV2AttentionMLA(
|
||||||
config=config,
|
config=config,
|
||||||
@@ -913,7 +883,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
max_position_embeddings=max_position_embeddings,
|
max_position_embeddings=max_position_embeddings,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
use_dp=self.enable_dp_attention,
|
reduce_results=False,
|
||||||
prefix=add_prefix("self_attn", prefix),
|
prefix=add_prefix("self_attn", prefix),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -933,8 +903,10 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
max_position_embeddings=max_position_embeddings,
|
max_position_embeddings=max_position_embeddings,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
reduce_results=False,
|
||||||
prefix=add_prefix("self_attn", prefix),
|
prefix=add_prefix("self_attn", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_nextn or (
|
if is_nextn or (
|
||||||
config.n_routed_experts is not None
|
config.n_routed_experts is not None
|
||||||
and layer_id >= config.first_k_dense_replace
|
and layer_id >= config.first_k_dense_replace
|
||||||
@@ -965,33 +937,47 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
residual: Optional[torch.Tensor],
|
residual: Optional[torch.Tensor],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# Self Attention
|
if residual is None:
|
||||||
if not forward_batch.forward_mode.is_idle():
|
residual = hidden_states
|
||||||
if residual is None:
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
residual = hidden_states
|
else:
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||||
else:
|
|
||||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
|
||||||
|
|
||||||
hidden_states = self.self_attn(
|
# Scatter
|
||||||
positions=positions,
|
if self.dp_size != 1:
|
||||||
hidden_states=hidden_states,
|
# important: forward batch.gathered_buffer is used both after scatter and after gather.
|
||||||
forward_batch=forward_batch,
|
# be careful about this!
|
||||||
)
|
hidden_states, global_hidden_states = (
|
||||||
hidden_states, residual = self.post_attention_layernorm(
|
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||||
hidden_states, residual
|
hidden_states,
|
||||||
)
|
)
|
||||||
|
dp_scatter(hidden_states, global_hidden_states, forward_batch)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
hidden_states = self.self_attn(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
forward_batch=forward_batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Gather
|
||||||
|
if get_tensor_model_parallel_world_size() > 1:
|
||||||
|
# all gather and all reduce
|
||||||
|
if self.dp_size != 1:
|
||||||
|
hidden_states, local_hidden_states = (
|
||||||
|
forward_batch.gathered_buffer,
|
||||||
|
hidden_states,
|
||||||
|
)
|
||||||
|
dp_gather(
|
||||||
|
hidden_states, local_hidden_states, forward_batch, self.layer_id
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||||
|
|
||||||
|
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
# Fully Connected
|
# Fully Connected
|
||||||
if self.enable_dp_attention:
|
hidden_states = self.mlp(hidden_states)
|
||||||
hidden_states, start_idx, end_idx = all_gather(
|
|
||||||
hidden_states, forward_batch, self.tp_rank, self.tp_size, self.tp_group
|
|
||||||
)
|
|
||||||
hidden_states = self.mlp(hidden_states)
|
|
||||||
hidden_states = hidden_states[start_idx:end_idx]
|
|
||||||
else:
|
|
||||||
hidden_states = self.mlp(hidden_states)
|
|
||||||
|
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
@@ -1027,12 +1013,27 @@ class DeepseekV2Model(nn.Module):
|
|||||||
)
|
)
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
self.dp_size = get_attention_dp_size()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
# Gather
|
||||||
|
if self.dp_size != 1:
|
||||||
|
input_ids, local_input_ids = (
|
||||||
|
torch.empty(
|
||||||
|
(forward_batch.gathered_buffer.shape[0],),
|
||||||
|
dtype=input_ids.dtype,
|
||||||
|
device=input_ids.device,
|
||||||
|
),
|
||||||
|
input_ids,
|
||||||
|
)
|
||||||
|
dp_gather(input_ids, local_input_ids, forward_batch, "embedding")
|
||||||
|
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
residual = None
|
residual = None
|
||||||
for i in range(len(self.layers)):
|
for i in range(len(self.layers)):
|
||||||
@@ -1059,22 +1060,14 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
self.model = DeepseekV2Model(
|
self.model = DeepseekV2Model(
|
||||||
config, quant_config, prefix=add_prefix("model", prefix)
|
config, quant_config, prefix=add_prefix("model", prefix)
|
||||||
)
|
)
|
||||||
if global_server_args_dict["enable_dp_attention"]:
|
self.lm_head = ParallelLMHead(
|
||||||
self.lm_head = ReplicatedLinear(
|
config.vocab_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
config.vocab_size,
|
quant_config=quant_config,
|
||||||
bias=False,
|
prefix=add_prefix("lm_head", prefix),
|
||||||
prefix=add_prefix("lm_head", prefix),
|
)
|
||||||
)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
|
self.dp_size = get_attention_dp_size()
|
||||||
else:
|
|
||||||
self.lm_head = ParallelLMHead(
|
|
||||||
config.vocab_size,
|
|
||||||
config.hidden_size,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=add_prefix("lm_head", prefix),
|
|
||||||
)
|
|
||||||
self.logits_processor = LogitsProcessor(config)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -1084,6 +1077,16 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, forward_batch)
|
hidden_states = self.model(input_ids, positions, forward_batch)
|
||||||
|
|
||||||
|
if self.dp_size != 1:
|
||||||
|
# important: forward batch.gathered_buffer is used both after scatter and after gather.
|
||||||
|
# be careful about this!
|
||||||
|
hidden_states, global_hidden_states = (
|
||||||
|
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||||
|
hidden_states,
|
||||||
|
)
|
||||||
|
dp_scatter(hidden_states, global_hidden_states, forward_batch)
|
||||||
|
|
||||||
return self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head, forward_batch
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -262,14 +262,14 @@ class ServerArgs:
|
|||||||
|
|
||||||
# Data parallelism attention
|
# Data parallelism attention
|
||||||
if self.enable_dp_attention:
|
if self.enable_dp_attention:
|
||||||
self.dp_size = self.tp_size
|
|
||||||
assert self.tp_size % self.dp_size == 0
|
|
||||||
self.chunked_prefill_size = self.chunked_prefill_size // 2
|
|
||||||
self.schedule_conservativeness = self.schedule_conservativeness * 0.3
|
self.schedule_conservativeness = self.schedule_conservativeness * 0.3
|
||||||
|
assert (
|
||||||
|
self.dp_size > 1
|
||||||
|
), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
|
||||||
|
assert self.tp_size % self.dp_size == 0
|
||||||
|
self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
|
f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
|
||||||
f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
|
|
||||||
"Data parallel size is adjusted to be the same as tensor parallel size. "
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Speculative Decoding
|
# Speculative Decoding
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ class TestDPAttention(unittest.TestCase):
|
|||||||
"--tp",
|
"--tp",
|
||||||
"2",
|
"2",
|
||||||
"--enable-dp-attention",
|
"--enable-dp-attention",
|
||||||
|
"--dp",
|
||||||
|
"2",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user