diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index abb7e5b..fd40d13 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -89,6 +89,7 @@ if TYPE_CHECKING: else: xgr = LazyLoader("xgr", globals(), "xgrammar") +import torch_npu import vllm.envs as envs_vllm import vllm_ascend.envs as envs_ascend @@ -96,6 +97,9 @@ import vllm_ascend.envs as envs_ascend if vllm_version_is("0.9.1"): from vllm.v1.spec_decode.utils import is_spec_decode_supported +if is_310p(): + torch_npu.npu.set_compile_mode(jit_compile=False) + @dataclass class GraphCaptureContext: @@ -2007,6 +2011,18 @@ class NPUModelRunner(LoRAModelRunnerMixin): with DeviceMemoryProfiler() as m: # noqa: SIM117 self.model = get_model(vllm_config=self.vllm_config) + + if is_310p(): + from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, QKVParallelLinear, + RowParallelLinear) + for module in self.model.modules(): + if isinstance(module, + (MergedColumnParallelLinear, + QKVParallelLinear, RowParallelLinear)): + module.weight.data = torch_npu.npu_format_cast( + module.weight.data, ACL_FORMAT_FRACTAL_NZ) + try: # For version compatibility, remove this after we abort vllm v0.9.1 support from vllm.model_executor.models.interfaces import \