[Bugfix] Disable torch.compile() (#370)
### What this PR does / why we need it? To resolve this [patch](https://github.com/vllm-project/vllm-ascend/pull/236/files#diff-43b96b39b5a52fe209d86449ad703a7ff5e1349ebaf1aa12ece8d82163ee5b61R24-R49), we need to set the `torch.compile()` backend to `eager` to disable compilation, falling back to the default PyTorch eager-mode execution. --------- Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
@@ -89,6 +89,7 @@ class ModelInputForNPU(ModelRunnerInputBase):
|
||||
async_callback: Optional[Callable] = None
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
|
||||
scheduler_outputs: Optional[SchedulerOutputs] = None
|
||||
previous_hidden_states: Optional[torch.Tensor] = None
|
||||
|
||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||
tensor_dict = {
|
||||
@@ -1082,6 +1083,7 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
kv_caches: List[torch.Tensor],
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
**kwargs,
|
||||
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
|
||||
if num_steps > 1:
|
||||
raise ValueError("num_steps > 1 is not supported in ModelRunner")
|
||||
@@ -1118,6 +1120,11 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
|
||||
} if self.has_inner_state else {}
|
||||
|
||||
previous_hidden_states = kwargs.get("previous_hidden_states")
|
||||
model_kwargs = {}
|
||||
if previous_hidden_states is not None:
|
||||
model_kwargs["previous_hidden_states"] = previous_hidden_states
|
||||
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time):
|
||||
model_forward_start = torch_npu.npu.Event(enable_timing=True)
|
||||
@@ -1135,7 +1142,8 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
|
||||
device=self.device),
|
||||
**seqlen_agnostic_kwargs)
|
||||
**seqlen_agnostic_kwargs,
|
||||
**model_kwargs)
|
||||
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time):
|
||||
|
||||
Reference in New Issue
Block a user