From 2d6113237407f9c447e07a23c82a8c9d803d9bf8 Mon Sep 17 00:00:00 2001
From: Ke Bao
Date: Mon, 10 Feb 2025 20:00:42 +0800
Subject: [PATCH] Support Eagle2 for Triton backend (#3466)

---
 .../srt/layers/attention/triton_backend.py    | 245 ++++++++++++++++--
 .../attention/triton_ops/extend_attention.py  |   8 +-
 python/sglang/srt/speculative/eagle_worker.py |  32 ++-
 test/srt/test_eagle_infer.py                  |  29 +++
 test/srt/test_triton_attention_kernels.py     |  12 +-
 5 files changed, 285 insertions(+), 41 deletions(-)

diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py
index 4da165486..e996cb159 100644
--- a/python/sglang/srt/layers/attention/triton_backend.py
+++ b/python/sglang/srt/layers/attention/triton_backend.py
@@ -3,6 +3,7 @@ from __future__ import annotations
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Callable, Optional
 
 import torch
+import triton
 
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_backend import (
@@ -18,7 +19,12 @@ if TYPE_CHECKING:
 
 
 class TritonAttnBackend(AttentionBackend):
-    def __init__(self, model_runner: ModelRunner):
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+    ):
         # Lazy import to avoid the initialization of cuda context
         from sglang.srt.layers.attention.triton_ops.decode_attention import (
             decode_attention_fwd,
@@ -33,14 +39,25 @@ class TritonAttnBackend(AttentionBackend):
         self.extend_attention_fwd = extend_attention_fwd
 
         max_bs = model_runner.req_to_token_pool.size
-        self.kv_indptr = torch.zeros(
-            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
-        )
+
+        if kv_indptr_buf is None:
+            self.kv_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+            )
+        else:
+            self.kv_indptr = kv_indptr_buf
+
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.qo_indptr = torch.zeros(
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
         )
+        self.mask_indptr = torch.zeros(
+            (max_bs + 1,), dtype=torch.int64, device=model_runner.device
+        )
+
+        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+
 
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
@@ -50,7 +67,7 @@ class TritonAttnBackend(AttentionBackend):
 
         self.forward_metadata = None
 
-        self.cuda_graph_max_seq_len = model_runner.model_config.context_len
+        self.max_context_len = model_runner.model_config.context_len
 
         self.device = model_runner.device
 
@@ -59,11 +76,31 @@ class TritonAttnBackend(AttentionBackend):
 
         bs = forward_batch.batch_size
         kv_indptr = self.kv_indptr
+        spec_info = forward_batch.spec_info
 
-        if forward_batch.forward_mode.is_decode():
-            attn_logits = torch.empty(
+        if forward_batch.forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = torch.zeros(
+                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
+                )
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
+                bs = kv_indptr.shape[0] - 1
+
+            attn_logits = torch.zeros(
                 (
-                    forward_batch.batch_size,
+                    bs,
                     self.num_head,
                     self.num_kv_splits,
                     self.v_head_dim + 1,
                 ),
@@ -72,12 +109,24 @@ class TritonAttnBackend(AttentionBackend):
                 device=self.device,
             )
+            qo_indptr = None
+            custom_mask = None
+            mask_indptr = None
             max_extend_len = None
-
+        elif forward_batch.forward_mode.is_target_verify():
+            bs = len(forward_batch.req_pool_indices)
+            qo_indptr = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            # Different from flashinfer kv_indptr and kv_indices construction
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.empty(
-                forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
+            kv_indices = torch.zeros(
+                kv_indptr[-1], dtype=torch.int32, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
@@ -89,15 +138,32 @@ class TritonAttnBackend(AttentionBackend):
                 self.req_to_token.stride(0),
             )
 
-            qo_indptr = None
-            custom_mask = None
-            mask_offsets = None
+            custom_mask = spec_info.custom_mask
+            seq_mask_len = self.num_draft_tokens * (
+                forward_batch.seq_lens + self.num_draft_tokens
+            )
+            mask_indptr = self.mask_indptr
+            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
+            mask_indptr = mask_indptr[: bs + 1]
+            max_extend_len = self.num_draft_tokens
+            attn_logits = None
+        elif forward_batch.forward_mode.is_draft_extend():
+            kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                spec_info.generate_attn_arg_prefill(
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    self.req_to_token,
+                )
+            )
+            mask_indptr = None
+            max_extend_len = torch.max(spec_info.accept_length).item()
+            attn_logits = None
         else:
             kv_indptr[1 : bs + 1] = torch.cumsum(
                 forward_batch.extend_prefix_lens, dim=0
             )
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.empty(
+            kv_indices = torch.zeros(
                 forward_batch.extend_prefix_lens.sum().item(),
                 dtype=torch.int32,
                 device=self.device,
@@ -116,8 +182,7 @@ class TritonAttnBackend(AttentionBackend):
             qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
-            mask_offsets = None
-
+            mask_indptr = None
             attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
 
@@ -128,22 +193,27 @@ class TritonAttnBackend(AttentionBackend):
             kv_indices,
             qo_indptr,
             custom_mask,
-            mask_offsets,
+            mask_indptr,
         )
 
-    def init_cuda_graph_state(self, max_bs: int):
-        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
+    def init_cuda_graph_state(
+        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+    ):
+        self.cuda_graph_max_total_num_tokens = max_bs * self.max_context_len
 
         self.cuda_graph_start_loc = torch.zeros(
             (max_bs,), dtype=torch.int32, device=self.device
         )
-        self.cuda_graph_attn_logits = torch.empty(
+        self.cuda_graph_attn_logits = torch.zeros(
             (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
             dtype=torch.float32,
             device=self.device,
         )
-        self.cuda_graph_kv_indices = torch.zeros(
-            (max_bs * self.cuda_graph_max_seq_len),
-            dtype=torch.int32,
-            device=self.device,
-        )
+        if kv_indices_buf is None:
+            self.cuda_graph_kv_indices = torch.zeros(
+                (max_bs * self.max_context_len),
+                dtype=torch.int32,
+                device=self.device,
+            )
+        else:
+            self.cuda_graph_kv_indices = kv_indices_buf
@@ -244,8 +314,9 @@ class TritonAttnBackend(AttentionBackend):
             kv_indices,
             qo_indptr,
             custom_mask,
-            mask_offsets,
+            mask_indptr,
         ) = self.forward_metadata
+
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -257,7 +328,7 @@ class TritonAttnBackend(AttentionBackend):
             kv_indptr,
             kv_indices,
             custom_mask,
-            mask_offsets,
+            mask_indptr,
             max_extend_len,
             layer.scaling,
             layer.logit_cap,
@@ -303,3 +374,136 @@ class TritonAttnBackend(AttentionBackend):
             layer.logit_cap,
         )
         return o
+
+
+class TritonMultiStepDraftBackend:
+    """
+    Wrap multiple Triton attention backends as one for multiple consecutive
+    draft decoding steps.
+    """
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+        max_bs = model_runner.req_to_token_pool.size
+        self.kv_indptr = torch.zeros(
+            (
+                self.speculative_num_steps,
+                max_bs + 1,
+            ),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                TritonAttnBackend(
+                    model_runner,
+                    skip_prefill=True,
+                    kv_indptr_buf=self.kv_indptr[i],
+                )
+            )
+        self.max_context_len = self.attn_backends[0].max_context_len
+        # Cached variables for generate_draft_decode_kv_indices
+        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+
+    def common_template(
+        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: Callable
+    ):
+        num_seqs = forward_batch.batch_size
+        bs = self.topk * num_seqs
+        seq_lens_sum = forward_batch.seq_lens_sum
+
+        self.generate_draft_decode_kv_indices[
+            (self.speculative_num_steps, num_seqs, self.topk)
+        ](
+            forward_batch.req_pool_indices,
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.seq_lens,
+            kv_indices_buffer,
+            self.kv_indptr,
+            forward_batch.positions,
+            num_seqs,
+            self.topk,
+            self.pool_len,
+            kv_indices_buffer.shape[1],
+            self.kv_indptr.shape[1],
+            triton.next_power_of_2(num_seqs),
+            triton.next_power_of_2(self.speculative_num_steps),
+            triton.next_power_of_2(bs),
+        )
+
+        for i in range(self.speculative_num_steps):
+            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
+            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
+                : seq_lens_sum * self.topk + bs * (i + 1)
+            ]
+            call_fn(i, forward_batch)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        kv_indices = torch.zeros(
+            (
+                self.speculative_num_steps,
+                forward_batch.batch_size * self.topk * self.max_context_len,
+            ),
+            dtype=torch.int32,
+            device="cuda",
+        )
+
+        def call_fn(i, forward_batch):
+            forward_batch.spec_info.kv_indptr = (
+                forward_batch.spec_info.kv_indptr.clone()
+            )
+            forward_batch.spec_info.kv_indices = (
+                forward_batch.spec_info.kv_indices.clone()
+            )
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+        self.common_template(forward_batch, kv_indices, call_fn)
+
+    def init_cuda_graph_state(self, max_bs: int):
+        self.cuda_graph_kv_indices = torch.zeros(
+            (self.speculative_num_steps, max_bs * self.max_context_len),
+            dtype=torch.int32,
+            device="cuda",
+        )
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+            )
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+    def init_forward_metadata_replay_cuda_graph(self, forward_batch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
diff --git a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
index 9fe1e1b60..238351d29 100644
--- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
@@ -50,7 +50,7 @@ def _fwd_kernel(
     kv_indptr,
     kv_indices,
     mask_ptr,
-    mask_offsets,
+    mask_indptr,
     sm_scale,
     kv_group_num,
     stride_qbs,
@@ -87,7 +87,7 @@ def _fwd_kernel(
     cur_seq_len = cur_seq_len_prefix + cur_seq_len_extend
 
     if USE_CUSTOM_MASK:
-        cur_seq_mask_start_idx = tl.load(mask_offsets + cur_seq)
+        cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)
 
     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_dv = tl.arange(0, BLOCK_DV)
@@ -288,7 +288,7 @@ def extend_attention_fwd(
     kv_indptr,
     kv_indices,
     custom_mask,
-    mask_offsets,
+    mask_indptr,
     max_len_extend,
     sm_scale=None,
     logit_cap=0.0,
@@ -364,7 +364,7 @@ def extend_attention_fwd(
         kv_indptr,
         kv_indices,
         custom_mask,
-        mask_offsets,
+        mask_indptr,
         sm_scale,
         kv_group_num,
         q_extend.stride(0),
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index c640be8c6..45798ba58 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -65,15 +65,31 @@ class EAGLEWorker(TpModelWorker):
         self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
 
         # Create multi-step attn backends and cuda graph runners
-        from sglang.srt.layers.attention.flashinfer_backend import (
-            FlashInferMultiStepDraftBackend,
-        )
+        if server_args.attention_backend == "flashinfer":
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = FlashInferMultiStepDraftBackend(
+                self.model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+        elif server_args.attention_backend == "triton":
+            from sglang.srt.layers.attention.triton_backend import (
+                TritonMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = TritonMultiStepDraftBackend(
+                self.model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+        else:
+            raise ValueError(
+                f"EAGLE is not supported in attention backend {server_args.attention_backend}"
+            )
 
-        self.draft_attn_backend = FlashInferMultiStepDraftBackend(
-            self.model_runner,
-            self.topk,
-            self.speculative_num_steps,
-        )
         self.model_runner.draft_attn_backend = self.draft_attn_backend
         self.init_cuda_graphs()
 
diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py
index 4a6170320..225ccc9d6 100644
--- a/test/srt/test_eagle_infer.py
+++ b/test/srt/test_eagle_infer.py
@@ -193,5 +193,34 @@ class TestEAGLEServer(unittest.TestCase):
             self.assertGreater(metrics["accuracy"], 0.20)
 
 
+class TestEAGLEServerTriton(TestEAGLEServer):
+    @classmethod
+    def setUpClass(cls):
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--speculative-algorithm",
+                "EAGLE",
+                "--speculative-draft-model-path",
+                DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+                "--speculative-num-steps",
+                "5",
+                "--speculative-eagle-topk",
+                "8",
+                "--speculative-num-draft-tokens",
+ "64", + "--mem-fraction-static", + "0.7", + "--attention-backend", + "triton", + # TODO: Support cuda graph + "--disable-cuda-graph", + ], + ) + + if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py index 73e304fec..14372593d 100644 --- a/test/srt/test_triton_attention_kernels.py +++ b/test/srt/test_triton_attention_kernels.py @@ -102,7 +102,7 @@ class TestTritonAttention(unittest.TestCase): qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0) custom_mask = None - mask_offsets = None + mask_indptr = None extend_attention_fwd( q_extend, @@ -115,7 +115,7 @@ class TestTritonAttention(unittest.TestCase): kv_indptr, kv_indices, custom_mask, - mask_offsets, + mask_indptr, max_len_extend, ) @@ -123,8 +123,8 @@ class TestTritonAttention(unittest.TestCase): custom_mask = torch.ones( (b_seq_mask_len.sum().item(),), dtype=torch.bool, device="cuda" ) - mask_offsets = torch.zeros((B + 1,), dtype=torch.int64, device="cuda") - mask_offsets[1 : B + 1] = torch.cumsum(b_seq_mask_len[:B], dim=0) + mask_indptr = torch.zeros((B + 1,), dtype=torch.int64, device="cuda") + mask_indptr[1 : B + 1] = torch.cumsum(b_seq_mask_len[:B], dim=0) for i in range(B): causal_mask = ( torch.tril( @@ -136,7 +136,7 @@ class TestTritonAttention(unittest.TestCase): b_seq_len_extend[i], b_seq_len_prefix[i], dtype=torch.bool ) mask_flatten = torch.cat([prefix_mask, causal_mask], dim=1).flatten() - custom_mask[mask_offsets[i] : mask_offsets[i + 1]] = mask_flatten + custom_mask[mask_indptr[i] : mask_indptr[i + 1]] = mask_flatten extend_attention_fwd( q_extend, @@ -149,7 +149,7 @@ class TestTritonAttention(unittest.TestCase): kv_indptr, kv_indices, custom_mask, - mask_offsets, + mask_indptr, max_len_extend, )
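
Usage sketch (illustrative, not part of the patch): the snippet below drives the new
Triton EAGLE path end to end with the same flags TestEAGLEServerTriton passes to
popen_launch_server. `python -m sglang.launch_server` and the `/health` and `/generate`
endpoints are existing sglang entry points; the model paths, port, and readiness loop
are assumptions made for this example.

import subprocess
import time

import requests

PORT = 30000  # assumed free port
server = subprocess.Popen(
    [
        "python", "-m", "sglang.launch_server",
        "--model-path", "meta-llama/Llama-2-7b-chat-hf",  # placeholder target model
        "--speculative-draft-model-path", "lmsys/sglang-EAGLE-llama2-chat-7B",  # placeholder draft model
        "--speculative-algorithm", "EAGLE",
        "--speculative-num-steps", "5",
        "--speculative-eagle-topk", "8",
        "--speculative-num-draft-tokens", "64",
        "--mem-fraction-static", "0.7",
        "--attention-backend", "triton",
        "--disable-cuda-graph",  # cuda graph support is still a TODO in the test above
        "--port", str(PORT),
    ]
)

# Wait until the server answers /health, then send one request through the
# speculative Triton decode path.
base_url = f"http://127.0.0.1:{PORT}"
for _ in range(180):
    try:
        if requests.get(f"{base_url}/health", timeout=1).status_code == 200:
            break
    except requests.exceptions.RequestException:
        time.sleep(1)

response = requests.post(
    f"{base_url}/generate",
    json={"text": "The capital of France is", "sampling_params": {"max_new_tokens": 32}},
)
print(response.json())
server.terminate()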