diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index a102d0b1..cee31250 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -2852,8 +2852,8 @@ class NPUModelRunner(GPUModelRunner): # a conv state in some special models. target_shape = (num_blocks, *shape) - target_idx += torch.prod(torch.tensor(target_shape)).item() - tensor = raw_tensor.view(dtype)[start_idx:target_idx].view(target_shape) + target_idx += math.prod(target_shape) * get_dtype_size(dtype) + tensor = raw_tensor[start_idx:target_idx].view(dtype).view(target_shape) start_idx = target_idx state_tensors.append(tensor) kv_caches[layer_name] = state_tensors