Optimize Triton Draft Backend (#11556)
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 import triton
@@ -12,6 +12,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
 from sglang.srt.utils import (
     get_bool_env_var,
     get_device_core_count,
@@ -423,6 +424,7 @@ class TritonAttnBackend(AttentionBackend):
         max_bs: int,
         max_num_tokens: int,
         kv_indices_buf: Optional[torch.Tensor] = None,
+        cuda_graph_num_kv_splits_buf: Optional[torch.Tensor] = None,
     ):
         self.cuda_graph_attn_logits = torch.zeros(
             (max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
@@ -434,9 +436,17 @@ class TritonAttnBackend(AttentionBackend):
             dtype=torch.float32,
             device=self.device,
         )
-        self.cuda_graph_num_kv_splits = torch.full(
-            (max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
-        )
+
+        if cuda_graph_num_kv_splits_buf is None:
+            self.cuda_graph_num_kv_splits = torch.full(
+                (max_num_tokens,),
+                self.max_kv_splits,
+                dtype=torch.int32,
+                device=self.device,
+            )
+        else:
+            self.cuda_graph_num_kv_splits = cuda_graph_num_kv_splits_buf
+
         if kv_indices_buf is None:
             self.cuda_graph_kv_indices = torch.zeros(
                 (max_num_tokens * self.max_context_len),
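The new cuda_graph_num_kv_splits_buf argument turns this buffer into an accept-or-allocate resource: a standalone TritonAttnBackend still allocates its own tensor, while TritonMultiStepDraftBackend (below) can hand every per-step backend the same one. A minimal sketch of the pattern, with illustrative names rather than the sglang API:

import torch
from typing import Optional

class Backend:
    # Accept-or-allocate: own a private buffer unless the caller injects one.
    def __init__(self, max_num_tokens: int, max_kv_splits: int,
                 num_kv_splits_buf: Optional[torch.Tensor] = None):
        if num_kv_splits_buf is None:
            self.num_kv_splits = torch.full(
                (max_num_tokens,), max_kv_splits, dtype=torch.int32
            )
        else:
            self.num_kv_splits = num_kv_splits_buf  # alias the caller's tensor

shared = torch.full((8,), 4, dtype=torch.int32)
a = Backend(8, 4, num_kv_splits_buf=shared)
b = Backend(8, 4, num_kv_splits_buf=shared)
shared[0] = 1                   # one write through the shared tensor...
assert a.num_kv_splits[0] == 1  # ...is visible in backend a
assert b.num_kv_splits[0] == 1  # ...and in backend b

Because all instances alias one tensor, a single update serves every backend, which is what the optimized replay path at the end of this diff relies on.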
@@ -683,9 +693,7 @@ class TritonAttnBackend(AttentionBackend):
                 )
 
             else:
-                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
-                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
-                num_token = spec_info.kv_indptr.shape[0] - 1
+                assert False, "Multi-step cuda graph init is not done here."
             self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])
 
         elif forward_mode.is_target_verify():
@@ -898,11 +906,8 @@ class TritonMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
-
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
-        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
         max_bs = model_runner.req_to_token_pool.size * self.topk
         self.kv_indptr = torch.zeros(
             (
@@ -912,7 +917,7 @@ class TritonMultiStepDraftBackend:
             dtype=torch.int32,
             device=model_runner.device,
         )
-        self.attn_backends = []
+        self.attn_backends: List[TritonAttnBackend] = []
         for i in range(self.speculative_num_steps):
             self.attn_backends.append(
                 TritonAttnBackend(
@@ -931,13 +936,19 @@ class TritonMultiStepDraftBackend:
         self.page_size = model_runner.server_args.page_size
 
     def common_template(
-        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+        self,
+        forward_batch: ForwardBatch,
+        kv_indices_buffer: Optional[torch.Tensor],
+        call_fn: int,
     ):
+        if kv_indices_buffer is None:
+            kv_indices_buffer = self.cuda_graph_kv_indices
+
         num_seqs = forward_batch.batch_size
         bs = self.topk * num_seqs
         seq_lens_sum = forward_batch.seq_lens_sum
 
-        self.generate_draft_decode_kv_indices[
+        generate_draft_decode_kv_indices[
             (self.speculative_num_steps, num_seqs, self.topk)
         ](
             forward_batch.req_pool_indices,
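generate_draft_decode_kv_indices is a @triton.jit kernel, and kernel[grid](args) launches one program instance per grid point; here the grid is the 3-tuple (speculative_num_steps, num_seqs, topk), so one launch covers every draft step at once. Referencing the module-level kernel directly also drops the self.generate_draft_decode_kv_indices indirection that __init__ used to set up. A toy launch with the same 3D-grid shape (mark_kernel is a hypothetical stand-in for the real kernel and needs a CUDA device):

import torch
import triton
import triton.language as tl

@triton.jit
def mark_kernel(out_ptr, n_seqs, topk):
    # One program per (step, seq, branch) triple, mirroring the
    # (speculative_num_steps, num_seqs, topk) grid in the diff above.
    step = tl.program_id(0)
    seq = tl.program_id(1)
    branch = tl.program_id(2)
    idx = step * n_seqs * topk + seq * topk + branch
    tl.store(out_ptr + idx, idx)

steps, n_seqs, topk = 2, 3, 4
out = torch.zeros(steps * n_seqs * topk, dtype=torch.int32, device="cuda")
mark_kernel[(steps, n_seqs, topk)](out, n_seqs, topk)  # 3D launch grid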
@@ -955,6 +966,9 @@ class TritonMultiStepDraftBackend:
             self.page_size,
         )
 
+        if call_fn is None:
+            return
+
         for i in range(self.speculative_num_steps):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
             forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
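With this guard, call_fn doubles as a mode switch: the capture path passes a per-step callback, while the optimized replay path passes None and gets only the shared kv-index kernel launch. A toy version of the control flow, with illustrative names:

from typing import Callable, Optional

def launch_shared_kernel():
    # Stand-in for the Triton launch above; always runs.
    print("generate kv indices for all steps")

def common_template(steps: int, call_fn: Optional[Callable[[int], None]]):
    launch_shared_kernel()
    if call_fn is None:           # replay: per-step init is not needed
        return
    for i in range(steps):        # capture: initialize each step's backend
        call_fn(i)

common_template(3, lambda i: print("capture step", i))
common_template(3, None)  # replay: only the shared launch runs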
@@ -989,9 +1003,18 @@ class TritonMultiStepDraftBackend:
             dtype=torch.int64,
             device=self.device,
         )
+        self.cuda_graph_num_kv_splits = torch.full(
+            (max_num_tokens,),
+            self.attn_backends[0].max_kv_splits,
+            dtype=torch.int32,
+            device=self.device,
+        )
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
-                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs,
+                max_num_tokens,
+                kv_indices_buf=self.cuda_graph_kv_indices[i],
+                cuda_graph_num_kv_splits_buf=self.cuda_graph_num_kv_splits,
             )
 
     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
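Note the asymmetry in what gets shared: each step receives its own row self.cuda_graph_kv_indices[i], since KV indices differ per draft step, but every step receives the same cuda_graph_num_kv_splits tensor, since the split count depends only on sequence lengths, which all steps share. A small sketch of the two sharing modes (shapes made up):

import torch

steps, max_tokens = 3, 8
kv_indices = torch.zeros((steps, max_tokens), dtype=torch.int64)
num_kv_splits = torch.full((max_tokens,), 16, dtype=torch.int32)

per_step = [kv_indices[i] for i in range(steps)]   # distinct rows, one per step
shared = [num_kv_splits for _ in range(steps)]     # one tensor, aliased by all

per_step[0][0] = 7
assert kv_indices[1, 0] == 0   # rows stay independent across steps
shared[0][0] = 2
assert shared[2][0] == 2       # every step sees the same split counts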
@@ -1006,24 +1029,24 @@ class TritonMultiStepDraftBackend:
                 spec_info=forward_batch.spec_info,
             )
 
-        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+        self.common_template(forward_batch, None, call_fn)
 
     def init_forward_metadata_replay_cuda_graph(
         self, forward_batch: ForwardBatch, bs: int
     ):
-        def call_fn(i, forward_batch):
-            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
-                bs,
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                seq_lens_sum=-1,
-                encoder_lens=None,
-                forward_mode=ForwardMode.DECODE,
-                spec_info=forward_batch.spec_info,
-                seq_lens_cpu=None,
-            )
+        self.common_template(forward_batch, None, None)
 
-        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+        # NOTE: Multi-step's attention backends use the slice of
+        # - kv_indptr buffer (cuda graph and non-cuda graph)
+        # - kv_indices buffer (cuda graph only)
+        # So we don't need to assign the KV indices inside the attention backend.
+
+        # Compute num_kv_splits only once
+        num_token = forward_batch.batch_size * self.topk
+        self.attn_backends[-1].get_num_kv_splits(
+            self.attn_backends[-1].cuda_graph_num_kv_splits[:num_token],
+            forward_batch.seq_lens[:bs],
+        )
 
 
 @triton.jit
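This is the payoff of the shared buffer: replay calls get_num_kv_splits once, on the last backend, and every other step observes the result through the aliased cuda_graph_num_kv_splits tensor. Previously the per-step call_fn re-ran init_forward_metadata_replay_cuda_graph, recomputing the splits speculative_num_steps times. A call-count sketch of the change (a toy model, not the sglang code):

speculative_num_steps = 4
calls = {"before": 0, "after": 0}

def get_num_kv_splits(tag):
    calls[tag] += 1  # stand-in for the real split computation

# Before: each step's replay hook recomputed the splits.
for _ in range(speculative_num_steps):
    get_num_kv_splits("before")

# After: computed once into the buffer that all steps alias.
get_num_kv_splits("after")

assert calls == {"before": 4, "after": 1}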