Allow use of TRTLLM_MHA backend for hybrid attention on Blackwell (#11138)
This commit is contained in:
@@ -178,7 +178,8 @@ def attn_backend_wrapper(runner, full_attn_backend):
|
|||||||
if is_blackwell():
|
if is_blackwell():
|
||||||
assert (
|
assert (
|
||||||
runner.server_args.attention_backend == "triton"
|
runner.server_args.attention_backend == "triton"
|
||||||
), "triton backend is the only supported backend on Blackwell GPUs for hybrid GDN models, use --attention-backend triton to specify the backend."
|
or runner.server_args.attention_backend == "trtllm_mha"
|
||||||
|
), "triton or trtllm_mha backend are the only supported backends on Blackwell GPUs for hybrid GDN models, use --attention-backend triton or --attention-backend trtllm_mha to specify the backend."
|
||||||
if is_npu():
|
if is_npu():
|
||||||
assert (
|
assert (
|
||||||
runner.server_args.attention_backend == "ascend"
|
runner.server_args.attention_backend == "ascend"
|
||||||
|
|||||||
@@ -1620,7 +1620,7 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
elif self.is_hybrid_gdn:
|
elif self.is_hybrid_gdn:
|
||||||
self.token_to_kv_pool = HybridLinearKVPool(
|
self.token_to_kv_pool = HybridLinearKVPool(
|
||||||
page_size=self.page_size if _is_npu else 1,
|
page_size=self.page_size,
|
||||||
size=self.max_total_num_tokens,
|
size=self.max_total_num_tokens,
|
||||||
dtype=self.kv_cache_dtype,
|
dtype=self.kv_cache_dtype,
|
||||||
head_num=self.model_config.get_num_kv_heads(
|
head_num=self.model_config.get_num_kv_heads(
|
||||||
|
|||||||
Reference in New Issue
Block a user