[Bugfix] Disable torch.compile() (#370)
### What this PR does / why we need it? To resolve this [patch](https://github.com/vllm-project/vllm-ascend/pull/236/files#diff-43b96b39b5a52fe209d86449ad703a7ff5e1349ebaf1aa12ece8d82163ee5b61R24-R49) , we need to set `torch.compile()` backend to `eager` to disable compile, using default pytorch way. --------- Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
@@ -988,7 +988,7 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
|||||||
self.num_heads,
|
self.num_heads,
|
||||||
self.v_head_dim,
|
self.v_head_dim,
|
||||||
dtype=query.dtype,
|
dtype=query.dtype,
|
||||||
device="npu")
|
device=query.device)
|
||||||
if (attn_metadata.block_tables is None
|
if (attn_metadata.block_tables is None
|
||||||
or attn_metadata.block_tables.numel() == 0):
|
or attn_metadata.block_tables.numel() == 0):
|
||||||
assert attn_metadata.attn_mask is not None
|
assert attn_metadata.attn_mask is not None
|
||||||
@@ -1015,11 +1015,12 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
|||||||
)
|
)
|
||||||
elif attn_metadata.decode_metadata:
|
elif attn_metadata.decode_metadata:
|
||||||
assert kv_cache is not None
|
assert kv_cache is not None
|
||||||
attn_output = torch.empty(num_tokens,
|
# if torch.empty is used here, the preemptive scheduling case of
|
||||||
self.num_heads,
|
# test_mtp_correctness.py will fail to run.
|
||||||
self.kv_lora_rank,
|
attn_output = torch.randn(
|
||||||
dtype=query.dtype,
|
[num_tokens, self.num_heads, self.kv_lora_rank],
|
||||||
device="npu")
|
dtype=query.dtype,
|
||||||
|
device=query.device)
|
||||||
self.seq_lens_tensor_cpu = torch.from_numpy(
|
self.seq_lens_tensor_cpu = torch.from_numpy(
|
||||||
np.array(attn_metadata.decode_metadata.seq_lens).astype(
|
np.array(attn_metadata.decode_metadata.seq_lens).astype(
|
||||||
np.int32))
|
np.int32))
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ class NPUPlatform(Platform):
|
|||||||
_enum = PlatformEnum.OOT
|
_enum = PlatformEnum.OOT
|
||||||
device_name: str = "npu"
|
device_name: str = "npu"
|
||||||
device_type: str = "npu"
|
device_type: str = "npu"
|
||||||
simple_compile_backend: str = "npu"
|
simple_compile_backend: str = "eager" # Disable torch.compile()
|
||||||
ray_device_key: str = "NPU"
|
ray_device_key: str = "NPU"
|
||||||
device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"
|
device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"
|
||||||
dispatch_key: str = "PrivateUse1"
|
dispatch_key: str = "PrivateUse1"
|
||||||
@@ -99,11 +99,13 @@ class NPUPlatform(Platform):
|
|||||||
if parallel_config.worker_cls == "auto":
|
if parallel_config.worker_cls == "auto":
|
||||||
if envs.VLLM_USE_V1:
|
if envs.VLLM_USE_V1:
|
||||||
parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
|
parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
|
||||||
|
elif vllm_config.speculative_config:
|
||||||
|
parallel_config.worker_cls = "vllm.spec_decode.spec_decode_worker.create_spec_worker"
|
||||||
|
parallel_config.sd_worker_cls = "vllm_ascend.worker.worker.NPUWorker"
|
||||||
|
elif vllm_config.scheduler_config.is_multi_step:
|
||||||
|
parallel_config.worker_cls = "vllm_ascend.worker.multi_step_worker.MultiStepWorker"
|
||||||
else:
|
else:
|
||||||
if vllm_config.scheduler_config.is_multi_step:
|
parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"
|
||||||
parallel_config.worker_cls = "vllm_ascend.worker.multi_step_worker.MultiStepWorker"
|
|
||||||
else:
|
|
||||||
parallel_config.worker_cls = "vllm_ascend.worker.worker.NPUWorker"
|
|
||||||
|
|
||||||
cache_config = vllm_config.cache_config
|
cache_config = vllm_config.cache_config
|
||||||
if cache_config and cache_config.block_size is None:
|
if cache_config and cache_config.block_size is None:
|
||||||
|
|||||||
@@ -89,6 +89,7 @@ class ModelInputForNPU(ModelRunnerInputBase):
|
|||||||
async_callback: Optional[Callable] = None
|
async_callback: Optional[Callable] = None
|
||||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
|
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
|
||||||
scheduler_outputs: Optional[SchedulerOutputs] = None
|
scheduler_outputs: Optional[SchedulerOutputs] = None
|
||||||
|
previous_hidden_states: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||||
tensor_dict = {
|
tensor_dict = {
|
||||||
@@ -1082,6 +1083,7 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
|||||||
kv_caches: List[torch.Tensor],
|
kv_caches: List[torch.Tensor],
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
num_steps: int = 1,
|
num_steps: int = 1,
|
||||||
|
**kwargs,
|
||||||
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
|
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
|
||||||
if num_steps > 1:
|
if num_steps > 1:
|
||||||
raise ValueError("num_steps > 1 is not supported in ModelRunner")
|
raise ValueError("num_steps > 1 is not supported in ModelRunner")
|
||||||
@@ -1118,6 +1120,11 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
|||||||
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
|
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
|
||||||
} if self.has_inner_state else {}
|
} if self.has_inner_state else {}
|
||||||
|
|
||||||
|
previous_hidden_states = kwargs.get("previous_hidden_states")
|
||||||
|
model_kwargs = {}
|
||||||
|
if previous_hidden_states is not None:
|
||||||
|
model_kwargs["previous_hidden_states"] = previous_hidden_states
|
||||||
|
|
||||||
if (self.observability_config is not None
|
if (self.observability_config is not None
|
||||||
and self.observability_config.collect_model_forward_time):
|
and self.observability_config.collect_model_forward_time):
|
||||||
model_forward_start = torch_npu.npu.Event(enable_timing=True)
|
model_forward_start = torch_npu.npu.Event(enable_timing=True)
|
||||||
@@ -1135,7 +1142,8 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
|||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
|
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
|
||||||
device=self.device),
|
device=self.device),
|
||||||
**seqlen_agnostic_kwargs)
|
**seqlen_agnostic_kwargs,
|
||||||
|
**model_kwargs)
|
||||||
|
|
||||||
if (self.observability_config is not None
|
if (self.observability_config is not None
|
||||||
and self.observability_config.collect_model_forward_time):
|
and self.observability_config.collect_model_forward_time):
|
||||||
|
|||||||
@@ -63,7 +63,8 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
|||||||
local_rank: int,
|
local_rank: int,
|
||||||
rank: int,
|
rank: int,
|
||||||
distributed_init_method: str,
|
distributed_init_method: str,
|
||||||
is_driver_worker: bool = False):
|
is_driver_worker: bool = False,
|
||||||
|
model_runner_cls: Optional[Type[ModelRunnerBase]] = None):
|
||||||
# Register ops when worker init.
|
# Register ops when worker init.
|
||||||
from vllm_ascend import ops # noqa: F401
|
from vllm_ascend import ops # noqa: F401
|
||||||
|
|
||||||
@@ -90,10 +91,10 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
|||||||
speculative_config = self.speculative_config
|
speculative_config = self.speculative_config
|
||||||
model_config = self.model_config
|
model_config = self.model_config
|
||||||
speculative_args = {} if speculative_config is None \
|
speculative_args = {} if speculative_config is None \
|
||||||
or (speculative_config.draft_model_config.model ==
|
or (speculative_config.draft_model_config.hf_config.model_type ==
|
||||||
model_config.model) \
|
model_config.hf_config.model_type) \
|
||||||
or (speculative_config.draft_model_config.hf_config.model_type
|
or (speculative_config.draft_model_config.hf_config.model_type
|
||||||
not in ["medusa", "mlp_speculator", "eagle"]) \
|
not in ["medusa", "mlp_speculator", "eagle", "deepseek_mtp"]) \
|
||||||
else {"return_hidden_states": True}
|
else {"return_hidden_states": True}
|
||||||
|
|
||||||
ModelRunnerClass: Type[ModelRunnerBase] = NPUModelRunner
|
ModelRunnerClass: Type[ModelRunnerBase] = NPUModelRunner
|
||||||
@@ -107,6 +108,8 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
|||||||
is_driver_worker=is_driver_worker,
|
is_driver_worker=is_driver_worker,
|
||||||
**speculative_args,
|
**speculative_args,
|
||||||
)
|
)
|
||||||
|
if model_runner_cls is not None:
|
||||||
|
self.model_runner = model_runner_cls(self.model_runner)
|
||||||
|
|
||||||
# Uninitialized cache engine. Will be initialized by
|
# Uninitialized cache engine. Will be initialized by
|
||||||
# initialize_cache.
|
# initialize_cache.
|
||||||
|
|||||||
Reference in New Issue
Block a user