[Bugfix] Fix get_rope_shape for Kimi-K2.5 (#7521)

### What this PR does / why we need it?
Delete the logic that copies the input of get_rope_shape from device to host.

- vLLM version: v0.17.0
- vLLM main: 8b6325758c
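
For context, here is a minimal, self-contained sketch of what the change means at the call site. The `get_rope_shape` below is a hypothetical stand-in, assumed to bicubically interpolate an `(H, W, C)` positional-embedding table to a target `(h, w)` grid, matching the call signature in the diff; the real helper lives in vLLM-Ascend.

```python
import torch
import torch.nn.functional as F


# Hypothetical stand-in for vLLM-Ascend's get_rope_shape (assumption):
# interpolates an (H, W, C) embedding table to (h*w, C) at the target shape.
def get_rope_shape(weight: torch.Tensor,
                   interpolation_mode: str = "bicubic",
                   shape: tuple[int, int] = (16, 16)) -> torch.Tensor:
    x = weight.permute(2, 0, 1).unsqueeze(0)  # (H, W, C) -> (1, C, H, W)
    x = F.interpolate(x, size=shape, mode=interpolation_mode,
                      align_corners=False)
    return x.squeeze(0).permute(1, 2, 0).flatten(end_dim=1)  # (h*w, C)


weight = torch.randn(16, 16, 64)  # stays on whatever device it lives on

# Before this PR: per grid entry, the weight took a float32 + CPU round trip.
pos_emb_old = get_rope_shape(
    weight.to(dtype=torch.float32).to("cpu"),
    interpolation_mode="bicubic",
    shape=(8, 8),
).to(weight.device, dtype=weight.dtype)

# After this PR: the weight is passed as-is, with no device/dtype copies.
pos_emb_new = get_rope_shape(weight, interpolation_mode="bicubic", shape=(8, 8))
```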

Signed-off-by: LoganJane <loganJane73@hotmail.com>
Commit b2e71b7930 (parent 9e2965bae2), authored by LoganJane on 2026-03-22 21:06:31 +08:00, committed by GitHub.


@@ -39,20 +39,15 @@ class AscendLearnable2DInterpPosEmbDivided_fixed(nn.Module):
     def forward(self, x: torch.Tensor, grid_thws: torch.Tensor) -> torch.Tensor:
         pos_embs = []
         for t, h, w in grid_thws.tolist():
-            x_device = x.device
-            x_dtype = x.dtype
             assert t <= self.num_frames, f"t:{t} > self.num_frames:{self.num_frames}"
             if (h, w) == self.weight.shape[:-1]:
                 pos_emb_2d = self.weight.flatten(end_dim=1)
             else:
-                weight_fp32 = self.weight.to(dtype=torch.float32)
-                weight_cpu = weight_fp32.to("cpu")
                 pos_emb_2d = get_rope_shape(
-                    weight_cpu,
+                    self.weight,
                     interpolation_mode=self.interpolation_mode,
                     shape=(h, w),
                 )
-            pos_emb_2d = pos_emb_2d.to(x_device, dtype=x_dtype)
             if t == 1:
                 pos_emb_3d = pos_emb_2d
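
The effect of the deletion: the positional-embedding weight is interpolated on its resident device and in its native dtype, removing a host synchronization and two tensor copies per `grid_thws` entry. This presumes, per the PR description, that `get_rope_shape` no longer needs a CPU float32 input.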