diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 381ec4a1c..8fec69f12 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -394,7 +394,6 @@ class FlashAttentionBackend(AttentionBackend): dtype=torch.int32, ) metadata_expand.max_seq_len_q = 1 - metadata_expand.max_seq_len_k = self.speculative_step_id + 1 metadata_expand.cu_seqlens_q = torch.arange( 0, metadata_expand.cache_seqlens_int32.numel() + 1, @@ -550,9 +549,6 @@ class FlashAttentionBackend(AttentionBackend): ), (1, 0), ) - metadata_expand.max_seq_len_k = ( - metadata_expand.cache_seqlens_int32.max().item() - ) self.forward_metadata_spec_decode_expand = metadata_expand elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed(): metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32) @@ -1421,9 +1417,6 @@ class FlashAttentionBackend(AttentionBackend): ] ) metadata_expand.max_seq_len_q = 1 - metadata_expand.max_seq_len_k = ( - self.speculative_step_id + 1 - ) # , do this in replay metadata_expand.cu_seqlens_q = ( self.draft_decode_metadata_topk_expand["cu_seqlens_q"][ : bs * self.topk + 1 @@ -1766,9 +1759,6 @@ class FlashAttentionBackend(AttentionBackend): dtype=torch.int32, ) ) - metadata_expand.max_seq_len_k = ( - metadata_expand.cache_seqlens_int32.max().item() - ) elif forward_mode.is_draft_extend(): metadata = self.draft_extend_metadata[bs] metadata.cache_seqlens_int32.copy_(seq_lens)