Refactor flashinfer logic for deepseek v3 and fix accuracy bug (#3785)
@@ -14,6 +14,7 @@
import json
import logging
import math
from enum import IntEnum, auto
from typing import List, Optional, Set, Union
@@ -103,7 +104,20 @@ class ModelConfig:
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
            self.v_head_dim = self.hf_config.v_head_dim

            # Handle rope scaling with yarn
            self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
            if self.hf_config.rope_scaling:
                mscale_all_dim = self.hf_config.rope_scaling.get(
                    "mscale_all_dim", False
                )
                scaling_factor = self.hf_config.rope_scaling["factor"]
                mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
                self.scaling = self.scaling * mscale * mscale

        elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
            self.head_dim = 128
            self.attention_arch = AttentionArch.MLA
@@ -414,3 +428,9 @@ def is_multimodal_model(model_architectures: List[str]):


def is_encoder_decoder_model(model_architectures: List[str]):
    return "MllamaForConditionalGeneration" in model_architectures


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0
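For intuition, here is a standalone sketch of the softmax scale computed above. The dimensions and rope settings (qk_nope_head_dim=128, qk_rope_head_dim=64, factor=40, mscale_all_dim=1.0) are illustrative DeepSeek-style values assumed for this example, not read from the diff:

    import math

    def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
        if scale <= 1:
            return 1.0
        return 0.1 * mscale * math.log(scale) + 1.0

    # Base softmax scale uses the full query head dim (nope + rope parts).
    qk_nope_head_dim, qk_rope_head_dim = 128, 64
    scaling = 1 / math.sqrt(qk_nope_head_dim + qk_rope_head_dim)  # 1/sqrt(192)

    # With yarn rope scaling, mscale is applied twice, matching
    # self.scaling * mscale * mscale in the hunk above.
    scaling_factor, mscale_all_dim = 40, 1.0
    mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
    scaling = scaling * mscale * mscale
    print(round(scaling, 4))  # ~0.1353 with these assumed values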
python/sglang/srt/layers/attention/flashinfer_mla_backend.py (new file, 521 lines)
@@ -0,0 +1,521 @@
from __future__ import annotations

"""
Support attention backend for flashinfer MLA.
When radix cache is enabled, the backend only uses the BatchMLAPaged wrapper when forwarding.
When radix cache is disabled, the backend uses BatchPrefill wrappers for prefilling (with or without prefix cache),
and uses the BatchMLAPaged wrapper for decoding.
More details can be found in https://docs.flashinfer.ai/api/mla.html
"""

import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union

import torch

from sglang.global_config import global_config
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.flashinfer_backend import (
    create_flashinfer_kv_indices_triton,
    should_use_tensor_core,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.spec_info import SpecInfo

if is_flashinfer_available():
    from flashinfer import (
        BatchPrefillWithPagedKVCacheWrapper,
        BatchPrefillWithRaggedKVCacheWrapper,
    )
    from flashinfer.cascade import merge_state
    from flashinfer.mla import BatchMLAPagedAttentionWrapper
@dataclass
class DecodeMetadata:
    decode_wrapper: BatchMLAPagedAttentionWrapper


@dataclass
class PrefillMetadata:
    prefill_wrapper: Union[
        BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
    ]
    use_ragged: bool


# Reuse this workspace buffer across all flashinfer wrappers
global_workspace_buffer = None
class FlashInferMLAAttnBackend(AttentionBackend):
    """Flashinfer attention kernels."""

    def __init__(
        self,
        model_runner: ModelRunner,
        kv_indptr_buf: Optional[torch.Tensor] = None,
    ):
        super().__init__()

        # Parse constants
        self.max_context_len = model_runner.model_config.context_len

        global_config.enable_flashinfer_mla = True

        # Allocate buffers
        global global_workspace_buffer
        if global_workspace_buffer is None:
            global_workspace_buffer = torch.empty(
                global_config.flashinfer_workspace_size,
                dtype=torch.uint8,
                device=model_runner.device,
            )
        self.workspace_buffer = global_workspace_buffer

        max_bs = model_runner.req_to_token_pool.size
        if kv_indptr_buf is None:
            self.kv_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )
        else:
            self.kv_indptr = kv_indptr_buf

        self.qo_indptr = torch.zeros(
            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
        )

        self.kv_last_page_len = torch.ones(
            (max_bs,), dtype=torch.int32, device=model_runner.device
        )

        self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
            self.workspace_buffer, "NHD"
        )

        if not global_server_args_dict["disable_radix_cache"]:
            # use mla paged prefill
            self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(
                self.workspace_buffer,
                backend="auto",
            )
        else:
            self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
                self.workspace_buffer,
                "NHD",
                backend="auto",
            )
        self.decode_wrapper = BatchMLAPagedAttentionWrapper(
            self.workspace_buffer, backend="auto"
        )

        # Create indices updater
        self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
            model_runner, self
        )
        self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode(
            model_runner, self
        )

        # Other metadata
        self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
        self.decode_cuda_graph_metadata = {}
        self.prefill_cuda_graph_metadata = {}
    def init_forward_metadata(self, forward_batch: ForwardBatch):
        if forward_batch.forward_mode.is_decode_or_idle():
            self.indices_updater_decode.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                decode_wrapper=self.decode_wrapper,
            )
            self.forward_metadata = DecodeMetadata(self.decode_wrapper)
        else:
            prefix_lens = forward_batch.extend_prefix_lens
            use_ragged = global_server_args_dict["disable_radix_cache"]

            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens,
                prefill_wrapper_paged=self.prefill_wrapper_paged,
                use_ragged=use_ragged,
            )
            self.forward_metadata = PrefillMetadata(
                self.prefill_wrapper_paged, use_ragged
            )

    def init_cuda_graph_state(
        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
    ):
        if kv_indices_buf is None:
            cuda_graph_kv_indices = torch.zeros(
                (max_bs * self.max_context_len,),
                dtype=torch.int32,
                device="cuda",
            )
        else:
            cuda_graph_kv_indices = kv_indices_buf

        self.cuda_graph_kv_indices = cuda_graph_kv_indices
        self.cuda_graph_custom_mask = torch.zeros(
            (max_bs * self.max_context_len),
            dtype=torch.uint8,
            device="cuda",
        )
        self.cuda_graph_qk_indptr = self.kv_indptr.clone()
        self.cuda_graph_qo_indptr = self.kv_indptr.clone()
    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        if forward_mode.is_decode_or_idle():
            decode_wrapper = BatchMLAPagedAttentionWrapper(
                self.workspace_buffer,
                use_cuda_graph=True,
                qo_indptr=self.qo_indptr[: num_tokens + 1],
                kv_indptr=self.kv_indptr[: num_tokens + 1],
                kv_indices=self.cuda_graph_kv_indices,
                kv_len_arr=self.kv_last_page_len[:num_tokens],
                backend="auto",
            )

            seq_lens_sum = seq_lens.sum().item()
            self.indices_updater_decode.update(
                req_pool_indices,
                seq_lens,
                seq_lens_sum,
                decode_wrapper=decode_wrapper,
            )
            self.decode_cuda_graph_metadata[bs] = decode_wrapper
            self.forward_metadata = DecodeMetadata(decode_wrapper)
        else:
            raise ValueError(f"Invalid mode: {forward_mode=}")

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        if forward_mode.is_decode_or_idle():
            self.indices_updater_decode.update(
                req_pool_indices[:bs],
                seq_lens[:bs],
                seq_lens_sum,
                decode_wrapper=self.decode_cuda_graph_metadata[bs],
            )
        else:
            raise ValueError(f"Invalid forward mode: {forward_mode=}")

    def get_cuda_graph_seq_len_fill_value(self):
        return 0
    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        cache_loc = forward_batch.out_cache_loc
        logits_soft_cap = layer.logit_cap

        if not global_server_args_dict["disable_radix_cache"]:
            # use mla paged prefill
            prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
            if k is not None:
                assert v is not None
                if save_kv_cache:
                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
            qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
            k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

            o = prefill_wrapper_paged.run(
                qall[:, :, : layer.v_head_dim],
                qall[:, :, layer.v_head_dim :],
                k_buf[:, :, : layer.v_head_dim],
                k_buf[:, :, layer.v_head_dim :],
            )
        else:
            # use mla ragged prefill
            o, _ = self.prefill_wrapper_ragged.forward_return_lse(
                q.view(-1, layer.tp_q_head_num, layer.head_dim),
                k.view(-1, layer.tp_k_head_num, layer.head_dim),
                v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
                causal=True,
                sm_scale=layer.scaling,
                logits_soft_cap=logits_soft_cap,
            )

            # FIXME: Here should be another prefill_paged to call

            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer,
                    cache_loc,
                    k,
                    v,
                )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        decode_wrapper = self.forward_metadata.decode_wrapper
        cache_loc = forward_batch.out_cache_loc

        if k is not None:
            assert v is not None
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer,
                    cache_loc,
                    k,
                    v,
                )
        reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
        k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
        reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
        o = decode_wrapper.run(
            reshaped_q[:, :, : layer.v_head_dim],
            reshaped_q[:, :, layer.v_head_dim :],
            reshaped_k[:, :, : layer.v_head_dim],
            reshaped_k[:, :, layer.v_head_dim :],
        )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
class FlashInferMLAIndicesUpdaterDecode:
    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
        # Parse Constants
        self.num_local_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
        self.scaling = model_runner.model_config.scaling
        self.data_type = model_runner.kv_cache_dtype
        self.attn_backend = attn_backend

        # Buffers and wrappers
        self.kv_indptr = attn_backend.kv_indptr
        self.kv_last_page_len = attn_backend.kv_last_page_len
        self.req_to_token = model_runner.req_to_token_pool.req_to_token

    def update(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        decode_wrapper: BatchMLAPagedAttentionWrapper,
    ):
        # Fall back to the backend's default decode wrapper if none is given.
        decode_wrapper = decode_wrapper or self.attn_backend.decode_wrapper
        self.call_begin_forward(
            decode_wrapper,
            req_pool_indices,
            seq_lens,
            seq_lens_sum,
            self.kv_indptr,
        )

    def call_begin_forward(
        self,
        wrapper: BatchMLAPagedAttentionWrapper,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        paged_kernel_lens_sum: int,
        kv_indptr: torch.Tensor,
    ):
        bs = len(req_pool_indices)
        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
        kv_indptr = kv_indptr[: bs + 1]
        kv_indices = torch.empty(
            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
        )
        create_flashinfer_kv_indices_triton[(bs,)](
            self.req_to_token,
            req_pool_indices,
            paged_kernel_lens,
            kv_indptr,
            None,
            kv_indices,
            self.req_to_token.shape[1],
        )

        sm_scale = self.scaling
        q_indptr = torch.arange(0, bs + 1).to(0).int()
        kv_lens = paged_kernel_lens.to(torch.int32)
        wrapper.plan(
            q_indptr,
            kv_indptr,
            kv_indices,
            kv_lens,
            self.num_local_heads,
            self.kv_lora_rank,
            self.qk_rope_head_dim,
            1,
            False,
            sm_scale,
            self.data_type,
            self.data_type,
        )
class FlashInferMLAIndicesUpdaterPrefill:
    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
        # Parse Constants
        self.num_qo_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
            get_attention_tp_size()
        )
        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
        self.v_head_dim = model_runner.model_config.v_head_dim
        self.scaling = model_runner.model_config.scaling
        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.attn_backend = attn_backend

        # Buffers and wrappers
        self.kv_indptr = attn_backend.kv_indptr
        self.kv_last_page_len = attn_backend.kv_last_page_len
        self.qo_indptr = attn_backend.qo_indptr
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged

    def update(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        prefix_lens: torch.Tensor,
        prefill_wrapper_paged: Union[
            BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
        ],
        use_ragged: bool,
    ):
        if use_ragged:
            paged_kernel_lens = prefix_lens
            paged_kernel_lens_sum = paged_kernel_lens.sum().item()
        else:
            paged_kernel_lens = seq_lens
            paged_kernel_lens_sum = seq_lens_sum

        self.call_begin_forward(
            self.prefill_wrapper_ragged,
            prefill_wrapper_paged,
            req_pool_indices,
            paged_kernel_lens,
            paged_kernel_lens_sum,
            seq_lens,
            prefix_lens,
            self.kv_indptr,
            self.qo_indptr,
            use_ragged,
        )

    def call_begin_forward(
        self,
        wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
        wrapper_paged: Union[
            BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
        ],
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        paged_kernel_lens_sum: int,
        seq_lens: torch.Tensor,
        prefix_lens: torch.Tensor,
        kv_indptr: torch.Tensor,
        qo_indptr: torch.Tensor,
        use_ragged: bool,
    ):
        bs = len(req_pool_indices)
        # Normal extend
        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
        kv_indptr = kv_indptr[: bs + 1]
        kv_indices = torch.empty(
            paged_kernel_lens_sum,
            dtype=torch.int32,
            device=req_pool_indices.device,
        )
        create_flashinfer_kv_indices_triton[(bs,)](
            self.req_to_token,
            req_pool_indices,
            paged_kernel_lens,
            kv_indptr,
            None,
            kv_indices,
            self.req_to_token.shape[1],
        )

        qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
        qo_indptr = qo_indptr[: bs + 1]
        sm_scale = self.scaling

        # extend part
        if use_ragged:
            wrapper_ragged.begin_forward(
                qo_indptr=qo_indptr,
                kv_indptr=qo_indptr,
                num_qo_heads=self.num_qo_heads,
                num_kv_heads=self.num_kv_heads,
                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
                head_dim_vo=self.v_head_dim,
                q_data_type=self.q_data_type,
            )

        if not global_server_args_dict["disable_radix_cache"]:
            # mla paged prefill
            kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
            wrapper_paged.plan(
                qo_indptr,
                kv_indptr,
                kv_indices,
                kv_len_arr,
                self.num_qo_heads,
                self.kv_lora_rank,
                self.qk_rope_head_dim,
                1,
                True,
                sm_scale,
                self.q_data_type,
                self.data_type,
            )

        # FIXME: Here should be some logic for prefill paged when not using radix cache?
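For context on the slicing in forward_extend/forward_decode above, here is a standalone sketch of the tensor layout this backend assumes. The concrete numbers (kv_lora_rank=512, qk_rope_head_dim=64, DeepSeek-V3-style) and the mapping head_dim = kv_lora_rank + qk_rope_head_dim, v_head_dim = kv_lora_rank are assumptions for illustration, not values taken from this diff:

    import torch

    # Hypothetical DeepSeek-V3-style MLA dimensions (assumed for illustration).
    kv_lora_rank, qk_rope_head_dim = 512, 64
    head_dim = kv_lora_rank + qk_rope_head_dim  # 576: what layer.head_dim holds in the absorbed path
    v_head_dim = kv_lora_rank                   # 512: what layer.v_head_dim holds in the absorbed path
    num_heads, num_tokens = 16, 8

    q = torch.randn(num_tokens, num_heads, head_dim)
    kv_cache = torch.randn(num_tokens, 1, head_dim)  # one shared latent "head" per token

    # The backend splits both q and the cached kv at v_head_dim:
    q_nope, q_rope = q[:, :, :v_head_dim], q[:, :, v_head_dim:]            # (..., 512) / (..., 64)
    ckv, k_rope = kv_cache[:, :, :v_head_dim], kv_cache[:, :, v_head_dim:]
    # These four pieces correspond to the four arguments passed to wrapper.run(...) above.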
@@ -34,6 +34,7 @@ from sglang.srt.distributed (
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.dp_attention import (

@@ -113,9 +114,9 @@ class ModelRunner:
        if self.server_args.device != "cpu":
            if server_args.enable_flashinfer_mla:
                logger.info(
-                   "FlashInfer MLA optimization is turned on. Use flashinfer backend for DeepseekV3ForCausalLM."
+                   "MLA optimization is turned on. Use flashinfer mla backend."
                )
-               self.server_args.attention_backend = "flashinfer"
+               self.server_args.attention_backend = "flashinfer_mla"
            else:
                logger.info("MLA optimization is turned on. Use triton backend.")
                self.server_args.attention_backend = "triton"

@@ -703,6 +704,8 @@ class ModelRunner:
            self.attn_backend = TritonAttnBackend(self)
        elif self.server_args.attention_backend == "torch_native":
            self.attn_backend = TorchNativeAttnBackend(self)
        elif self.server_args.attention_backend == "flashinfer_mla":
            self.attn_backend = FlashInferMLAAttnBackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
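A minimal sketch of the dispatch in the hunk above. Only server_args.enable_flashinfer_mla and the "flashinfer_mla" backend string appear in this diff; ServerArgsLike is a hypothetical stand-in, not the real sglang ServerArgs, and the corresponding CLI flag is presumably --enable-flashinfer-mla:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ServerArgsLike:
        enable_flashinfer_mla: bool = False
        attention_backend: Optional[str] = None

    def pick_mla_backend(args: ServerArgsLike) -> str:
        # Mirrors ModelRunner: flashinfer_mla when the flag is set, triton otherwise.
        return "flashinfer_mla" if args.enable_flashinfer_mla else "triton"

    print(pick_mla_backend(ServerArgsLike(enable_flashinfer_mla=True)))   # flashinfer_mla
    print(pick_mla_backend(ServerArgsLike(enable_flashinfer_mla=False)))  # triton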
@@ -510,25 +510,27 @@ class DeepseekV2AttentionMLA(nn.Module):
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
-        if global_server_args_dict["enable_flashinfer_mla"]:
-            if global_server_args_dict["disable_radix_cache"]:
-                if forward_batch.forward_mode.is_extend():
-                    return self.forward_normal(positions, hidden_states, forward_batch)
-                else:
-                    return self.forward_absorb(positions, hidden_states, forward_batch)
-            else:
-                return self.forward_absorb(positions, hidden_states, forward_batch)
-        else:
-            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
-            if (
-                forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-                and forward_batch.extend_prefix_lens.sum() == 0
-            ):
-                return self.forward_normal(positions, hidden_states, forward_batch)
-            else:
-                return self.forward_absorb(positions, hidden_states, forward_batch)
+        def no_absorb() -> bool:
+            if global_server_args_dict["enable_flashinfer_mla"]:
+                # Flashinfer MLA: Only do not use absorb when prefilling/extending without radix cache
+                return (
+                    global_server_args_dict["disable_radix_cache"]
+                    and forward_batch.forward_mode.is_extend()
+                )
+            else:
+                # Triton: Use normal computation for prefill and use weight absorption for extend/decode
+                return (
+                    forward_batch.forward_mode.is_extend()
+                    and not forward_batch.forward_mode.is_target_verify()
+                    and not forward_batch.forward_mode.is_draft_extend()
+                    and forward_batch.extend_prefix_lens.sum() == 0
+                )
+
+        if no_absorb():
+            return self.forward_normal(positions, hidden_states, forward_batch)
+        else:
+            return self.forward_absorb(positions, hidden_states, forward_batch)

    def forward_normal(
        self,
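For quick reference, a simplified standalone sketch of the decision the new no_absorb() helper encodes (in the real code the helper closes over forward_batch; the triton path here ignores the target_verify/draft_extend/prefix-length details shown in the diff):

    def no_absorb(enable_flashinfer_mla: bool, disable_radix_cache: bool, is_extend: bool) -> bool:
        if enable_flashinfer_mla:
            # Flashinfer MLA path: plain (non-absorbed) attention only for ragged extend,
            # i.e. when the radix cache is disabled and the batch is prefilling/extending.
            return disable_radix_cache and is_extend
        # Triton path (simplified): plain attention only for extends with no cached prefix.
        return is_extend

    # Flashinfer MLA, radix cache ON : always forward_absorb (paged MLA for prefill and decode).
    assert no_absorb(True, False, True) is False
    # Flashinfer MLA, radix cache OFF: forward_normal for extend, forward_absorb for decode.
    assert no_absorb(True, True, True) is True
    assert no_absorb(True, True, False) is False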