Add fast decode plan for flashinfer mla (#3987)
@@ -186,5 +186,5 @@ Please consult the documentation below to learn more about the parameters you ma
 * `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you.
 * `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row.
 * `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8.
-* `enable_flashinfer_mla`: The backend for flashinfer MLA wrapper that accelerates deepseek models.
+* `enable_flashinfer_mla`: Use the attention backend with the FlashInfer MLA wrapper for DeepSeek models. When this argument is provided, the `attention_backend` argument is overridden.
 * `flashinfer_mla_disable_ragged`: Disable usage of ragged prefill wrapper for flashinfer mla attention backend. Should be used when `enable_flashinfer_mla` is turned on.
@@ -83,7 +83,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be
 
 - **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase.
 
-- **Flashinfer MLA Wrapper**: By providing `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be referred to [this document](https://docs.flashinfer.ai/api/mla.html). (In Experiment Stage)
+- **Flashinfer MLA Wrapper**: By providing the `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by FlashInfer. See [this document](https://docs.flashinfer.ai/api/mla.html) for more details. In long-input scenarios, FlashInfer MLA can improve performance significantly. Optimized Triton kernels are used when FlashInfer MLA is turned off.
 
 - **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.
 
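The **Weight Absorption** bullet above rests on the associative law of matrix multiplication: the KV up-projection can be folded into the query instead of being applied to every cached token, so the compressed latent cache never has to be expanded during decode. The sketch below only illustrates that reassociation; the shapes and names (`w_uk`, `c_kv`) are made up for the example and are not the real DeepSeek dimensions or kernels.

```python
# Illustrative only: the algebra behind weight absorption, with toy shapes.
import torch

num_heads, head_dim, kv_lora_rank, seq_len = 8, 128, 64, 16
q = torch.randn(num_heads, head_dim, dtype=torch.float64)        # one token's queries
w_uk = torch.randn(head_dim, kv_lora_rank, dtype=torch.float64)  # KV up-projection
c_kv = torch.randn(seq_len, kv_lora_rank, dtype=torch.float64)   # compressed KV cache

# Naive order: expand the cache into full keys, then take attention scores.
scores_naive = q @ (c_kv @ w_uk.T).T            # (num_heads, seq_len)

# Absorbed order: fold w_uk into q and score against the compressed cache directly.
scores_absorbed = (q @ w_uk) @ c_kv.T           # (num_heads, seq_len)

assert torch.allclose(scores_naive, scores_absorbed)
```

The absorbed order trades a small per-query matmul for never materializing the full keys, which is what shifts the decode phase toward a better balance of computation and memory access.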
@@ -29,9 +29,8 @@ class AttentionBackend(ABC):
         num_tokens: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        **kwargs,
     ):
         """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()
@@ -42,9 +41,8 @@ class AttentionBackend(ABC):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        **kwargs,
     ):
         """Init the metadata for a forward pass for replying a cuda graph."""
         raise NotImplementedError()
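With the abstract signatures reduced to `**kwargs`, the CUDA-graph runner can pass backend-specific replay metadata (such as the `seq_lens_cpu` tensor used by the FlashInfer MLA backend later in this diff) without every backend having to declare it. The toy classes below are simplified stand-ins for illustration, not the real sglang backends.

```python
# Simplified stand-ins: shows how **kwargs lets the graph runner pass extra
# replay metadata (e.g. seq_lens_cpu) that only some backends care about.
import torch

class TritonLikeBackend:
    def init_forward_metadata_replay_cuda_graph(self, bs, seq_lens, **kwargs):
        # Backend-specific extras such as seq_lens_cpu are simply ignored here.
        print("triton-like replay, bs =", bs)

class MLALikeBackend:
    def init_forward_metadata_replay_cuda_graph(
        self, bs, seq_lens, seq_lens_cpu: torch.Tensor, **kwargs
    ):
        # This backend actually consumes the host-side lengths.
        print("mla-like replay, host kv lens =", seq_lens_cpu[:bs].tolist())

seq_lens = torch.tensor([5, 9])
for backend in (TritonLikeBackend(), MLALikeBackend()):
    backend.init_forward_metadata_replay_cuda_graph(
        bs=2, seq_lens=seq_lens, seq_lens_cpu=seq_lens.clone()
    )
```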
@@ -269,9 +269,10 @@ class FlashInferAttnBackend(AttentionBackend):
         num_tokens: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
+        encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
+        **kwargs,
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
@@ -339,9 +340,10 @@ class FlashInferAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
+        encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
+        **kwargs,
     ):
         if forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
@@ -10,6 +10,7 @@ More details can be found in https://docs.flashinfer.ai/api/mla.html
 """
 
 from dataclasses import dataclass
+from functools import partial
 from typing import TYPE_CHECKING, Optional, Union
 
 import torch
@@ -27,14 +28,12 @@ from sglang.srt.utils import is_flashinfer_available
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
 
 if is_flashinfer_available():
     from flashinfer import (
         BatchMLAPagedAttentionWrapper,
         BatchPrefillWithRaggedKVCacheWrapper,
     )
-    from flashinfer.cascade import merge_state
 
 
 @dataclass
@@ -63,6 +62,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
 
         # Parse constants
         self.max_context_len = model_runner.model_config.context_len
+        self.device = model_runner.device
 
         global_config.enable_flashinfer_mla = True
 
@@ -85,10 +85,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
         )
 
-        self.kv_last_page_len = torch.ones(
-            (max_bs,), dtype=torch.int32, device=model_runner.device
-        )
-
         self.q_indptr_decode = torch.arange(
             0, max_bs + 1, dtype=torch.int32, device=model_runner.device
         )
@@ -126,6 +122,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                 forward_batch.seq_lens,
                 forward_batch.seq_lens_sum,
                 decode_wrapper=self.decode_wrapper,
+                init_metadata_replay=False,
             )
             self.forward_metadata = DecodeMetadata(self.decode_wrapper)
         else:
@@ -161,13 +158,20 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             cuda_graph_kv_indices = kv_indices_buf
 
         self.cuda_graph_kv_indices = cuda_graph_kv_indices
-        self.cuda_graph_custom_mask = torch.zeros(
-            (max_bs * self.max_context_len),
-            dtype=torch.uint8,
-            device="cuda",
+        self.cuda_graph_qo_indptr = self.q_indptr_decode.clone()
+        self.cuda_graph_kv_indptr = self.kv_indptr.clone()
+        self.cuda_graph_kv_lens = torch.ones(
+            (max_bs,), dtype=torch.int32, device=self.device
         )
-        self.cuda_graph_qk_indptr = self.kv_indptr.clone()
-        self.cuda_graph_qo_indptr = self.kv_indptr.clone()
+
+        # For fast decode plan in graph replaying
+        self.cuda_graph_qo_indptr_cpu = self.cuda_graph_qo_indptr.to("cpu")
+        self.cuda_graph_kv_indptr_cpu = self.cuda_graph_kv_indptr.to("cpu")
+        self.fast_decode_kwargs = {
+            "qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu,
+            "kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu,
+            "kv_indices": self.cuda_graph_kv_indices,
+        }
 
     def init_forward_metadata_capture_cuda_graph(
         self,
@@ -175,18 +179,17 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         num_tokens: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        **kwargs,
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrapper = BatchMLAPagedAttentionWrapper(
                 self.workspace_buffer,
                 use_cuda_graph=True,
-                qo_indptr=self.qo_indptr[: num_tokens + 1],
-                kv_indptr=self.kv_indptr[: num_tokens + 1],
+                qo_indptr=self.cuda_graph_qo_indptr[: num_tokens + 1],
+                kv_indptr=self.cuda_graph_kv_indptr[: num_tokens + 1],
                 kv_indices=self.cuda_graph_kv_indices,
-                kv_len_arr=self.kv_last_page_len[:num_tokens],
+                kv_len_arr=self.cuda_graph_kv_lens[:num_tokens],
                 backend="auto",
             )
 
@@ -196,9 +199,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                 seq_lens,
                 seq_lens_sum,
                 decode_wrapper=decode_wrapper,
+                init_metadata_replay=False,
             )
             self.decode_cuda_graph_metadata[bs] = decode_wrapper
             self.forward_metadata = DecodeMetadata(decode_wrapper)
+            decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper)
         else:
             raise ValueError(f"Invalid mode: {forward_mode=}")
 
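The `decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper)` line above rebinds `plan` on that single wrapper object: `functools.partial` pre-binds the instance as the first argument of the module-level function, so later calls to `decode_wrapper.plan(...)` go through the fast path while other wrapper instances keep the library's original method. A minimal illustration of the pattern with toy names:

```python
# Illustrative only: the functools.partial trick for a per-instance method
# override; other instances (and the class itself) are untouched.
from functools import partial

class Wrapper:
    def plan(self, *args):
        return "slow plan"

def fast_plan(self, *args):
    return "fast plan"

w = Wrapper()
w.plan = partial(fast_plan, w)   # override only this object's plan
print(w.plan())                  # -> fast plan
print(Wrapper().plan())          # -> slow plan
```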
@@ -208,16 +213,30 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        seq_lens_cpu: torch.Tensor,
+        **kwargs,
     ):
         if forward_mode.is_decode_or_idle():
+            kv_len_arr_cpu = seq_lens_cpu[:bs]
+            self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
+                kv_len_arr_cpu, dim=0
+            )
+            self.fast_decode_kwargs.update(
+                {
+                    "qo_indptr_cpu": self.cuda_graph_qo_indptr_cpu[: bs + 1],
+                    "kv_indptr_cpu": self.cuda_graph_kv_indptr_cpu[: bs + 1],
+                    "kv_len_arr_cpu": kv_len_arr_cpu,
+                }
+            )
+
             self.indices_updater_decode.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
                 seq_lens_sum,
                 decode_wrapper=self.decode_cuda_graph_metadata[bs],
+                init_metadata_replay=True,
+                **self.fast_decode_kwargs,
             )
         else:
             raise ValueError(f"Invalid forward mode: {forward_mode=}")
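The replay path above builds the `kv_indptr` prefix sums on the CPU mirrors created in `init_cuda_graph_state`, so no device tensor has to be read back (and no stream synchronization is forced) while a captured graph is being replayed. A minimal sketch of that host-side update, with assumed sizes rather than the real sglang buffers:

```python
# Sketch of the host-side indptr update done at graph replay time (assumed
# shapes, not the exact sglang code). kv_len_arr_cpu already lives on the CPU,
# so the cumsum never touches the GPU and never blocks on the stream.
import torch

max_bs = 8
kv_indptr_cpu = torch.zeros(max_bs + 1, dtype=torch.int32)  # preallocated once

def update_kv_indptr_cpu(kv_len_arr_cpu: torch.Tensor, bs: int) -> torch.Tensor:
    kv_indptr_cpu[1 : bs + 1] = torch.cumsum(kv_len_arr_cpu[:bs], dim=0)
    return kv_indptr_cpu[: bs + 1]

print(update_kv_indptr_cpu(torch.tensor([3, 7, 2], dtype=torch.int32), bs=3))
# tensor([ 0,  3, 10, 12], dtype=torch.int32)
```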
@@ -317,7 +336,6 @@ class FlashInferMLAIndicesUpdaterDecode:
 
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
-        self.kv_last_page_len = attn_backend.kv_last_page_len
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.q_indptr = attn_backend.q_indptr_decode
 
@@ -327,6 +345,8 @@ class FlashInferMLAIndicesUpdaterDecode:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrapper: BatchMLAPagedAttentionWrapper,
+        init_metadata_replay: bool = False,
+        **fast_decode_kwargs,
     ):
         decode_wrapper = decode_wrapper or self.decode_wrapper
         self.call_begin_forward(
@@ -336,6 +356,8 @@ class FlashInferMLAIndicesUpdaterDecode:
             seq_lens_sum,
             self.q_indptr,
             self.kv_indptr,
+            init_metadata_replay,
+            **fast_decode_kwargs,
         )
 
     def call_begin_forward(
@@ -346,14 +368,19 @@ class FlashInferMLAIndicesUpdaterDecode:
         paged_kernel_lens_sum: int,
         q_indptr: torch.Tensor,
         kv_indptr: torch.Tensor,
+        init_metadata_replay: bool = False,
+        **fast_decode_kwargs,
     ):
         bs = len(req_pool_indices)
         q_indptr = q_indptr[: bs + 1]
         kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = torch.empty(
-            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+        kv_indices = (
+            torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
+            if not init_metadata_replay
+            else fast_decode_kwargs["kv_indices"]
         )
+
         kv_lens = paged_kernel_lens.to(torch.int32)
         sm_scale = self.scaling
 
@@ -366,21 +393,36 @@ class FlashInferMLAIndicesUpdaterDecode:
             kv_indices,
             self.req_to_token.shape[1],
         )
-        wrapper.plan(
-            q_indptr,
-            kv_indptr,
-            kv_indices,
-            kv_lens,
-            self.num_local_heads,
-            self.kv_lora_rank,
-            self.qk_rope_head_dim,
-            1,
-            False,
-            sm_scale,
-            self.data_type,
-            self.data_type,
-        )
+        if not init_metadata_replay:
+            wrapper.plan(
+                q_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_lens,
+                self.num_local_heads,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+                1,
+                False,
+                sm_scale,
+                self.data_type,
+                self.data_type,
+            )
+        else:
+            wrapper.plan(
+                fast_decode_kwargs["qo_indptr_cpu"],
+                fast_decode_kwargs["kv_indptr_cpu"],
+                kv_indices,
+                fast_decode_kwargs["kv_len_arr_cpu"],
+                self.num_local_heads,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+                1,
+                False,
+                sm_scale,
+                self.data_type,
+                self.data_type,
+            )
 
 
 class FlashInferMLAIndicesUpdaterPrefill:
@@ -400,7 +442,6 @@ class FlashInferMLAIndicesUpdaterPrefill:
 
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
-        self.kv_last_page_len = attn_backend.kv_last_page_len
        self.qo_indptr = attn_backend.qo_indptr
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
@@ -497,3 +538,42 @@ class FlashInferMLAIndicesUpdaterPrefill:
             self.q_data_type,
             self.data_type,
         )
+
+
+def fast_mla_decode_plan(
+    self,
+    qo_indptr_cpu: torch.Tensor,
+    kv_indptr_cpu: torch.Tensor,
+    kv_indices: torch.Tensor,
+    kv_len_arr_cpu: torch.Tensor,
+    num_heads: int,
+    head_dim_ckv: int,
+    head_dim_kpe: int,
+    page_size: int,
+    causal: bool,
+    sm_scale: float,
+    q_data_type: torch.dtype,
+    kv_data_type: torch.dtype,
+) -> None:
+    """A faster version of BatchMLAPagedAttentionWrapper::plan,
+    for skipping the stream synchronization in original plan function during
+    cuda graph replaying.
+    """
+    self._causal = causal
+    self._page_size = page_size
+    self._sm_scale = sm_scale
+
+    with self.device as device:
+        stream = torch.cuda.current_stream(device).cuda_stream
+        self._cached_module.plan(
+            self._float_workspace_buffer,
+            self._int_workspace_buffer,
+            self._pin_memory_int_workspace_buffer,
+            qo_indptr_cpu,
+            kv_indptr_cpu,
+            kv_len_arr_cpu,
+            num_heads,
+            head_dim_ckv,
+            causal,
+            stream,
+        )
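`fast_mla_decode_plan` accepts host-resident `qo_indptr_cpu`, `kv_indptr_cpu`, and `kv_len_arr_cpu` tensors, so the plan step that runs on every graph replay no longer needs to copy anything back from the GPU. The snippet below is a generic illustration of the cost being avoided, not flashinfer internals: deriving the same metadata from a CUDA tensor requires a device-to-host copy, which blocks until the stream drains, whereas CPU tensors do not.

```python
# Generic illustration (not flashinfer internals) of the synchronization that
# the fast plan avoids: reading a CUDA tensor back to the host blocks until
# the stream is idle, while tensors that already live on the CPU do not.
import torch

if torch.cuda.is_available():
    kv_lens_gpu = torch.randint(1, 4096, (256,), dtype=torch.int32, device="cuda")
    kv_lens_cpu = kv_lens_gpu.cpu()  # one-time copy, done outside the hot path

    # Slow path: every plan call would pay a device-to-host copy + stream sync.
    indptr_from_gpu = torch.cumsum(kv_lens_gpu.cpu(), dim=0)

    # Fast path: plan metadata is computed from host-resident tensors only.
    indptr_from_cpu = torch.cumsum(kv_lens_cpu, dim=0)

    assert torch.equal(indptr_from_gpu, indptr_from_cpu)
```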
@@ -230,9 +230,10 @@ class TritonAttnBackend(AttentionBackend):
         num_tokens: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
+        encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
+        **kwargs,
     ):
         assert encoder_lens is None, "Not supported"
 
@@ -308,9 +309,10 @@ class TritonAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
+        encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
+        **kwargs,
     ):
         # NOTE: encoder_lens expected to be zeros or None
         if forward_mode.is_decode_or_idle():
@@ -582,6 +582,9 @@ class ScheduleBatch:
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
 
+    # For decode
+    decode_seq_lens: List[int] = None
+
     # For extend and mixed chunekd prefill
     prefix_lens: List[int] = None
     extend_lens: List[int] = None
@@ -1168,8 +1171,10 @@ class ScheduleBatch:
 
     def get_model_worker_batch(self):
         if self.forward_mode.is_decode_or_idle():
+            decode_seq_lens = self.seq_lens.cpu()
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
+            decode_seq_lens = None
             extend_seq_lens = self.extend_lens
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens
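The `.cpu()` call added above runs once per scheduled decode batch, on the scheduler side and outside any captured graph region; everything downstream (the `ModelWorkerBatch` and `ForwardBatch` fields, and the CUDA-graph replay path later in this diff) reuses the same host tensor instead of issuing its own blocking copy. A rough sketch of that flow with a simplified stand-in class:

```python
# Simplified flow of the host-side sequence lengths added in this commit:
# one device-to-host copy per decode batch, reused by all downstream consumers.
import torch

class MiniScheduleBatch:
    def __init__(self, seq_lens: torch.Tensor):
        self.seq_lens = seq_lens  # would be a CUDA tensor in the real code

    def get_model_worker_batch(self) -> dict:
        # Single blocking copy, done up front on the scheduler side.
        return {"decode_seq_lens": self.seq_lens.cpu()}

batch = MiniScheduleBatch(torch.tensor([12, 40, 7]))
decode_seq_lens_cpu = batch.get_model_worker_batch()["decode_seq_lens"]
print(decode_seq_lens_cpu.tolist())  # [12, 40, 7], reused by the replay path
```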
@@ -1194,6 +1199,7 @@ class ScheduleBatch:
             top_logprobs_nums=self.top_logprobs_nums,
             global_num_tokens=self.global_num_tokens,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
+            decode_seq_lens=decode_seq_lens,
             extend_num_tokens=self.extend_num_tokens,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
@@ -1267,6 +1273,9 @@ class ModelWorkerBatch:
     global_num_tokens: Optional[List[int]]
     can_run_dp_cuda_graph: bool
 
+    # For decode
+    decode_seq_lens: Optional[torch.Tensor]
+
     # For extend
     extend_num_tokens: Optional[int]
     extend_seq_lens: Optional[List[int]]
@@ -199,6 +199,10 @@ class CudaGraphRunner:
         if self.enable_torch_compile:
             set_torch_compile_config()
 
+        self.seq_lens_cpu = torch.full(
+            (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
+        )
+
         # Graph inputs
         with torch.device("cuda"):
             self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
@@ -369,9 +373,9 @@ class CudaGraphRunner:
             num_tokens,
             req_pool_indices,
             seq_lens,
-            encoder_lens,
             forward_batch.forward_mode,
-            forward_batch.spec_info,
+            encoder_lens=encoder_lens,
+            spec_info=forward_batch.spec_info,
         )
 
         # Run and capture
@@ -434,6 +438,7 @@ class CudaGraphRunner:
         if bs != raw_bs:
             self.seq_lens.fill_(1)
             self.out_cache_loc.zero_()
+            self.seq_lens_cpu.fill_(1)
 
         # Common inputs
         self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
@@ -441,6 +446,8 @@ class CudaGraphRunner:
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
         self.positions[:raw_num_token].copy_(forward_batch.positions)
+        if forward_batch.decode_seq_lens_cpu is not None:
+            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
 
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
@@ -456,9 +463,10 @@ class CudaGraphRunner:
             self.req_pool_indices,
             self.seq_lens,
             forward_batch.seq_lens_sum + (bs - raw_bs),
-            self.encoder_lens,
             forward_batch.forward_mode,
-            forward_batch.spec_info,
+            encoder_lens=self.encoder_lens,
+            spec_info=forward_batch.spec_info,
+            seq_lens_cpu=self.seq_lens_cpu,
         )
 
         # Replay
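Putting the `CudaGraphRunner` pieces together: the runner keeps a preallocated host-side `seq_lens_cpu` buffer, resets it to the fill value when the batch is padded up to a captured size, copies the real decode lengths into its prefix, and hands the whole buffer to the attention backend as `seq_lens_cpu`. The sketch below mirrors that sequence with simplified names and a made-up fill value; it is not the runner itself.

```python
# Simplified sketch of the padded host-side length buffer used for replay.
import torch

max_bs, seq_len_fill_value = 4, 1
seq_lens_cpu = torch.full((max_bs,), seq_len_fill_value, dtype=torch.int32)

def prepare_replay(decode_seq_lens_cpu: torch.Tensor, raw_bs: int, bs: int) -> torch.Tensor:
    if bs != raw_bs:                     # batch was padded up to a captured size
        seq_lens_cpu.fill_(seq_len_fill_value)
    seq_lens_cpu[:raw_bs].copy_(decode_seq_lens_cpu)
    return seq_lens_cpu                  # passed as seq_lens_cpu=... to the backend

print(prepare_replay(torch.tensor([17, 33], dtype=torch.int32), raw_bs=2, bs=4))
# tensor([17, 33,  1,  1], dtype=torch.int32)
```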
@@ -152,6 +152,9 @@ class ForwardBatch:
     # Position information
     positions: torch.Tensor = None
 
+    # For decode
+    decode_seq_lens_cpu: Optional[torch.Tensor] = None
+
     # For extend
     extend_num_tokens: Optional[int] = None
     extend_seq_lens: Optional[torch.Tensor] = None
@@ -256,6 +259,8 @@ class ForwardBatch:
         if ret.forward_mode.is_decode():
             if ret.positions is None:
                 ret.positions = clamp_position(batch.seq_lens)
+            if ret.decode_seq_lens_cpu is None:
+                ret.decode_seq_lens_cpu = batch.decode_seq_lens
         else:
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32