Minor style change of triton backend (#7165)
@@ -20,117 +20,6 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 
 
-@triton.jit
-def get_num_kv_splits_triton(
-    num_kv_splits_ptr,
-    seq_lens_ptr,
-    num_seq,
-    num_group,
-    num_head,
-    num_kv_head,
-    max_kv_splits,
-    device_core_count,
-    MAX_NUM_SEQ: tl.constexpr,
-):
-    # TODO: this method is tunable, we need more online serving data to tune it
-    offs_seq = tl.arange(0, MAX_NUM_SEQ)
-    mask_seq = offs_seq < num_seq
-
-    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
-    max_seq_len = tl.max(seq_lens)
-    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
-    min_seq_len = tl.min(seq_lens)
-    if max_seq_len * 8 < min_seq_len * 10:
-        min_seq_len = max_seq_len
-    max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
-    kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
-
-    # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
-    ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
-    ext_device_core_count = tl.cast(
-        device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
-    )
-    block_h, num_kv_group = 16, num_head // num_kv_head
-    if num_kv_group == 1:
-        token_grid = num_seq * num_group * num_head
-    else:
-        # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
-        block_h = tl.minimum(block_h, num_kv_group)
-        token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
-    max_kv_splits_2 = tl.minimum(
-        tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
-    )
-    kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
-
-    num_kv_splits = tl.maximum(
-        tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
-    )
-
-    offs_token = offs_seq * num_group
-    mask_token = offs_token < num_seq * num_group
-    for i in range(0, num_group):
-        tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
-
-
-def update_sliding_window_buffer(
-    window_kv_indptr,
-    req_to_token,
-    sliding_window_size,
-    seq_lens,
-    req_pool_indices,
-    bs,
-    device,
-):
-    window_kv_lens = torch.minimum(
-        seq_lens,
-        torch.tensor(sliding_window_size + 1),
-    )
-    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
-    window_kv_indptr = window_kv_indptr[: bs + 1]
-    window_kv_indices = torch.empty(
-        window_kv_indptr[-1], dtype=torch.int32, device=device
-    )
-    window_kv_start_idx = seq_lens - window_kv_lens
-    create_flashinfer_kv_indices_triton[(bs,)](
-        req_to_token,
-        req_pool_indices,
-        window_kv_lens,
-        window_kv_indptr,
-        window_kv_start_idx,
-        window_kv_indices,
-        req_to_token.stride(0),
-    )
-    return window_kv_indptr, window_kv_indices, window_kv_lens
-
-
-def update_sliding_window_buffer_cuda_graph(
-    window_kv_indptr,
-    window_kv_indices,
-    req_to_token,
-    sliding_window_size,
-    seq_lens,
-    req_pool_indices,
-    bs,
-):
-    window_kv_lens = torch.minimum(
-        seq_lens,
-        torch.tensor(sliding_window_size + 1),
-    )
-    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
-    window_kv_indptr = window_kv_indptr[: bs + 1]
-    window_kv_start_idx = seq_lens - window_kv_lens
-    create_flashinfer_kv_indices_triton[(bs,)](
-        req_to_token,
-        req_pool_indices,
-        window_kv_lens,
-        window_kv_indptr,
-        window_kv_start_idx,
-        window_kv_indices,
-        req_to_token.stride(0),
-    )
-    return window_kv_indptr, window_kv_lens
-
-
 @dataclass
 class ForwardMetadata:
     attn_logits: torch.Tensor
@@ -165,8 +54,8 @@ class TritonAttnBackend(AttentionBackend):
 
         super().__init__()
 
-        self.decode_attention_fwd = decode_attention_fwd
-        self.extend_attention_fwd = extend_attention_fwd
+        self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
+        self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
 
         self.skip_prefill = skip_prefill
 
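Not part of the diff itself, but for context: wrapping the attention entry points in torch.compiler.disable keeps torch.compile from tracing into the Triton launch code. A minimal, self-contained sketch of that behavior (the function names below are illustrative, not from the backend):

import torch

def triton_like_launch(x):
    # Stand-in for a kernel launch that torch.compile should not trace.
    return x * 2

# torch.compiler.disable returns a callable that always runs eagerly.
triton_like_launch = torch.compiler.disable(triton_like_launch)

@torch.compile
def step(x):
    # The compiled graph breaks here and calls triton_like_launch eagerly.
    return triton_like_launch(x) + 1

print(step(torch.ones(3)))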
@@ -973,3 +862,114 @@ class TritonMultiStepDraftBackend:
             )
 
         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+
+@triton.jit
+def get_num_kv_splits_triton(
+    num_kv_splits_ptr,
+    seq_lens_ptr,
+    num_seq,
+    num_group,
+    num_head,
+    num_kv_head,
+    max_kv_splits,
+    device_core_count,
+    MAX_NUM_SEQ: tl.constexpr,
+):
+    # TODO: this method is tunable, we need more online serving data to tune it
+    offs_seq = tl.arange(0, MAX_NUM_SEQ)
+    mask_seq = offs_seq < num_seq
+
+    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
+    max_seq_len = tl.max(seq_lens)
+    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
+    min_seq_len = tl.min(seq_lens)
+    if max_seq_len * 8 < min_seq_len * 10:
+        min_seq_len = max_seq_len
+    max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
+    kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
+
+    # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
+    ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
+    ext_device_core_count = tl.cast(
+        device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
+    )
+    block_h, num_kv_group = 16, num_head // num_kv_head
+    if num_kv_group == 1:
+        token_grid = num_seq * num_group * num_head
+    else:
+        # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
+        block_h = tl.minimum(block_h, num_kv_group)
+        token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
+    max_kv_splits_2 = tl.minimum(
+        tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
+    )
+    kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
+
+    num_kv_splits = tl.maximum(
+        tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
+    )
+
+    offs_token = offs_seq * num_group
+    mask_token = offs_token < num_seq * num_group
+    for i in range(0, num_group):
+        tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
+
+
+def update_sliding_window_buffer(
+    window_kv_indptr,
+    req_to_token,
+    sliding_window_size,
+    seq_lens,
+    req_pool_indices,
+    bs,
+    device,
+):
+    window_kv_lens = torch.minimum(
+        seq_lens,
+        torch.tensor(sliding_window_size + 1),
+    )
+    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
+    window_kv_indptr = window_kv_indptr[: bs + 1]
+    window_kv_indices = torch.empty(
+        window_kv_indptr[-1], dtype=torch.int32, device=device
+    )
+    window_kv_start_idx = seq_lens - window_kv_lens
+    create_flashinfer_kv_indices_triton[(bs,)](
+        req_to_token,
+        req_pool_indices,
+        window_kv_lens,
+        window_kv_indptr,
+        window_kv_start_idx,
+        window_kv_indices,
+        req_to_token.stride(0),
+    )
+    return window_kv_indptr, window_kv_indices, window_kv_lens
+
+
+def update_sliding_window_buffer_cuda_graph(
+    window_kv_indptr,
+    window_kv_indices,
+    req_to_token,
+    sliding_window_size,
+    seq_lens,
+    req_pool_indices,
+    bs,
+):
+    window_kv_lens = torch.minimum(
+        seq_lens,
+        torch.tensor(sliding_window_size + 1),
+    )
+    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
+    window_kv_indptr = window_kv_indptr[: bs + 1]
+    window_kv_start_idx = seq_lens - window_kv_lens
+    create_flashinfer_kv_indices_triton[(bs,)](
+        req_to_token,
+        req_pool_indices,
+        window_kv_lens,
+        window_kv_indptr,
+        window_kv_start_idx,
+        window_kv_indices,
+        req_to_token.stride(0),
+    )
+    return window_kv_indptr, window_kv_lens
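As a rough usage reference (not taken from this commit): the relocated get_num_kv_splits_triton kernel runs as a single program over all sequences and writes one split count per token row. A host-side launch might look like the sketch below, where the batch size, head counts, max_kv_splits, and the SM-count lookup are illustrative assumptions rather than values from the backend:

import torch
import triton

bs = 4
seq_lens = torch.tensor([17, 512, 2048, 96], dtype=torch.int32, device="cuda")
num_kv_splits = torch.empty(bs, dtype=torch.int32, device="cuda")  # one entry per token row (num_group == 1)
sm_count = torch.cuda.get_device_properties(0).multi_processor_count

get_num_kv_splits_triton[(1,)](
    num_kv_splits,      # num_kv_splits_ptr (output)
    seq_lens,           # seq_lens_ptr
    bs,                 # num_seq
    1,                  # num_group (plain decode, no speculative draft tokens)
    32,                 # num_head (illustrative)
    8,                  # num_kv_head (illustrative)
    16,                 # max_kv_splits (illustrative cap)
    sm_count,           # device_core_count
    MAX_NUM_SEQ=triton.next_power_of_2(bs),
)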