Support nextn for flashinfer mla attention backend (#4218)
@@ -84,7 +84,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be
- **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase (sketched below).
- **Flashinfer MLA Wrapper**: By providing the `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by FlashInfer. More details can be found in [this document](https://docs.flashinfer.ai/api/mla.html). Under long-input scenarios, FlashInfer MLA can improve performance significantly. Optimized Triton kernels are used when FlashInfer MLA is turned off.
- **Flashinfer MLA Wrapper**: By providing the `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by FlashInfer. More details can be found in [this document](https://docs.flashinfer.ai/api/mla.html). Under long-input scenarios, FlashInfer MLA can improve performance significantly. Optimized Triton kernels are used when FlashInfer MLA is turned off. Currently, when using the FlashInfer MLA wrapper together with speculative decoding, the `speculative_eagle_topk` parameter should be set to 1.
- **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enable efficient FP8 inference. Additionally, we have implemented a Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.
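To make the weight-absorption bullet concrete, here is a minimal PyTorch sketch of the reordering it describes. The tensor names and shapes (`q`, `w_uk`, `c_kv`, `kv_lora_rank`) are illustrative assumptions, not the actual SGLang implementation:

```python
import torch

# Hypothetical shapes: one query head, a latent (compressed) KV cache.
num_q, head_dim, kv_lora_rank, kv_len = 4, 128, 512, 1024

q = torch.randn(num_q, head_dim)            # query in the original head dim
w_uk = torch.randn(head_dim, kv_lora_rank)  # key up-projection (transposed view)
c_kv = torch.randn(kv_len, kv_lora_rank)    # compressed (latent) KV cache

# Naive order: decompress every cached key, then take dot products.
# This materializes a (kv_len, head_dim) tensor of decompressed keys.
scores_naive = q @ (c_kv @ w_uk.t()).t()

# Absorbed order: fold w_uk into the query once, then attend directly over
# the compressed cache. Same result by associativity, and the decompressed
# keys are never materialized, so the cache can stay in its compact form.
q_absorbed = q @ w_uk                       # (num_q, kv_lora_rank)
scores_absorbed = q_absorbed @ c_kv.t()

torch.testing.assert_close(scores_naive, scores_absorbed, rtol=1e-4, atol=1e-4)
```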
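For the new `speculative_eagle_topk` note, the server flags exercised by the test added in this commit give a concrete starting point. A sketch of launching such a configuration from Python, assuming the helpers live in `sglang.test.test_utils` as in SGLang's test suite:

```python
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

# FlashInfer MLA + EAGLE speculative decoding: keep --speculative-eagle-topk at 1.
process = popen_launch_server(
    "lmsys/sglang-ci-dsv3-test",
    DEFAULT_URL_FOR_TEST,
    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    other_args=[
        "--trust-remote-code",
        "--enable-flashinfer-mla",
        "--speculative-algorithm",
        "EAGLE",
        "--speculative-draft",
        "lmsys/sglang-ci-dsv3-test-NextN",
        "--speculative-num-steps",
        "4",
        "--speculative-eagle-topk",
        "1",
        "--speculative-num-draft-tokens",
        "4",
    ],
)
```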
@@ -11,9 +11,10 @@ More details can be found in https://docs.flashinfer.ai/api/mla.html

from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Callable, Optional, Union

import torch
import triton

from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -23,6 +24,7 @@ from sglang.srt.layers.attention.flashinfer_backend import (
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.utils import is_flashinfer_available

if TYPE_CHECKING:
@@ -58,12 +60,16 @@ class FlashInferMLAAttnBackend(AttentionBackend):
    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
        q_indptr_decode_buf: Optional[torch.Tensor] = None,
    ):
        super().__init__()

        # Parse constants
        self.max_context_len = model_runner.model_config.context_len
        self.device = model_runner.device
        self.skip_prefill = skip_prefill

        global_config.enable_flashinfer_mla = True

@@ -78,35 +84,51 @@ class FlashInferMLAAttnBackend(AttentionBackend):
        self.workspace_buffer = global_workspace_buffer

        max_bs = model_runner.req_to_token_pool.size
        self.kv_indptr = torch.zeros(
            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
        )
        if kv_indptr_buf is None:
            self.kv_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )
        else:
            self.kv_indptr = kv_indptr_buf

        self.qo_indptr = torch.zeros(
            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
        )
        if not self.skip_prefill:
            self.qo_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
            )

        self.q_indptr_decode = torch.arange(
            0, max_bs + 1, dtype=torch.int32, device=model_runner.device
        )
        if q_indptr_decode_buf is None:
            self.q_indptr_decode = torch.arange(
                0, max_bs + 1, dtype=torch.int32, device=model_runner.device
            )
        else:
            self.q_indptr_decode = q_indptr_decode_buf

        self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
            self.workspace_buffer, "NHD"
        )

        self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(
            self.workspace_buffer,
            backend="auto",
        )
        if not self.skip_prefill:
            self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(
                self.workspace_buffer,
                backend="auto",
            )

            # FlashinferMLA backend uses mla wrapper for target verify
            self.prefill_wrapper_verify = BatchMLAPagedAttentionWrapper(
                self.workspace_buffer,
                backend="auto",
            )

        self.decode_wrapper = BatchMLAPagedAttentionWrapper(
            self.workspace_buffer, backend="auto"
        )

        # Create indices updater
        self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
            model_runner, self
        )
        if not skip_prefill:
            self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
                model_runner, self
            )

        self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode(
            model_runner, self
        )
@@ -114,7 +136,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
        # Other metadata
        self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
        self.decode_cuda_graph_metadata = {}
        self.prefill_cuda_graph_metadata = {}
        self.prefill_cuda_graph_metadata = {}  # For verify

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        if forward_batch.forward_mode.is_decode_or_idle():
@@ -126,6 +148,28 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                init_metadata_replay=False,
            )
            self.forward_metadata = DecodeMetadata(self.decode_wrapper)
        elif forward_batch.forward_mode.is_draft_extend():
            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper_paged=self.prefill_wrapper_paged,
                use_ragged=False,
                spec_info=forward_batch.spec_info,
            )
            self.forward_metadata = PrefillMetadata(self.prefill_wrapper_paged, False)
        elif forward_batch.forward_mode.is_target_verify():
            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper_paged=self.prefill_wrapper_verify,
                use_ragged=False,
                spec_info=forward_batch.spec_info,
            )
            self.forward_metadata = PrefillMetadata(self.prefill_wrapper_verify, False)
        else:
            prefix_lens = forward_batch.extend_prefix_lens
            extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
@@ -202,10 +246,33 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                seq_lens_sum,
                decode_wrapper=decode_wrapper,
                init_metadata_replay=False,
                spec_info=spec_info,
            )
            self.decode_cuda_graph_metadata[bs] = decode_wrapper
            self.forward_metadata = DecodeMetadata(decode_wrapper)
            decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper)
        elif forward_mode.is_target_verify():
            verify_wrapper = BatchMLAPagedAttentionWrapper(
                self.workspace_buffer,
                use_cuda_graph=True,
                qo_indptr=self.cuda_graph_qo_indptr[: bs + 1],
                kv_indptr=self.cuda_graph_kv_indptr[: bs + 1],
                kv_indices=self.cuda_graph_kv_indices,
                kv_len_arr=self.cuda_graph_kv_lens[:bs],
                backend="auto",
            )
            seq_lens_sum = seq_lens.sum().item()
            self.indices_updater_prefill.update(
                req_pool_indices,
                seq_lens,
                seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper_paged=verify_wrapper,
                use_ragged=False,
                spec_info=spec_info,
            )
            self.prefill_cuda_graph_metadata[bs] = verify_wrapper
            self.forward_metadata = PrefillMetadata(verify_wrapper, False)
        else:
            raise ValueError(f"Invalid mode: {forward_mode=}")

@@ -221,6 +288,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        if forward_mode.is_decode_or_idle():
            assert seq_lens_cpu is not None
            kv_len_arr_cpu = seq_lens_cpu[:bs]
            self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
                kv_len_arr_cpu, dim=0
@@ -239,8 +307,19 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                seq_lens_sum,
                decode_wrapper=self.decode_cuda_graph_metadata[bs],
                init_metadata_replay=True,
                spec_info=spec_info,
                **self.fast_decode_kwargs,
            )
        elif forward_mode.is_target_verify():
            self.indices_updater_prefill.update(
                req_pool_indices[:bs],
                seq_lens[:bs],
                seq_lens_sum,
                prefix_lens=None,
                prefill_wrapper_paged=self.prefill_cuda_graph_metadata[bs],
                use_ragged=False,
                spec_info=spec_info,
            )
        else:
            raise ValueError(f"Invalid forward mode: {forward_mode=}")

@@ -254,7 +333,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
        save_kv_cache: bool = True,
    ):

        cache_loc = forward_batch.out_cache_loc
@@ -297,7 +376,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
        save_kv_cache: bool = True,
    ):
        decode_wrapper = self.forward_metadata.decode_wrapper
        cache_loc = forward_batch.out_cache_loc
@@ -349,6 +428,7 @@ class FlashInferMLAIndicesUpdaterDecode:
        seq_lens_sum: int,
        decode_wrapper: BatchMLAPagedAttentionWrapper,
        init_metadata_replay: bool = False,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
        **fast_decode_kwargs,
    ):
        decode_wrapper = decode_wrapper or self.decode_wrapper
@@ -360,6 +440,7 @@ class FlashInferMLAIndicesUpdaterDecode:
            self.q_indptr,
            self.kv_indptr,
            init_metadata_replay,
            spec_info,
            **fast_decode_kwargs,
        )

@@ -372,30 +453,33 @@ class FlashInferMLAIndicesUpdaterDecode:
        q_indptr: torch.Tensor,
        kv_indptr: torch.Tensor,
        init_metadata_replay: bool = False,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
        **fast_decode_kwargs,
    ):
        bs = len(req_pool_indices)
        q_indptr = q_indptr[: bs + 1]
        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
        kv_indptr = kv_indptr[: bs + 1]
        kv_indices = (
            torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
            if not init_metadata_replay
            else fast_decode_kwargs["kv_indices"]
        )

        kv_lens = paged_kernel_lens.to(torch.int32)
        sm_scale = self.scaling
        if spec_info is None:
            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = (
                torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
                if not init_metadata_replay
                else fast_decode_kwargs["kv_indices"]
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                paged_kernel_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.shape[1],
            )
        else:
            kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

        create_flashinfer_kv_indices_triton[(bs,)](
            self.req_to_token,
            req_pool_indices,
            paged_kernel_lens,
            kv_indptr,
            None,
            kv_indices,
            self.req_to_token.shape[1],
        )
        if not init_metadata_replay:
            wrapper.plan(
                q_indptr,
@@ -457,6 +541,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
        prefix_lens: torch.Tensor,
        prefill_wrapper_paged: BatchMLAPagedAttentionWrapper,
        use_ragged: bool,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
    ):
        if use_ragged:
            paged_kernel_lens = prefix_lens
@@ -476,6 +561,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
            self.kv_indptr,
            self.qo_indptr,
            use_ragged,
            spec_info,
        )

    def call_begin_forward(
@@ -490,29 +576,46 @@ class FlashInferMLAIndicesUpdaterPrefill:
        kv_indptr: torch.Tensor,
        qo_indptr: torch.Tensor,
        use_ragged: bool,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
    ):
        bs = len(req_pool_indices)
        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
        kv_indptr = kv_indptr[: bs + 1]
        kv_indices = torch.empty(
            paged_kernel_lens_sum,
            dtype=torch.int32,
            device=req_pool_indices.device,
        )
        create_flashinfer_kv_indices_triton[(bs,)](
            self.req_to_token,
            req_pool_indices,
            paged_kernel_lens,
            kv_indptr,
            None,
            kv_indices,
            self.req_to_token.shape[1],
        )

        qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
        qo_indptr = qo_indptr[: bs + 1]
        bs = len(seq_lens)
        sm_scale = self.scaling

        if spec_info is None:
            assert len(seq_lens) == len(req_pool_indices)
            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
            kv_indices = torch.empty(
                paged_kernel_lens_sum,
                dtype=torch.int32,
                device=req_pool_indices.device,
            )
            create_flashinfer_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                paged_kernel_lens,
                kv_indptr,
                None,
                kv_indices,
                self.req_to_token.shape[1],
            )
            qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
            qo_indptr = qo_indptr[: bs + 1]
            custom_mask = None
        else:
            assert isinstance(spec_info, EagleDraftInput) or isinstance(
                spec_info, EagleVerifyInput
            )
            # TODO: Support topk > 1 with custom mask
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(
                    req_pool_indices,
                    paged_kernel_lens,
                    paged_kernel_lens_sum,
                    self.req_to_token,
                )
            )

        if use_ragged:
            # ragged prefill
            wrapper_ragged.begin_forward(
@@ -543,6 +646,163 @@ class FlashInferMLAIndicesUpdaterPrefill:
        )


class FlashInferMLAMultiStepDraftBackend:
    """
    Wrap multiple flashinfer mla attention backends as one for multiple consecutive
    draft decoding steps.
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        topk: int,
        speculative_num_steps: int,
    ):
        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices

        if topk > 1:
            raise ValueError(
                f"Currently Flashinfer MLA only supports topk=1 for speculative decoding"
            )
        self.topk = topk
        self.speculative_num_steps = speculative_num_steps
        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices

        max_bs = model_runner.req_to_token_pool.size * self.topk
        self.kv_indptr = torch.zeros(
            (
                self.speculative_num_steps,
                max_bs + 1,
            ),
            dtype=torch.int32,
            device=model_runner.device,
        )
        self.q_indptr_decode = torch.arange(
            0, max_bs + 1, dtype=torch.int32, device=model_runner.device
        )

        self.attn_backends = []
        for i in range(self.speculative_num_steps):
            self.attn_backends.append(
                FlashInferMLAAttnBackend(
                    model_runner,
                    skip_prefill=True,
                    kv_indptr_buf=self.kv_indptr[i],
                    q_indptr_decode_buf=self.q_indptr_decode,
                )
            )

        self.max_context_len = self.attn_backends[0].max_context_len

        # Cached variables for generate_draft_decode_kv_indices
        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]

    def common_template(
        self,
        forward_batch: ForwardBatch,
        kv_indices_buffer: torch.Tensor,
        call_fn: Callable,
    ):
        num_seqs = forward_batch.batch_size
        bs = self.topk * num_seqs
        seq_lens_sum = forward_batch.seq_lens_sum

        self.generate_draft_decode_kv_indices[
            (self.speculative_num_steps, num_seqs, self.topk)
        ](
            forward_batch.req_pool_indices,
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.seq_lens,
            kv_indices_buffer,
            self.kv_indptr,
            forward_batch.positions,
            num_seqs,
            self.topk,
            self.pool_len,
            kv_indices_buffer.shape[1],
            self.kv_indptr.shape[1],
            triton.next_power_of_2(num_seqs),
            triton.next_power_of_2(self.speculative_num_steps),
            triton.next_power_of_2(bs),
        )

        assert forward_batch.spec_info is not None
        assert isinstance(forward_batch.spec_info, EagleDraftInput)

        for i in range(self.speculative_num_steps - 1):
            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                : seq_lens_sum * self.topk + bs * (i + 1)
            ]
            call_fn(i, forward_batch)

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        kv_indices = torch.zeros(
            (
                self.speculative_num_steps,
                forward_batch.batch_size * self.topk * self.max_context_len,
            ),
            dtype=torch.int32,
            device="cuda",
        )

        def call_fn(i, forward_batch):
            assert forward_batch.spec_info is not None
            assert isinstance(forward_batch.spec_info, EagleDraftInput)
            forward_batch.spec_info.kv_indptr = (
                forward_batch.spec_info.kv_indptr.clone()
            )
            forward_batch.spec_info.kv_indices = (
                forward_batch.spec_info.kv_indices.clone()
            )
            self.attn_backends[i].init_forward_metadata(forward_batch)

        self.common_template(forward_batch, kv_indices, call_fn)

    def init_cuda_graph_state(self, max_bs: int):
        self.cuda_graph_kv_indices = torch.zeros(
            (self.speculative_num_steps, max_bs * self.max_context_len),
            dtype=torch.int32,
            device="cuda",
        )

        for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(
                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                forward_batch.batch_size,
                forward_batch.batch_size * self.topk,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
            )

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)

    def init_forward_metadata_replay_cuda_graph(
        self, forward_batch: ForwardBatch, bs: int
    ):
        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                bs,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                seq_lens_sum=-1,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
                seq_lens_cpu=forward_batch.decode_seq_lens_cpu,
            )

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)


def fast_mla_decode_plan(
    self,
    qo_indptr_cpu: torch.Tensor,

@@ -555,6 +555,8 @@ class DeepseekV2AttentionMLA(nn.Module):
            return (
                not global_server_args_dict["flashinfer_mla_disable_ragged"]
                and forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and forward_batch.extend_prefix_lens.sum() == 0
            )
        else:

@@ -123,6 +123,16 @@ class EAGLEWorker(TpModelWorker):
                self.topk,
                self.speculative_num_steps,
            )
        elif self.server_args.attention_backend == "flashinfer_mla":
            from sglang.srt.layers.attention.flashinfer_mla_backend import (
                FlashInferMLAMultiStepDraftBackend,
            )

            self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
                self.model_runner,
                self.topk,
                self.speculative_num_steps,
            )
        else:
            raise ValueError(
                f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"

@@ -1,6 +1,7 @@
import unittest
from types import SimpleNamespace

import requests
import torch

from sglang.srt.utils import kill_process_tree
@@ -100,5 +101,67 @@ class TestFlashinferMLANoRagged(unittest.TestCase):
        self.assertGreater(metrics["accuracy"], 0.62)


class TestFlashinferMLAMTP(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = "lmsys/sglang-ci-dsv3-test"
        cls.base_url = DEFAULT_URL_FOR_TEST
        other_args = ["--trust-remote-code"]
        if torch.cuda.is_available() and torch.version.cuda:
            other_args.extend(
                [
                    "--cuda-graph-max-bs",
                    "2",
                    "--disable-radix",
                    "--enable-torch-compile",
                    "--torch-compile-max-bs",
                    "1",
                    "--speculative-algorithm",
                    "EAGLE",
                    "--speculative-draft",
                    "lmsys/sglang-ci-dsv3-test-NextN",
                    "--speculative-num-steps",
                    "4",
                    "--speculative-eagle-topk",
                    "1",
                    "--speculative-num-draft-tokens",
                    "4",
                    "--enable-flashinfer-mla",
                ]
            )
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        requests.get(self.base_url + "/flush_cache")

        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(metrics)

        self.assertGreater(metrics["accuracy"], 0.60)

        server_info = requests.get(self.base_url + "/get_server_info")
        avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
        print(f"{avg_spec_accept_length=}")
        self.assertGreater(avg_spec_accept_length, 2.5)


if __name__ == "__main__":
    unittest.main()