[CORE]initial support for torchair with non-mla backend (#1506)

### What this PR does / why we need it?
This PR supports torchair graph mode with non-mla backend on both 800IA2
and 300I Duo platforms. The main change is to add
`attention_v1_torchair.py` to support specific attention related
operations that are required by torchair.

### Does this PR introduce _any_ user-facing change?
Before this PR, vLLM-Ascend only allows deepseek to use torchair. Now we
can also use it with pangu. Besides, we add a support model list to
control which type of models that can use torchair.

### How was this patch tested?
We have test it with PanguProMoE on both 800IA2 and 300I Duo platforms,
and model generates answer normally.

---------

Signed-off-by: angazenn <zengyanjia@huawei.com>
Signed-off-by: tianyitang <tangtianyi4@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: tianyitang <tangtianyi4@huawei.com>
This commit is contained in:
Angazenn
2025-07-03 22:21:42 +08:00
committed by GitHub
parent 9fbd8017c0
commit a5f33590d3
19 changed files with 1130 additions and 84 deletions

View File

@@ -2049,9 +2049,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
from torchair import patch_for_hcom # type: ignore
patch_for_hcom()
if is_310p():
# on 300I Duo platform, we need to patch broadcast. however, this patch will be
# overwritten by patch_for_hcom in torchair. so we need to re-patch it here.
from vllm_ascend.patch.platform.patch_common.patch_distributed import \
communication_adaptation_310p
communication_adaptation_310p()
config = torchair.CompilerConfig()
config.experimental_config.frozen_parameter = True
config.experimental_config.tiling_schedule_optimize = True
# enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to
# disable it on 300I Duo platform now.
config.experimental_config.tiling_schedule_optimize = not is_310p()
config.experimental_config.enable_view_optimize = \
get_ascend_config().torchair_graph_config.enable_view_optimize
torch.npu.set_compile_mode(jit_compile=False)
@@ -2149,27 +2159,50 @@ class NPUModelRunner(LoRAModelRunnerMixin):
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size)
if self.torchair_graph_enabled:
layer_kv_cache_nope = torch.zeros(
kv_cache_shape[:-1] +
(self.model_config.hf_text_config.kv_lora_rank, ),
dtype=self.dtype,
pin_memory=True,
device=self.device)
layer_kv_cache_pe = torch.zeros(
kv_cache_shape[:-1] +
(self.model_config.hf_text_config.qk_rope_head_dim,
),
dtype=self.dtype,
pin_memory=True,
device=self.device)
kv_caches[layer_name] = (layer_kv_cache_nope,
layer_kv_cache_pe)
kv_caches[layer_name] = (
torch_npu.npu_format_cast(kv_caches[layer_name][0],
acl_format),
torch_npu.npu_format_cast(kv_caches[layer_name][1],
acl_format),
)
if len(kv_cache_shape) == 3:
# for non MLA attention backend that use torchair, we consider to pass kv_cache layout
# of BSH ([num_blocks, block_size, kv_head_dim * head_size]) to attention.
kv_caches[layer_name] = (
torch.zeros(kv_cache_shape,
dtype=self.kv_cache_dtype,
device=self.device),
torch.zeros(kv_cache_shape,
dtype=self.kv_cache_dtype,
device=self.device))
# atb reshape_and_cache does not support torchair.
kv_caches[layer_name] = (
torch_npu.npu_format_cast(
kv_caches[layer_name][0],
ACL_FORMAT_FRACTAL_ND),
torch_npu.npu_format_cast(
kv_caches[layer_name][1],
ACL_FORMAT_FRACTAL_ND),
)
else:
# for MLA attention backend that use torchair.
layer_kv_cache_nope = torch.zeros(
kv_cache_shape[:-1] +
(self.model_config.hf_text_config.kv_lora_rank,
),
dtype=self.dtype,
pin_memory=True,
device=self.device)
layer_kv_cache_pe = torch.zeros(
kv_cache_shape[:-1] +
(self.model_config.hf_text_config.
qk_rope_head_dim, ),
dtype=self.dtype,
pin_memory=True,
device=self.device)
kv_caches[layer_name] = (layer_kv_cache_nope,
layer_kv_cache_pe)
kv_caches[layer_name] = (
torch_npu.npu_format_cast(
kv_caches[layer_name][0], acl_format),
torch_npu.npu_format_cast(
kv_caches[layer_name][1], acl_format),
)
else:
kv_caches[layer_name] = torch.zeros(
kv_cache_shape,