diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py
index e49ba7f1f..29a9f37b0 100644
--- a/python/sglang/srt/models/qwen2_5_vl.py
+++ b/python/sglang/srt/models/qwen2_5_vl.py
@@ -59,6 +59,7 @@ from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInp
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
+from sglang.srt.models.utils import permute_inv
 from sglang.srt.utils import add_prefix
 from sglang.srt.utils.hf_transformers_utils import get_processor

@@ -405,6 +406,7 @@ class Qwen2_5_VisionTransformer(nn.Module):

         # Move window_index to the same device as x before using it to index x
         window_index = window_index.to(device=x.device)
+        reverse_indices = permute_inv(window_index)

         # Ensure rotary_pos_emb is on the same device/dtype as x
         rotary_pos_emb = rotary_pos_emb.to(device=x.device, dtype=x.dtype)
@@ -451,8 +453,6 @@ class Qwen2_5_VisionTransformer(nn.Module):

         # adapter
         x = self.merger(x)
-
-        reverse_indices = torch.argsort(window_index)
         x = x[reverse_indices, :]

         return x
diff --git a/python/sglang/srt/models/utils.py b/python/sglang/srt/models/utils.py
index 3adab87fe..8100e6734 100644
--- a/python/sglang/srt/models/utils.py
+++ b/python/sglang/srt/models/utils.py
@@ -53,3 +53,9 @@ def create_fused_set_kv_buffer_arg(
         v_scale=layer.v_scale,
         cache_loc=forward_batch.out_cache_loc,
     )
+
+
+def permute_inv(perm: torch.Tensor) -> torch.Tensor:
+    inv_perm = torch.empty_like(perm)
+    inv_perm[perm] = torch.arange(perm.numel(), device=perm.device, dtype=perm.dtype)
+    return inv_perm
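
Note: for a permutation tensor, the scatter-based permute_inv yields the inverse permutation, equivalent to the torch.argsort call it replaces but without the sort. A minimal sanity check (a sketch, assuming the helper is importable from sglang.srt.models.utils as added above):

    import torch
    from sglang.srt.models.utils import permute_inv

    perm = torch.randperm(16)
    # the inverse permutation matches argsort when the input is a true permutation
    assert torch.equal(permute_inv(perm), torch.argsort(perm))
    # indexing by perm and then by its inverse restores the original row order
    x = torch.randn(16, 4)
    assert torch.equal(x[perm][permute_inv(perm)], x)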