595 lines
21 KiB
Python
595 lines
21 KiB
Python
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Optional, Union

import torch
import triton

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.flashinfer_backend import (
    create_flashinfer_kv_indices_triton,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
|
|
|
|
|
class TritonAttnBackend(AttentionBackend):
    """Attention backend built on the Triton decode/extend attention kernels.

    The ``init_forward_metadata*`` methods prepare per-batch index tensors and
    ``forward_extend`` / ``forward_decode`` consume them. Eagerly computed (and
    captured) metadata is stored in ``self.forward_metadata`` as the 7-tuple::

        (attn_logits, max_extend_len, kv_indptr, kv_indices,
         qo_indptr, custom_mask, mask_indptr)

    Entries that a given forward mode does not use are ``None``.
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
    ):
        """Allocate persistent index buffers and cache model/server settings.

        Args:
            model_runner: Source of the request/token pools, model config,
                server args and device.
            skip_prefill: When True, ``qo_indptr`` and ``mask_indptr`` are not
                allocated (nor is the CUDA-graph custom mask later on).
            kv_indptr_buf: Optional externally owned buffer used as
                ``self.kv_indptr`` instead of allocating a fresh one.
        """
        # Lazy import to avoid the initialization of cuda context
        from sglang.srt.layers.attention.triton_ops.decode_attention import (
            decode_attention_fwd,
        )
        from sglang.srt.layers.attention.triton_ops.extend_attention import (
            extend_attention_fwd,
        )

        super().__init__()

        self.decode_attention_fwd = decode_attention_fwd
        self.extend_attention_fwd = extend_attention_fwd

        self.skip_prefill = skip_prefill

        device = model_runner.device
        max_bs = model_runner.req_to_token_pool.size

        if kv_indptr_buf is None:
            self.kv_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=device
            )
        else:
            self.kv_indptr = kv_indptr_buf

        self.req_to_token = model_runner.req_to_token_pool.req_to_token

        if not self.skip_prefill:
            self.qo_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int32, device=device
            )
            self.mask_indptr = torch.zeros(
                (max_bs + 1,), dtype=torch.int64, device=device
            )

        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens

        self.num_head = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )

        self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]

        self.forward_metadata = None

        self.max_context_len = model_runner.model_config.context_len

        self.device = device

    # ------------------------------------------------------------------
    # Private helpers shared by the eager, capture and replay code paths
    # ------------------------------------------------------------------

    def _update_kv_indptr(self, bs: int, lens: torch.Tensor) -> torch.Tensor:
        """Store the running sum of ``lens`` in ``self.kv_indptr[1:bs + 1]``.

        Element 0 is never written here, so it keeps whatever value the buffer
        was created with (zero when this class allocated it).

        Returns:
            The ``[: bs + 1]`` view of ``self.kv_indptr``.
        """
        self.kv_indptr[1 : bs + 1] = torch.cumsum(lens, dim=0)
        return self.kv_indptr[: bs + 1]

    def _fill_kv_indices(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        lens: torch.Tensor,
        kv_indptr: torch.Tensor,
        kv_indices: torch.Tensor,
    ) -> None:
        """Launch ``create_flashinfer_kv_indices_triton`` on a ``(bs,)`` grid.

        ``kv_indices`` is the output buffer; the fifth kernel argument is
        always passed as ``None`` by this backend.
        """
        create_flashinfer_kv_indices_triton[(bs,)](
            self.req_to_token,
            req_pool_indices,
            lens,
            kv_indptr,
            None,
            kv_indices,
            self.req_to_token.stride(0),
        )

    def _verify_qo_range(self, bs: int) -> torch.Tensor:
        """Return ``[0, n, 2n, ..., bs * n]`` with ``n = self.num_draft_tokens``."""
        return torch.arange(
            0,
            (1 + bs) * self.num_draft_tokens,
            step=self.num_draft_tokens,
            dtype=torch.int32,
            device=self.device,
        )

    def _update_mask_indptr(self, bs: int, seq_lens: torch.Tensor) -> torch.Tensor:
        """Store per-request custom-mask offsets in ``self.mask_indptr``.

        Each request contributes
        ``num_draft_tokens * (seq_len + num_draft_tokens)`` mask entries.

        Returns:
            The ``[: bs + 1]`` view of ``self.mask_indptr``.
        """
        seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
        self.mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
        return self.mask_indptr[: bs + 1]

    @staticmethod
    def _alloc_output(q: torch.Tensor, layer: RadixAttention) -> torch.Tensor:
        """Allocate the uninitialised attention output buffer for ``q``.

        When the query/key head dim differs from the value head dim the output
        gets ``tp_q_head_num * v_head_dim`` columns; otherwise it mirrors ``q``.
        """
        if layer.qk_head_dim != layer.v_head_dim:
            return q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        return torch.empty_like(q)

    # ------------------------------------------------------------------
    # Metadata initialisation
    # ------------------------------------------------------------------

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init auxiliary variables for triton attention backend.

        Builds the index tensors required by the batch's forward mode (decode or
        idle, target-verify, draft-extend, or regular extend as the fallback)
        and stores them in ``self.forward_metadata``.
        """
        bs = forward_batch.batch_size
        spec_info = forward_batch.spec_info
        mode = forward_batch.forward_mode

        # Defaults; every branch overrides only the entries it actually uses.
        attn_logits = None
        max_extend_len = None
        qo_indptr = None
        custom_mask = None
        mask_indptr = None

        if mode.is_decode_or_idle():
            if spec_info is None:
                kv_indptr = self._update_kv_indptr(bs, forward_batch.seq_lens)
                kv_indices = torch.zeros(
                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
                )
                self._fill_kv_indices(
                    bs,
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    kv_indptr,
                    kv_indices,
                )
            else:
                # Speculative decode: the indices were prepared by the caller,
                # and the effective batch size follows from their length.
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                bs = kv_indptr.shape[0] - 1

            attn_logits = torch.zeros(
                (
                    bs,
                    self.num_head,
                    self.num_kv_splits,
                    self.v_head_dim + 1,
                ),
                dtype=torch.float32,
                device=self.device,
            )
        elif mode.is_target_verify():
            bs = len(forward_batch.req_pool_indices)
            qo_indptr = self._verify_qo_range(bs)
            # Different with flashinfer kv_indptr and kv_indices construction
            kv_indptr = self._update_kv_indptr(bs, forward_batch.seq_lens)
            kv_indices = torch.zeros(
                kv_indptr[-1], dtype=torch.int32, device=self.device
            )
            self._fill_kv_indices(
                bs,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                kv_indptr,
                kv_indices,
            )

            custom_mask = spec_info.custom_mask
            mask_indptr = self._update_mask_indptr(bs, forward_batch.seq_lens[:bs])
            max_extend_len = self.num_draft_tokens
        elif mode.is_draft_extend():
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    None,
                    self.req_to_token,
                )
            )
            max_extend_len = torch.max(spec_info.accept_length).item()
        else:
            prefix_lens = forward_batch.extend_prefix_lens
            kv_indptr = self._update_kv_indptr(bs, prefix_lens)
            kv_indices = torch.zeros(
                prefix_lens.sum().item(),
                dtype=torch.int32,
                device=self.device,
            )
            self._fill_kv_indices(
                bs,
                forward_batch.req_pool_indices,
                prefix_lens,
                kv_indptr,
                kv_indices,
            )

            self.qo_indptr[1 : bs + 1] = torch.cumsum(
                forward_batch.extend_seq_lens, dim=0
            )
            qo_indptr = self.qo_indptr[: bs + 1]
            max_extend_len = torch.max(forward_batch.extend_seq_lens).item()

        self.forward_metadata = (
            attn_logits,
            max_extend_len,
            kv_indptr,
            kv_indices,
            qo_indptr,
            custom_mask,
            mask_indptr,
        )

    def init_cuda_graph_state(
        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
    ):
        """Allocate the static buffers used by CUDA-graph capture and replay.

        Args:
            max_bs: Batch size the buffers are dimensioned for.
            kv_indices_buf: Optional externally owned buffer used as
                ``self.cuda_graph_kv_indices`` instead of allocating one.
        """
        self.cuda_graph_attn_logits = torch.zeros(
            (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
            dtype=torch.float32,
            device=self.device,
        )

        if kv_indices_buf is None:
            self.cuda_graph_kv_indices = torch.zeros(
                (max_bs * self.max_context_len),
                dtype=torch.int32,
                device=self.device,
            )
        else:
            self.cuda_graph_kv_indices = kv_indices_buf

        if not self.skip_prefill:
            # NOTE(review): target-verify masks occupy
            # num_draft_tokens * (seq_len + num_draft_tokens) entries per
            # request (see _update_mask_indptr) - confirm that
            # max_bs * max_context_len is always large enough.
            self.cuda_graph_custom_mask = torch.zeros(
                (max_bs * self.max_context_len),
                dtype=torch.uint8,
                device=self.device,
            )

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        """Build ``self.forward_metadata`` on the static CUDA-graph buffers.

        Only decode/idle and target-verify modes are supported. ``num_tokens``
        is accepted but not read by this implementation.

        Raises:
            AssertionError: If ``encoder_lens`` is not None.
            ValueError: If ``forward_mode`` is neither decode/idle nor
                target-verify.
        """
        assert encoder_lens is None, "Not supported"

        # Defaults; every branch overrides only the entries it actually uses.
        attn_logits = None
        max_extend_len = None
        qo_indptr = None
        custom_mask = None
        mask_indptr = None

        if forward_mode.is_decode_or_idle():
            if spec_info is None:
                kv_indptr = self._update_kv_indptr(bs, seq_lens)
                kv_indices = self.cuda_graph_kv_indices
                self._fill_kv_indices(
                    bs, req_pool_indices, seq_lens, kv_indptr, kv_indices
                )
            else:
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

            attn_logits = self.cuda_graph_attn_logits
        elif forward_mode.is_target_verify():
            self.qo_indptr[: bs + 1] = self._verify_qo_range(bs)
            qo_indptr = self.qo_indptr[: bs + 1]

            kv_indptr = self._update_kv_indptr(bs, seq_lens)
            kv_indices = self.cuda_graph_kv_indices
            self._fill_kv_indices(
                bs, req_pool_indices, seq_lens, kv_indptr, kv_indices
            )

            custom_mask = self.cuda_graph_custom_mask
            mask_indptr = self._update_mask_indptr(bs, seq_lens)
            max_extend_len = self.num_draft_tokens
        else:
            raise ValueError(
                f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
            )

        self.forward_metadata = (
            attn_logits,
            max_extend_len,
            kv_indptr,
            kv_indices,
            qo_indptr,
            custom_mask,
            mask_indptr,
        )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        """Refresh the static CUDA-graph buffers in place before a replay.

        ``self.forward_metadata`` is not reassigned; only the persistent buffers
        are overwritten. ``seq_lens_sum``, ``encoder_lens`` and ``seq_lens_cpu``
        are accepted but not read by this implementation.

        Raises:
            ValueError: If ``forward_mode`` is neither decode/idle nor
                target-verify.
        """
        # NOTE: encoder_lens expected to be zeros or None
        if forward_mode.is_decode_or_idle():
            # Update kv_indptr, kv_indices
            kv_indices = self.cuda_graph_kv_indices
            if spec_info is None:
                lens = seq_lens[:bs]
                kv_indptr = self._update_kv_indptr(bs, lens)
                self._fill_kv_indices(
                    bs, req_pool_indices[:bs], lens, kv_indptr, kv_indices
                )
            else:
                # Copy the caller-prepared indices into the static buffers.
                self.kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
        elif forward_mode.is_target_verify():
            # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
            bs = len(req_pool_indices)
            self.qo_indptr[: bs + 1] = self._verify_qo_range(bs)

            kv_indptr = self._update_kv_indptr(bs, seq_lens)
            kv_indices = self.cuda_graph_kv_indices
            self._fill_kv_indices(
                bs, req_pool_indices, seq_lens, kv_indptr, kv_indices
            )

            custom_mask = self.cuda_graph_custom_mask
            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
            self._update_mask_indptr(bs, seq_lens)
        else:
            raise ValueError(
                f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
            )

    def get_cuda_graph_seq_len_fill_value(self):
        """Return the sequence-length fill value for CUDA-graph batches (1)."""
        return 1

    # ------------------------------------------------------------------
    # Attention forward passes
    # ------------------------------------------------------------------

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        """Run the extend attention kernel and return the output tensor.

        When ``save_kv_cache`` is true, ``k``/``v`` are first written to the
        KV pool at ``forward_batch.out_cache_loc``. Index tensors come from
        ``self.forward_metadata``.
        """
        # TODO: reuse the buffer across layers
        o = self._alloc_output(q, layer)

        kv_pool = forward_batch.token_to_kv_pool
        if save_kv_cache:
            kv_pool.set_kv_buffer(layer, forward_batch.out_cache_loc, k, v)

        (
            _,
            max_extend_len,
            kv_indptr,
            kv_indices,
            qo_indptr,
            custom_mask,
            mask_indptr,
        ) = self.forward_metadata

        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            kv_pool.get_key_buffer(layer.layer_id),
            kv_pool.get_value_buffer(layer.layer_id),
            qo_indptr,
            kv_indptr,
            kv_indices,
            custom_mask,
            mask_indptr,
            max_extend_len,
            layer.scaling,
            layer.logit_cap,
        )
        return o

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        """Run the decode attention kernel and return the output tensor.

        When ``save_kv_cache`` is true, ``k``/``v`` are written to the KV pool
        at ``forward_batch.out_cache_loc`` before the kernel is launched. Index
        tensors and the logits scratch buffer come from
        ``self.forward_metadata``.
        """
        # During torch.compile, there is a bug in rotary_emb that causes the
        # output value to have a 3D tensor shape. This reshapes the output correctly.
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        # TODO: reuse the buffer across layers
        o = self._alloc_output(q, layer)

        attn_logits, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata

        kv_pool = forward_batch.token_to_kv_pool
        if save_kv_cache:
            kv_pool.set_kv_buffer(layer, forward_batch.out_cache_loc, k, v)

        self.decode_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            kv_pool.get_key_buffer(layer.layer_id),
            kv_pool.get_value_buffer(layer.layer_id),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            kv_indptr,
            kv_indices,
            attn_logits,
            self.num_kv_splits,
            layer.scaling,
            layer.logit_cap,
        )
        return o
