diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 1dad4a28..80299a05 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -31,6 +31,7 @@ import torch_npu # noqa: F401 from packaging.version import InvalidVersion, Version from torch_npu.npu.streams import Event from vllm.logger import logger +from vllm.sequence import IntermediateTensors import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config @@ -844,6 +845,13 @@ def weak_ref_tensors( return [weak_ref_tensor(t) for t in tensors] if isinstance(tensors, tuple): return tuple(weak_ref_tensor(t) for t in tensors) + # For IntermediateTensors used in pipeline parallelism + if isinstance(tensors, IntermediateTensors): + ret = IntermediateTensors({ + key: weak_ref_tensor(val) + for key, val in tensors.tensors.items() + }) + return ret raise ValueError("Invalid type for tensors")