[Bugfix] Disable torch.compile() (#370)
### What this PR does / why we need it? To resolve this [patch](https://github.com/vllm-project/vllm-ascend/pull/236/files#diff-43b96b39b5a52fe209d86449ad703a7ff5e1349ebaf1aa12ece8d82163ee5b61R24-R49), we need to set the `torch.compile()` backend to `eager` to disable compilation and fall back to the default PyTorch (eager-mode) execution path. --------- Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
@@ -63,7 +63,8 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool = False):
|
||||
is_driver_worker: bool = False,
|
||||
model_runner_cls: Optional[Type[ModelRunnerBase]] = None):
|
||||
# Register ops when worker init.
|
||||
from vllm_ascend import ops # noqa: F401
|
||||
|
||||
@@ -90,10 +91,10 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
speculative_config = self.speculative_config
|
||||
model_config = self.model_config
|
||||
speculative_args = {} if speculative_config is None \
|
||||
or (speculative_config.draft_model_config.model ==
|
||||
model_config.model) \
|
||||
or (speculative_config.draft_model_config.hf_config.model_type ==
|
||||
model_config.hf_config.model_type) \
|
||||
or (speculative_config.draft_model_config.hf_config.model_type
|
||||
not in ["medusa", "mlp_speculator", "eagle"]) \
|
||||
not in ["medusa", "mlp_speculator", "eagle", "deepseek_mtp"]) \
|
||||
else {"return_hidden_states": True}
|
||||
|
||||
ModelRunnerClass: Type[ModelRunnerBase] = NPUModelRunner
|
||||
@@ -107,6 +108,8 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
is_driver_worker=is_driver_worker,
|
||||
**speculative_args,
|
||||
)
|
||||
if model_runner_cls is not None:
|
||||
self.model_runner = model_runner_cls(self.model_runner)
|
||||
|
||||
# Uninitialized cache engine. Will be initialized by
|
||||
# initialize_cache.
|
||||
|
||||
Reference in New Issue
Block a user