Support MHA with chunked prefix cache for the flashinfer/flashmla backends; support page size > 1 for MHA chunked prefix cache (#8616)
Co-authored-by: xuyongfei.xyf <xuyongfei.xyf@antgroup.com>
This commit is contained in:
@@ -241,6 +241,9 @@ class ForwardBatch:
|
||||
prefix_chunk_num_tokens: Optional[List[int]] = None
|
||||
# KV Indices for each chunk
|
||||
prefix_chunk_kv_indices: Optional[List[torch.Tensor]] = None
|
||||
# For MLA chunked prefix cache used in chunked prefill
|
||||
# Tells the attention backend whether the log-sum-exp (LSE) needs to be returned
|
||||
mha_return_lse: Optional[bool] = None
|
||||
|
||||
# For multimodal
|
||||
mm_inputs: Optional[List[MultimodalInputs]] = None
|
||||
|
||||
@@ -518,9 +518,6 @@ class ModelRunner:
|
||||
|
||||
if not self.use_mla_backend:
|
||||
server_args.disable_chunked_prefix_cache = True
|
||||
elif self.page_size > 1:
|
||||
logger.info("Disable chunked prefix cache when page size > 1.")
|
||||
server_args.disable_chunked_prefix_cache = True
|
||||
|
||||
if not server_args.disable_chunked_prefix_cache:
|
||||
logger.info("Chunked prefix cache is turned on.")
|
||||
|
||||
Reference in New Issue
Block a user