diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 0765b673a..2b34a2965 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -31,6 +31,7 @@ from sglang.srt.layers.parameter import ( _ColumnvLLMParameter, ) from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod +from sglang.srt.layers.utils import pad_or_narrow_weight from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs if TYPE_CHECKING: @@ -625,9 +626,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear): # bitsandbytes loads the weights of the specific portion # no need to narrow here if not use_bitsandbytes_4bit and not self.use_presharded_weights: - loaded_weight = loaded_weight.narrow( - output_dim, start_idx, shard_size - ) + # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned + end_idx = start_idx + shard_size + if end_idx > loaded_weight.shape[output_dim]: + loaded_weight = pad_or_narrow_weight( + loaded_weight, output_dim, start_idx, shard_size + ) + else: + loaded_weight = loaded_weight.narrow( + output_dim, start_idx, shard_size + ) # Special case for AQLM codebooks. elif is_metadata: @@ -1302,7 +1310,16 @@ class RowParallelLinear(LinearBase): shard_size, ) else: - loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) + # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned + end_idx = start_idx + shard_size + if end_idx > loaded_weight.shape[input_dim]: + loaded_weight = pad_or_narrow_weight( + loaded_weight, input_dim, start_idx, shard_size + ) + else: + loaded_weight = loaded_weight.narrow( + input_dim, start_idx, shard_size + ) # Special case for loading scales off disk, which often do not # have a shape (such as in the case of AutoFP8). 
diff --git a/python/sglang/srt/layers/parameter.py b/python/sglang/srt/layers/parameter.py index 1ea75d70c..3cc1d2344 100644 --- a/python/sglang/srt/layers/parameter.py +++ b/python/sglang/srt/layers/parameter.py @@ -7,6 +7,7 @@ from typing import Callable, Optional, Union import torch from torch.nn import Parameter +from sglang.srt.layers.utils import pad_or_narrow_weight from sglang.srt.utils import is_cpu __all__ = [ @@ -156,9 +157,17 @@ class _ColumnvLLMParameter(BasevLLMParameter): ) else: if not use_presharded_weights: - loaded_weight = loaded_weight.narrow( - self.output_dim, tp_rank * shard_size, shard_size - ) + # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned + start_idx = tp_rank * shard_size + end_idx = start_idx + shard_size + if end_idx > loaded_weight.shape[self.output_dim]: + loaded_weight = pad_or_narrow_weight( + loaded_weight, self.output_dim, start_idx, shard_size + ) + else: + loaded_weight = loaded_weight.narrow( + self.output_dim, start_idx, shard_size + ) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -258,9 +267,17 @@ class RowvLLMParameter(BasevLLMParameter): return else: - loaded_weight = loaded_weight.narrow( - self.input_dim, tp_rank * shard_size, shard_size - ) + # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned + start_idx = tp_rank * shard_size + end_idx = start_idx + shard_size + if end_idx > loaded_weight.shape[self.input_dim]: + loaded_weight = pad_or_narrow_weight( + loaded_weight, self.input_dim, start_idx, shard_size + ) + else: + loaded_weight = loaded_weight.narrow( + self.input_dim, start_idx, shard_size + ) if len(loaded_weight.shape) == 0: loaded_weight = loaded_weight.reshape(1) diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index cab505a50..17a79190d 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ 
def pad_or_narrow_weight(
    loaded_weight: torch.Tensor, input_dim: int, start_idx: int, shard_size: int
) -> torch.Tensor:
    """Slice ``shard_size`` elements along ``input_dim``, zero-padding any overrun.

    The requested window is ``[start_idx, start_idx + shard_size)`` along
    ``input_dim``. Any part of the window that lies beyond the end of
    ``loaded_weight`` is filled with zeros, which lets a checkpoint whose
    dimension is not evenly divisible across shards (e.g. qwen2_5_VL's MLP,
    which is not 8-aligned) be loaded into a padded parameter.

    Args:
        loaded_weight: Full (unsharded) tensor read from the checkpoint.
        input_dim: Dimension to slice along. Despite the name, it is used
            as-is for whichever dimension the caller passes.
        start_idx: First index of the shard along ``input_dim``.
        shard_size: Number of elements the returned tensor has along
            ``input_dim``.

    Returns:
        A tensor shaped like ``loaded_weight`` except that ``input_dim`` has
        size ``shard_size``. When the window is fully in range this is the
        plain ``narrow`` result; otherwise it is a new tensor with trailing
        zeros, using the dtype and device of ``loaded_weight``.
    """
    # Real elements available inside the window, clamped to [0, shard_size]
    # so that a fully in-range window never produces a negative pad size.
    available = loaded_weight.shape[input_dim] - start_idx
    valid_size = min(max(available, 0), shard_size)
    pad_size = shard_size - valid_size

    if valid_size > 0 and pad_size == 0:
        # Window fits entirely: behave exactly like a plain narrow.
        return loaded_weight.narrow(input_dim, start_idx, shard_size)

    pad_shape = list(loaded_weight.shape)
    pad_shape[input_dim] = pad_size
    pad = torch.zeros(
        pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
    )

    if valid_size == 0:
        # Window starts at or past the end of the data: all padding.
        return pad

    loaded_slice = loaded_weight.narrow(input_dim, start_idx, valid_size)
    return torch.cat([loaded_slice, pad], dim=input_dim)
--- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -265,7 +265,7 @@ class Qwen2_5_VisionTransformer(nn.Module): self.fullatt_block_indexes = vision_config.fullatt_block_indexes self.window_size = vision_config.window_size self.patch_size = vision_config.patch_size - mlp_hidden_size: int = vision_config.intermediate_size + mlp_hidden_size: int = ((vision_config.intermediate_size + 7) // 8) * 8 self.patch_embed = Qwen2_5_VisionPatchEmbed( patch_size=patch_size, temporal_patch_size=temporal_patch_size,