diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py
index e996cb159..627a1db23 100644
--- a/python/sglang/srt/layers/attention/triton_backend.py
+++ b/python/sglang/srt/layers/attention/triton_backend.py
@@ -38,6 +38,8 @@ class TritonAttnBackend(AttentionBackend):
         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd
 
+        self.skip_prefill = skip_prefill
+
         max_bs = model_runner.req_to_token_pool.size
 
         if kv_indptr_buf is None:
@@ -48,13 +50,15 @@ class TritonAttnBackend(AttentionBackend):
             self.kv_indptr = kv_indptr_buf
 
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
 
-        self.qo_indptr = torch.zeros(
-            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
-        )
-        self.mask_indptr = torch.zeros(
-            (max_bs + 1,), dtype=torch.int64, device=model_runner.device
-        )
+        if not self.skip_prefill:
+            self.qo_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+            )
+
+            self.mask_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int64, device=model_runner.device
+            )
 
         self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
 
@@ -196,22 +200,29 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr,
         )
 
-    def init_cuda_graph_state(self, max_bs: int):
-        self.cuda_graph_max_total_num_tokens = max_bs * self.max_context_len
-
-        self.cuda_graph_start_loc = torch.zeros(
-            (max_bs,), dtype=torch.int32, device=self.device
-        )
+    def init_cuda_graph_state(
+        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+    ):
         self.cuda_graph_attn_logits = torch.zeros(
             (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
             dtype=torch.float32,
             device=self.device,
         )
-        self.cuda_graph_kv_indices = torch.zeros(
-            (max_bs * self.max_context_len),
-            dtype=torch.int32,
-            device=self.device,
-        )
+        if kv_indices_buf is None:
+            self.cuda_graph_kv_indices = torch.zeros(
+                (max_bs * self.max_context_len),
+                dtype=torch.int32,
+                device=self.device,
+            )
+        else:
+            self.cuda_graph_kv_indices = kv_indices_buf
+
+        if not self.skip_prefill:
+            self.cuda_graph_custom_mask = torch.zeros(
+                (max_bs * self.max_context_len),
+                dtype=torch.uint8,
+                device=self.device,
+            )
 
     def init_forward_metadata_capture_cuda_graph(
         self,
@@ -224,31 +235,71 @@ class TritonAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInfo],
     ):
         assert encoder_lens is None, "Not supported"
-        assert forward_mode.is_decode(), "Not supported"
-        assert spec_info is None, "Not supported"
 
-        kv_indptr = self.kv_indptr
-        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
-        kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = self.cuda_graph_kv_indices
-        create_flashinfer_kv_indices_triton[(bs,)](
-            self.req_to_token,
-            req_pool_indices,
-            seq_lens,
-            kv_indptr,
-            None,
-            kv_indices,
-            self.req_to_token.stride(0),
-        )
+        if forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                kv_indptr = self.kv_indptr
+                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = self.cuda_graph_kv_indices
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices,
+                    seq_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
+
+            attn_logits = self.cuda_graph_attn_logits
+            max_extend_len = None
+            qo_indptr = None
+            custom_mask = None
+            mask_indptr = None
+        elif forward_mode.is_target_verify():
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+
+            custom_mask = self.cuda_graph_custom_mask
+            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
+            mask_indptr = self.mask_indptr[: bs + 1]
+            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
+            max_extend_len = self.num_draft_tokens
+            attn_logits = None
+        else:
+            raise ValueError(
+                f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
+            )
         self.forward_metadata = (
-            self.cuda_graph_attn_logits,
-            None,
+            attn_logits,
+            max_extend_len,
             kv_indptr,
             kv_indices,
-            None,
-            None,
-            None,
+            qo_indptr,
+            custom_mask,
+            mask_indptr,
         )
 
     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -262,22 +313,57 @@ class TritonAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInfo],
    ):
         # NOTE: encoder_lens expected to be zeros or None
-        self.cuda_graph_start_loc.zero_()
-        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
-
-        kv_indptr = self.kv_indptr
-        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
-        kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = self.cuda_graph_kv_indices
-        create_flashinfer_kv_indices_triton[(bs,)](
-            self.req_to_token,
-            req_pool_indices[:bs],
-            seq_lens[:bs],
-            kv_indptr,
-            None,
-            kv_indices,
-            self.req_to_token.stride(0),
-        )
+        if forward_mode.is_decode_or_idle():
+            # Update kv_indptr, kv_indices
+            kv_indptr = self.kv_indptr
+            kv_indices = self.cuda_graph_kv_indices
+            if spec_info is None:
+                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices[:bs],
+                    seq_lens[:bs],
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
+                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
+        elif forward_mode.is_target_verify():
+            # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
+            bs = len(req_pool_indices)
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+            custom_mask = self.cuda_graph_custom_mask
+            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
+            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
+            mask_indptr = self.mask_indptr[: bs + 1]
+            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
+        else:
+            raise ValueError(
+                f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
+            )
 
     def get_cuda_graph_seq_len_fill_value(self):
         return 1
@@ -407,6 +493,7 @@ class TritonMultiStepDraftBackend:
             )
         )
         self.max_context_len = self.attn_backends[0].max_context_len
+        self.device = model_runner.device
 
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
@@ -450,7 +537,7 @@ class TritonMultiStepDraftBackend:
                 forward_batch.batch_size * self.topk * self.max_context_len,
             ),
             dtype=torch.int32,
-            device="cuda",
+            device=self.device,
         )
 
         def call_fn(i, forward_batch):
@@ -468,7 +555,7 @@ class TritonMultiStepDraftBackend:
         self.cuda_graph_kv_indices = torch.zeros(
             (self.speculative_num_steps, max_bs * self.max_context_len),
             dtype=torch.int32,
-            device="cuda",
+            device=self.device,
         )
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py
index 225ccc9d6..5e627fd11 100644
--- a/test/srt/test_eagle_infer.py
+++ b/test/srt/test_eagle_infer.py
@@ -216,8 +216,6 @@ class TestEAGLEServerTriton(TestEAGLEServer):
                 "0.7",
                 "--attention-backend",
                 "triton",
-                # TODO: Support cuda graph
-                "--disable-cuda-graph",
             ],
         )
 
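
Note (illustration only, not part of the patch): the target-verify branch added above builds qo_indptr, kv_indptr, and mask_indptr purely from seq_lens and num_draft_tokens. The toy CPU sketch below mirrors those formulas with made-up sizes so the intended layout is easy to check; tensor names follow the patch.

import torch

bs = 2                  # batch size (example value)
num_draft_tokens = 4    # speculative_num_draft_tokens (example value)
seq_lens = torch.tensor([7, 5], dtype=torch.int32)

# qo_indptr: every request contributes exactly num_draft_tokens query tokens.
qo_indptr = torch.arange(
    0, (1 + bs) * num_draft_tokens, step=num_draft_tokens, dtype=torch.int32
)
# -> [0, 4, 8]

# kv_indptr: prefix sum of the per-request KV lengths.
kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
# -> [0, 7, 12]

# mask_indptr: each request owns a num_draft_tokens x (seq_len + num_draft_tokens)
# slice of the flattened custom mask.
seq_mask_len = num_draft_tokens * (seq_lens + num_draft_tokens)
mask_indptr = torch.zeros(bs + 1, dtype=torch.int64)
mask_indptr[1:] = torch.cumsum(seq_mask_len, dim=0)
# -> [0, 44, 80]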
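
Note (assumption, the exact call site is truncated in the hunk above): init_cuda_graph_state now accepts an optional kv_indices_buf, and TritonMultiStepDraftBackend allocates a single (speculative_num_steps, max_bs * max_context_len) buffer. The sketch below shows the presumable sharing pattern; the keyword call is kept in a comment because it is not shown in the diff.

import torch

speculative_num_steps, max_bs, max_context_len = 3, 8, 2048  # example sizes
shared_kv_indices = torch.zeros(
    (speculative_num_steps, max_bs * max_context_len), dtype=torch.int32
)

# Presumably each per-step backend is handed one row of the shared buffer, e.g.
#   self.attn_backends[i].init_cuda_graph_state(max_bs, kv_indices_buf=shared_kv_indices[i])
# so CUDA graph capture/replay for draft step i writes into shared_kv_indices[i]
# instead of allocating a private max_bs * max_context_len tensor.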