Support cuda graph in the triton attention backend (#1401)
This commit is contained in:
@@ -36,14 +36,41 @@ class AttentionBackend(ABC):
|
|||||||
def init_forward_metadata(
|
def init_forward_metadata(
|
||||||
self, batch: ScheduleBatch, input_metadata: InputMetadata
|
self, batch: ScheduleBatch, input_metadata: InputMetadata
|
||||||
):
|
):
|
||||||
pass
|
"""Init the metadata for a forward pass."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def forward(self, q, k, v, layer, input_metadata: InputMetadata):
|
def init_cuda_graph_state(self, max_bs: int):
|
||||||
|
"""Init the global shared states for cuda graph."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
|
self, bs: int, req_pool_indices, seq_lens
|
||||||
|
):
|
||||||
|
"""Init the metadata for a forward pass for capturing a cuda graph."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
|
self, bs: int, req_pool_indices, seq_lens
|
||||||
|
):
|
||||||
|
"""Init the metadata for a forward pass for replying a cuda graph."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def forward(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
|
||||||
|
"""Run forward on an attention layer."""
|
||||||
if input_metadata.forward_mode.is_decode():
|
if input_metadata.forward_mode.is_decode():
|
||||||
return self.forward_decode(q, k, v, layer, input_metadata)
|
return self.forward_decode(q, k, v, layer, input_metadata)
|
||||||
else:
|
else:
|
||||||
return self.forward_extend(q, k, v, layer, input_metadata)
|
return self.forward_extend(q, k, v, layer, input_metadata)
|
||||||
|
|
||||||
|
def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class FlashInferAttnBackend(AttentionBackend):
|
class FlashInferAttnBackend(AttentionBackend):
|
||||||
"""Flashinfer attention kernels."""
|
"""Flashinfer attention kernels."""
|
||||||
@@ -153,7 +180,9 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
self.cuda_graph_kv_indices.clone(),
|
self.cuda_graph_kv_indices.clone(),
|
||||||
]
|
]
|
||||||
|
|
||||||
def capture_cuda_graph_init(self, bs: int, req_pool_indices, seq_lens):
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
|
self, bs: int, req_pool_indices, seq_lens
|
||||||
|
):
|
||||||
if self.model_runner.sliding_window_size is None:
|
if self.model_runner.sliding_window_size is None:
|
||||||
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
||||||
self.workspace_buffer,
|
self.workspace_buffer,
|
||||||
@@ -194,7 +223,9 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
self.forward_metadata = (False, None, decode_wrapper)
|
self.forward_metadata = (False, None, decode_wrapper)
|
||||||
|
|
||||||
def replay_cuda_graph_init(self, bs: int, req_pool_indices, seq_lens):
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
|
self, bs: int, req_pool_indices, seq_lens
|
||||||
|
):
|
||||||
update_flashinfer_indices(
|
update_flashinfer_indices(
|
||||||
ForwardMode.DECODE,
|
ForwardMode.DECODE,
|
||||||
self.model_runner,
|
self.model_runner,
|
||||||
@@ -204,6 +235,9 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
self.cuda_graph_metadata[bs],
|
self.cuda_graph_metadata[bs],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
|
return 0
|
||||||
|
|
||||||
def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
|
def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
|
||||||
if not isinstance(self.prefill_wrapper_paged, list):
|
if not isinstance(self.prefill_wrapper_paged, list):
|
||||||
prefill_wrapper_paged = self.prefill_wrapper_paged
|
prefill_wrapper_paged = self.prefill_wrapper_paged
|
||||||
@@ -290,6 +324,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
def __init__(self, model_runner: ModelRunner):
|
def __init__(self, model_runner: ModelRunner):
|
||||||
# Lazy import to avoid the initialization of cuda context
|
# Lazy import to avoid the initialization of cuda context
|
||||||
from sglang.srt.layers.triton_attention.decode_attention import (
|
from sglang.srt.layers.triton_attention.decode_attention import (
|
||||||
|
REDUCE_TORCH_TYPE,
|
||||||
decode_attention_fwd,
|
decode_attention_fwd,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.triton_attention.extend_attention import (
|
from sglang.srt.layers.triton_attention.extend_attention import (
|
||||||
@@ -300,29 +335,78 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
self.decode_attention_fwd = decode_attention_fwd
|
self.decode_attention_fwd = decode_attention_fwd
|
||||||
self.extend_attention_fwd = extend_attention_fwd
|
self.extend_attention_fwd = extend_attention_fwd
|
||||||
|
self.REDUCE_TORCH_TYPE = REDUCE_TORCH_TYPE
|
||||||
|
self.num_head = model_runner.model_config.num_attention_heads
|
||||||
|
|
||||||
self.forward_metadata = None
|
self.forward_metadata = None
|
||||||
|
|
||||||
|
self.cuda_graph_max_seq_len = model_runner.model_config.context_len
|
||||||
|
|
||||||
def init_forward_metadata(
|
def init_forward_metadata(
|
||||||
self, batch: ScheduleBatch, input_metadata: InputMetadata
|
self, batch: ScheduleBatch, input_metadata: InputMetadata
|
||||||
):
|
):
|
||||||
"""Init auxiliary variables for triton attention backend."""
|
"""Init auxiliary variables for triton attention backend."""
|
||||||
|
|
||||||
if input_metadata.forward_mode.is_decode():
|
if input_metadata.forward_mode.is_decode():
|
||||||
max_seq_len = torch.max(input_metadata.seq_lens).item()
|
|
||||||
start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32)
|
start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32)
|
||||||
start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0)
|
start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0)
|
||||||
|
|
||||||
total_num_tokens = torch.sum(input_metadata.seq_lens).item()
|
total_num_tokens = torch.sum(input_metadata.seq_lens).item()
|
||||||
|
attn_logits = torch.empty(
|
||||||
|
(self.num_head, total_num_tokens),
|
||||||
|
dtype=self.REDUCE_TORCH_TYPE,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
|
||||||
|
max_seq_len = torch.max(input_metadata.seq_lens).item()
|
||||||
max_extend_len = None
|
max_extend_len = None
|
||||||
else:
|
else:
|
||||||
start_loc = max_seq_len = total_num_tokens = None
|
start_loc = attn_logits = max_seq_len = None
|
||||||
prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
|
prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
|
||||||
max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()
|
max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()
|
||||||
|
|
||||||
self.forward_metadata = start_loc, max_seq_len, max_extend_len, total_num_tokens
|
self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
|
||||||
|
|
||||||
|
def init_cuda_graph_state(self, max_bs: int):
|
||||||
|
self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
|
||||||
|
|
||||||
|
self.cuda_graph_start_loc = torch.zeros(
|
||||||
|
(max_bs,), dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
|
self.cuda_graph_attn_logits = torch.empty(
|
||||||
|
(self.num_head, self.cuda_graph_max_total_num_tokens),
|
||||||
|
dtype=self.REDUCE_TORCH_TYPE,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
|
||||||
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
|
self, bs: int, req_pool_indices, seq_lens
|
||||||
|
):
|
||||||
|
self.forward_metadata = (
|
||||||
|
self.cuda_graph_start_loc,
|
||||||
|
self.cuda_graph_attn_logits,
|
||||||
|
self.cuda_graph_max_seq_len,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
|
self, bs: int, req_pool_indices, seq_lens
|
||||||
|
):
|
||||||
|
self.cuda_graph_start_loc.zero_()
|
||||||
|
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
||||||
|
|
||||||
|
self.forward_metadata = (
|
||||||
|
self.cuda_graph_start_loc,
|
||||||
|
self.cuda_graph_attn_logits,
|
||||||
|
self.cuda_graph_max_seq_len,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
|
return 1
|
||||||
|
|
||||||
def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
|
def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
|
||||||
|
# TODO: reuse the buffer across layers
|
||||||
if layer.qk_head_dim != layer.v_head_dim:
|
if layer.qk_head_dim != layer.v_head_dim:
|
||||||
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
|
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
|
||||||
else:
|
else:
|
||||||
@@ -332,8 +416,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
layer.layer_id, input_metadata.out_cache_loc, k, v
|
layer.layer_id, input_metadata.out_cache_loc, k, v
|
||||||
)
|
)
|
||||||
|
|
||||||
start_loc, max_seq_len, max_extend_len, total_num_tokens = self.forward_metadata
|
start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
|
||||||
|
|
||||||
self.extend_attention_fwd(
|
self.extend_attention_fwd(
|
||||||
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
||||||
k.contiguous(),
|
k.contiguous(),
|
||||||
@@ -350,16 +433,16 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
layer.scaling,
|
layer.scaling,
|
||||||
layer.logit_cap,
|
layer.logit_cap,
|
||||||
)
|
)
|
||||||
|
|
||||||
return o
|
return o
|
||||||
|
|
||||||
def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
|
def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
|
||||||
|
# TODO: reuse the buffer across layers
|
||||||
if layer.qk_head_dim != layer.v_head_dim:
|
if layer.qk_head_dim != layer.v_head_dim:
|
||||||
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
|
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
|
||||||
else:
|
else:
|
||||||
o = torch.empty_like(q)
|
o = torch.empty_like(q)
|
||||||
|
|
||||||
start_loc, max_seq_len, max_extend_len, total_num_tokens = self.forward_metadata
|
start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
|
||||||
|
|
||||||
input_metadata.token_to_kv_pool.set_kv_buffer(
|
input_metadata.token_to_kv_pool.set_kv_buffer(
|
||||||
layer.layer_id, input_metadata.out_cache_loc, k, v
|
layer.layer_id, input_metadata.out_cache_loc, k, v
|
||||||
@@ -374,10 +457,9 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
input_metadata.req_pool_indices,
|
input_metadata.req_pool_indices,
|
||||||
start_loc,
|
start_loc,
|
||||||
input_metadata.seq_lens,
|
input_metadata.seq_lens,
|
||||||
|
attn_logits,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
total_num_tokens,
|
|
||||||
layer.scaling,
|
layer.scaling,
|
||||||
layer.logit_cap,
|
layer.logit_cap,
|
||||||
)
|
)
|
||||||
|
|
||||||
return o
|
return o
|
||||||
|
|||||||
@@ -66,18 +66,18 @@ class FlashinferUpdater:
|
|||||||
self.head_dim = model_runner.model_config.head_dim
|
self.head_dim = model_runner.model_config.head_dim
|
||||||
self.batch_size = len(req_pool_indices)
|
self.batch_size = len(req_pool_indices)
|
||||||
|
|
||||||
self.kv_last_page_len = torch.ones(
|
self.decode_wrapper = (
|
||||||
(self.batch_size,), dtype=torch.int32, device="cuda"
|
decode_wrapper or self.model_runner.attn_backend.decode_wrapper
|
||||||
|
)
|
||||||
|
self.prefill_wrapper_ragged = (
|
||||||
|
self.model_runner.attn_backend.prefill_wrapper_ragged
|
||||||
|
)
|
||||||
|
self.prefill_wrapper_paged = (
|
||||||
|
self.model_runner.attn_backend.prefill_wrapper_paged
|
||||||
)
|
)
|
||||||
|
|
||||||
(
|
self.kv_last_page_len = torch.ones(
|
||||||
self.decode_wrapper,
|
(self.batch_size,), dtype=torch.int32, device="cuda"
|
||||||
self.prefill_wrapper_ragged,
|
|
||||||
self.prefill_wrapper_paged,
|
|
||||||
) = (
|
|
||||||
decode_wrapper or self.model_runner.attn_backend.decode_wrapper,
|
|
||||||
self.model_runner.attn_backend.prefill_wrapper_ragged,
|
|
||||||
self.model_runner.attn_backend.prefill_wrapper_paged,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init_indices_no_sliding_window(self):
|
def _init_indices_no_sliding_window(self):
|
||||||
|
|||||||
@@ -114,7 +114,7 @@ def _fwd_kernel_stage1(
|
|||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _fwd_kernel_stage2(
|
def _fwd_kernel_stage2(
|
||||||
Logics,
|
logits,
|
||||||
V_Buffer,
|
V_Buffer,
|
||||||
Out,
|
Out,
|
||||||
Req_to_tokens,
|
Req_to_tokens,
|
||||||
@@ -162,7 +162,7 @@ def _fwd_kernel_stage2(
|
|||||||
)
|
)
|
||||||
|
|
||||||
qk = tl.load(
|
qk = tl.load(
|
||||||
Logics
|
logits
|
||||||
+ cur_head * stride_logic_h
|
+ cur_head * stride_logic_h
|
||||||
+ (cur_batch_start_loc + start_n + offs_n),
|
+ (cur_batch_start_loc + start_n + offs_n),
|
||||||
mask=start_n + offs_n < cur_batch_seq_len,
|
mask=start_n + offs_n < cur_batch_seq_len,
|
||||||
@@ -238,7 +238,7 @@ def _decode_att_m_fwd(
|
|||||||
|
|
||||||
|
|
||||||
def _decode_softmax_reducev_fwd(
|
def _decode_softmax_reducev_fwd(
|
||||||
logics,
|
logits,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
o,
|
o,
|
||||||
req_to_tokens,
|
req_to_tokens,
|
||||||
@@ -247,9 +247,9 @@ def _decode_softmax_reducev_fwd(
|
|||||||
b_seq_len,
|
b_seq_len,
|
||||||
):
|
):
|
||||||
BLOCK = 64
|
BLOCK = 64
|
||||||
batch, head = b_seq_len.shape[0], logics.shape[0]
|
batch, head = b_seq_len.shape[0], logits.shape[0]
|
||||||
grid = (batch, head, 1)
|
grid = (batch, head, 1)
|
||||||
kv_group_num = logics.shape[0] // v_buffer.shape[1]
|
kv_group_num = logits.shape[0] // v_buffer.shape[1]
|
||||||
|
|
||||||
num_warps = 1
|
num_warps = 1
|
||||||
|
|
||||||
@@ -257,14 +257,14 @@ def _decode_softmax_reducev_fwd(
|
|||||||
BLOCK_DMODEL = triton.next_power_of_2(Lv)
|
BLOCK_DMODEL = triton.next_power_of_2(Lv)
|
||||||
|
|
||||||
_fwd_kernel_stage2[grid](
|
_fwd_kernel_stage2[grid](
|
||||||
logics,
|
logits,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
o,
|
o,
|
||||||
req_to_tokens,
|
req_to_tokens,
|
||||||
b_req_idx,
|
b_req_idx,
|
||||||
b_start_loc,
|
b_start_loc,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
logics.stride(0),
|
logits.stride(0),
|
||||||
v_buffer.stride(0),
|
v_buffer.stride(0),
|
||||||
v_buffer.stride(1),
|
v_buffer.stride(1),
|
||||||
o.stride(0),
|
o.stride(0),
|
||||||
@@ -387,7 +387,7 @@ def _fwd_grouped_kernel_stage1(
|
|||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _fwd_grouped_kernel_stage2(
|
def _fwd_grouped_kernel_stage2(
|
||||||
Logics,
|
logits,
|
||||||
V_Buffer,
|
V_Buffer,
|
||||||
Out,
|
Out,
|
||||||
Req_to_tokens,
|
Req_to_tokens,
|
||||||
@@ -443,7 +443,7 @@ def _fwd_grouped_kernel_stage2(
|
|||||||
)
|
)
|
||||||
|
|
||||||
qk = tl.load(
|
qk = tl.load(
|
||||||
Logics + offs_qk,
|
logits + offs_qk,
|
||||||
mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len),
|
mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len),
|
||||||
other=float("-inf"),
|
other=float("-inf"),
|
||||||
)
|
)
|
||||||
@@ -531,7 +531,7 @@ def _decode_grouped_att_m_fwd(
|
|||||||
|
|
||||||
|
|
||||||
def _decode_grouped_softmax_reducev_fwd(
|
def _decode_grouped_softmax_reducev_fwd(
|
||||||
logics,
|
logits,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
o,
|
o,
|
||||||
req_to_tokens,
|
req_to_tokens,
|
||||||
@@ -540,8 +540,8 @@ def _decode_grouped_softmax_reducev_fwd(
|
|||||||
b_seq_len,
|
b_seq_len,
|
||||||
):
|
):
|
||||||
BLOCK = 128
|
BLOCK = 128
|
||||||
batch, head_num = b_seq_len.shape[0], logics.shape[0]
|
batch, head_num = b_seq_len.shape[0], logits.shape[0]
|
||||||
kv_group_num = logics.shape[0] // v_buffer.shape[1]
|
kv_group_num = logits.shape[0] // v_buffer.shape[1]
|
||||||
BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
|
BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
|
||||||
grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
|
grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
|
||||||
|
|
||||||
@@ -551,14 +551,14 @@ def _decode_grouped_softmax_reducev_fwd(
|
|||||||
BLOCK_DMODEL = triton.next_power_of_2(Lv)
|
BLOCK_DMODEL = triton.next_power_of_2(Lv)
|
||||||
|
|
||||||
_fwd_grouped_kernel_stage2[grid](
|
_fwd_grouped_kernel_stage2[grid](
|
||||||
logics,
|
logits,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
o,
|
o,
|
||||||
req_to_tokens,
|
req_to_tokens,
|
||||||
b_req_idx,
|
b_req_idx,
|
||||||
b_start_loc,
|
b_start_loc,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
logics.stride(0),
|
logits.stride(0),
|
||||||
v_buffer.stride(0),
|
v_buffer.stride(0),
|
||||||
v_buffer.stride(1),
|
v_buffer.stride(1),
|
||||||
o.stride(0),
|
o.stride(0),
|
||||||
@@ -584,17 +584,11 @@ def decode_attention_fwd(
|
|||||||
b_req_idx,
|
b_req_idx,
|
||||||
b_start_loc,
|
b_start_loc,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
|
attn_logits,
|
||||||
max_len_in_batch,
|
max_len_in_batch,
|
||||||
total_num_tokens,
|
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap=0.0,
|
logit_cap=0.0,
|
||||||
att_m=None,
|
|
||||||
):
|
):
|
||||||
if att_m is None:
|
|
||||||
att_m = torch.empty(
|
|
||||||
(q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
|
|
||||||
)
|
|
||||||
|
|
||||||
kv_group_num = q.shape[1] // v_buffer.shape[1]
|
kv_group_num = q.shape[1] // v_buffer.shape[1]
|
||||||
|
|
||||||
if kv_group_num == 1:
|
if kv_group_num == 1:
|
||||||
@@ -602,7 +596,7 @@ def decode_attention_fwd(
|
|||||||
_decode_att_m_fwd(
|
_decode_att_m_fwd(
|
||||||
q,
|
q,
|
||||||
k_buffer,
|
k_buffer,
|
||||||
att_m,
|
attn_logits,
|
||||||
req_to_token,
|
req_to_token,
|
||||||
b_req_idx,
|
b_req_idx,
|
||||||
b_start_loc,
|
b_start_loc,
|
||||||
@@ -612,7 +606,7 @@ def decode_attention_fwd(
|
|||||||
logit_cap,
|
logit_cap,
|
||||||
)
|
)
|
||||||
_decode_softmax_reducev_fwd(
|
_decode_softmax_reducev_fwd(
|
||||||
att_m,
|
attn_logits,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
o,
|
o,
|
||||||
req_to_token,
|
req_to_token,
|
||||||
@@ -625,7 +619,7 @@ def decode_attention_fwd(
|
|||||||
_decode_grouped_att_m_fwd(
|
_decode_grouped_att_m_fwd(
|
||||||
q,
|
q,
|
||||||
k_buffer,
|
k_buffer,
|
||||||
att_m,
|
attn_logits,
|
||||||
req_to_token,
|
req_to_token,
|
||||||
b_req_idx,
|
b_req_idx,
|
||||||
b_start_loc,
|
b_start_loc,
|
||||||
@@ -635,7 +629,7 @@ def decode_attention_fwd(
|
|||||||
logit_cap,
|
logit_cap,
|
||||||
)
|
)
|
||||||
_decode_grouped_softmax_reducev_fwd(
|
_decode_grouped_softmax_reducev_fwd(
|
||||||
att_m,
|
attn_logits,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
o,
|
o,
|
||||||
req_to_token,
|
req_to_token,
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Copyright 2023-2024 SGLang Team
|
Copyright 2023-2024 SGLang Team
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@@ -17,13 +19,12 @@ limitations under the License.
|
|||||||
|
|
||||||
import bisect
|
import bisect
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Callable
|
from typing import TYPE_CHECKING, Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from vllm.distributed.parallel_state import graph_capture
|
from vllm.distributed.parallel_state import graph_capture
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
|
|
||||||
from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
|
|
||||||
from sglang.srt.layers.logits_processor import (
|
from sglang.srt.layers.logits_processor import (
|
||||||
LogitsMetadata,
|
LogitsMetadata,
|
||||||
LogitsProcessor,
|
LogitsProcessor,
|
||||||
@@ -35,6 +36,9 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetad
|
|||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
from sglang.srt.utils import monkey_patch_vllm_all_gather
|
from sglang.srt.utils import monkey_patch_vllm_all_gather
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
|
||||||
|
|
||||||
def _to_torch(model: torch.nn.Module, reverse: bool = False):
|
def _to_torch(model: torch.nn.Module, reverse: bool = False):
|
||||||
for sub in model._modules.values():
|
for sub in model._modules.values():
|
||||||
@@ -111,7 +115,7 @@ class CudaGraphRunner:
|
|||||||
self.req_pool_indices = torch.zeros(
|
self.req_pool_indices = torch.zeros(
|
||||||
(self.max_bs,), dtype=torch.int32, device="cuda"
|
(self.max_bs,), dtype=torch.int32, device="cuda"
|
||||||
)
|
)
|
||||||
self.seq_lens = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
|
self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
|
||||||
self.position_ids_offsets = torch.ones(
|
self.position_ids_offsets = torch.ones(
|
||||||
(self.max_bs,), dtype=torch.int32, device="cuda"
|
(self.max_bs,), dtype=torch.int32, device="cuda"
|
||||||
)
|
)
|
||||||
@@ -121,6 +125,9 @@ class CudaGraphRunner:
|
|||||||
|
|
||||||
# Attention backend
|
# Attention backend
|
||||||
self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
|
self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
|
||||||
|
self.seq_len_fill_value = (
|
||||||
|
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
|
||||||
|
)
|
||||||
|
|
||||||
# Sampling info
|
# Sampling info
|
||||||
vocab_size = model_runner.model_config.vocab_size
|
vocab_size = model_runner.model_config.vocab_size
|
||||||
@@ -176,7 +183,7 @@ class CudaGraphRunner:
|
|||||||
out_cache_loc = self.out_cache_loc[:bs]
|
out_cache_loc = self.out_cache_loc[:bs]
|
||||||
|
|
||||||
# Attention backend
|
# Attention backend
|
||||||
self.model_runner.attn_backend.capture_cuda_graph_init(
|
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
|
||||||
bs, req_pool_indices, seq_lens
|
bs, req_pool_indices, seq_lens
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -227,7 +234,7 @@ class CudaGraphRunner:
|
|||||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||||
bs = self.capture_bs[index]
|
bs = self.capture_bs[index]
|
||||||
if bs != raw_bs:
|
if bs != raw_bs:
|
||||||
self.seq_lens.zero_()
|
self.seq_lens.fill_(self.seq_len_fill_value)
|
||||||
self.position_ids_offsets.fill_(1)
|
self.position_ids_offsets.fill_(1)
|
||||||
self.out_cache_loc.zero_()
|
self.out_cache_loc.zero_()
|
||||||
|
|
||||||
@@ -239,7 +246,7 @@ class CudaGraphRunner:
|
|||||||
self.out_cache_loc[:raw_bs] = batch.out_cache_loc
|
self.out_cache_loc[:raw_bs] = batch.out_cache_loc
|
||||||
|
|
||||||
# Attention backend
|
# Attention backend
|
||||||
self.model_runner.attn_backend.replay_cuda_graph_init(
|
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||||
bs, self.req_pool_indices, self.seq_lens
|
bs, self.req_pool_indices, self.seq_lens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -445,12 +445,6 @@ class ModelRunner:
|
|||||||
if self.server_args.disable_cuda_graph:
|
if self.server_args.disable_cuda_graph:
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.server_args.attention_backend != "flashinfer":
|
|
||||||
logger.warning(
|
|
||||||
f"Cuda graph is not supported for attention backend: {self.server_args.attention_backend}"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info("Capture cuda graph begin. This can take up to several minutes.")
|
logger.info("Capture cuda graph begin. This can take up to several minutes.")
|
||||||
self.cuda_graph_runner = CudaGraphRunner(self)
|
self.cuda_graph_runner = CudaGraphRunner(self)
|
||||||
|
|
||||||
|
|||||||
@@ -96,6 +96,16 @@ class TestServingThroughput(unittest.TestCase):
|
|||||||
if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
|
if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
|
||||||
assert res["output_throughput"] > 2400
|
assert res["output_throughput"] > 2400
|
||||||
|
|
||||||
|
def test_default_with_triton_attention_backend(self):
|
||||||
|
res = self.run_test(
|
||||||
|
disable_radix_cache=ServerArgs.disable_radix_cache,
|
||||||
|
attention_backend="triton",
|
||||||
|
chunked_prefill_size=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
|
||||||
|
assert res["output_throughput"] > 2400
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user