Files
sglang/python/sglang/srt/layers/attention/hybrid_attn_backend.py

101 lines
3.1 KiB
Python

from typing import Optional, Union
import torch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
class HybridAttnBackend(AttentionBackend):
"""Support different backends for prefill and decode."""
def __init__(
self, prefill_backend: AttentionBackend, decode_backend: AttentionBackend
):
self.prefill_backend = prefill_backend
self.decode_backend = decode_backend
def init_forward_metadata(self, forward_batch: ForwardBatch):
if forward_batch.forward_mode.is_decode():
self.decode_backend.init_forward_metadata(forward_batch)
else:
self.prefill_backend.init_forward_metadata(forward_batch)
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
self.decode_backend.init_forward_metadata_capture_cuda_graph(
bs,
num_tokens,
req_pool_indices,
seq_lens,
encoder_lens,
forward_mode,
spec_info,
)
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
seq_lens_cpu: Optional[torch.Tensor],
):
self.decode_backend.init_forward_metadata_replay_cuda_graph(
bs,
req_pool_indices,
seq_lens,
seq_lens_sum,
encoder_lens,
forward_mode,
spec_info,
seq_lens_cpu,
)
def get_cuda_graph_seq_len_fill_value(self):
return self.decode_backend.get_cuda_graph_seq_len_fill_value()
def forward_decode(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
**kwargs,
):
return self.decode_backend.forward_decode(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
)
def forward_extend(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
**kwargs,
):
return self.prefill_backend.forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
)