diff --git a/vllm_ascend/patch/worker/patch_kimi_k25.py b/vllm_ascend/patch/worker/patch_kimi_k25.py index d7646937..45531dd4 100644 --- a/vllm_ascend/patch/worker/patch_kimi_k25.py +++ b/vllm_ascend/patch/worker/patch_kimi_k25.py @@ -39,20 +39,15 @@ class AscendLearnable2DInterpPosEmbDivided_fixed(nn.Module): def forward(self, x: torch.Tensor, grid_thws: torch.Tensor) -> torch.Tensor: pos_embs = [] for t, h, w in grid_thws.tolist(): - x_device = x.device - x_dtype = x.dtype assert t <= self.num_frames, f"t:{t} > self.num_frames:{self.num_frames}" if (h, w) == self.weight.shape[:-1]: pos_emb_2d = self.weight.flatten(end_dim=1) else: - weight_fp32 = self.weight.to(dtype=torch.float32) - weight_cpu = weight_fp32.to("cpu") pos_emb_2d = get_rope_shape( - weight_cpu, + self.weight, interpolation_mode=self.interpolation_mode, shape=(h, w), ) - pos_emb_2d = pos_emb_2d.to(x_device, dtype=x_dtype) if t == 1: pos_emb_3d = pos_emb_2d