Support Qwen3-Next on Ascend NPU (#10379)

This commit is contained in:
Even Zhou
2025-09-13 07:31:37 +08:00
committed by GitHub
parent d5e2a37414
commit 16cd550c85
10 changed files with 79 additions and 26 deletions

View File

@@ -649,6 +649,7 @@ class HybridLinearKVPool(KVCache):
self,
size: int,
dtype: torch.dtype,
page_size: int,
head_num: int,
head_dim: int,
full_attention_layer_ids: List[int],
@@ -659,10 +660,14 @@ class HybridLinearKVPool(KVCache):
self.dtype = dtype
self.device = device
self.full_layer_nums = len(full_attention_layer_ids)
self.page_size = 1
self.page_size = page_size
# TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
assert not enable_kvcache_transpose
self.full_kv_pool = MHATokenToKVPool(
if _is_npu:
TokenToKVPoolClass = AscendTokenToKVPool
else:
TokenToKVPoolClass = MHATokenToKVPool
self.full_kv_pool = TokenToKVPoolClass(
size=size,
page_size=self.page_size,
dtype=dtype,
@@ -904,8 +909,12 @@ class AscendTokenToKVPool(MHATokenToKVPool):
cache_v: torch.Tensor,
k_scale: Optional[float] = None,
v_scale: Optional[float] = None,
layer_id_override: Optional[int] = None,
):
layer_id = layer.layer_id
if layer_id_override is not None:
layer_id = layer_id_override
else:
layer_id = layer.layer_id
if cache_k.dtype != self.dtype:
if k_scale is not None:
cache_k.div_(k_scale)