Support target model verification in the attention backend (#2678)
Co-authored-by: yukavio <kavioyu@gmail.com>
This commit is contained in:
@@ -1,10 +1,14 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
|
from sglang.srt.speculative.spec_info import SpecInfo
|
||||||
|
|
||||||
|
|
||||||
class AttentionBackend(ABC):
|
class AttentionBackend(ABC):
|
||||||
@@ -22,9 +26,12 @@ class AttentionBackend(ABC):
|
|||||||
def init_forward_metadata_capture_cuda_graph(
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
self,
|
self,
|
||||||
bs: int,
|
bs: int,
|
||||||
|
num_token: int,
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
encoder_lens: Optional[torch.Tensor] = None,
|
encoder_lens: Optional[torch.Tensor],
|
||||||
|
forward_mode: ForwardMode,
|
||||||
|
spec_info: Optional[SpecInfo],
|
||||||
):
|
):
|
||||||
"""Init the metadata for a forward pass for capturing a cuda graph."""
|
"""Init the metadata for a forward pass for capturing a cuda graph."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@@ -35,7 +42,9 @@ class AttentionBackend(ABC):
|
|||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
encoder_lens: Optional[torch.Tensor] = None,
|
encoder_lens: Optional[torch.Tensor],
|
||||||
|
forward_mode: ForwardMode,
|
||||||
|
spec_info: Optional[SpecInfo],
|
||||||
):
|
):
|
||||||
"""Init the metadata for a forward pass for replying a cuda graph."""
|
"""Init the metadata for a forward pass for replying a cuda graph."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ from __future__ import annotations
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from sglang.srt.layers.attention import AttentionBackend
|
from sglang.srt.layers.attention import AttentionBackend
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
@@ -52,8 +51,6 @@ class DoubleSparseAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
self.forward_metadata = None
|
self.forward_metadata = None
|
||||||
|
|
||||||
self.cuda_graph_max_seq_len = model_runner.model_config.context_len
|
|
||||||
|
|
||||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
"""Init auxiliary variables for triton attention backend."""
|
"""Init auxiliary variables for triton attention backend."""
|
||||||
|
|
||||||
@@ -115,55 +112,6 @@ class DoubleSparseAttnBackend(AttentionBackend):
|
|||||||
ds_req_to_token,
|
ds_req_to_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_cuda_graph_state(self, max_bs: int):
|
|
||||||
# TODO(Andy): Support CUDA graph for double sparse attention
|
|
||||||
raise ValueError(
|
|
||||||
"Double sparse attention does not support CUDA graph for now. Please --disable-cuda-graph"
|
|
||||||
)
|
|
||||||
self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
|
|
||||||
|
|
||||||
self.cuda_graph_start_loc = torch.zeros(
|
|
||||||
(max_bs,), dtype=torch.int32, device="cuda"
|
|
||||||
)
|
|
||||||
self.cuda_graph_attn_logits = torch.empty(
|
|
||||||
(
|
|
||||||
self.num_head,
|
|
||||||
self.cuda_graph_max_total_num_tokens,
|
|
||||||
),
|
|
||||||
dtype=self.reduce_dtype,
|
|
||||||
device="cuda",
|
|
||||||
)
|
|
||||||
|
|
||||||
def init_forward_metadata_capture_cuda_graph(
|
|
||||||
self,
|
|
||||||
bs: int,
|
|
||||||
req_pool_indices: torch.Tensor,
|
|
||||||
seq_lens: torch.Tensor,
|
|
||||||
encoder_lens=None,
|
|
||||||
):
|
|
||||||
# NOTE: encoder_lens expected to be zeros or None
|
|
||||||
self.forward_metadata = (
|
|
||||||
self.cuda_graph_start_loc,
|
|
||||||
self.cuda_graph_attn_logits,
|
|
||||||
self.cuda_graph_max_seq_len,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(
|
|
||||||
self,
|
|
||||||
bs: int,
|
|
||||||
req_pool_indices: torch.Tensor,
|
|
||||||
seq_lens: torch.Tensor,
|
|
||||||
seq_lens_sum: int,
|
|
||||||
encoder_lens=None,
|
|
||||||
):
|
|
||||||
# NOTE: encoder_lens expected to be zeros or None
|
|
||||||
self.cuda_graph_start_loc.zero_()
|
|
||||||
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
|
||||||
|
|
||||||
def get_cuda_graph_seq_len_fill_value(self):
|
|
||||||
return 1
|
|
||||||
|
|
||||||
def forward_extend(
|
def forward_extend(
|
||||||
self,
|
self,
|
||||||
q,
|
q,
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
|
|||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from typing import TYPE_CHECKING, List, Union
|
from typing import TYPE_CHECKING, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
@@ -18,12 +18,13 @@ import triton.language as tl
|
|||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.layers.attention import AttentionBackend
|
from sglang.srt.layers.attention import AttentionBackend
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.utils import is_flashinfer_available
|
from sglang.srt.utils import is_flashinfer_available
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
from sglang.srt.speculative.spec_info import SpecInfo
|
||||||
|
|
||||||
if is_flashinfer_available():
|
if is_flashinfer_available():
|
||||||
from flashinfer import (
|
from flashinfer import (
|
||||||
@@ -113,11 +114,15 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
# Two wrappers: one for sliding window attention and one for full attention.
|
# Two wrappers: one for sliding window attention and one for full attention.
|
||||||
# Using two wrappers is unnecessary in the current PR, but are prepared for future PRs
|
# Using two wrappers is unnecessary in the current PR, but are prepared for future PRs
|
||||||
self.prefill_wrappers_paged = []
|
self.prefill_wrappers_paged = []
|
||||||
|
self.prefill_wrappers_verify = []
|
||||||
self.decode_wrappers = []
|
self.decode_wrappers = []
|
||||||
for _ in range(self.num_wrappers):
|
for _ in range(self.num_wrappers):
|
||||||
self.prefill_wrappers_paged.append(
|
self.prefill_wrappers_paged.append(
|
||||||
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
|
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
|
||||||
)
|
)
|
||||||
|
self.prefill_wrappers_verify.append(
|
||||||
|
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
|
||||||
|
)
|
||||||
self.decode_wrappers.append(
|
self.decode_wrappers.append(
|
||||||
BatchDecodeWithPagedKVCacheWrapper(
|
BatchDecodeWithPagedKVCacheWrapper(
|
||||||
self.workspace_buffer,
|
self.workspace_buffer,
|
||||||
@@ -135,6 +140,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
# Other metadata
|
# Other metadata
|
||||||
self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
|
self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
|
||||||
self.decode_cuda_graph_metadata = {}
|
self.decode_cuda_graph_metadata = {}
|
||||||
|
self.prefill_cuda_graph_metadata = {}
|
||||||
|
|
||||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
if forward_batch.forward_mode.is_decode():
|
if forward_batch.forward_mode.is_decode():
|
||||||
@@ -144,8 +150,37 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
forward_batch.seq_lens_sum,
|
forward_batch.seq_lens_sum,
|
||||||
decode_wrappers=self.decode_wrappers,
|
decode_wrappers=self.decode_wrappers,
|
||||||
encoder_lens=forward_batch.encoder_lens,
|
encoder_lens=forward_batch.encoder_lens,
|
||||||
|
spec_info=forward_batch.spec_info,
|
||||||
)
|
)
|
||||||
self.forward_metadata = DecodeMetadata(self.decode_wrappers)
|
self.forward_metadata = DecodeMetadata(self.decode_wrappers)
|
||||||
|
elif forward_batch.forward_mode.is_draft_extend():
|
||||||
|
self.indices_updater_prefill.update(
|
||||||
|
forward_batch.req_pool_indices,
|
||||||
|
forward_batch.seq_lens,
|
||||||
|
forward_batch.seq_lens_sum,
|
||||||
|
prefix_lens=None,
|
||||||
|
prefill_wrappers=self.prefill_wrappers_paged,
|
||||||
|
use_ragged=False,
|
||||||
|
encoder_lens=forward_batch.encoder_lens,
|
||||||
|
spec_info=forward_batch.spec_info,
|
||||||
|
)
|
||||||
|
self.forward_metadata = PrefillMetadata(
|
||||||
|
self.prefill_wrappers_paged, False, False
|
||||||
|
)
|
||||||
|
elif forward_batch.forward_mode.is_target_verify():
|
||||||
|
self.indices_updater_prefill.update(
|
||||||
|
forward_batch.req_pool_indices,
|
||||||
|
forward_batch.seq_lens,
|
||||||
|
forward_batch.seq_lens_sum,
|
||||||
|
prefix_lens=None,
|
||||||
|
prefill_wrappers=self.prefill_wrappers_verify,
|
||||||
|
use_ragged=False,
|
||||||
|
encoder_lens=forward_batch.encoder_lens,
|
||||||
|
spec_info=forward_batch.spec_info,
|
||||||
|
)
|
||||||
|
self.forward_metadata = PrefillMetadata(
|
||||||
|
self.prefill_wrappers_verify, False, False
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
prefix_lens = forward_batch.extend_prefix_lens
|
prefix_lens = forward_batch.extend_prefix_lens
|
||||||
|
|
||||||
@@ -165,6 +200,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
prefill_wrappers=self.prefill_wrappers_paged,
|
prefill_wrappers=self.prefill_wrappers_paged,
|
||||||
use_ragged=use_ragged,
|
use_ragged=use_ragged,
|
||||||
encoder_lens=forward_batch.encoder_lens,
|
encoder_lens=forward_batch.encoder_lens,
|
||||||
|
spec_info=None,
|
||||||
)
|
)
|
||||||
self.forward_metadata = PrefillMetadata(
|
self.forward_metadata = PrefillMetadata(
|
||||||
self.prefill_wrappers_paged, use_ragged, extend_no_prefix
|
self.prefill_wrappers_paged, use_ragged, extend_no_prefix
|
||||||
@@ -180,37 +216,80 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
|
cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
self.cuda_graph_custom_mask = torch.zeros(
|
||||||
|
(max_bs * self.max_context_len),
|
||||||
|
dtype=torch.uint8,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
|
||||||
|
self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]
|
||||||
|
|
||||||
def init_forward_metadata_capture_cuda_graph(
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
self,
|
self,
|
||||||
bs: int,
|
bs: int,
|
||||||
|
num_token: int,
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
encoder_lens: torch.Tensor = None,
|
encoder_lens: Optional[torch.Tensor],
|
||||||
|
forward_mode: ForwardMode,
|
||||||
|
spec_info: Optional[SpecInfo],
|
||||||
):
|
):
|
||||||
decode_wrappers = []
|
if forward_mode.is_decode():
|
||||||
for i in range(self.num_wrappers):
|
decode_wrappers = []
|
||||||
decode_wrappers.append(
|
for i in range(self.num_wrappers):
|
||||||
BatchDecodeWithPagedKVCacheWrapper(
|
decode_wrappers.append(
|
||||||
self.workspace_buffer,
|
BatchDecodeWithPagedKVCacheWrapper(
|
||||||
"NHD",
|
self.workspace_buffer,
|
||||||
use_cuda_graph=True,
|
"NHD",
|
||||||
use_tensor_cores=self.decode_use_tensor_cores,
|
use_cuda_graph=True,
|
||||||
paged_kv_indptr_buffer=self.kv_indptr[i][: bs + 1],
|
use_tensor_cores=self.decode_use_tensor_cores,
|
||||||
paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
|
paged_kv_indptr_buffer=self.kv_indptr[i][: num_token + 1],
|
||||||
paged_kv_last_page_len_buffer=self.kv_last_page_len[:bs],
|
paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
|
||||||
|
paged_kv_last_page_len_buffer=self.kv_last_page_len[:num_token],
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
seq_lens_sum = seq_lens.sum().item()
|
||||||
|
self.indices_updater_decode.update(
|
||||||
|
req_pool_indices,
|
||||||
|
seq_lens,
|
||||||
|
seq_lens_sum,
|
||||||
|
decode_wrappers=decode_wrappers,
|
||||||
|
encoder_lens=encoder_lens,
|
||||||
|
spec_info=spec_info,
|
||||||
)
|
)
|
||||||
|
self.decode_cuda_graph_metadata[bs] = decode_wrappers
|
||||||
seq_lens_sum = seq_lens.sum().item()
|
self.forward_metadata = DecodeMetadata(decode_wrappers)
|
||||||
self.indices_updater_decode.update(
|
elif forward_mode.is_target_verify():
|
||||||
req_pool_indices,
|
prefill_wrappers = []
|
||||||
seq_lens,
|
for i in range(self.num_wrappers):
|
||||||
seq_lens_sum,
|
prefill_wrappers.append(
|
||||||
decode_wrappers=decode_wrappers,
|
BatchPrefillWithPagedKVCacheWrapper(
|
||||||
encoder_lens=encoder_lens,
|
self.workspace_buffer,
|
||||||
)
|
"NHD",
|
||||||
self.decode_cuda_graph_metadata[bs] = decode_wrappers
|
use_cuda_graph=True,
|
||||||
self.forward_metadata = DecodeMetadata(decode_wrappers)
|
qo_indptr_buf=self.cuda_graph_qo_indptr[i][: bs + 1],
|
||||||
|
paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1],
|
||||||
|
paged_kv_indices_buf=self.cuda_graph_kv_indices[i],
|
||||||
|
paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
|
||||||
|
custom_mask_buf=self.cuda_graph_custom_mask,
|
||||||
|
qk_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
seq_lens_sum = seq_lens.sum().item()
|
||||||
|
self.indices_updater_prefill.update(
|
||||||
|
req_pool_indices,
|
||||||
|
seq_lens,
|
||||||
|
seq_lens_sum,
|
||||||
|
prefix_lens=None,
|
||||||
|
prefill_wrappers=prefill_wrappers,
|
||||||
|
use_ragged=False,
|
||||||
|
encoder_lens=encoder_lens,
|
||||||
|
spec_info=spec_info,
|
||||||
|
)
|
||||||
|
self.prefill_cuda_graph_metadata[bs] = prefill_wrappers
|
||||||
|
self.forward_metadata = PrefillMetadata(prefill_wrappers, False, False)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid mode: {forward_mode=}")
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
self,
|
self,
|
||||||
@@ -218,24 +297,41 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
encoder_lens: torch.Tensor = None,
|
encoder_lens: Optional[torch.Tensor],
|
||||||
|
forward_mode: ForwardMode,
|
||||||
|
spec_info: Optional[SpecInfo],
|
||||||
):
|
):
|
||||||
self.indices_updater_decode.update(
|
if forward_mode.is_decode():
|
||||||
req_pool_indices[:bs],
|
self.indices_updater_decode.update(
|
||||||
seq_lens[:bs],
|
req_pool_indices[:bs],
|
||||||
seq_lens_sum,
|
seq_lens[:bs],
|
||||||
decode_wrappers=self.decode_cuda_graph_metadata[bs],
|
seq_lens_sum,
|
||||||
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
|
decode_wrappers=self.decode_cuda_graph_metadata[bs],
|
||||||
)
|
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
|
||||||
|
spec_info=spec_info,
|
||||||
|
)
|
||||||
|
elif forward_mode.is_target_verify():
|
||||||
|
self.indices_updater_prefill.update(
|
||||||
|
req_pool_indices[:bs],
|
||||||
|
seq_lens[:bs],
|
||||||
|
seq_lens_sum,
|
||||||
|
prefix_lens=None,
|
||||||
|
prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
|
||||||
|
use_ragged=False,
|
||||||
|
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
|
||||||
|
spec_info=spec_info,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid forward mode")
|
||||||
|
|
||||||
def get_cuda_graph_seq_len_fill_value(self):
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def forward_extend(
|
def forward_extend(
|
||||||
self,
|
self,
|
||||||
q,
|
q: torch.Tensor,
|
||||||
k,
|
k: torch.Tensor,
|
||||||
v,
|
v: torch.Tensor,
|
||||||
layer: RadixAttention,
|
layer: RadixAttention,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
save_kv_cache=True,
|
save_kv_cache=True,
|
||||||
@@ -293,9 +389,9 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
def forward_decode(
|
def forward_decode(
|
||||||
self,
|
self,
|
||||||
q,
|
q: torch.Tensor,
|
||||||
k,
|
k: torch.Tensor,
|
||||||
v,
|
v: torch.Tensor,
|
||||||
layer: RadixAttention,
|
layer: RadixAttention,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
save_kv_cache=True,
|
save_kv_cache=True,
|
||||||
@@ -348,7 +444,6 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
self.data_type = model_runner.kv_cache_dtype
|
self.data_type = model_runner.kv_cache_dtype
|
||||||
self.q_data_type = model_runner.dtype
|
self.q_data_type = model_runner.dtype
|
||||||
self.sliding_window_size = model_runner.sliding_window_size
|
self.sliding_window_size = model_runner.sliding_window_size
|
||||||
|
|
||||||
self.attn_backend = attn_backend
|
self.attn_backend = attn_backend
|
||||||
|
|
||||||
# Buffers and wrappers
|
# Buffers and wrappers
|
||||||
@@ -371,7 +466,8 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||||
encoder_lens: torch.Tensor,
|
encoder_lens: Optional[torch.Tensor],
|
||||||
|
spec_info: Optional[SpecInfo],
|
||||||
):
|
):
|
||||||
# Keep the signature for type checking. It will be assigned during runtime.
|
# Keep the signature for type checking. It will be assigned during runtime.
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@@ -382,7 +478,8 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||||
encoder_lens: torch.Tensor,
|
encoder_lens: Optional[torch.Tensor],
|
||||||
|
spec_info: Optional[SpecInfo],
|
||||||
):
|
):
|
||||||
decode_wrappers = decode_wrappers or self.decode_wrappers
|
decode_wrappers = decode_wrappers or self.decode_wrappers
|
||||||
self.call_begin_forward(
|
self.call_begin_forward(
|
||||||
@@ -392,6 +489,7 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
seq_lens_sum,
|
seq_lens_sum,
|
||||||
self.kv_indptr[0],
|
self.kv_indptr[0],
|
||||||
None,
|
None,
|
||||||
|
spec_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_sliding_window(
|
def update_sliding_window(
|
||||||
@@ -400,7 +498,8 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||||
encoder_lens: torch.Tensor,
|
encoder_lens: Optional[torch.Tensor],
|
||||||
|
spec_info: Optional[SpecInfo],
|
||||||
):
|
):
|
||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
if wrapper_id == 0:
|
if wrapper_id == 0:
|
||||||
@@ -424,6 +523,7 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
paged_kernel_lens_sum_tmp,
|
paged_kernel_lens_sum_tmp,
|
||||||
self.kv_indptr[wrapper_id],
|
self.kv_indptr[wrapper_id],
|
||||||
kv_start_idx_tmp,
|
kv_start_idx_tmp,
|
||||||
|
spec_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_cross_attention(
|
def update_cross_attention(
|
||||||
@@ -432,7 +532,8 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||||
encoder_lens: torch.Tensor,
|
encoder_lens: Optional[torch.Tensor],
|
||||||
|
spec_info: Optional[SpecInfo],
|
||||||
):
|
):
|
||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
if wrapper_id == 0:
|
if wrapper_id == 0:
|
||||||
@@ -452,6 +553,7 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
seq_lens_sum,
|
seq_lens_sum,
|
||||||
self.kv_indptr[wrapper_id],
|
self.kv_indptr[wrapper_id],
|
||||||
kv_start_idx,
|
kv_start_idx,
|
||||||
|
spec_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
def call_begin_forward(
|
def call_begin_forward(
|
||||||
@@ -462,23 +564,30 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
paged_kernel_lens_sum: int,
|
paged_kernel_lens_sum: int,
|
||||||
kv_indptr: torch.Tensor,
|
kv_indptr: torch.Tensor,
|
||||||
kv_start_idx: torch.Tensor,
|
kv_start_idx: torch.Tensor,
|
||||||
|
spec_info: Optional[SpecInfo],
|
||||||
):
|
):
|
||||||
bs = len(req_pool_indices)
|
if spec_info is None:
|
||||||
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
|
bs = len(req_pool_indices)
|
||||||
kv_indptr = kv_indptr[: bs + 1]
|
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||||
kv_indices = torch.empty(
|
kv_indptr = kv_indptr[: bs + 1]
|
||||||
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
|
kv_indices = torch.empty(
|
||||||
)
|
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
create_flashinfer_kv_indices_triton[(bs,)](
|
create_flashinfer_kv_indices_triton[(bs,)](
|
||||||
self.req_to_token,
|
self.req_to_token,
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
paged_kernel_lens,
|
paged_kernel_lens,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_start_idx,
|
kv_start_idx,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
self.req_to_token.shape[1],
|
self.req_to_token.shape[1],
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
bs, kv_indices, kv_indptr = spec_info.generate_attn_arg_decode(
|
||||||
|
req_pool_indices,
|
||||||
|
paged_kernel_lens,
|
||||||
|
self.req_to_token,
|
||||||
|
)
|
||||||
|
|
||||||
wrapper.end_forward()
|
wrapper.end_forward()
|
||||||
wrapper.begin_forward(
|
wrapper.begin_forward(
|
||||||
@@ -507,7 +616,6 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
self.data_type = model_runner.kv_cache_dtype
|
self.data_type = model_runner.kv_cache_dtype
|
||||||
self.q_data_type = model_runner.dtype
|
self.q_data_type = model_runner.dtype
|
||||||
self.sliding_window_size = model_runner.sliding_window_size
|
self.sliding_window_size = model_runner.sliding_window_size
|
||||||
|
|
||||||
self.attn_backend = attn_backend
|
self.attn_backend = attn_backend
|
||||||
|
|
||||||
# Buffers and wrappers
|
# Buffers and wrappers
|
||||||
@@ -534,7 +642,8 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
prefix_lens: torch.Tensor,
|
prefix_lens: torch.Tensor,
|
||||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||||
use_ragged: bool,
|
use_ragged: bool,
|
||||||
encoder_lens: torch.Tensor,
|
encoder_lens: Optional[torch.Tensor],
|
||||||
|
spec_info: Optional[SpecInfo],
|
||||||
):
|
):
|
||||||
# Keep the signature for type checking. It will be assigned during runtime.
|
# Keep the signature for type checking. It will be assigned during runtime.
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@@ -547,7 +656,8 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
prefix_lens: torch.Tensor,
|
prefix_lens: torch.Tensor,
|
||||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||||
use_ragged: bool,
|
use_ragged: bool,
|
||||||
encoder_lens: torch.Tensor,
|
encoder_lens: Optional[torch.Tensor],
|
||||||
|
spec_info: Optional[SpecInfo],
|
||||||
):
|
):
|
||||||
if use_ragged:
|
if use_ragged:
|
||||||
paged_kernel_lens = prefix_lens
|
paged_kernel_lens = prefix_lens
|
||||||
@@ -568,6 +678,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
self.kv_indptr[0],
|
self.kv_indptr[0],
|
||||||
self.qo_indptr[0],
|
self.qo_indptr[0],
|
||||||
use_ragged,
|
use_ragged,
|
||||||
|
spec_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_sliding_window(
|
def update_sliding_window(
|
||||||
@@ -578,7 +689,8 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
prefix_lens: torch.Tensor,
|
prefix_lens: torch.Tensor,
|
||||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||||
use_ragged: bool,
|
use_ragged: bool,
|
||||||
encoder_lens: torch.Tensor,
|
encoder_lens: Optional[torch.Tensor],
|
||||||
|
spec_info: Optional[SpecInfo],
|
||||||
):
|
):
|
||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
if wrapper_id == 0:
|
if wrapper_id == 0:
|
||||||
@@ -607,6 +719,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
self.kv_indptr[wrapper_id],
|
self.kv_indptr[wrapper_id],
|
||||||
self.qo_indptr[wrapper_id],
|
self.qo_indptr[wrapper_id],
|
||||||
use_ragged,
|
use_ragged,
|
||||||
|
spec_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_cross_attention(
|
def update_cross_attention(
|
||||||
@@ -617,7 +730,8 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
prefix_lens: torch.Tensor,
|
prefix_lens: torch.Tensor,
|
||||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||||
use_ragged: bool,
|
use_ragged: bool,
|
||||||
encoder_lens: torch.Tensor,
|
encoder_lens: Optional[torch.Tensor],
|
||||||
|
spec_info: Optional[SpecInfo],
|
||||||
):
|
):
|
||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
if wrapper_id == 0:
|
if wrapper_id == 0:
|
||||||
@@ -643,6 +757,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
self.kv_indptr[wrapper_id],
|
self.kv_indptr[wrapper_id],
|
||||||
self.qo_indptr[wrapper_id],
|
self.qo_indptr[wrapper_id],
|
||||||
use_ragged,
|
use_ragged,
|
||||||
|
spec_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
def call_begin_forward(
|
def call_begin_forward(
|
||||||
@@ -658,25 +773,37 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
kv_indptr: torch.Tensor,
|
kv_indptr: torch.Tensor,
|
||||||
qo_indptr: torch.Tensor,
|
qo_indptr: torch.Tensor,
|
||||||
use_ragged: bool,
|
use_ragged: bool,
|
||||||
|
spec_info: Optional[SpecInfo],
|
||||||
):
|
):
|
||||||
bs = len(req_pool_indices)
|
bs = len(req_pool_indices)
|
||||||
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
|
if spec_info is None:
|
||||||
kv_indptr = kv_indptr[: bs + 1]
|
# Normal extend
|
||||||
kv_indices = torch.empty(
|
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||||
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
|
kv_indptr = kv_indptr[: bs + 1]
|
||||||
)
|
kv_indices = torch.empty(
|
||||||
create_flashinfer_kv_indices_triton[(bs,)](
|
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
|
||||||
self.req_to_token,
|
)
|
||||||
req_pool_indices,
|
create_flashinfer_kv_indices_triton[(bs,)](
|
||||||
paged_kernel_lens,
|
self.req_to_token,
|
||||||
kv_indptr,
|
req_pool_indices,
|
||||||
kv_start_idx,
|
paged_kernel_lens,
|
||||||
kv_indices,
|
kv_indptr,
|
||||||
self.req_to_token.shape[1],
|
kv_start_idx,
|
||||||
)
|
kv_indices,
|
||||||
|
self.req_to_token.shape[1],
|
||||||
|
)
|
||||||
|
|
||||||
qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
|
qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
|
||||||
qo_indptr = qo_indptr[: bs + 1]
|
qo_indptr = qo_indptr[: bs + 1]
|
||||||
|
custom_mask = None
|
||||||
|
else:
|
||||||
|
kv_indices, kv_indptr, qo_indptr, custom_mask = (
|
||||||
|
spec_info.generate_attn_arg_prefill(
|
||||||
|
req_pool_indices,
|
||||||
|
paged_kernel_lens,
|
||||||
|
self.req_to_token,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# extend part
|
# extend part
|
||||||
if use_ragged:
|
if use_ragged:
|
||||||
@@ -702,6 +829,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
1,
|
1,
|
||||||
q_data_type=self.q_data_type,
|
q_data_type=self.q_data_type,
|
||||||
|
custom_mask=custom_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.functional import scaled_dot_product_attention
|
from torch.nn.functional import scaled_dot_product_attention
|
||||||
@@ -23,43 +23,6 @@ class TorchNativeAttnBackend(AttentionBackend):
|
|||||||
"""Init the metadata for a forward pass."""
|
"""Init the metadata for a forward pass."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def init_cuda_graph_state(self, max_bs: int):
|
|
||||||
# TODO: Support CUDA graph
|
|
||||||
raise ValueError(
|
|
||||||
"Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
|
|
||||||
)
|
|
||||||
|
|
||||||
def init_forward_metadata_capture_cuda_graph(
|
|
||||||
self,
|
|
||||||
bs: int,
|
|
||||||
req_pool_indices: torch.Tensor,
|
|
||||||
seq_lens: torch.Tensor,
|
|
||||||
encoder_lens: Optional[torch.Tensor] = None,
|
|
||||||
):
|
|
||||||
# TODO: Support CUDA graph
|
|
||||||
raise ValueError(
|
|
||||||
"Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
|
|
||||||
)
|
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(
|
|
||||||
self,
|
|
||||||
bs: int,
|
|
||||||
req_pool_indices: torch.Tensor,
|
|
||||||
seq_lens: torch.Tensor,
|
|
||||||
seq_lens_sum: int,
|
|
||||||
encoder_lens: Optional[torch.Tensor] = None,
|
|
||||||
):
|
|
||||||
# TODO: Support CUDA graph
|
|
||||||
raise ValueError(
|
|
||||||
"Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_cuda_graph_seq_len_fill_value(self):
|
|
||||||
# TODO: Support CUDA graph
|
|
||||||
raise ValueError(
|
|
||||||
"Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run_sdpa_forward_extend(
|
def _run_sdpa_forward_extend(
|
||||||
self,
|
self,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
|
|||||||
@@ -1,15 +1,16 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.layers.attention import AttentionBackend
|
from sglang.srt.layers.attention import AttentionBackend
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
from sglang.srt.speculative.spec_info import SpecInfo
|
||||||
|
|
||||||
|
|
||||||
class TritonAttnBackend(AttentionBackend):
|
class TritonAttnBackend(AttentionBackend):
|
||||||
@@ -80,11 +81,17 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
def init_forward_metadata_capture_cuda_graph(
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
self,
|
self,
|
||||||
bs: int,
|
bs: int,
|
||||||
|
num_token: int,
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
encoder_lens=None,
|
encoder_lens: Optional[torch.Tensor],
|
||||||
|
forward_mode: ForwardMode,
|
||||||
|
spec_info: Optional[SpecInfo],
|
||||||
):
|
):
|
||||||
# NOTE: encoder_lens expected to be zeros or None
|
assert encoder_lens is None, "Not supported"
|
||||||
|
assert forward_mode.is_decode(), "Not supported"
|
||||||
|
assert spec_info is None, "Not supported"
|
||||||
|
|
||||||
self.forward_metadata = (
|
self.forward_metadata = (
|
||||||
self.cuda_graph_attn_logits,
|
self.cuda_graph_attn_logits,
|
||||||
None,
|
None,
|
||||||
@@ -96,7 +103,9 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
encoder_lens=None,
|
encoder_lens: Optional[torch.Tensor],
|
||||||
|
forward_mode: ForwardMode,
|
||||||
|
spec_info: Optional[SpecInfo],
|
||||||
):
|
):
|
||||||
# NOTE: encoder_lens expected to be zeros or None
|
# NOTE: encoder_lens expected to be zeros or None
|
||||||
self.cuda_graph_start_loc.zero_()
|
self.cuda_graph_start_loc.zero_()
|
||||||
@@ -107,9 +116,9 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
def forward_extend(
|
def forward_extend(
|
||||||
self,
|
self,
|
||||||
q,
|
q: torch.Tensor,
|
||||||
k,
|
k: torch.Tensor,
|
||||||
v,
|
v: torch.Tensor,
|
||||||
layer: RadixAttention,
|
layer: RadixAttention,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
save_kv_cache=True,
|
save_kv_cache=True,
|
||||||
@@ -146,9 +155,9 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
def forward_decode(
|
def forward_decode(
|
||||||
self,
|
self,
|
||||||
q,
|
q: torch.Tensor,
|
||||||
k,
|
k: torch.Tensor,
|
||||||
v,
|
v: torch.Tensor,
|
||||||
layer: RadixAttention,
|
layer: RadixAttention,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
save_kv_cache=True,
|
save_kv_cache=True,
|
||||||
|
|||||||
@@ -25,14 +25,14 @@ from vllm.distributed import get_tensor_model_parallel_rank
|
|||||||
from vllm.distributed.parallel_state import graph_capture
|
from vllm.distributed.parallel_state import graph_capture
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import (
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
LogitsMetadata,
|
|
||||||
LogitsProcessor,
|
|
||||||
LogitsProcessorOutput,
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
|
from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
|
||||||
from sglang.srt.layers.torchao_utils import save_gemlite_cache
|
from sglang.srt.layers.torchao_utils import save_gemlite_cache
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
|
CaptureHiddenMode,
|
||||||
|
ForwardBatch,
|
||||||
|
ForwardMode,
|
||||||
|
)
|
||||||
from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather
|
from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -153,6 +153,10 @@ class CudaGraphRunner:
|
|||||||
if bs <= model_runner.req_to_token_pool.size
|
if bs <= model_runner.req_to_token_pool.size
|
||||||
and bs <= model_runner.server_args.cuda_graph_max_bs
|
and bs <= model_runner.server_args.cuda_graph_max_bs
|
||||||
]
|
]
|
||||||
|
|
||||||
|
self.capture_forward_mode = ForwardMode.DECODE
|
||||||
|
self.num_tokens_per_bs = 1
|
||||||
|
|
||||||
self.compile_bs = (
|
self.compile_bs = (
|
||||||
[
|
[
|
||||||
bs
|
bs
|
||||||
@@ -165,8 +169,8 @@ class CudaGraphRunner:
|
|||||||
|
|
||||||
# Attention backend
|
# Attention backend
|
||||||
self.max_bs = max(self.capture_bs)
|
self.max_bs = max(self.capture_bs)
|
||||||
self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
|
self.max_num_token = self.max_bs * self.num_tokens_per_bs
|
||||||
|
self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
|
||||||
self.seq_len_fill_value = (
|
self.seq_len_fill_value = (
|
||||||
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
|
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
|
||||||
)
|
)
|
||||||
@@ -179,12 +183,13 @@ class CudaGraphRunner:
|
|||||||
|
|
||||||
# Common inputs
|
# Common inputs
|
||||||
with torch.device("cuda"):
|
with torch.device("cuda"):
|
||||||
self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32)
|
self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int32)
|
||||||
self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
|
self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
|
||||||
self.seq_lens = torch.full(
|
self.seq_lens = torch.full(
|
||||||
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
|
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
|
||||||
)
|
)
|
||||||
self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
|
self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int32)
|
||||||
|
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
|
||||||
self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)
|
self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)
|
||||||
|
|
||||||
if self.is_encoder_decoder:
|
if self.is_encoder_decoder:
|
||||||
@@ -229,6 +234,9 @@ class CudaGraphRunner:
|
|||||||
self.model_runner.model.capture_mode = False
|
self.model_runner.model.capture_mode = False
|
||||||
|
|
||||||
def can_run(self, forward_batch: ForwardBatch):
|
def can_run(self, forward_batch: ForwardBatch):
|
||||||
|
if not forward_batch.forward_mode.is_cuda_graph():
|
||||||
|
return False
|
||||||
|
|
||||||
if self.enable_dp_attention:
|
if self.enable_dp_attention:
|
||||||
min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
|
min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
|
||||||
forward_batch.global_num_tokens
|
forward_batch.global_num_tokens
|
||||||
@@ -258,12 +266,12 @@ class CudaGraphRunner:
|
|||||||
def capture(self):
|
def capture(self):
|
||||||
with graph_capture() as graph_capture_context:
|
with graph_capture() as graph_capture_context:
|
||||||
self.stream = graph_capture_context.stream
|
self.stream = graph_capture_context.stream
|
||||||
capture_bs = (
|
capture_range = (
|
||||||
tqdm.tqdm(self.capture_bs)
|
tqdm.tqdm(self.capture_bs)
|
||||||
if get_tensor_model_parallel_rank() == 0
|
if get_tensor_model_parallel_rank() == 0
|
||||||
else self.capture_bs
|
else self.capture_bs
|
||||||
)
|
)
|
||||||
for bs in capture_bs:
|
for bs in capture_range:
|
||||||
with patch_model(
|
with patch_model(
|
||||||
self.model_runner.model,
|
self.model_runner.model,
|
||||||
bs in self.compile_bs,
|
bs in self.compile_bs,
|
||||||
@@ -283,12 +291,15 @@ class CudaGraphRunner:
|
|||||||
def capture_one_batch_size(self, bs: int, forward: Callable):
|
def capture_one_batch_size(self, bs: int, forward: Callable):
|
||||||
graph = torch.cuda.CUDAGraph()
|
graph = torch.cuda.CUDAGraph()
|
||||||
stream = self.stream
|
stream = self.stream
|
||||||
|
num_token = bs * self.num_tokens_per_bs
|
||||||
|
|
||||||
# Common inputs
|
# Common inputs
|
||||||
input_ids = self.input_ids[:bs]
|
input_ids = self.input_ids[:num_token]
|
||||||
req_pool_indices = self.req_pool_indices[:bs]
|
req_pool_indices = self.req_pool_indices[:bs]
|
||||||
seq_lens = self.seq_lens[:bs]
|
seq_lens = self.seq_lens[:bs]
|
||||||
out_cache_loc = self.out_cache_loc[:bs]
|
out_cache_loc = self.out_cache_loc[:num_token]
|
||||||
|
positions = self.positions[:num_token]
|
||||||
|
|
||||||
if self.is_encoder_decoder:
|
if self.is_encoder_decoder:
|
||||||
encoder_lens = self.encoder_lens[:bs]
|
encoder_lens = self.encoder_lens[:bs]
|
||||||
else:
|
else:
|
||||||
@@ -304,37 +315,41 @@ class CudaGraphRunner:
|
|||||||
global_num_tokens = None
|
global_num_tokens = None
|
||||||
gathered_buffer = None
|
gathered_buffer = None
|
||||||
|
|
||||||
|
forward_batch = ForwardBatch(
|
||||||
|
forward_mode=self.capture_forward_mode,
|
||||||
|
batch_size=bs,
|
||||||
|
input_ids=input_ids,
|
||||||
|
req_pool_indices=req_pool_indices,
|
||||||
|
seq_lens=seq_lens,
|
||||||
|
req_to_token_pool=self.model_runner.req_to_token_pool,
|
||||||
|
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
||||||
|
attn_backend=self.model_runner.attn_backend,
|
||||||
|
out_cache_loc=out_cache_loc,
|
||||||
|
seq_lens_sum=seq_lens_sum,
|
||||||
|
encoder_lens=encoder_lens,
|
||||||
|
return_logprob=False,
|
||||||
|
top_logprobs_nums=[0] * num_token,
|
||||||
|
positions=positions,
|
||||||
|
global_num_tokens=global_num_tokens,
|
||||||
|
mrope_positions=mrope_positions,
|
||||||
|
gathered_buffer=gathered_buffer,
|
||||||
|
)
|
||||||
|
|
||||||
# Attention backend
|
# Attention backend
|
||||||
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
|
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
|
||||||
bs,
|
bs,
|
||||||
|
num_token,
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
encoder_lens,
|
encoder_lens,
|
||||||
|
forward_batch.forward_mode,
|
||||||
|
forward_batch.spec_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run and capture
|
# Run and capture
|
||||||
def run_once():
|
def run_once():
|
||||||
forward_batch = ForwardBatch(
|
|
||||||
forward_mode=ForwardMode.DECODE,
|
|
||||||
batch_size=bs,
|
|
||||||
input_ids=input_ids,
|
|
||||||
req_pool_indices=req_pool_indices,
|
|
||||||
seq_lens=seq_lens,
|
|
||||||
req_to_token_pool=self.model_runner.req_to_token_pool,
|
|
||||||
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
|
||||||
attn_backend=self.model_runner.attn_backend,
|
|
||||||
out_cache_loc=out_cache_loc,
|
|
||||||
seq_lens_sum=seq_lens_sum,
|
|
||||||
encoder_lens=encoder_lens,
|
|
||||||
return_logprob=False,
|
|
||||||
top_logprobs_nums=[0] * bs,
|
|
||||||
positions=clamp_position(seq_lens),
|
|
||||||
mrope_positions=mrope_positions,
|
|
||||||
global_num_tokens=global_num_tokens,
|
|
||||||
gathered_buffer=gathered_buffer,
|
|
||||||
)
|
|
||||||
logits_output = forward(input_ids, forward_batch.positions, forward_batch)
|
logits_output = forward(input_ids, forward_batch.positions, forward_batch)
|
||||||
return logits_output.next_token_logits
|
return logits_output.next_token_logits, logits_output.hidden_states
|
||||||
|
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
@@ -360,6 +375,9 @@ class CudaGraphRunner:
|
|||||||
def replay(self, forward_batch: ForwardBatch):
|
def replay(self, forward_batch: ForwardBatch):
|
||||||
assert forward_batch.out_cache_loc is not None
|
assert forward_batch.out_cache_loc is not None
|
||||||
raw_bs = forward_batch.batch_size
|
raw_bs = forward_batch.batch_size
|
||||||
|
# In normal decoding case, raw_bs == raw_num_token
|
||||||
|
# But in speculative decoding, raw_num_token is raw_bs * self.num_tokens_per_bs
|
||||||
|
raw_num_token = forward_batch.input_ids.numel()
|
||||||
|
|
||||||
# Pad
|
# Pad
|
||||||
if self.enable_dp_attention:
|
if self.enable_dp_attention:
|
||||||
@@ -374,10 +392,13 @@ class CudaGraphRunner:
|
|||||||
self.out_cache_loc.zero_()
|
self.out_cache_loc.zero_()
|
||||||
|
|
||||||
# Common inputs
|
# Common inputs
|
||||||
self.input_ids[:raw_bs].copy_(forward_batch.input_ids)
|
self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
|
||||||
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
|
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
|
||||||
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
|
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
|
||||||
self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
|
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
|
||||||
|
positions = clamp_position(forward_batch.seq_lens)
|
||||||
|
self.positions[:raw_num_token].copy_(positions)
|
||||||
|
|
||||||
if self.is_encoder_decoder:
|
if self.is_encoder_decoder:
|
||||||
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
||||||
if forward_batch.mrope_positions is not None:
|
if forward_batch.mrope_positions is not None:
|
||||||
@@ -390,13 +411,18 @@ class CudaGraphRunner:
|
|||||||
self.seq_lens,
|
self.seq_lens,
|
||||||
forward_batch.seq_lens_sum + (bs - raw_bs),
|
forward_batch.seq_lens_sum + (bs - raw_bs),
|
||||||
self.encoder_lens,
|
self.encoder_lens,
|
||||||
|
forward_batch.forward_mode,
|
||||||
|
forward_batch.spec_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Replay
|
# Replay
|
||||||
self.graphs[bs].replay()
|
self.graphs[bs].replay()
|
||||||
next_token_logits = self.output_buffers[bs][:raw_bs]
|
next_token_logits, hidden_states = self.output_buffers[bs]
|
||||||
|
|
||||||
logits_output = LogitsProcessorOutput(
|
logits_output = LogitsProcessorOutput(
|
||||||
next_token_logits=next_token_logits,
|
next_token_logits=next_token_logits[:raw_num_token],
|
||||||
|
hidden_states=(
|
||||||
|
hidden_states[:raw_num_token] if hidden_states is not None else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
return logits_output
|
return logits_output
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ class ForwardMode(IntEnum):
|
|||||||
return self == ForwardMode.DRAFT_EXTEND
|
return self == ForwardMode.DRAFT_EXTEND
|
||||||
|
|
||||||
def is_cuda_graph(self):
|
def is_cuda_graph(self):
|
||||||
return self in (ForwardMode.DECODE, ForwardMode.TARGET_VERIFY)
|
return self == ForwardMode.DECODE or self == ForwardMode.TARGET_VERIFY
|
||||||
|
|
||||||
def is_dummy_first(self):
|
def is_dummy_first(self):
|
||||||
return self == ForwardMode.DUMMY_FIRST
|
return self == ForwardMode.DUMMY_FIRST
|
||||||
|
|||||||
Reference in New Issue
Block a user