[Fix] Fix hicache backend (#8991)
@@ -611,12 +611,7 @@ class Scheduler(
                 hicache_ratio=server_args.hicache_ratio,
                 hicache_size=server_args.hicache_size,
                 hicache_write_policy=server_args.hicache_write_policy,
-                hicache_io_backend=(
-                    "direct"
-                    if server_args.attention_backend
-                    == "fa3"  # hot fix for incompatibility
-                    else server_args.hicache_io_backend
-                ),
+                hicache_io_backend=server_args.hicache_io_backend,
                 hicache_mem_layout=server_args.hicache_mem_layout,
                 hicache_storage_backend=server_args.hicache_storage_backend,
                 hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
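For context, the removed expression was a hot fix that silently overrode the configured I/O backend: whenever FA3 was the attention backend, HiCache was forced onto the slower "direct" path. A minimal sketch of the old behavior (the standalone helper and its signature are illustrative only, not the actual API):

def old_effective_hicache_io_backend(attention_backend: str, hicache_io_backend: str) -> str:
    # Previous hot fix: force the slower "direct" I/O path whenever FA3
    # handled attention, discarding the user's configured value (e.g. "kernel").
    if attention_backend == "fa3":
        return "direct"
    return hicache_io_backend

After this commit the Scheduler simply forwards server_args.hicache_io_backend, and the FA3/HiCache compatibility is resolved in ModelRunner instead, as the following hunks show.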
@@ -403,7 +403,6 @@ class ModelRunner:
                 is_hopper_with_cuda_12_3()
                 and is_no_spec_infer_or_topk_one(server_args)
                 and is_fa3_default_architecture(self.model_config.hf_config)
-                and (not server_args.enable_hierarchical_cache)
             ):
                 server_args.attention_backend = "fa3"
             elif _is_hip:
@@ -416,9 +415,7 @@ class ModelRunner:
             )
         else:
             # MLA architecture
-            if is_hopper_with_cuda_12_3() and (
-                not server_args.enable_hierarchical_cache
-            ):
+            if is_hopper_with_cuda_12_3():
                 server_args.attention_backend = "fa3"
             elif is_sm100_supported():
                 server_args.attention_backend = "flashinfer"
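Both ModelRunner hunks above delete the not server_args.enable_hierarchical_cache guard, so FA3 can again be auto-selected while the hierarchical cache is enabled; compatibility is now handled explicitly by the block added below. A condensed sketch of the selection after this change (the wrapper function and its parameters are hypothetical; the helper predicates are the ones named in the diff):

def pick_attention_backend(server_args, hf_config, use_mla_backend: bool) -> None:
    # Auto-selection no longer consults enable_hierarchical_cache.
    if not use_mla_backend:
        # MHA architecture
        if (
            is_hopper_with_cuda_12_3()
            and is_no_spec_infer_or_topk_one(server_args)
            and is_fa3_default_architecture(hf_config)
        ):
            server_args.attention_backend = "fa3"
    else:
        # MLA architecture
        if is_hopper_with_cuda_12_3():
            server_args.attention_backend = "fa3"
        elif is_sm100_supported():
            server_args.attention_backend = "flashinfer"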
@@ -506,6 +503,27 @@ class ModelRunner:
         if self.model_config.context_len > 8192:
             self.mem_fraction_static *= 0.85

+        if (
+            server_args.enable_hierarchical_cache
+            and server_args.hicache_io_backend == "kernel"
+        ):
+            # fix for the compatibility issue with FlashAttention3 decoding and HiCache kernel backend
+            if server_args.decode_attention_backend is None:
+                if not self.use_mla_backend:
+                    server_args.decode_attention_backend = (
+                        "flashinfer" if is_flashinfer_available() else "triton"
+                    )
+                else:
+                    server_args.decode_attention_backend = (
+                        "flashinfer" if is_sm100_supported() else "triton"
+                    )
+            elif server_args.decode_attention_backend == "fa3":
+                server_args.hicache_io_backend = "direct"
+                logger.warning(
+                    "FlashAttention3 decode backend is not compatible with hierarchical cache. "
+                    "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
+                )
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
