[Refactor] cleanup converting_weight_acl_format_format (#2482)
move maybe_converting_weight_acl_format_format to torchair module, it's
only used with 310p+torchair
- vLLM version: v0.10.1.1
- vLLM main:
49ab23b3cc
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -36,10 +36,11 @@ from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
|
||||
check_torchair_cache_exist,
|
||||
converting_weight_acl_format,
|
||||
register_torchair_model,
|
||||
write_kv_cache_bytes_to_file)
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
is_310p, maybe_converting_weight_acl_format)
|
||||
is_310p)
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
|
||||
@@ -136,9 +137,8 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
assert isinstance(kv, tuple), "kv_cache must be a tuple"
|
||||
torch._dynamo.mark_static(kv[0])
|
||||
torch._dynamo.mark_static(kv[1])
|
||||
|
||||
maybe_converting_weight_acl_format(self.model,
|
||||
ACL_FORMAT_FRACTAL_NZ)
|
||||
if is_310p():
|
||||
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
compiled_model = self._get_torchair_lazy_compiled_model(num_tokens)
|
||||
model_kwargs = {}
|
||||
@@ -152,6 +152,8 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
if is_310p():
|
||||
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)
|
||||
hidden_states = super()._generate_dummy_run_hidden_states(
|
||||
with_prefill, is_torchair_compile, input_ids, positions,
|
||||
attn_metadata, num_tokens, intermediate_tensors, inputs_embeds)
|
||||
@@ -261,9 +263,8 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
"attn_metadata": attn_metadata
|
||||
}
|
||||
if not with_prefill:
|
||||
maybe_converting_weight_acl_format(self.model,
|
||||
ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
if is_310p():
|
||||
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_NZ)
|
||||
compiled_model = self._get_torchair_lazy_compiled_model(
|
||||
padded_num_tokens_across_dp)
|
||||
hidden_states = compiled_model(
|
||||
@@ -275,8 +276,8 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
)
|
||||
else:
|
||||
assert self.model is not None
|
||||
maybe_converting_weight_acl_format(self.model,
|
||||
ACL_FORMAT_FRACTAL_ND)
|
||||
if is_310p():
|
||||
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
|
||||
@@ -5,6 +5,7 @@ from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
try:
|
||||
# Recent release of torchair has moved these ops to `.scope`.
|
||||
@@ -125,6 +126,25 @@ def npu_wait_tensor(self: torch.Tensor,
|
||||
return _npu_wait_tensor(self, dependency) if enabled else self
|
||||
|
||||
|
||||
def converting_weight_acl_format(model, format):
|
||||
# currently, there are some operations which do not support ACL_FORMAT_FRACTAL_NZ
|
||||
# in eager mode but support it in torchair graph mode. since ACL_FORMAT_FRACTAL_NZ
|
||||
# is much more preferred than ACL_FORMAT_FRACTAL_ND on 300I Duo, we add this
|
||||
# conversion when using torchair graph mode on 300I Duo platform.
|
||||
# TODO: we will remove this conversion if npu_quant_grouped_matmul_dequant
|
||||
# accepts weight format of ACL_FORMAT_FRACTAL_NZ in eager mode.
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, FusedMoE):
|
||||
if torch_npu.get_npu_format(module.w13_weight.data) == format:
|
||||
return
|
||||
module.w13_weight.data = torch_npu.npu_format_cast(
|
||||
module.w13_weight.data, format)
|
||||
module.w2_weight.data = torch_npu.npu_format_cast(
|
||||
module.w2_weight.data, format)
|
||||
|
||||
|
||||
def register_torchair_model():
|
||||
from vllm import ModelRegistry
|
||||
|
||||
|
||||
Reference in New Issue
Block a user