Optimize Triton Draft Backend (#11556)
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Optional, Union
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
@@ -12,6 +12,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
|
|||||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
from sglang.srt.layers.radix_attention import AttentionType
|
from sglang.srt.layers.radix_attention import AttentionType
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
|
from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
get_device_core_count,
|
get_device_core_count,
|
||||||
@@ -423,6 +424,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
max_bs: int,
|
max_bs: int,
|
||||||
max_num_tokens: int,
|
max_num_tokens: int,
|
||||||
kv_indices_buf: Optional[torch.Tensor] = None,
|
kv_indices_buf: Optional[torch.Tensor] = None,
|
||||||
|
cuda_graph_num_kv_splits_buf: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
self.cuda_graph_attn_logits = torch.zeros(
|
self.cuda_graph_attn_logits = torch.zeros(
|
||||||
(max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
|
(max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
|
||||||
@@ -434,9 +436,17 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
self.cuda_graph_num_kv_splits = torch.full(
|
|
||||||
(max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
|
if cuda_graph_num_kv_splits_buf is None:
|
||||||
)
|
self.cuda_graph_num_kv_splits = torch.full(
|
||||||
|
(max_num_tokens,),
|
||||||
|
self.max_kv_splits,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.cuda_graph_num_kv_splits = cuda_graph_num_kv_splits_buf
|
||||||
|
|
||||||
if kv_indices_buf is None:
|
if kv_indices_buf is None:
|
||||||
self.cuda_graph_kv_indices = torch.zeros(
|
self.cuda_graph_kv_indices = torch.zeros(
|
||||||
(max_num_tokens * self.max_context_len),
|
(max_num_tokens * self.max_context_len),
|
||||||
@@ -683,9 +693,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
|
assert False, "Multi-step cuda graph init is not done here."
|
||||||
kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
|
|
||||||
num_token = spec_info.kv_indptr.shape[0] - 1
|
|
||||||
self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])
|
self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])
|
||||||
|
|
||||||
elif forward_mode.is_target_verify():
|
elif forward_mode.is_target_verify():
|
||||||
@@ -898,11 +906,8 @@ class TritonMultiStepDraftBackend:
|
|||||||
topk: int,
|
topk: int,
|
||||||
speculative_num_steps: int,
|
speculative_num_steps: int,
|
||||||
):
|
):
|
||||||
from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
|
|
||||||
|
|
||||||
self.topk = topk
|
self.topk = topk
|
||||||
self.speculative_num_steps = speculative_num_steps
|
self.speculative_num_steps = speculative_num_steps
|
||||||
self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
|
|
||||||
max_bs = model_runner.req_to_token_pool.size * self.topk
|
max_bs = model_runner.req_to_token_pool.size * self.topk
|
||||||
self.kv_indptr = torch.zeros(
|
self.kv_indptr = torch.zeros(
|
||||||
(
|
(
|
||||||
@@ -912,7 +917,7 @@ class TritonMultiStepDraftBackend:
|
|||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=model_runner.device,
|
device=model_runner.device,
|
||||||
)
|
)
|
||||||
self.attn_backends = []
|
self.attn_backends: List[TritonAttnBackend] = []
|
||||||
for i in range(self.speculative_num_steps):
|
for i in range(self.speculative_num_steps):
|
||||||
self.attn_backends.append(
|
self.attn_backends.append(
|
||||||
TritonAttnBackend(
|
TritonAttnBackend(
|
||||||
@@ -931,13 +936,19 @@ class TritonMultiStepDraftBackend:
|
|||||||
self.page_size = model_runner.server_args.page_size
|
self.page_size = model_runner.server_args.page_size
|
||||||
|
|
||||||
def common_template(
|
def common_template(
|
||||||
self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
|
self,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
kv_indices_buffer: Optional[torch.Tensor],
|
||||||
|
call_fn: int,
|
||||||
):
|
):
|
||||||
|
if kv_indices_buffer is None:
|
||||||
|
kv_indices_buffer = self.cuda_graph_kv_indices
|
||||||
|
|
||||||
num_seqs = forward_batch.batch_size
|
num_seqs = forward_batch.batch_size
|
||||||
bs = self.topk * num_seqs
|
bs = self.topk * num_seqs
|
||||||
seq_lens_sum = forward_batch.seq_lens_sum
|
seq_lens_sum = forward_batch.seq_lens_sum
|
||||||
|
|
||||||
self.generate_draft_decode_kv_indices[
|
generate_draft_decode_kv_indices[
|
||||||
(self.speculative_num_steps, num_seqs, self.topk)
|
(self.speculative_num_steps, num_seqs, self.topk)
|
||||||
](
|
](
|
||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
@@ -955,6 +966,9 @@ class TritonMultiStepDraftBackend:
|
|||||||
self.page_size,
|
self.page_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if call_fn is None:
|
||||||
|
return
|
||||||
|
|
||||||
for i in range(self.speculative_num_steps):
|
for i in range(self.speculative_num_steps):
|
||||||
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
|
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
|
||||||
forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
|
forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
|
||||||
@@ -989,9 +1003,18 @@ class TritonMultiStepDraftBackend:
|
|||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
self.cuda_graph_num_kv_splits = torch.full(
|
||||||
|
(max_num_tokens,),
|
||||||
|
self.attn_backends[0].max_kv_splits,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
for i in range(self.speculative_num_steps):
|
for i in range(self.speculative_num_steps):
|
||||||
self.attn_backends[i].init_cuda_graph_state(
|
self.attn_backends[i].init_cuda_graph_state(
|
||||||
max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
|
max_bs,
|
||||||
|
max_num_tokens,
|
||||||
|
kv_indices_buf=self.cuda_graph_kv_indices[i],
|
||||||
|
cuda_graph_num_kv_splits_buf=self.cuda_graph_num_kv_splits,
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
|
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
|
||||||
@@ -1006,24 +1029,24 @@ class TritonMultiStepDraftBackend:
|
|||||||
spec_info=forward_batch.spec_info,
|
spec_info=forward_batch.spec_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
|
self.common_template(forward_batch, None, call_fn)
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
self, forward_batch: ForwardBatch, bs: int
|
self, forward_batch: ForwardBatch, bs: int
|
||||||
):
|
):
|
||||||
def call_fn(i, forward_batch):
|
self.common_template(forward_batch, None, None)
|
||||||
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
|
|
||||||
bs,
|
|
||||||
forward_batch.req_pool_indices,
|
|
||||||
forward_batch.seq_lens,
|
|
||||||
seq_lens_sum=-1,
|
|
||||||
encoder_lens=None,
|
|
||||||
forward_mode=ForwardMode.DECODE,
|
|
||||||
spec_info=forward_batch.spec_info,
|
|
||||||
seq_lens_cpu=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
|
# NOTE: Multi-step's attention backends use the slice of
|
||||||
|
# - kv_indptr buffer (cuda graph and non-cuda graph)
|
||||||
|
# - kv_indices buffer (cuda graph only)
|
||||||
|
# So we don't need to assign the KV indices inside the attention backend.
|
||||||
|
|
||||||
|
# Compute num_kv_splits only once
|
||||||
|
num_token = forward_batch.batch_size * self.topk
|
||||||
|
self.attn_backends[-1].get_num_kv_splits(
|
||||||
|
self.attn_backends[-1].cuda_graph_num_kv_splits[:num_token],
|
||||||
|
forward_batch.seq_lens[:bs],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
|
|||||||
Reference in New Issue
Block a user