diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index 4dd62ac..b3b8ecb 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -20,10 +20,11 @@ from typing import Optional import torch +import torch_npu from vllm.config import VllmConfig from vllm.forward_context import get_forward_context -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, +from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, maybe_converting_weight_acl_format) from vllm_ascend.worker.model_runner_v1 import NPUModelRunner @@ -113,3 +114,7 @@ class NPUTorchairModelRunner(NPUModelRunner): with_prefill, is_torchair_compile, input_ids, positions, attn_metadata, num_tokens, intermediate_tensors, inputs_embeds) return hidden_states + + def _convert_torch_format(self, kv_cache): + kv_cache = torch_npu.npu_format_cast(kv_cache, ACL_FORMAT_FRACTAL_ND) + return kv_cache diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index a0fe9e0..ae1cff3 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -110,6 +110,9 @@ import vllm_ascend.envs as envs_ascend if is_310p(): torch_npu.npu.set_compile_mode(jit_compile=False) + ACL_FORMAT = ACL_FORMAT_FRACTAL_NZ +else: + ACL_FORMAT = ACL_FORMAT_FRACTAL_ND @dataclass @@ -2047,8 +2050,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): if isinstance(module, (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear)): - module.weight.data = torch_npu.npu_format_cast( - module.weight.data, ACL_FORMAT_FRACTAL_NZ) + module.weight.data = self._convert_torch_format( + module.weight.data) if self.drafter: logger.info("Loading drafter model...") if isinstance(self.drafter, EagleProposer): @@ -2133,6 +2136,10 @@ class NPUModelRunner(LoRAModelRunnerMixin): ge_cache=False) return self.torchair_compiled_models[batch_size] + def _convert_torch_format(self, tensor): + tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT) + return tensor + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. @@ -2141,9 +2148,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): cache size of each layer """ self.kv_cache_config = kv_cache_config - import torch_npu - acl_format = ACL_FORMAT_FRACTAL_NZ if is_310p( - ) and not self.torchair_graph_enabled else ACL_FORMAT_FRACTAL_ND kv_caches: Dict[str, torch.Tensor] = {} def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor: @@ -2202,7 +2206,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): kv_cache_spec.head_size) dtype = kv_cache_spec.dtype if self.model_config.is_deepseek_mla: - num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape rope_dim = self.model_config.hf_text_config.qk_rope_head_dim nope_dim = head_size - rope_dim @@ -2218,10 +2221,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): nope_cache = torch.zeros(nope_cache_shape, dtype=dtype, device=self.device) - rope_cache = torch_npu.npu_format_cast( - rope_cache, acl_format) - nope_cache = torch_npu.npu_format_cast( - nope_cache, acl_format) + rope_cache = self._convert_torch_format(rope_cache) + nope_cache = self._convert_torch_format(nope_cache) else: # In order to transfer kv cache through the reigster_memory api from llmdatadist, the memory @@ -2259,8 +2260,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): kv_cache = torch.zeros(cache_shape, dtype=dtype, device=self.device) - kv_cache = torch_npu.npu_format_cast( - kv_cache, acl_format) + kv_cache = self._convert_torch_format(kv_cache) else: cache_size = math.prod(cache_shape) cache_size_aligned = cache_size + alignment