Support Qwen3-Next on Ascend NPU (#10379)
This commit is contained in:
@@ -649,6 +649,7 @@ class HybridLinearKVPool(KVCache):
|
||||
self,
|
||||
size: int,
|
||||
dtype: torch.dtype,
|
||||
page_size: int,
|
||||
head_num: int,
|
||||
head_dim: int,
|
||||
full_attention_layer_ids: List[int],
|
||||
@@ -659,10 +660,14 @@ class HybridLinearKVPool(KVCache):
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.full_layer_nums = len(full_attention_layer_ids)
|
||||
self.page_size = 1
|
||||
self.page_size = page_size
|
||||
# TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
|
||||
assert not enable_kvcache_transpose
|
||||
self.full_kv_pool = MHATokenToKVPool(
|
||||
if _is_npu:
|
||||
TokenToKVPoolClass = AscendTokenToKVPool
|
||||
else:
|
||||
TokenToKVPoolClass = MHATokenToKVPool
|
||||
self.full_kv_pool = TokenToKVPoolClass(
|
||||
size=size,
|
||||
page_size=self.page_size,
|
||||
dtype=dtype,
|
||||
@@ -904,8 +909,12 @@ class AscendTokenToKVPool(MHATokenToKVPool):
|
||||
cache_v: torch.Tensor,
|
||||
k_scale: Optional[float] = None,
|
||||
v_scale: Optional[float] = None,
|
||||
layer_id_override: Optional[int] = None,
|
||||
):
|
||||
layer_id = layer.layer_id
|
||||
if layer_id_override is not None:
|
||||
layer_id = layer_id_override
|
||||
else:
|
||||
layer_id = layer.layer_id
|
||||
if cache_k.dtype != self.dtype:
|
||||
if k_scale is not None:
|
||||
cache_k.div_(k_scale)
|
||||
|
||||
Reference in New Issue
Block a user