Support target model verification in the attention backend (#2678)
Co-authored-by: yukavio <kavioyu@gmail.com>
@@ -1,10 +1,14 @@
+from __future__ import annotations
+
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 import torch
 
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+    from sglang.srt.speculative.spec_info import SpecInfo
 
 
 class AttentionBackend(ABC):
@@ -22,9 +26,12 @@ class AttentionBackend(ABC):
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
+        num_token: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor] = None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
         """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()
@@ -35,7 +42,9 @@ class AttentionBackend(ABC):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        encoder_lens: Optional[torch.Tensor] = None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
         """Init the metadata for a forward pass for replaying a cuda graph."""
         raise NotImplementedError()
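
Taken together, the hunks above widen the base-class CUDA-graph contract: capture and replay now receive the token count (`num_token`), the batch's `ForwardMode`, and an optional `SpecInfo`, and `encoder_lens` loses its default. A minimal sketch of a backend conforming to the new signature; `SketchBackend` and its metadata tuple are illustrative placeholders, not sglang code:

```python
from typing import Optional

import torch


class SketchBackend(AttentionBackend):
    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_token: int,                 # bs * tokens-per-request; equals bs for plain decode
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,      # DECODE or TARGET_VERIFY under CUDA graphs
        spec_info: Optional[SpecInfo],  # draft-token info when speculative decoding is on
    ):
        # Point metadata at buffers sized by num_token so the captured graph
        # replays with fixed tensor addresses.
        self.forward_metadata = (num_token, forward_mode.is_target_verify())
```
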
@@ -3,7 +3,6 @@ from __future__ import annotations
 from typing import TYPE_CHECKING
 
 import torch
-import torch.nn as nn
 
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -52,8 +51,6 @@ class DoubleSparseAttnBackend(AttentionBackend):
 
         self.forward_metadata = None
 
-        self.cuda_graph_max_seq_len = model_runner.model_config.context_len
-
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
@@ -115,55 +112,6 @@ class DoubleSparseAttnBackend(AttentionBackend):
             ds_req_to_token,
         )
 
-    def init_cuda_graph_state(self, max_bs: int):
-        # TODO(Andy): Support CUDA graph for double sparse attention
-        raise ValueError(
-            "Double sparse attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
-
-        self.cuda_graph_start_loc = torch.zeros(
-            (max_bs,), dtype=torch.int32, device="cuda"
-        )
-        self.cuda_graph_attn_logits = torch.empty(
-            (
-                self.num_head,
-                self.cuda_graph_max_total_num_tokens,
-            ),
-            dtype=self.reduce_dtype,
-            device="cuda",
-        )
-
-    def init_forward_metadata_capture_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        encoder_lens=None,
-    ):
-        # NOTE: encoder_lens expected to be zeros or None
-        self.forward_metadata = (
-            self.cuda_graph_start_loc,
-            self.cuda_graph_attn_logits,
-            self.cuda_graph_max_seq_len,
-            None,
-        )
-
-    def init_forward_metadata_replay_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        seq_lens_sum: int,
-        encoder_lens=None,
-    ):
-        # NOTE: encoder_lens expected to be zeros or None
-        self.cuda_graph_start_loc.zero_()
-        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
-
-    def get_cuda_graph_seq_len_fill_value(self):
-        return 1
-
     def forward_extend(
         self,
         q,
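
With the overrides above deleted, `DoubleSparseAttnBackend` no longer raises its own `ValueError`; attempts to capture a graph now fall through to the `AttentionBackend` stubs. A hedged sketch of what a caller would hit, assuming the base-class behavior from the first hunks:

```python
# Assumes the base-class stubs are reached (see the AttentionBackend hunks).
try:
    backend.init_forward_metadata_capture_cuda_graph(
        bs, num_token, req_pool_indices, seq_lens, None, forward_mode, None
    )
except NotImplementedError:
    # Double sparse attention still has no CUDA-graph path; launch the
    # server with --disable-cuda-graph.
    raise
```
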
@@ -10,7 +10,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
-from typing import TYPE_CHECKING, List, Union
+from typing import TYPE_CHECKING, List, Optional, Union
 
 import torch
 import triton
@@ -18,12 +18,13 @@ import triton.language as tl
 
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInfo
 
 if is_flashinfer_available():
     from flashinfer import (
@@ -113,11 +114,15 @@ class FlashInferAttnBackend(AttentionBackend):
         # Two wrappers: one for sliding window attention and one for full attention.
         # Using two wrappers is unnecessary in the current PR, but are prepared for future PRs
         self.prefill_wrappers_paged = []
+        self.prefill_wrappers_verify = []
         self.decode_wrappers = []
         for _ in range(self.num_wrappers):
             self.prefill_wrappers_paged.append(
                 BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
             )
+            self.prefill_wrappers_verify.append(
+                BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
+            )
             self.decode_wrappers.append(
                 BatchDecodeWithPagedKVCacheWrapper(
                     self.workspace_buffer,
@@ -135,6 +140,7 @@ class FlashInferAttnBackend(AttentionBackend):
         # Other metadata
         self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
         self.decode_cuda_graph_metadata = {}
+        self.prefill_cuda_graph_metadata = {}
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode():
@@ -144,8 +150,37 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.seq_lens_sum,
                 decode_wrappers=self.decode_wrappers,
                 encoder_lens=forward_batch.encoder_lens,
+                spec_info=forward_batch.spec_info,
             )
             self.forward_metadata = DecodeMetadata(self.decode_wrappers)
+        elif forward_batch.forward_mode.is_draft_extend():
+            self.indices_updater_prefill.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=self.prefill_wrappers_paged,
+                use_ragged=False,
+                encoder_lens=forward_batch.encoder_lens,
+                spec_info=forward_batch.spec_info,
+            )
+            self.forward_metadata = PrefillMetadata(
+                self.prefill_wrappers_paged, False, False
+            )
+        elif forward_batch.forward_mode.is_target_verify():
+            self.indices_updater_prefill.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=self.prefill_wrappers_verify,
+                use_ragged=False,
+                encoder_lens=forward_batch.encoder_lens,
+                spec_info=forward_batch.spec_info,
+            )
+            self.forward_metadata = PrefillMetadata(
+                self.prefill_wrappers_verify, False, False
+            )
         else:
             prefix_lens = forward_batch.extend_prefix_lens
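
In summary, `init_forward_metadata` now routes each forward mode to its own wrapper set (a reading aid, not code from the diff):

```python
# Mode -> wrappers -> metadata, as dispatched above:
#   DECODE        : decode_wrappers         -> DecodeMetadata
#   DRAFT_EXTEND  : prefill_wrappers_paged  -> PrefillMetadata(paged, False, False)
#   TARGET_VERIFY : prefill_wrappers_verify -> PrefillMetadata(verify, False, False)
#   EXTEND (else) : prefill_wrappers_paged  -> PrefillMetadata(paged, use_ragged, extend_no_prefix)
```
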
@@ -165,6 +200,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 prefill_wrappers=self.prefill_wrappers_paged,
                 use_ragged=use_ragged,
                 encoder_lens=forward_batch.encoder_lens,
+                spec_info=None,
             )
             self.forward_metadata = PrefillMetadata(
                 self.prefill_wrappers_paged, use_ragged, extend_no_prefix
@@ -180,37 +216,80 @@ class FlashInferAttnBackend(AttentionBackend):
             cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
         ]
 
+        self.cuda_graph_custom_mask = torch.zeros(
+            (max_bs * self.max_context_len),
+            dtype=torch.uint8,
+            device="cuda",
+        )
+        self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
+        self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
+        num_token: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
-        encoder_lens: torch.Tensor = None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
-        decode_wrappers = []
-        for i in range(self.num_wrappers):
-            decode_wrappers.append(
-                BatchDecodeWithPagedKVCacheWrapper(
-                    self.workspace_buffer,
-                    "NHD",
-                    use_cuda_graph=True,
-                    use_tensor_cores=self.decode_use_tensor_cores,
-                    paged_kv_indptr_buffer=self.kv_indptr[i][: bs + 1],
-                    paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                    paged_kv_last_page_len_buffer=self.kv_last_page_len[:bs],
-                )
-            )
-        seq_lens_sum = seq_lens.sum().item()
-        self.indices_updater_decode.update(
-            req_pool_indices,
-            seq_lens,
-            seq_lens_sum,
-            decode_wrappers=decode_wrappers,
-            encoder_lens=encoder_lens,
-        )
-        self.decode_cuda_graph_metadata[bs] = decode_wrappers
-        self.forward_metadata = DecodeMetadata(decode_wrappers)
+        if forward_mode.is_decode():
+            decode_wrappers = []
+            for i in range(self.num_wrappers):
+                decode_wrappers.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_cuda_graph=True,
+                        use_tensor_cores=self.decode_use_tensor_cores,
+                        paged_kv_indptr_buffer=self.kv_indptr[i][: num_token + 1],
+                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+                        paged_kv_last_page_len_buffer=self.kv_last_page_len[:num_token],
+                    )
+                )
+            seq_lens_sum = seq_lens.sum().item()
+            self.indices_updater_decode.update(
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                decode_wrappers=decode_wrappers,
+                encoder_lens=encoder_lens,
+                spec_info=spec_info,
+            )
+            self.decode_cuda_graph_metadata[bs] = decode_wrappers
+            self.forward_metadata = DecodeMetadata(decode_wrappers)
+        elif forward_mode.is_target_verify():
+            prefill_wrappers = []
+            for i in range(self.num_wrappers):
+                prefill_wrappers.append(
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_cuda_graph=True,
+                        qo_indptr_buf=self.cuda_graph_qo_indptr[i][: bs + 1],
+                        paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1],
+                        paged_kv_indices_buf=self.cuda_graph_kv_indices[i],
+                        paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
+                        custom_mask_buf=self.cuda_graph_custom_mask,
+                        qk_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],
+                    )
+                )
+            seq_lens_sum = seq_lens.sum().item()
+            self.indices_updater_prefill.update(
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=prefill_wrappers,
+                use_ragged=False,
+                encoder_lens=encoder_lens,
+                spec_info=spec_info,
+            )
+            self.prefill_cuda_graph_metadata[bs] = prefill_wrappers
+            self.forward_metadata = PrefillMetadata(prefill_wrappers, False, False)
+        else:
+            raise ValueError(f"Invalid mode: {forward_mode=}")
 
     def init_forward_metadata_replay_cuda_graph(
         self,
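
Capture now branches on `forward_mode`: decode batches get `BatchDecodeWithPagedKVCacheWrapper`s, while target-verify batches get `BatchPrefillWithPagedKVCacheWrapper`s bound to the preallocated `qo`/`qk` indptr and custom-mask buffers, so tensor addresses stay fixed across replays. An illustrative capture-time call for a verify batch; `draft_len` is an assumed name for the number of draft tokens verified per request:

```python
backend.init_forward_metadata_capture_cuda_graph(
    bs=8,
    num_token=8 * draft_len,          # one graph position per draft token
    req_pool_indices=req_pool_indices,
    seq_lens=seq_lens,
    encoder_lens=None,
    forward_mode=ForwardMode.TARGET_VERIFY,
    spec_info=spec_info,              # carries the draft-token structure
)
```
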
@@ -218,24 +297,41 @@ class FlashInferAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        encoder_lens: torch.Tensor = None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
-        self.indices_updater_decode.update(
-            req_pool_indices[:bs],
-            seq_lens[:bs],
-            seq_lens_sum,
-            decode_wrappers=self.decode_cuda_graph_metadata[bs],
-            encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
-        )
+        if forward_mode.is_decode():
+            self.indices_updater_decode.update(
+                req_pool_indices[:bs],
+                seq_lens[:bs],
+                seq_lens_sum,
+                decode_wrappers=self.decode_cuda_graph_metadata[bs],
+                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
+                spec_info=spec_info,
+            )
+        elif forward_mode.is_target_verify():
+            self.indices_updater_prefill.update(
+                req_pool_indices[:bs],
+                seq_lens[:bs],
+                seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
+                use_ragged=False,
+                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
+                spec_info=spec_info,
+            )
+        else:
+            raise ValueError("Invalid forward mode")
 
     def get_cuda_graph_seq_len_fill_value(self):
         return 0
 
     def forward_extend(
         self,
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
@@ -293,9 +389,9 @@ class FlashInferAttnBackend(AttentionBackend):
 
     def forward_decode(
         self,
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
@@ -348,7 +444,6 @@ class FlashInferIndicesUpdaterDecode:
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
         self.sliding_window_size = model_runner.sliding_window_size
-
         self.attn_backend = attn_backend
 
         # Buffers and wrappers
@@ -371,7 +466,8 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -382,7 +478,8 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(
@@ -392,6 +489,7 @@ class FlashInferIndicesUpdaterDecode:
             seq_lens_sum,
             self.kv_indptr[0],
             None,
+            spec_info,
         )
 
     def update_sliding_window(
@@ -400,7 +498,8 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -424,6 +523,7 @@ class FlashInferIndicesUpdaterDecode:
                 paged_kernel_lens_sum_tmp,
                 self.kv_indptr[wrapper_id],
                 kv_start_idx_tmp,
+                spec_info,
             )
 
     def update_cross_attention(
@@ -432,7 +532,8 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -452,6 +553,7 @@ class FlashInferIndicesUpdaterDecode:
             seq_lens_sum,
             self.kv_indptr[wrapper_id],
             kv_start_idx,
+            spec_info,
         )
 
     def call_begin_forward(
@@ -462,23 +564,30 @@ class FlashInferIndicesUpdaterDecode:
         paged_kernel_lens_sum: int,
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
+        spec_info: Optional[SpecInfo],
     ):
-        bs = len(req_pool_indices)
-        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
-        kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = torch.empty(
-            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
-        )
-
-        create_flashinfer_kv_indices_triton[(bs,)](
-            self.req_to_token,
-            req_pool_indices,
-            paged_kernel_lens,
-            kv_indptr,
-            kv_start_idx,
-            kv_indices,
-            self.req_to_token.shape[1],
-        )
+        if spec_info is None:
+            bs = len(req_pool_indices)
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                paged_kernel_lens,
+                kv_indptr,
+                kv_start_idx,
+                kv_indices,
+                self.req_to_token.shape[1],
+            )
+        else:
+            bs, kv_indices, kv_indptr = spec_info.generate_attn_arg_decode(
+                req_pool_indices,
+                paged_kernel_lens,
+                self.req_to_token,
+            )
 
         wrapper.end_forward()
         wrapper.begin_forward(
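
When `spec_info` is present, the decode updater delegates index construction to it. The contract implied by the call site, inferred rather than defined in this diff:

```python
# Assumed SpecInfo decode hook, inferred from the call site above:
#   bs, kv_indices, kv_indptr = spec_info.generate_attn_arg_decode(
#       req_pool_indices,   # (bs,) request slots into req_to_token
#       paged_kernel_lens,  # (bs,) per-request KV lengths
#       req_to_token,       # (pool_size, max_context_len) page table
#   )
# The returned triple must match what the non-speculative branch builds with
# create_flashinfer_kv_indices_triton.
```
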
@@ -507,7 +616,6 @@ class FlashInferIndicesUpdaterPrefill:
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
         self.sliding_window_size = model_runner.sliding_window_size
-
         self.attn_backend = attn_backend
 
         # Buffers and wrappers
@@ -534,7 +642,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -547,7 +656,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         if use_ragged:
             paged_kernel_lens = prefix_lens
@@ -568,6 +678,7 @@ class FlashInferIndicesUpdaterPrefill:
             self.kv_indptr[0],
             self.qo_indptr[0],
             use_ragged,
+            spec_info,
         )
 
     def update_sliding_window(
@@ -578,7 +689,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -607,6 +719,7 @@ class FlashInferIndicesUpdaterPrefill:
             self.kv_indptr[wrapper_id],
             self.qo_indptr[wrapper_id],
             use_ragged,
+            spec_info,
         )
 
     def update_cross_attention(
@@ -617,7 +730,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -643,6 +757,7 @@ class FlashInferIndicesUpdaterPrefill:
             self.kv_indptr[wrapper_id],
             self.qo_indptr[wrapper_id],
             use_ragged,
+            spec_info,
         )
 
     def call_begin_forward(
@@ -658,25 +773,37 @@ class FlashInferIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
+        spec_info: Optional[SpecInfo],
     ):
         bs = len(req_pool_indices)
-        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
-        kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = torch.empty(
-            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
-        )
-        create_flashinfer_kv_indices_triton[(bs,)](
-            self.req_to_token,
-            req_pool_indices,
-            paged_kernel_lens,
-            kv_indptr,
-            kv_start_idx,
-            kv_indices,
-            self.req_to_token.shape[1],
-        )
+        if spec_info is None:
+            # Normal extend
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                paged_kernel_lens,
+                kv_indptr,
+                kv_start_idx,
+                kv_indices,
+                self.req_to_token.shape[1],
+            )
 
-        qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
-        qo_indptr = qo_indptr[: bs + 1]
+            qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+            qo_indptr = qo_indptr[: bs + 1]
+            custom_mask = None
+        else:
+            kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                spec_info.generate_attn_arg_prefill(
+                    req_pool_indices,
+                    paged_kernel_lens,
+                    self.req_to_token,
+                )
+            )
 
         # extend part
         if use_ragged:
@@ -702,6 +829,7 @@ class FlashInferIndicesUpdaterPrefill:
             self.head_dim,
             1,
             q_data_type=self.q_data_type,
+            custom_mask=custom_mask,
         )
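
The prefill-side hook additionally returns a `custom_mask`, which `call_begin_forward` now passes through to `wrapper.begin_forward`. Presumably the mask encodes which draft tokens may attend to which KV positions during target verification; this is an inference from the call sites, not something stated in the diff:

```python
# Assumed prefill-side SpecInfo hook, inferred from the call sites:
kv_indices, kv_indptr, qo_indptr, custom_mask = spec_info.generate_attn_arg_prefill(
    req_pool_indices,
    paged_kernel_lens,
    req_to_token,
)
# custom_mask then flows into wrapper.begin_forward(..., custom_mask=custom_mask),
# overriding the default causal mask for TARGET_VERIFY batches.
```
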
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
 
 import torch
 from torch.nn.functional import scaled_dot_product_attention
@@ -23,43 +23,6 @@ class TorchNativeAttnBackend(AttentionBackend):
         """Init the metadata for a forward pass."""
         pass
 
-    def init_cuda_graph_state(self, max_bs: int):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
-    def init_forward_metadata_capture_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor] = None,
-    ):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
-    def init_forward_metadata_replay_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        seq_lens_sum: int,
-        encoder_lens: Optional[torch.Tensor] = None,
-    ):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
-    def get_cuda_graph_seq_len_fill_value(self):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
     def _run_sdpa_forward_extend(
         self,
         query: torch.Tensor,
@@ -1,15 +1,16 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
 
 import torch
 
 from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInfo
 
 
 class TritonAttnBackend(AttentionBackend):
@@ -80,11 +81,17 @@ class TritonAttnBackend(AttentionBackend):
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
+        num_token: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
-        encoder_lens=None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
-        # NOTE: encoder_lens expected to be zeros or None
+        assert encoder_lens is None, "Not supported"
+        assert forward_mode.is_decode(), "Not supported"
+        assert spec_info is None, "Not supported"
+
         self.forward_metadata = (
             self.cuda_graph_attn_logits,
             None,
@@ -96,7 +103,9 @@ class TritonAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        encoder_lens=None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
         # NOTE: encoder_lens expected to be zeros or None
         self.cuda_graph_start_loc.zero_()
@@ -107,9 +116,9 @@ class TritonAttnBackend(AttentionBackend):
 
     def forward_extend(
         self,
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
@@ -146,9 +155,9 @@ class TritonAttnBackend(AttentionBackend):
 
     def forward_decode(
         self,
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
@@ -25,14 +25,14 @@ from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
-from sglang.srt.layers.logits_processor import (
-    LogitsMetadata,
-    LogitsProcessor,
-    LogitsProcessorOutput,
-)
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
 from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather
 
 if TYPE_CHECKING:
@@ -153,6 +153,10 @@ class CudaGraphRunner:
             if bs <= model_runner.req_to_token_pool.size
             and bs <= model_runner.server_args.cuda_graph_max_bs
         ]
 
+        self.capture_forward_mode = ForwardMode.DECODE
+        self.num_tokens_per_bs = 1
+
         self.compile_bs = (
             [
                 bs
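
These two fields pin down how request-indexed and token-indexed sizes relate throughout the runner (a reading aid restating the diff's invariants):

```python
# num_token = bs * num_tokens_per_bs
#   plain decode:       num_tokens_per_bs == 1 -> num_token == bs
#   speculative verify: num_tokens_per_bs == k -> num_token == bs * k
# Token-indexed buffers (input_ids, out_cache_loc, positions) are sized by
# max_num_token; request-indexed buffers (req_pool_indices, seq_lens, ...) by max_bs.
```
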
@@ -165,8 +169,8 @@ class CudaGraphRunner:
 
         # Attention backend
         self.max_bs = max(self.capture_bs)
-        self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
+        self.max_num_token = self.max_bs * self.num_tokens_per_bs
+        self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
@@ -179,12 +183,13 @@ class CudaGraphRunner:
 
         # Common inputs
         with torch.device("cuda"):
-            self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32)
+            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int32)
             self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
             self.seq_lens = torch.full(
                 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
             )
-            self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
+            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int32)
+            self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)
 
             if self.is_encoder_decoder:
@@ -229,6 +234,9 @@ class CudaGraphRunner:
         self.model_runner.model.capture_mode = False
 
     def can_run(self, forward_batch: ForwardBatch):
+        if not forward_batch.forward_mode.is_cuda_graph():
+            return False
+
         if self.enable_dp_attention:
             min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
                 forward_batch.global_num_tokens
@@ -258,12 +266,12 @@ class CudaGraphRunner:
     def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
-            capture_bs = (
+            capture_range = (
                 tqdm.tqdm(self.capture_bs)
                 if get_tensor_model_parallel_rank() == 0
                 else self.capture_bs
             )
-            for bs in capture_bs:
+            for bs in capture_range:
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,
@@ -283,12 +291,15 @@ class CudaGraphRunner:
     def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
         stream = self.stream
+        num_token = bs * self.num_tokens_per_bs
 
         # Common inputs
-        input_ids = self.input_ids[:bs]
+        input_ids = self.input_ids[:num_token]
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
-        out_cache_loc = self.out_cache_loc[:bs]
+        out_cache_loc = self.out_cache_loc[:num_token]
+        positions = self.positions[:num_token]
 
         if self.is_encoder_decoder:
             encoder_lens = self.encoder_lens[:bs]
         else:
@@ -304,37 +315,41 @@ class CudaGraphRunner:
             global_num_tokens = None
             gathered_buffer = None
 
+        forward_batch = ForwardBatch(
+            forward_mode=self.capture_forward_mode,
+            batch_size=bs,
+            input_ids=input_ids,
+            req_pool_indices=req_pool_indices,
+            seq_lens=seq_lens,
+            req_to_token_pool=self.model_runner.req_to_token_pool,
+            token_to_kv_pool=self.model_runner.token_to_kv_pool,
+            attn_backend=self.model_runner.attn_backend,
+            out_cache_loc=out_cache_loc,
+            seq_lens_sum=seq_lens_sum,
+            encoder_lens=encoder_lens,
+            return_logprob=False,
+            top_logprobs_nums=[0] * num_token,
+            positions=positions,
+            global_num_tokens=global_num_tokens,
+            mrope_positions=mrope_positions,
+            gathered_buffer=gathered_buffer,
+        )
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
             bs,
+            num_token,
             req_pool_indices,
             seq_lens,
             encoder_lens,
+            forward_batch.forward_mode,
+            forward_batch.spec_info,
         )
 
         # Run and capture
         def run_once():
-            forward_batch = ForwardBatch(
-                forward_mode=ForwardMode.DECODE,
-                batch_size=bs,
-                input_ids=input_ids,
-                req_pool_indices=req_pool_indices,
-                seq_lens=seq_lens,
-                req_to_token_pool=self.model_runner.req_to_token_pool,
-                token_to_kv_pool=self.model_runner.token_to_kv_pool,
-                attn_backend=self.model_runner.attn_backend,
-                out_cache_loc=out_cache_loc,
-                seq_lens_sum=seq_lens_sum,
-                encoder_lens=encoder_lens,
-                return_logprob=False,
-                top_logprobs_nums=[0] * bs,
-                positions=clamp_position(seq_lens),
-                mrope_positions=mrope_positions,
-                global_num_tokens=global_num_tokens,
-                gathered_buffer=gathered_buffer,
-            )
             logits_output = forward(input_ids, forward_batch.positions, forward_batch)
-            return logits_output.next_token_logits
+            return logits_output.next_token_logits, logits_output.hidden_states
 
         for _ in range(2):
             torch.cuda.synchronize()
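
Hoisting the `ForwardBatch` out of `run_once` lets warmup and capture reuse the same preallocated tensors, and `run_once` now also returns `hidden_states` so speculative verification can read them back after replay. The surrounding capture idiom, abridged and with the memory-pool argument elided:

```python
# Abridged sketch of the standard torch.cuda.graph capture pattern used here:
for _ in range(2):
    torch.cuda.synchronize()
    run_once()                  # warmup so allocator state settles
torch.cuda.synchronize()
with torch.cuda.graph(graph, stream=stream):
    out = run_once()            # `out` tensors become the replay output buffers
torch.cuda.synchronize()
```
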
@@ -360,6 +375,9 @@ class CudaGraphRunner:
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
         raw_bs = forward_batch.batch_size
+        # In normal decoding case, raw_bs == raw_num_token
+        # But in speculative decoding, raw_num_token is raw_bs * self.num_tokens_per_bs
+        raw_num_token = forward_batch.input_ids.numel()
 
         # Pad
         if self.enable_dp_attention:
@@ -374,10 +392,13 @@ class CudaGraphRunner:
             self.out_cache_loc.zero_()
 
         # Common inputs
-        self.input_ids[:raw_bs].copy_(forward_batch.input_ids)
+        self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
-        self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
+        self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
+        positions = clamp_position(forward_batch.seq_lens)
+        self.positions[:raw_num_token].copy_(positions)
 
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
@@ -390,13 +411,18 @@ class CudaGraphRunner:
             self.seq_lens,
             forward_batch.seq_lens_sum + (bs - raw_bs),
             self.encoder_lens,
+            forward_batch.forward_mode,
+            forward_batch.spec_info,
         )
 
         # Replay
         self.graphs[bs].replay()
-        next_token_logits = self.output_buffers[bs][:raw_bs]
+        next_token_logits, hidden_states = self.output_buffers[bs]
 
         logits_output = LogitsProcessorOutput(
-            next_token_logits=next_token_logits,
+            next_token_logits=next_token_logits[:raw_num_token],
+            hidden_states=(
+                hidden_states[:raw_num_token] if hidden_states is not None else None
+            ),
        )
         return logits_output
@@ -96,7 +96,7 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.DRAFT_EXTEND
 
     def is_cuda_graph(self):
-        return self in (ForwardMode.DECODE, ForwardMode.TARGET_VERIFY)
+        return self == ForwardMode.DECODE or self == ForwardMode.TARGET_VERIFY
 
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
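
`is_cuda_graph` is the first check in `CudaGraphRunner.can_run`; the semantics after this change:

```python
assert ForwardMode.DECODE.is_cuda_graph()
assert ForwardMode.TARGET_VERIFY.is_cuda_graph()   # verify batches can replay graphs
assert not ForwardMode.EXTEND.is_cuda_graph()      # normal prefill stays eager
```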