[Bugfix] Disable torch.compile() (#370)

### What this PR does / why we need it?
To address this
[patch](https://github.com/vllm-project/vllm-ascend/pull/236/files#diff-43b96b39b5a52fe209d86449ad703a7ff5e1349ebaf1aa12ece8d82163ee5b61R24-R49)
, we need to set the `torch.compile()` backend to `eager` to disable
compilation and fall back to PyTorch's default eager-mode execution.


---------

Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
Shanshan Shen
2025-03-21 15:55:51 +08:00
committed by GitHub
parent 9a175ca0fc
commit 89ca63a2c2
4 changed files with 30 additions and 16 deletions

View File

@@ -89,6 +89,7 @@ class ModelInputForNPU(ModelRunnerInputBase):
async_callback: Optional[Callable] = None
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
scheduler_outputs: Optional[SchedulerOutputs] = None
previous_hidden_states: Optional[torch.Tensor] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
@@ -1082,6 +1083,7 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
**kwargs,
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
if num_steps > 1:
raise ValueError("num_steps > 1 is not supported in ModelRunner")
@@ -1118,6 +1120,11 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_inner_state else {}
previous_hidden_states = kwargs.get("previous_hidden_states")
model_kwargs = {}
if previous_hidden_states is not None:
model_kwargs["previous_hidden_states"] = previous_hidden_states
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_start = torch_npu.npu.Event(enable_timing=True)
@@ -1135,7 +1142,8 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
**seqlen_agnostic_kwargs)
**seqlen_agnostic_kwargs,
**model_kwargs)
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):

View File

@@ -63,7 +63,8 @@ class NPUWorker(LocalOrDistributedWorkerBase):
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False):
is_driver_worker: bool = False,
model_runner_cls: Optional[Type[ModelRunnerBase]] = None):
# Register ops when worker init.
from vllm_ascend import ops # noqa: F401
@@ -90,10 +91,10 @@ class NPUWorker(LocalOrDistributedWorkerBase):
speculative_config = self.speculative_config
model_config = self.model_config
speculative_args = {} if speculative_config is None \
or (speculative_config.draft_model_config.model ==
model_config.model) \
or (speculative_config.draft_model_config.hf_config.model_type ==
model_config.hf_config.model_type) \
or (speculative_config.draft_model_config.hf_config.model_type
not in ["medusa", "mlp_speculator", "eagle"]) \
not in ["medusa", "mlp_speculator", "eagle", "deepseek_mtp"]) \
else {"return_hidden_states": True}
ModelRunnerClass: Type[ModelRunnerBase] = NPUModelRunner
@@ -107,6 +108,8 @@ class NPUWorker(LocalOrDistributedWorkerBase):
is_driver_worker=is_driver_worker,
**speculative_args,
)
if model_runner_cls is not None:
self.model_runner = model_runner_cls(self.model_runner)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.