Support cuda graph in the triton attention backend (#1401)
@@ -36,14 +36,41 @@ class AttentionBackend(ABC):
     def init_forward_metadata(
         self, batch: ScheduleBatch, input_metadata: InputMetadata
     ):
-        pass
+        """Init the metadata for a forward pass."""
+        raise NotImplementedError()

-    def forward(self, q, k, v, layer, input_metadata: InputMetadata):
+    def init_cuda_graph_state(self, max_bs: int):
+        """Init the global shared states for cuda graph."""
+        raise NotImplementedError()
+
+    def init_forward_metadata_capture_cuda_graph(
+        self, bs: int, req_pool_indices, seq_lens
+    ):
+        """Init the metadata for a forward pass for capturing a cuda graph."""
+        raise NotImplementedError()
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, bs: int, req_pool_indices, seq_lens
+    ):
+        """Init the metadata for a forward pass for replaying a cuda graph."""
+        raise NotImplementedError()
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        raise NotImplementedError()
+
+    def forward(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+        """Run forward on an attention layer."""
         if input_metadata.forward_mode.is_decode():
             return self.forward_decode(q, k, v, layer, input_metadata)
         else:
             return self.forward_extend(q, k, v, layer, input_metadata)

     def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
         raise NotImplementedError()

     def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
         raise NotImplementedError()


 class FlashInferAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""
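The five new methods above are the entire contract a backend must implement to participate in CUDA graph capture. A minimal sketch of a conforming subclass — standalone stand-ins written for this note, not code from the commit:

```python
from abc import ABC


class AttentionBackendStub(ABC):
    """Trimmed stand-in for sglang's AttentionBackend, mirroring the hooks above."""

    def init_forward_metadata(self, batch, input_metadata):
        raise NotImplementedError()

    def init_cuda_graph_state(self, max_bs: int):
        raise NotImplementedError()

    def init_forward_metadata_capture_cuda_graph(self, bs, req_pool_indices, seq_lens):
        raise NotImplementedError()

    def init_forward_metadata_replay_cuda_graph(self, bs, req_pool_indices, seq_lens):
        raise NotImplementedError()

    def get_cuda_graph_seq_len_fill_value(self):
        raise NotImplementedError()


class RecordingBackend(AttentionBackendStub):
    """Satisfies the contract by just recording which hook ran."""

    def init_forward_metadata(self, batch, input_metadata):
        self.forward_metadata = ("eager", None)

    def init_cuda_graph_state(self, max_bs: int):
        # Allocate every graph-lifetime buffer here, sized for the largest batch.
        self.max_bs = max_bs

    def init_forward_metadata_capture_cuda_graph(self, bs, req_pool_indices, seq_lens):
        # May only point metadata at preallocated, fixed-address buffers.
        self.forward_metadata = ("capture", bs)

    def init_forward_metadata_replay_cuda_graph(self, bs, req_pool_indices, seq_lens):
        # May only mutate those buffers in place before the graph replays.
        self.forward_metadata = ("replay", bs)

    def get_cuda_graph_seq_len_fill_value(self):
        return 1  # what padded batch slots report as their sequence length
```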
@@ -153,7 +180,9 @@ class FlashInferAttnBackend(AttentionBackend):
             self.cuda_graph_kv_indices.clone(),
         ]

-    def capture_cuda_graph_init(self, bs: int, req_pool_indices, seq_lens):
+    def init_forward_metadata_capture_cuda_graph(
+        self, bs: int, req_pool_indices, seq_lens
+    ):
         if self.model_runner.sliding_window_size is None:
             decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 self.workspace_buffer,
@@ -194,7 +223,9 @@ class FlashInferAttnBackend(AttentionBackend):

         self.forward_metadata = (False, None, decode_wrapper)

-    def replay_cuda_graph_init(self, bs: int, req_pool_indices, seq_lens):
+    def init_forward_metadata_replay_cuda_graph(
+        self, bs: int, req_pool_indices, seq_lens
+    ):
         update_flashinfer_indices(
             ForwardMode.DECODE,
             self.model_runner,
@@ -204,6 +235,9 @@ class FlashInferAttnBackend(AttentionBackend):
             self.cuda_graph_metadata[bs],
         )

+    def get_cuda_graph_seq_len_fill_value(self):
+        return 0
+
     def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
         if not isinstance(self.prefill_wrapper_paged, list):
             prefill_wrapper_paged = self.prefill_wrapper_paged
@@ -290,6 +324,7 @@ class TritonAttnBackend(AttentionBackend):
     def __init__(self, model_runner: ModelRunner):
         # Lazy import to avoid the initialization of cuda context
         from sglang.srt.layers.triton_attention.decode_attention import (
+            REDUCE_TORCH_TYPE,
             decode_attention_fwd,
         )
         from sglang.srt.layers.triton_attention.extend_attention import (
@@ -300,29 +335,78 @@ class TritonAttnBackend(AttentionBackend):

         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd
+        self.REDUCE_TORCH_TYPE = REDUCE_TORCH_TYPE
+        self.num_head = model_runner.model_config.num_attention_heads

         self.forward_metadata = None

+        self.cuda_graph_max_seq_len = model_runner.model_config.context_len
+
     def init_forward_metadata(
         self, batch: ScheduleBatch, input_metadata: InputMetadata
     ):
         """Init auxiliary variables for triton attention backend."""

         if input_metadata.forward_mode.is_decode():
-            max_seq_len = torch.max(input_metadata.seq_lens).item()
             start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32)
             start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0)
+
             total_num_tokens = torch.sum(input_metadata.seq_lens).item()
+            attn_logits = torch.empty(
+                (self.num_head, total_num_tokens),
+                dtype=self.REDUCE_TORCH_TYPE,
+                device="cuda",
+            )
+
+            max_seq_len = torch.max(input_metadata.seq_lens).item()
             max_extend_len = None
         else:
-            start_loc = max_seq_len = total_num_tokens = None
+            start_loc = attn_logits = max_seq_len = None
             prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
             max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()

-        self.forward_metadata = start_loc, max_seq_len, max_extend_len, total_num_tokens
+        self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
+
+    def init_cuda_graph_state(self, max_bs: int):
+        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
+
+        self.cuda_graph_start_loc = torch.zeros(
+            (max_bs,), dtype=torch.int32, device="cuda"
+        )
+        self.cuda_graph_attn_logits = torch.empty(
+            (self.num_head, self.cuda_graph_max_total_num_tokens),
+            dtype=self.REDUCE_TORCH_TYPE,
+            device="cuda",
+        )
+
+    def init_forward_metadata_capture_cuda_graph(
+        self, bs: int, req_pool_indices, seq_lens
+    ):
+        self.forward_metadata = (
+            self.cuda_graph_start_loc,
+            self.cuda_graph_attn_logits,
+            self.cuda_graph_max_seq_len,
+            None,
+        )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, bs: int, req_pool_indices, seq_lens
+    ):
+        self.cuda_graph_start_loc.zero_()
+        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
+
+        self.forward_metadata = (
+            self.cuda_graph_start_loc,
+            self.cuda_graph_attn_logits,
+            self.cuda_graph_max_seq_len,
+            None,
+        )
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
     def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
         else:
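The replay hook recomputes `start_loc` in place as an exclusive cumsum of `seq_lens`, and the fill value of 1 returned by `get_cuda_graph_seq_len_fill_value` plausibly keeps every padded slot a valid, non-empty range for the kernels. A CPU-sized sketch with made-up lengths:

```python
# Sketch of what init_forward_metadata_replay_cuda_graph computes: an exclusive
# cumsum of seq_lens gives each request's start offset into the flattened
# (num_head, total_num_tokens) attn_logits buffer. Numbers are hypothetical.
import torch

max_bs = 8
seq_lens = torch.ones(max_bs, dtype=torch.int32)  # padded slots hold the fill value 1
seq_lens[:3] = torch.tensor([5, 2, 7], dtype=torch.int32)  # 3 real requests

bs = 4  # padded batch size chosen by the graph runner
start_loc = torch.zeros(max_bs, dtype=torch.int32)
start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
print(start_loc[:bs].tolist())  # [0, 5, 7, 14]
```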
@@ -332,8 +416,7 @@ class TritonAttnBackend(AttentionBackend):
             layer.layer_id, input_metadata.out_cache_loc, k, v
         )

-        start_loc, max_seq_len, max_extend_len, total_num_tokens = self.forward_metadata
-
+        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -350,16 +433,16 @@ class TritonAttnBackend(AttentionBackend):
             layer.scaling,
             layer.logit_cap,
         )

         return o

     def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
         else:
             o = torch.empty_like(q)

-        start_loc, max_seq_len, max_extend_len, total_num_tokens = self.forward_metadata
+        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata

         input_metadata.token_to_kv_pool.set_kv_buffer(
             layer.layer_id, input_metadata.out_cache_loc, k, v
@@ -374,10 +457,9 @@ class TritonAttnBackend(AttentionBackend):
             input_metadata.req_pool_indices,
             start_loc,
             input_metadata.seq_lens,
+            attn_logits,
             max_seq_len,
-            total_num_tokens,
             layer.scaling,
             layer.logit_cap,
         )

         return o
@@ -66,18 +66,18 @@ class FlashinferUpdater:
         self.head_dim = model_runner.model_config.head_dim
         self.batch_size = len(req_pool_indices)

-        self.kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
+        self.decode_wrapper = (
+            decode_wrapper or self.model_runner.attn_backend.decode_wrapper
+        )
+        self.prefill_wrapper_ragged = (
+            self.model_runner.attn_backend.prefill_wrapper_ragged
+        )
+        self.prefill_wrapper_paged = (
+            self.model_runner.attn_backend.prefill_wrapper_paged
         )

-        (
-            self.decode_wrapper,
-            self.prefill_wrapper_ragged,
-            self.prefill_wrapper_paged,
-        ) = (
-            decode_wrapper or self.model_runner.attn_backend.decode_wrapper,
-            self.model_runner.attn_backend.prefill_wrapper_ragged,
-            self.model_runner.attn_backend.prefill_wrapper_paged,
+        self.kv_last_page_len = torch.ones(
+            (self.batch_size,), dtype=torch.int32, device="cuda"
         )

     def _init_indices_no_sliding_window(self):
@@ -114,7 +114,7 @@ def _fwd_kernel_stage1(

 @triton.jit
 def _fwd_kernel_stage2(
-    Logics,
+    logits,
     V_Buffer,
     Out,
     Req_to_tokens,
@@ -162,7 +162,7 @@ def _fwd_kernel_stage2(
         )

         qk = tl.load(
-            Logics
+            logits
             + cur_head * stride_logic_h
             + (cur_batch_start_loc + start_n + offs_n),
             mask=start_n + offs_n < cur_batch_seq_len,
@@ -238,7 +238,7 @@ def _decode_att_m_fwd(


 def _decode_softmax_reducev_fwd(
-    logics,
+    logits,
     v_buffer,
     o,
     req_to_tokens,
@@ -247,9 +247,9 @@ def _decode_softmax_reducev_fwd(
     b_seq_len,
 ):
     BLOCK = 64
-    batch, head = b_seq_len.shape[0], logics.shape[0]
+    batch, head = b_seq_len.shape[0], logits.shape[0]
     grid = (batch, head, 1)
-    kv_group_num = logics.shape[0] // v_buffer.shape[1]
+    kv_group_num = logits.shape[0] // v_buffer.shape[1]

     num_warps = 1
@@ -257,14 +257,14 @@ def _decode_softmax_reducev_fwd(
     BLOCK_DMODEL = triton.next_power_of_2(Lv)

     _fwd_kernel_stage2[grid](
-        logics,
+        logits,
         v_buffer,
         o,
         req_to_tokens,
         b_req_idx,
         b_start_loc,
         b_seq_len,
-        logics.stride(0),
+        logits.stride(0),
         v_buffer.stride(0),
         v_buffer.stride(1),
         o.stride(0),
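The `Logics` → `logits` change is a pure rename; what the two kernel stages actually compute is easier to read in dense PyTorch. A reference sketch for a single decode request — an approximation that ignores paging, grouped heads, and `logit_cap`:

```python
# A dense PyTorch reference (an assumption for illustration, not the kernel
# itself): stage 1 fills one row of `logits` per head with scaled q.k scores
# over the cached tokens; stage 2 softmaxes that row and reduces over V.
import torch


def decode_attention_reference(q, k, v, sm_scale):
    # q: (num_head, head_dim) -- the single new decode token
    # k, v: (seq_len, num_head, head_dim) -- this request's KV cache
    logits = torch.einsum("hd,shd->hs", q, k) * sm_scale  # stage 1
    probs = torch.softmax(logits, dim=-1)                 # stage 2: softmax
    return torch.einsum("hs,shd->hd", probs, v)           # stage 2: reduce V


num_head, head_dim, seq_len = 4, 64, 10
q = torch.randn(num_head, head_dim)
k = torch.randn(seq_len, num_head, head_dim)
v = torch.randn(seq_len, num_head, head_dim)
out = decode_attention_reference(q, k, v, head_dim**-0.5)
print(out.shape)  # torch.Size([4, 64])
```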
@@ -387,7 +387,7 @@ def _fwd_grouped_kernel_stage1(

 @triton.jit
 def _fwd_grouped_kernel_stage2(
-    Logics,
+    logits,
     V_Buffer,
     Out,
     Req_to_tokens,
@@ -443,7 +443,7 @@ def _fwd_grouped_kernel_stage2(
         )

         qk = tl.load(
-            Logics + offs_qk,
+            logits + offs_qk,
             mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len),
             other=float("-inf"),
         )
@@ -531,7 +531,7 @@ def _decode_grouped_att_m_fwd(


 def _decode_grouped_softmax_reducev_fwd(
-    logics,
+    logits,
     v_buffer,
     o,
     req_to_tokens,
@@ -540,8 +540,8 @@ def _decode_grouped_softmax_reducev_fwd(
     b_seq_len,
 ):
     BLOCK = 128
-    batch, head_num = b_seq_len.shape[0], logics.shape[0]
-    kv_group_num = logics.shape[0] // v_buffer.shape[1]
+    batch, head_num = b_seq_len.shape[0], logits.shape[0]
+    kv_group_num = logits.shape[0] // v_buffer.shape[1]
     BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
     grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
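In these grouped kernels, `logits.shape[0]` is the number of query heads on the rank and `v_buffer.shape[1]` the number of KV heads, so `kv_group_num` query heads share one KV head. A worked example of the launch math with hypothetical head counts:

```python
# Mirrors the grid computation above; the head counts are made up.
def next_power_of_2(n: int) -> int:
    return 1 << (n - 1).bit_length()


def cdiv(a: int, b: int) -> int:
    return -(-a // b)


batch, head_num, num_kv_heads = 16, 32, 8
kv_group_num = head_num // num_kv_heads           # 4 query heads per KV head
BLOCK_H = max(16, next_power_of_2(kv_group_num))  # 16
grid = (batch, cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
print(kv_group_num, BLOCK_H, grid)  # 4 16 (16, 8, 1)
```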
@@ -551,14 +551,14 @@ def _decode_grouped_softmax_reducev_fwd(
     BLOCK_DMODEL = triton.next_power_of_2(Lv)

     _fwd_grouped_kernel_stage2[grid](
-        logics,
+        logits,
         v_buffer,
         o,
         req_to_tokens,
         b_req_idx,
         b_start_loc,
         b_seq_len,
-        logics.stride(0),
+        logits.stride(0),
         v_buffer.stride(0),
         v_buffer.stride(1),
         o.stride(0),
@@ -584,17 +584,11 @@ def decode_attention_fwd(
     b_req_idx,
     b_start_loc,
     b_seq_len,
+    attn_logits,
     max_len_in_batch,
-    total_num_tokens,
     sm_scale,
     logit_cap=0.0,
-    att_m=None,
 ):
-    if att_m is None:
-        att_m = torch.empty(
-            (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
-        )
-
     kv_group_num = q.shape[1] // v_buffer.shape[1]

     if kv_group_num == 1:
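This signature change is the heart of the CUDA graph support in the kernel: the `attn_logits` scratch buffer is no longer allocated lazily inside the function, where a fresh allocation per call would give the captured graph an unstable address, but is passed in by the caller. A sketch of the resulting calling convention; `REDUCE_TORCH_TYPE` is assumed here to be `torch.float32`, and a CUDA device is required:

```python
# Buffer shapes follow the backend code above (num_head rows, one column per
# cached token across the batch, bounded by max_bs * max context length).
import torch

num_head, max_bs, max_seq_len = 32, 8, 4096

# Allocated once, outside the captured region, and reused on every replay:
attn_logits = torch.empty(
    (num_head, max_bs * max_seq_len), dtype=torch.float32, device="cuda"
)

# Every decode step then passes the same tensor in, so the kernels write into
# a fixed address that a captured CUDA graph can safely replay:
# decode_attention_fwd(q, k_buffer, v_buffer, o, req_to_token, b_req_idx,
#                      b_start_loc, b_seq_len, attn_logits, max_len_in_batch,
#                      sm_scale, logit_cap=0.0)
```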
@@ -602,7 +596,7 @@ def decode_attention_fwd(
         _decode_att_m_fwd(
             q,
             k_buffer,
-            att_m,
+            attn_logits,
             req_to_token,
             b_req_idx,
             b_start_loc,
@@ -612,7 +606,7 @@ def decode_attention_fwd(
             logit_cap,
         )
         _decode_softmax_reducev_fwd(
-            att_m,
+            attn_logits,
             v_buffer,
             o,
             req_to_token,
@@ -625,7 +619,7 @@ def decode_attention_fwd(
         _decode_grouped_att_m_fwd(
             q,
             k_buffer,
-            att_m,
+            attn_logits,
             req_to_token,
             b_req_idx,
             b_start_loc,
@@ -635,7 +629,7 @@ def decode_attention_fwd(
             logit_cap,
         )
         _decode_grouped_softmax_reducev_fwd(
-            att_m,
+            attn_logits,
             v_buffer,
             o,
             req_to_token,
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,13 +19,12 @@ limitations under the License.

 import bisect
 from contextlib import contextmanager
-from typing import Callable
+from typing import TYPE_CHECKING, Callable

 import torch
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

-from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
 from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
@@ -35,6 +36,9 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetad
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import monkey_patch_vllm_all_gather

+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+

 def _to_torch(model: torch.nn.Module, reverse: bool = False):
     for sub in model._modules.values():
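Moving the `ModelRunner` import under `TYPE_CHECKING`, together with the new `from __future__ import annotations`, presumably breaks a runtime circular import between the runner and the graph runner while keeping the annotations usable. The pattern in isolation:

```python
# Generic illustration of the import idiom adopted above: the annotation-only
# import runs solely under a static type checker, never at runtime.
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # No runtime import, hence no runtime circular dependency.
    from sglang.srt.model_executor.model_runner import ModelRunner


class CudaGraphRunnerSketch:
    def __init__(self, model_runner: ModelRunner):
        # With __future__ annotations, the hint above is a string that is
        # never evaluated at runtime.
        self.model_runner = model_runner
```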
@@ -111,7 +115,7 @@ class CudaGraphRunner:
         self.req_pool_indices = torch.zeros(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
-        self.seq_lens = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
         self.position_ids_offsets = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
@@ -121,6 +125,9 @@ class CudaGraphRunner:

         # Attention backend
         self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
+        self.seq_len_fill_value = (
+            self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
+        )

         # Sampling info
         vocab_size = model_runner.model_config.vocab_size
@@ -176,7 +183,7 @@ class CudaGraphRunner:
         out_cache_loc = self.out_cache_loc[:bs]

         # Attention backend
-        self.model_runner.attn_backend.capture_cuda_graph_init(
+        self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
             bs, req_pool_indices, seq_lens
         )
@@ -227,7 +234,7 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.zero_()
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.position_ids_offsets.fill_(1)
             self.out_cache_loc.zero_()
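Replay pads a raw batch up to the nearest captured batch size, and the padding slots get the backend's seq-len fill value rather than zero. A toy illustration with hypothetical capture sizes:

```python
# Mirrors the bisect-based padding above; capture_bs and lengths are made up.
import bisect

import torch

capture_bs = [1, 2, 4, 8, 16]  # batch sizes with a captured graph
seq_len_fill_value = 1         # triton backend's choice (flashinfer returns 0)

raw_bs = 3
index = bisect.bisect_left(capture_bs, raw_bs)
bs = capture_bs[index]         # -> 4: replay the graph captured for bs=4

seq_lens = torch.full((max(capture_bs),), seq_len_fill_value, dtype=torch.int32)
seq_lens[:raw_bs] = torch.tensor([17, 5, 9], dtype=torch.int32)
print(bs, seq_lens[:bs].tolist())  # 4 [17, 5, 9, 1]
```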
@@ -239,7 +246,7 @@ class CudaGraphRunner:
         self.out_cache_loc[:raw_bs] = batch.out_cache_loc

         # Attention backend
-        self.model_runner.attn_backend.replay_cuda_graph_init(
+        self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs, self.req_pool_indices, self.seq_lens
         )
@@ -445,12 +445,6 @@ class ModelRunner:
         if self.server_args.disable_cuda_graph:
             return

-        if self.server_args.attention_backend != "flashinfer":
-            logger.warning(
-                f"Cuda graph is not supported for attention backend: {self.server_args.attention_backend}"
-            )
-            return
-
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)
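With every backend hook in place, the flashinfer-only guard above can simply be deleted. For reference, the PyTorch mechanism the whole PR is built around, in miniature — a generic capture/replay sketch (not SGLang code; requires a CUDA device):

```python
# Capture a graph over static tensors once, then update those tensors in
# place and replay; only in-place writes to the captured tensors are visible.
import torch

static_in = torch.zeros(8, 128, device="cuda")
static_out = torch.empty_like(static_in)

# Warm up on a side stream before capture, per the PyTorch CUDA graph docs.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out.copy_(static_in * 2)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out.copy_(static_in * 2)

static_in.copy_(torch.randn(8, 128, device="cuda"))
g.replay()
assert torch.allclose(static_out, static_in * 2)
```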