Use seq_len_fill_value in the cuda graph runners (#7233)
This commit is contained in:
@@ -1807,7 +1807,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
"""Get the fill value for sequence length in CUDA graph."""
|
||||
- return 0
+ return 1
|
||||
|
||||
def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
|
||||
"""Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
|
||||
|
||||
@@ -440,7 +440,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
raise ValueError("Invalid forward mode")
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
- return 0
+ return 1
|
||||
|
||||
def forward_extend(
|
||||
self,
|
||||
|
||||
@@ -364,7 +364,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
raise ValueError(f"Invalid forward mode: {forward_mode=}")
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
- return 0
+ return 1
|
||||
|
||||
def forward_extend(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user