Optimize Triton decoding kernel for dynamic workload (#4553)

This commit is contained in:
JieXin Liang
2025-03-19 12:25:38 +08:00
committed by GitHub
parent 588865f0e0
commit c0e9a36c5f
7 changed files with 277 additions and 57 deletions

View File

@@ -26,6 +26,7 @@ import tqdm
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -195,6 +196,9 @@ class CudaGraphRunner:
# Attention backend
self.max_bs = max(self.capture_bs)
self.max_num_token = self.max_bs * self.num_tokens_per_bs
self.num_head = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
self.seq_len_fill_value = (
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
@@ -503,9 +507,15 @@ class CudaGraphRunner:
if hasattr(forward_batch.spec_info, "hidden_states"):
self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
num_kv_heads = self.num_head
if hasattr(forward_batch.token_to_kv_pool, "k_buffer"):
if isinstance(forward_batch.token_to_kv_pool.k_buffer, list):
num_kv_heads = forward_batch.token_to_kv_pool.k_buffer[0].shape[1]
# Attention backend
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
bs,
num_kv_heads,
self.req_pool_indices,
self.seq_lens,
forward_batch.seq_lens_sum + (bs - raw_bs),