from typing import Optional, Union

import torch

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInput


class HybridAttnBackend(AttentionBackend):
    """Support different backends for prefill and decode."""

    def __init__(
        self,
        model_runner: ModelRunner,
        prefill_backend: AttentionBackend,
        decode_backend: AttentionBackend,
    ):
        self.model_runner = model_runner
        self.prefill_backend = prefill_backend
        self.decode_backend = decode_backend
        self.data_type = model_runner.kv_cache_dtype

    def _select_backend(self, forward_mode: ForwardMode) -> AttentionBackend:
        """
        Select the appropriate attention backend based on the forward mode.

        Args:
            forward_mode: The current forward mode indicating the operation type

        Returns:
            The selected attention backend (prefill or decode)

        Note:
            - decode_or_idle: Always uses decode backend
            - target_verify or draft_extend: Uses decode backend if
              speculative_attention_mode is "decode", otherwise prefill backend
            - prefill: Always uses prefill backend
        """
        if forward_mode.is_decode_or_idle():
            return self.decode_backend
        elif forward_mode.is_target_verify() or forward_mode.is_draft_extend():
            return (
                self.decode_backend
                if self.model_runner.server_args.speculative_attention_mode == "decode"
                else self.prefill_backend
            )
        else:
            return self.prefill_backend
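
    # Example of the routing above (illustrative; mode values follow the
    # ForwardMode enum):
    #
    #     self._select_backend(ForwardMode.DECODE)  # -> decode_backend
    #     self._select_backend(ForwardMode.EXTEND)  # -> prefill_backend
    #     # TARGET_VERIFY / DRAFT_EXTEND follow
    #     # server_args.speculative_attention_mode:
    #     #   "decode"  -> decode_backend
    #     #   "prefill" -> prefill_backend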

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        backend = self._select_backend(forward_batch.forward_mode)
        backend.init_forward_metadata(forward_batch)

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
        if (
            self.model_runner.server_args.speculative_algorithm is not None
            and self.model_runner.server_args.speculative_attention_mode == "prefill"
        ):
            # When speculative decoding is enabled, we need to initialize the
            # backend that will be used for target_verify.
            self.prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens)

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInput],
    ):
        backend = self._select_backend(forward_mode)
        backend.init_forward_metadata_capture_cuda_graph(
            bs,
            num_tokens,
            req_pool_indices,
            seq_lens,
            encoder_lens,
            forward_mode,
            spec_info,
        )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInput],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        backend = self._select_backend(forward_mode)
        backend.init_forward_metadata_replay_cuda_graph(
            bs,
            req_pool_indices,
            seq_lens,
            seq_lens_sum,
            encoder_lens,
            forward_mode,
            spec_info,
            seq_lens_cpu,
        )
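
    # Note: capture and replay must resolve to the same backend. The metadata
    # recorded by init_forward_metadata_capture_cuda_graph for a given
    # forward_mode is later refreshed by
    # init_forward_metadata_replay_cuda_graph, which is why both methods
    # repeat the _select_backend call with the same forward_mode instead of
    # caching the choice.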

    def get_cuda_graph_seq_len_fill_value(self):
        return self.decode_backend.get_cuda_graph_seq_len_fill_value()

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        **kwargs,
    ):
        return self.decode_backend.forward_decode(
            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
        )

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        **kwargs,
    ):
        backend = self._select_backend(forward_batch.forward_mode)
        return backend.forward_extend(
            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
        )

    def get_indexer_metadata(
        self, layer_id: int, forward_batch: ForwardBatch
    ) -> Optional[BaseIndexerMetadata]:
        backend = self._select_backend(forward_batch.forward_mode)
        return backend.get_indexer_metadata(layer_id, forward_batch)
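

# Minimal usage sketch (illustrative only; kept as a comment so importing this
# module stays side-effect free). The concrete backend classes named below are
# assumptions -- any AttentionBackend subclasses built from the same
# ModelRunner work the same way:
#
#     from sglang.srt.layers.attention.flashinfer_backend import (
#         FlashInferAttnBackend,
#     )
#     from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
#
#     prefill_backend = FlashInferAttnBackend(model_runner)
#     decode_backend = TritonAttnBackend(model_runner)
#     attn_backend = HybridAttnBackend(
#         model_runner, prefill_backend, decode_backend
#     )
#
#     # All per-batch hooks (init_forward_metadata, forward_extend,
#     # forward_decode, get_indexer_metadata) then dispatch on
#     # forward_batch.forward_mode via _select_backend.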