Super tiny fix typo (#8046)
This commit is contained in:
@@ -1617,7 +1617,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
metadata.max_seq_len_k + self.page_size - 1
|
metadata.max_seq_len_k + self.page_size - 1
|
||||||
) // self.page_size
|
) // self.page_size
|
||||||
|
|
||||||
normal_decode_set_medadata(
|
normal_decode_set_metadata(
|
||||||
metadata.cache_seqlens_int32,
|
metadata.cache_seqlens_int32,
|
||||||
metadata.cu_seqlens_k,
|
metadata.cu_seqlens_k,
|
||||||
metadata.page_table,
|
metadata.page_table,
|
||||||
@@ -1666,7 +1666,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
max_seq_pages = (max_len + self.page_size - 1) // self.page_size
|
max_seq_pages = (max_len + self.page_size - 1) // self.page_size
|
||||||
metadata.max_seq_len_k = max_len
|
metadata.max_seq_len_k = max_len
|
||||||
|
|
||||||
normal_decode_set_medadata(
|
normal_decode_set_metadata(
|
||||||
metadata.cache_seqlens_int32,
|
metadata.cache_seqlens_int32,
|
||||||
metadata.cu_seqlens_k,
|
metadata.cu_seqlens_k,
|
||||||
metadata.page_table,
|
metadata.page_table,
|
||||||
@@ -2089,7 +2089,7 @@ class FlashAttentionMultiStepBackend:
|
|||||||
# @torch.compile(dynamic=True, backend=get_compiler_backend())
|
# @torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||||
# TODO: fuse these kernels
|
# TODO: fuse these kernels
|
||||||
# NOTE: torch.compile makes it slower in speculative decoding
|
# NOTE: torch.compile makes it slower in speculative decoding
|
||||||
def normal_decode_set_medadata(
|
def normal_decode_set_metadata(
|
||||||
cache_seqlens_int32: torch.Tensor,
|
cache_seqlens_int32: torch.Tensor,
|
||||||
cu_seqlens_k: torch.Tensor,
|
cu_seqlens_k: torch.Tensor,
|
||||||
page_table: torch.Tensor,
|
page_table: torch.Tensor,
|
||||||
|
|||||||
Reference in New Issue
Block a user