[Fix] Fix hicache backend (#8991)
This commit is contained in:
@@ -611,12 +611,7 @@ class Scheduler(
|
|||||||
hicache_ratio=server_args.hicache_ratio,
|
hicache_ratio=server_args.hicache_ratio,
|
||||||
hicache_size=server_args.hicache_size,
|
hicache_size=server_args.hicache_size,
|
||||||
hicache_write_policy=server_args.hicache_write_policy,
|
hicache_write_policy=server_args.hicache_write_policy,
|
||||||
hicache_io_backend=(
|
hicache_io_backend=server_args.hicache_io_backend,
|
||||||
"direct"
|
|
||||||
if server_args.attention_backend
|
|
||||||
== "fa3" # hot fix for incompatibility
|
|
||||||
else server_args.hicache_io_backend
|
|
||||||
),
|
|
||||||
hicache_mem_layout=server_args.hicache_mem_layout,
|
hicache_mem_layout=server_args.hicache_mem_layout,
|
||||||
hicache_storage_backend=server_args.hicache_storage_backend,
|
hicache_storage_backend=server_args.hicache_storage_backend,
|
||||||
hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
|
hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
|
||||||
|
|||||||
@@ -403,7 +403,6 @@ class ModelRunner:
|
|||||||
is_hopper_with_cuda_12_3()
|
is_hopper_with_cuda_12_3()
|
||||||
and is_no_spec_infer_or_topk_one(server_args)
|
and is_no_spec_infer_or_topk_one(server_args)
|
||||||
and is_fa3_default_architecture(self.model_config.hf_config)
|
and is_fa3_default_architecture(self.model_config.hf_config)
|
||||||
and (not server_args.enable_hierarchical_cache)
|
|
||||||
):
|
):
|
||||||
server_args.attention_backend = "fa3"
|
server_args.attention_backend = "fa3"
|
||||||
elif _is_hip:
|
elif _is_hip:
|
||||||
@@ -416,9 +415,7 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# MLA architecture
|
# MLA architecture
|
||||||
if is_hopper_with_cuda_12_3() and (
|
if is_hopper_with_cuda_12_3():
|
||||||
not server_args.enable_hierarchical_cache
|
|
||||||
):
|
|
||||||
server_args.attention_backend = "fa3"
|
server_args.attention_backend = "fa3"
|
||||||
elif is_sm100_supported():
|
elif is_sm100_supported():
|
||||||
server_args.attention_backend = "flashinfer"
|
server_args.attention_backend = "flashinfer"
|
||||||
@@ -506,6 +503,27 @@ class ModelRunner:
|
|||||||
if self.model_config.context_len > 8192:
|
if self.model_config.context_len > 8192:
|
||||||
self.mem_fraction_static *= 0.85
|
self.mem_fraction_static *= 0.85
|
||||||
|
|
||||||
|
if (
|
||||||
|
server_args.enable_hierarchical_cache
|
||||||
|
and server_args.hicache_io_backend == "kernel"
|
||||||
|
):
|
||||||
|
# fix for the compatibility issue with FlashAttention3 decoding and HiCache kernel backend
|
||||||
|
if server_args.decode_attention_backend is None:
|
||||||
|
if not self.use_mla_backend:
|
||||||
|
server_args.decode_attention_backend = (
|
||||||
|
"flashinfer" if is_flashinfer_available() else "triton"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
server_args.decode_attention_backend = (
|
||||||
|
"flashinfer" if is_sm100_supported() else "triton"
|
||||||
|
)
|
||||||
|
elif server_args.decode_attention_backend == "fa3":
|
||||||
|
server_args.hicache_io_backend = "direct"
|
||||||
|
logger.warning(
|
||||||
|
"FlashAttention3 decode backend is not compatible with hierarchical cache. "
|
||||||
|
f"Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
|
||||||
|
)
|
||||||
|
|
||||||
def init_torch_distributed(self):
|
def init_torch_distributed(self):
|
||||||
logger.info("Init torch distributed begin.")
|
logger.info("Init torch distributed begin.")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user