From 89ca63a2c2d98dbd153b33596888090426a9e9f0 Mon Sep 17 00:00:00 2001 From: Shanshan Shen <467638484@qq.com> Date: Fri, 21 Mar 2025 15:55:51 +0800 Subject: [PATCH] [Bugfix] Disable torch.compile() (#370) ### What this PR does / why we need it? To resolve this [patch](https://github.com/vllm-project/vllm-ascend/pull/236/files#diff-43b96b39b5a52fe209d86449ad703a7ff5e1349ebaf1aa12ece8d82163ee5b61R24-R49), we need to set the `torch.compile()` backend to `eager` to disable compilation and fall back to the default PyTorch eager-mode execution. --------- Signed-off-by: shen-shanshan <467638484@qq.com> --- vllm_ascend/attention/attention.py | 13 +++++++------ vllm_ascend/platform.py | 12 +++++++----- vllm_ascend/worker/model_runner.py | 10 +++++++++- vllm_ascend/worker/worker.py | 11 +++++++---- 4 files changed, 30 insertions(+), 16 deletions(-) diff --git a/vllm_ascend/attention/attention.py b/vllm_ascend/attention/attention.py index 5771a11..587d443 100644 --- a/vllm_ascend/attention/attention.py +++ b/vllm_ascend/attention/attention.py @@ -988,7 +988,7 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl): self.num_heads, self.v_head_dim, dtype=query.dtype, - device="npu") + device=query.device) if (attn_metadata.block_tables is None or attn_metadata.block_tables.numel() == 0): assert attn_metadata.attn_mask is not None @@ -1015,11 +1015,12 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl): ) elif attn_metadata.decode_metadata: assert kv_cache is not None - attn_output = torch.empty(num_tokens, - self.num_heads, - self.kv_lora_rank, - dtype=query.dtype, - device="npu") + # if torch.empty is used here, the preemptive scheduling case of + # test_mtp_correctness.py will fail to run. 
+ attn_output = torch.randn( + [num_tokens, self.num_heads, self.kv_lora_rank], + dtype=query.dtype, + device=query.device) self.seq_lens_tensor_cpu = torch.from_numpy( np.array(attn_metadata.decode_metadata.seq_lens).astype( np.int32)) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index bddac5e..5a69df0 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -40,7 +40,7 @@ class NPUPlatform(Platform): _enum = PlatformEnum.OOT device_name: str = "npu" device_type: str = "npu" - simple_compile_backend: str = "npu" + simple_compile_backend: str = "eager" # Disable torch.compile() ray_device_key: str = "NPU" device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES" dispatch_key: str = "PrivateUse1" @@ -99,11 +99,13 @@ class NPUPlatform(Platform): if parallel_config.worker_cls == "auto": if envs.VLLM_USE_V1: parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker" + elif vllm_config.speculative_config: + parallel_config.worker_cls = "vllm.spec_decode.spec_decode_worker.create_spec_worker" + parallel_config.sd_worker_cls = "vllm_ascend.worker.worker.NPUWorker" + elif vllm_config.scheduler_config.is_multi_step: + parallel_config.worker_cls = "vllm_ascend.worker.multi_step_worker.MultiStepWorker" else: - if vllm_config.scheduler_config.is_multi_step: - parallel_config.worker_cls = "vllm_ascend.worker.multi_step_worker.MultiStepWorker" - else: - parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker" + parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker" cache_config = vllm_config.cache_config if cache_config and cache_config.block_size is None: diff --git a/vllm_ascend/worker/model_runner.py b/vllm_ascend/worker/model_runner.py index c30831f..8a21792 100644 --- a/vllm_ascend/worker/model_runner.py +++ b/vllm_ascend/worker/model_runner.py @@ -89,6 +89,7 @@ class ModelInputForNPU(ModelRunnerInputBase): async_callback: Optional[Callable] = None seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] 
= None scheduler_outputs: Optional[SchedulerOutputs] = None + previous_hidden_states: Optional[torch.Tensor] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { @@ -1082,6 +1083,7 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]): kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, + **kwargs, ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: if num_steps > 1: raise ValueError("num_steps > 1 is not supported in ModelRunner") @@ -1118,6 +1120,11 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]): "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, } if self.has_inner_state else {} + previous_hidden_states = kwargs.get("previous_hidden_states") + model_kwargs = {} + if previous_hidden_states is not None: + model_kwargs["previous_hidden_states"] = previous_hidden_states + if (self.observability_config is not None and self.observability_config.collect_model_forward_time): model_forward_start = torch_npu.npu.Event(enable_timing=True) @@ -1135,7 +1142,8 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]): intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self.device), - **seqlen_agnostic_kwargs) + **seqlen_agnostic_kwargs, + **model_kwargs) if (self.observability_config is not None and self.observability_config.collect_model_forward_time): diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py index 0f7256b..736c138 100644 --- a/vllm_ascend/worker/worker.py +++ b/vllm_ascend/worker/worker.py @@ -63,7 +63,8 @@ class NPUWorker(LocalOrDistributedWorkerBase): local_rank: int, rank: int, distributed_init_method: str, - is_driver_worker: bool = False): + is_driver_worker: bool = False, + model_runner_cls: Optional[Type[ModelRunnerBase]] = None): # Register ops when worker init. 
from vllm_ascend import ops # noqa: F401 @@ -90,10 +91,10 @@ class NPUWorker(LocalOrDistributedWorkerBase): speculative_config = self.speculative_config model_config = self.model_config speculative_args = {} if speculative_config is None \ - or (speculative_config.draft_model_config.model == - model_config.model) \ + or (speculative_config.draft_model_config.hf_config.model_type == + model_config.hf_config.model_type) \ or (speculative_config.draft_model_config.hf_config.model_type - not in ["medusa", "mlp_speculator", "eagle"]) \ + not in ["medusa", "mlp_speculator", "eagle", "deepseek_mtp"]) \ else {"return_hidden_states": True} ModelRunnerClass: Type[ModelRunnerBase] = NPUModelRunner @@ -107,6 +108,8 @@ class NPUWorker(LocalOrDistributedWorkerBase): is_driver_worker=is_driver_worker, **speculative_args, ) + if model_runner_cls is not None: + self.model_runner = model_runner_cls(self.model_runner) # Uninitialized cache engine. Will be initialized by # initialize_cache.