|
|
|
|
|
|
class TritonMultiStepDraftBackend:
    """
    Wrap multiple triton attention backends as one for multiple consecutive
    draft decoding steps.

    One ``TritonAttnBackend`` (created with ``skip_prefill=True``) is kept per
    speculative step; step ``i`` uses row ``i`` of the shared ``kv_indptr``
    buffer as its own ``kv_indptr``.
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        topk: int,
        speculative_num_steps: int,
    ):
        """Create the per-step backends and the shared ``kv_indptr`` buffer.

        Args:
            model_runner: Source of the request/token pool, model config and
                device; forwarded to every per-step backend.
            topk: Multiplier applied to the pool size and the batch size when
                dimensioning per-step index buffers.
            speculative_num_steps: Number of draft steps, i.e. the number of
                wrapped backends.
        """
        # Imported here rather than at module level, matching the original code.
        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices

        self.topk = topk
        self.speculative_num_steps = speculative_num_steps
        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices

        max_bs = model_runner.req_to_token_pool.size * self.topk
        self.kv_indptr = torch.zeros(
            (
                self.speculative_num_steps,
                max_bs + 1,
            ),
            dtype=torch.int32,
            device=model_runner.device,
        )
        self.attn_backends = [
            TritonAttnBackend(
                model_runner,
                skip_prefill=True,
                kv_indptr_buf=self.kv_indptr[i],
            )
            for i in range(self.speculative_num_steps)
        ]
        self.max_context_len = self.attn_backends[0].max_context_len
        self.device = model_runner.device
        # Cached variables for generate_draft_decode_kv_indices
        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]

    def common_template(
        self,
        forward_batch: ForwardBatch,
        kv_indices_buffer: torch.Tensor,
        call_fn: Callable[[int, ForwardBatch], None],
    ):
        """Fill the per-step KV indices, then invoke ``call_fn`` for each step.

        The ``generate_draft_decode_kv_indices`` kernel is launched once on a
        ``(speculative_num_steps, num_seqs, topk)`` grid, writing into
        ``kv_indices_buffer`` and ``self.kv_indptr``. Then, for every step
        ``i``, ``forward_batch.spec_info.kv_indptr`` / ``.kv_indices`` are
        pointed at that step's slices and ``call_fn(i, forward_batch)`` runs.

        Args:
            forward_batch: Batch whose ``spec_info`` is mutated in place.
            kv_indices_buffer: 2-D buffer indexed by step along dim 0.
            call_fn: Callback receiving the step index and the batch.
        """
        num_seqs = forward_batch.batch_size
        bs = self.topk * num_seqs
        seq_lens_sum = forward_batch.seq_lens_sum

        self.generate_draft_decode_kv_indices[
            (self.speculative_num_steps, num_seqs, self.topk)
        ](
            forward_batch.req_pool_indices,
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.seq_lens,
            kv_indices_buffer,
            self.kv_indptr,
            forward_batch.positions,
            num_seqs,
            self.topk,
            self.pool_len,
            kv_indices_buffer.shape[1],
            self.kv_indptr.shape[1],
            triton.next_power_of_2(num_seqs),
            triton.next_power_of_2(self.speculative_num_steps),
            triton.next_power_of_2(bs),
        )

        spec_info = forward_batch.spec_info
        for i in range(self.speculative_num_steps):
            spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
            # The slice handed to step i grows by `bs` entries per step.
            spec_info.kv_indices = kv_indices_buffer[i][
                : seq_lens_sum * self.topk + bs * (i + 1)
            ]
            call_fn(i, forward_batch)

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Initialise eager-mode metadata on every per-step backend.

        A fresh KV-indices buffer is allocated for this call, and each backend
        receives cloned ``kv_indptr`` / ``kv_indices`` tensors.
        """
        kv_indices = torch.zeros(
            (
                self.speculative_num_steps,
                forward_batch.batch_size * self.topk * self.max_context_len,
            ),
            dtype=torch.int32,
            device=self.device,
        )

        def call_fn(i: int, forward_batch: ForwardBatch) -> None:
            spec_info = forward_batch.spec_info
            # Clone so the backend does not hold views into the shared buffers.
            spec_info.kv_indptr = spec_info.kv_indptr.clone()
            spec_info.kv_indices = spec_info.kv_indices.clone()
            self.attn_backends[i].init_forward_metadata(forward_batch)

        self.common_template(forward_batch, kv_indices, call_fn)

    def init_cuda_graph_state(self, max_bs: int):
        """Allocate the shared CUDA-graph KV-indices buffer.

        Row ``i`` of the buffer is handed to backend ``i`` as its
        ``kv_indices_buf``.
        """
        self.cuda_graph_kv_indices = torch.zeros(
            (self.speculative_num_steps, max_bs * self.max_context_len),
            dtype=torch.int32,
            device=self.device,
        )
        for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(
                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
        """Run CUDA-graph capture initialisation on every per-step backend.

        Each backend is driven in ``ForwardMode.DECODE`` with the batch's
        ``spec_info`` and ``encoder_lens=None``.
        """

        def call_fn(i: int, forward_batch: ForwardBatch) -> None:
            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                forward_batch.batch_size,
                forward_batch.batch_size * self.topk,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
            )

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)

    def init_forward_metadata_replay_cuda_graph(self, forward_batch: ForwardBatch):
        """Run CUDA-graph replay initialisation on every per-step backend.

        Each backend is driven in ``ForwardMode.DECODE`` with the batch's
        ``spec_info``; ``seq_lens_sum`` is passed as ``-1`` and
        ``encoder_lens`` / ``seq_lens_cpu`` as ``None``.
        """

        def call_fn(i: int, forward_batch: ForwardBatch) -> None:
            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                forward_batch.batch_size,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                seq_lens_sum=-1,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
                seq_lens_cpu=None,
            )

        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
|