From b2e71b79301cdcd5c230e73cadb7523ddc76fde9 Mon Sep 17 00:00:00 2001
From: LoganJane <42287016+LoganJane@users.noreply.github.com>
Date: Sun, 22 Mar 2026 21:06:31 +0800
Subject: [PATCH] [Bugfix] Fix get_rope_shape for Kimi-K2.5 (#7521)

### What this PR does / why we need it?
Delete the logic that moves the input of get_rope_shape from the device to the host.

- vLLM version: v0.17.0
- vLLM main: https://github.com/vllm-project/vllm/commit/8b6325758cce5f9c36d38f2462edbd368b97a07c

Signed-off-by: LoganJane
---
 vllm_ascend/patch/worker/patch_kimi_k25.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/vllm_ascend/patch/worker/patch_kimi_k25.py b/vllm_ascend/patch/worker/patch_kimi_k25.py
index d7646937..45531dd4 100644
--- a/vllm_ascend/patch/worker/patch_kimi_k25.py
+++ b/vllm_ascend/patch/worker/patch_kimi_k25.py
@@ -39,20 +39,15 @@ class AscendLearnable2DInterpPosEmbDivided_fixed(nn.Module):
     def forward(self, x: torch.Tensor,
                 grid_thws: torch.Tensor) -> torch.Tensor:
         pos_embs = []
         for t, h, w in grid_thws.tolist():
-            x_device = x.device
-            x_dtype = x.dtype
             assert t <= self.num_frames, f"t:{t} > self.num_frames:{self.num_frames}"
             if (h, w) == self.weight.shape[:-1]:
                 pos_emb_2d = self.weight.flatten(end_dim=1)
             else:
-                weight_fp32 = self.weight.to(dtype=torch.float32)
-                weight_cpu = weight_fp32.to("cpu")
                 pos_emb_2d = get_rope_shape(
-                    weight_cpu,
+                    self.weight,
                     interpolation_mode=self.interpolation_mode,
                     shape=(h, w),
                 )
-                pos_emb_2d = pos_emb_2d.to(x_device, dtype=x_dtype)
             if t == 1:
                 pos_emb_3d = pos_emb_2d
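
For illustration only (not part of the patch): the change keeps self.weight on its
original device and dtype when calling get_rope_shape, instead of casting it to fp32,
copying it to the CPU, and copying the result back to x.device. Below is a minimal,
self-contained sketch of the same pattern. It is not the vllm_ascend code, and
torch.nn.functional.interpolate stands in for get_rope_shape, whose definition is not
shown in this hunk.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Learnable2DInterpPosEmbSketch(nn.Module):
    """Sketch of a learnable 2D positional-embedding table resized on-device."""

    def __init__(self, height: int, width: int, dim: int) -> None:
        super().__init__()
        # (H, W, dim) learnable table, so self.weight.shape[:-1] == (H, W).
        self.weight = nn.Parameter(torch.empty(height, width, dim))
        nn.init.normal_(self.weight)

    def forward(self, grid_thws: torch.Tensor) -> list[torch.Tensor]:
        pos_embs = []
        for _t, h, w in grid_thws.tolist():
            if (h, w) == self.weight.shape[:-1]:
                # Target grid matches the table; just flatten (H, W) -> (H*W).
                pos_emb_2d = self.weight.flatten(end_dim=1)
            else:
                # Resize directly on the weight's device/dtype; no CPU round trip.
                resized = F.interpolate(
                    self.weight.permute(2, 0, 1).unsqueeze(0),  # (1, dim, H, W)
                    size=(h, w),
                    mode="bicubic",
                    align_corners=False,
                )
                pos_emb_2d = resized.squeeze(0).permute(1, 2, 0).flatten(end_dim=1)
            pos_embs.append(pos_emb_2d)
        return pos_embs


if __name__ == "__main__":
    m = Learnable2DInterpPosEmbSketch(16, 16, 32)
    out = m(torch.tensor([[1, 16, 16], [1, 8, 8]]))
    print([o.shape for o in out])  # [(256, 32), (64, 32)]
```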