Use seq_len_fill_value in the cuda graph runners (#7233)

This commit is contained in:
Lianmin Zheng
2025-06-16 15:57:07 -07:00
committed by GitHub
parent 8e2363dc15
commit c64290dcb5
7 changed files with 19 additions and 19 deletions

View File

@@ -1807,7 +1807,7 @@ class FlashAttentionBackend(AttentionBackend):
def get_cuda_graph_seq_len_fill_value(self):
"""Get the fill value for sequence length in CUDA graph."""
return 0
return 1
def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
"""Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""

View File

@@ -440,7 +440,7 @@ class FlashInferAttnBackend(AttentionBackend):
raise ValueError("Invalid forward mode")
def get_cuda_graph_seq_len_fill_value(self):
return 0
return 1
def forward_extend(
self,

View File

@@ -364,7 +364,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
raise ValueError(f"Invalid forward mode: {forward_mode=}")
def get_cuda_graph_seq_len_fill_value(self):
return 0
return 1
def forward_extend(
self,