[Bugfix] Fix get_rope_shape for Kimi-K2.5 (#7521)
### What this PR does / why we need it?
Delete the logic that copies the input of get_rope_shape from device to host, so the positional-embedding interpolation runs on the device instead of round-tripping through the CPU.
- vLLM version: v0.17.0
- vLLM main: 8b6325758c
Signed-off-by: LoganJane <loganJane73@hotmail.com>
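In sketch form, the change drops the explicit host round-trip and calls the helper on the device-resident weight (a minimal illustration of the diff below; `get_rope_shape` itself is defined elsewhere and not shown in this PR):

```python
# Before: upcast to float32, copy to CPU, interpolate, then copy back.
weight_cpu = self.weight.to(dtype=torch.float32).to("cpu")
pos_emb_2d = get_rope_shape(weight_cpu,
                            interpolation_mode=self.interpolation_mode,
                            shape=(h, w))
pos_emb_2d = pos_emb_2d.to(x.device, dtype=x.dtype)

# After: interpolate directly on the device, keeping the weight's
# device and dtype throughout.
pos_emb_2d = get_rope_shape(self.weight,
                            interpolation_mode=self.interpolation_mode,
                            shape=(h, w))
```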
```diff
@@ -39,20 +39,15 @@ class AscendLearnable2DInterpPosEmbDivided_fixed(nn.Module):
     def forward(self, x: torch.Tensor, grid_thws: torch.Tensor) -> torch.Tensor:
         pos_embs = []
         for t, h, w in grid_thws.tolist():
-            x_device = x.device
-            x_dtype = x.dtype
             assert t <= self.num_frames, f"t:{t} > self.num_frames:{self.num_frames}"
             if (h, w) == self.weight.shape[:-1]:
                 pos_emb_2d = self.weight.flatten(end_dim=1)
             else:
-                weight_fp32 = self.weight.to(dtype=torch.float32)
-                weight_cpu = weight_fp32.to("cpu")
                 pos_emb_2d = get_rope_shape(
-                    weight_cpu,
+                    self.weight,
                     interpolation_mode=self.interpolation_mode,
                     shape=(h, w),
                 )
-                pos_emb_2d = pos_emb_2d.to(x_device, dtype=x_dtype)
 
             if t == 1:
                 pos_emb_3d = pos_emb_2d
```
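`get_rope_shape` is not part of this diff. A minimal sketch of what such a helper could look like, assuming it resizes a learnable `(H, W, C)` embedding grid with `torch.nn.functional.interpolate` and flattens the result (the function name comes from the diff; this body is hypothetical):

```python
import torch
import torch.nn.functional as F

def get_rope_shape_sketch(weight: torch.Tensor,
                          interpolation_mode: str = "bicubic",
                          shape: tuple[int, int] = (16, 16)) -> torch.Tensor:
    """Hypothetical stand-in for get_rope_shape: resize an (H, W, C)
    positional-embedding grid to `shape` and flatten it to (H*W, C),
    staying on whatever device and dtype `weight` already uses."""
    h, w = shape
    # (H, W, C) -> (1, C, H, W): F.interpolate expects NCHW input.
    grid = weight.permute(2, 0, 1).unsqueeze(0)
    # align_corners applies to the (bi)linear/bicubic modes assumed here.
    grid = F.interpolate(grid, size=(h, w), mode=interpolation_mode,
                         align_corners=False)
    # Back to (h, w, C), then flatten the two spatial dims, matching
    # the `flatten(end_dim=1)` fast path in forward().
    return grid.squeeze(0).permute(1, 2, 0).flatten(end_dim=1)
```

Under that assumption, passing `self.weight` directly keeps the resize on the accelerator and avoids the blocking device-to-host copy the deleted code paid for every grid whose `(h, w)` differs from the weight's native shape.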