[CORE]initial support for torchair with non-mla backend (#1506)
### What this PR does / why we need it? This PR supports torchair graph mode with the non-MLA backend on both the 800I A2 and 300I Duo platforms. The main change is the addition of `attention_v1_torchair.py`, which supports the attention-related operations required by torchair. ### Does this PR introduce _any_ user-facing change? Before this PR, vLLM-Ascend only allowed DeepSeek models to use torchair. Now it can also be used with Pangu. In addition, we add a list of supported models to control which model types can use torchair. ### How was this patch tested? We have tested it with PanguProMoE on both the 800I A2 and 300I Duo platforms, and the model generates answers normally. --------- Signed-off-by: angazenn <zengyanjia@huawei.com> Signed-off-by: tianyitang <tangtianyi4@huawei.com> Co-authored-by: angazenn <zengyanjia@huawei.com> Co-authored-by: tianyitang <tangtianyi4@huawei.com>
This commit is contained in:
@@ -2049,9 +2049,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
from torchair import patch_for_hcom # type: ignore
|
||||
|
||||
patch_for_hcom()
|
||||
|
||||
if is_310p():
|
||||
# on 300I Duo platform, we need to patch broadcast. however, this patch will be
|
||||
# overwritten by patch_for_hcom in torchair, so we need to re-patch it here.
|
||||
from vllm_ascend.patch.platform.patch_common.patch_distributed import \
|
||||
communication_adaptation_310p
|
||||
communication_adaptation_310p()
|
||||
|
||||
config = torchair.CompilerConfig()
|
||||
config.experimental_config.frozen_parameter = True
|
||||
config.experimental_config.tiling_schedule_optimize = True
|
||||
# enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to
|
||||
# disable it on 300I Duo platform now.
|
||||
config.experimental_config.tiling_schedule_optimize = not is_310p()
|
||||
config.experimental_config.enable_view_optimize = \
|
||||
get_ascend_config().torchair_graph_config.enable_view_optimize
|
||||
torch.npu.set_compile_mode(jit_compile=False)
|
||||
@@ -2149,27 +2159,50 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size)
|
||||
if self.torchair_graph_enabled:
|
||||
layer_kv_cache_nope = torch.zeros(
|
||||
kv_cache_shape[:-1] +
|
||||
(self.model_config.hf_text_config.kv_lora_rank, ),
|
||||
dtype=self.dtype,
|
||||
pin_memory=True,
|
||||
device=self.device)
|
||||
layer_kv_cache_pe = torch.zeros(
|
||||
kv_cache_shape[:-1] +
|
||||
(self.model_config.hf_text_config.qk_rope_head_dim,
|
||||
),
|
||||
dtype=self.dtype,
|
||||
pin_memory=True,
|
||||
device=self.device)
|
||||
kv_caches[layer_name] = (layer_kv_cache_nope,
|
||||
layer_kv_cache_pe)
|
||||
kv_caches[layer_name] = (
|
||||
torch_npu.npu_format_cast(kv_caches[layer_name][0],
|
||||
acl_format),
|
||||
torch_npu.npu_format_cast(kv_caches[layer_name][1],
|
||||
acl_format),
|
||||
)
|
||||
if len(kv_cache_shape) == 3:
|
||||
# for non-MLA attention backends that use torchair, we pass a kv_cache layout
|
||||
# of BSH ([num_blocks, block_size, kv_head_dim * head_size]) to attention.
|
||||
|
||||
kv_caches[layer_name] = (
|
||||
torch.zeros(kv_cache_shape,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device),
|
||||
torch.zeros(kv_cache_shape,
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device))
|
||||
# atb reshape_and_cache does not support torchair.
|
||||
kv_caches[layer_name] = (
|
||||
torch_npu.npu_format_cast(
|
||||
kv_caches[layer_name][0],
|
||||
ACL_FORMAT_FRACTAL_ND),
|
||||
torch_npu.npu_format_cast(
|
||||
kv_caches[layer_name][1],
|
||||
ACL_FORMAT_FRACTAL_ND),
|
||||
)
|
||||
else:
|
||||
# for MLA attention backend that use torchair.
|
||||
layer_kv_cache_nope = torch.zeros(
|
||||
kv_cache_shape[:-1] +
|
||||
(self.model_config.hf_text_config.kv_lora_rank,
|
||||
),
|
||||
dtype=self.dtype,
|
||||
pin_memory=True,
|
||||
device=self.device)
|
||||
layer_kv_cache_pe = torch.zeros(
|
||||
kv_cache_shape[:-1] +
|
||||
(self.model_config.hf_text_config.
|
||||
qk_rope_head_dim, ),
|
||||
dtype=self.dtype,
|
||||
pin_memory=True,
|
||||
device=self.device)
|
||||
kv_caches[layer_name] = (layer_kv_cache_nope,
|
||||
layer_kv_cache_pe)
|
||||
kv_caches[layer_name] = (
|
||||
torch_npu.npu_format_cast(
|
||||
kv_caches[layer_name][0], acl_format),
|
||||
torch_npu.npu_format_cast(
|
||||
kv_caches[layer_name][1], acl_format),
|
||||
)
|
||||
else:
|
||||
kv_caches[layer_name] = torch.zeros(
|
||||
kv_cache_shape,
|
||||
|
||||
Reference in New Issue
Block a user