[Bugfix] Fix get_rope_shape for Kimi-K2.5 (#7521)
### What this PR does / why we need it?
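Remove the logic that moved the input of get_rope_shape from device to host: the weight is no longer cast to float32 and copied to the CPU before the call (nor is the result moved back to the input's device and dtype afterwards); self.weight is now passed to get_rope_shape directly.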
- vLLM version: v0.17.0
- vLLM main:
8b6325758c
Signed-off-by: LoganJane <loganJane73@hotmail.com>
This commit is contained in:
@@ -39,20 +39,15 @@ class AscendLearnable2DInterpPosEmbDivided_fixed(nn.Module):
|
|||||||
def forward(self, x: torch.Tensor, grid_thws: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor, grid_thws: torch.Tensor) -> torch.Tensor:
|
||||||
pos_embs = []
|
pos_embs = []
|
||||||
for t, h, w in grid_thws.tolist():
|
for t, h, w in grid_thws.tolist():
|
||||||
x_device = x.device
|
|
||||||
x_dtype = x.dtype
|
|
||||||
assert t <= self.num_frames, f"t:{t} > self.num_frames:{self.num_frames}"
|
assert t <= self.num_frames, f"t:{t} > self.num_frames:{self.num_frames}"
|
||||||
if (h, w) == self.weight.shape[:-1]:
|
if (h, w) == self.weight.shape[:-1]:
|
||||||
pos_emb_2d = self.weight.flatten(end_dim=1)
|
pos_emb_2d = self.weight.flatten(end_dim=1)
|
||||||
else:
|
else:
|
||||||
weight_fp32 = self.weight.to(dtype=torch.float32)
|
|
||||||
weight_cpu = weight_fp32.to("cpu")
|
|
||||||
pos_emb_2d = get_rope_shape(
|
pos_emb_2d = get_rope_shape(
|
||||||
weight_cpu,
|
self.weight,
|
||||||
interpolation_mode=self.interpolation_mode,
|
interpolation_mode=self.interpolation_mode,
|
||||||
shape=(h, w),
|
shape=(h, w),
|
||||||
)
|
)
|
||||||
pos_emb_2d = pos_emb_2d.to(x_device, dtype=x_dtype)
|
|
||||||
|
|
||||||
if t == 1:
|
if t == 1:
|
||||||
pos_emb_3d = pos_emb_2d
|
pos_emb_3d = pos_emb_2d
|
||||||
|
|||||||
Reference in New Issue
Block a user