fix: FP8 quantization failure of Qwen2.5-VL 7B model (#10112)
Signed-off-by: PanJason <pyyjason@gmail.com>
@@ -31,6 +31,7 @@ from sglang.srt.layers.parameter import (
     _ColumnvLLMParameter,
 )
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
+from sglang.srt.layers.utils import pad_or_narrow_weight
 from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
 
 if TYPE_CHECKING:
@@ -625,9 +626,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             # bitsandbytes loads the weights of the specific portion
             # no need to narrow here
             if not use_bitsandbytes_4bit and not self.use_presharded_weights:
-                loaded_weight = loaded_weight.narrow(
-                    output_dim, start_idx, shard_size
-                )
+                # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
+                end_idx = start_idx + shard_size
+                if end_idx > loaded_weight.shape[output_dim]:
+                    loaded_weight = pad_or_narrow_weight(
+                        loaded_weight, output_dim, start_idx, shard_size
+                    )
+                else:
+                    loaded_weight = loaded_weight.narrow(
+                        output_dim, start_idx, shard_size
+                    )
 
         # Special case for AQLM codebooks.
         elif is_metadata:
@@ -1302,7 +1310,16 @@ class RowParallelLinear(LinearBase):
                     shard_size,
                 )
             else:
-                loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
+                # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
+                end_idx = start_idx + shard_size
+                if end_idx > loaded_weight.shape[input_dim]:
+                    loaded_weight = pad_or_narrow_weight(
+                        loaded_weight, input_dim, start_idx, shard_size
+                    )
+                else:
+                    loaded_weight = loaded_weight.narrow(
+                        input_dim, start_idx, shard_size
+                    )
 
         # Special case for loading scales off disk, which often do not
         # have a shape (such as in the case of AutoFP8).
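
The two hunks above only change what happens when a computed shard would run past the end of the checkpoint tensor. A minimal standalone sketch of that decision, using illustrative numbers only (a hypothetical checkpoint width of 3420 padded up to 3424, tp_size 2; neither value comes from this diff):

# Illustrative only: decide, per TP rank, whether the shard can be narrowed
# directly from the checkpoint tensor or needs zero-padding (values are made up).
import torch

checkpoint_dim = 3420                         # hypothetical non-8-aligned checkpoint width
padded_dim = ((checkpoint_dim + 7) // 8) * 8  # 3424, width of the parameter actually created
tp_size = 2
shard_size = padded_dim // tp_size            # 1712 rows per rank

loaded_weight = torch.randn(checkpoint_dim, 16)
for tp_rank in range(tp_size):
    start_idx = tp_rank * shard_size
    end_idx = start_idx + shard_size
    if end_idx > loaded_weight.shape[0]:
        # last rank: rows 1712..3423 requested, only 1712..3419 exist -> pad 4 zero rows
        print(f"rank {tp_rank}: pad_or_narrow_weight, {end_idx - loaded_weight.shape[0]} rows of zeros")
    else:
        shard = loaded_weight.narrow(0, start_idx, shard_size)
        print(f"rank {tp_rank}: narrow({start_idx}, {shard_size}), shape {tuple(shard.shape)}")

Only the last rank's shard overruns in this sketch, so only that rank takes the padding path; every other rank still uses the plain narrow.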
@@ -7,6 +7,7 @@ from typing import Callable, Optional, Union
 import torch
 from torch.nn import Parameter
 
+from sglang.srt.layers.utils import pad_or_narrow_weight
 from sglang.srt.utils import is_cpu
 
 __all__ = [
@@ -156,9 +157,17 @@ class _ColumnvLLMParameter(BasevLLMParameter):
             )
         else:
             if not use_presharded_weights:
-                loaded_weight = loaded_weight.narrow(
-                    self.output_dim, tp_rank * shard_size, shard_size
-                )
+                # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
+                start_idx = tp_rank * shard_size
+                end_idx = start_idx + shard_size
+                if end_idx > loaded_weight.shape[self.output_dim]:
+                    loaded_weight = pad_or_narrow_weight(
+                        loaded_weight, self.output_dim, start_idx, shard_size
+                    )
+                else:
+                    loaded_weight = loaded_weight.narrow(
+                        self.output_dim, start_idx, shard_size
+                    )
 
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
@@ -258,9 +267,17 @@ class RowvLLMParameter(BasevLLMParameter):
 
             return
         else:
-            loaded_weight = loaded_weight.narrow(
-                self.input_dim, tp_rank * shard_size, shard_size
-            )
+            # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
+            start_idx = tp_rank * shard_size
+            end_idx = start_idx + shard_size
+            if end_idx > loaded_weight.shape[self.input_dim]:
+                loaded_weight = pad_or_narrow_weight(
+                    loaded_weight, self.input_dim, start_idx, shard_size
+                )
+            else:
+                loaded_weight = loaded_weight.narrow(
+                    self.input_dim, start_idx, shard_size
+                )
 
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
@@ -393,13 +393,23 @@ class W8A8Int8LinearMethod(LinearMethodBase):
                 x.dtype,
                 True,  # is_vnni
             )
 
         x_q, x_scale = per_token_quant_int8(x)
 
-        return int8_scaled_mm(
-            x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
-        )
+        x_q_2d = x_q.view(-1, x_q.shape[-1])
+        x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
+        output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
+
+        output = int8_scaled_mm(
+            x_q_2d,
+            layer.weight,
+            x_scale_2d,
+            layer.weight_scale,
+            out_dtype=x.dtype,
+            bias=bias,
+        )
+        return output.view(output_shape)
 
 
 class W8A8Int8MoEMethod(FusedMoEMethodBase):
     """MoE method for INT8.
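
The reshaping above suggests int8_scaled_mm consumes 2-D activations and scales. A sketch of the same flatten-run-restore pattern, with torch.matmul standing in for the quantized GEMM (shapes are illustrative, not taken from the diff):

# Illustrative only: flatten leading dims to 2-D, run the GEMM, restore the shape.
# The weight is laid out so that dim 1 is the output width, matching the use of
# layer.weight.shape[1] in the hunk above.
import torch

x = torch.randn(2, 5, 128)       # e.g. (batch, seq, hidden) activations
weight = torch.randn(128, 256)

x_2d = x.view(-1, x.shape[-1])                    # (10, 128)
out_2d = torch.matmul(x_2d, weight)               # 2-D GEMM stand-in for int8_scaled_mm
output_shape = [*x.shape[:-1], weight.shape[1]]   # [2, 5, 256]
out = out_2d.view(output_shape)
assert out.shape == (2, 5, 256)

Flattening the leading dims keeps the kernel interface 2-D while callers can still pass batched, multi-dimensional activations.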
@@ -15,6 +15,29 @@ def get_layer_id(weight_name):
     return None
 
 
+def pad_or_narrow_weight(
+    loaded_weight: torch.Tensor, input_dim: int, start_idx: int, shard_size: int
+) -> torch.Tensor:
+    # Padding with zeros for special case such as qwen2_5_VL's mlp which is not 8-aligned
+    valid_size = max(loaded_weight.shape[input_dim] - start_idx, 0)
+
+    if valid_size > 0:
+        loaded_slice = loaded_weight.narrow(input_dim, start_idx, valid_size)
+        pad_shape = list(loaded_weight.shape)
+        pad_shape[input_dim] = shard_size - valid_size
+        pad = torch.zeros(
+            pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
+        )
+        return torch.cat([loaded_slice, pad], dim=input_dim)
+
+    # All padding
+    pad_shape = list(loaded_weight.shape)
+    pad_shape[input_dim] = shard_size
+    return torch.zeros(
+        pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
+    )
+
+
 class PPMissingLayer(torch.nn.Identity):
     # Adapted from
     # https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
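
A minimal usage sketch of the helper added above (the import path matches the new import lines earlier in this commit; tensor values are illustrative):

# Illustrative only: the helper keeps whatever rows exist past start_idx and
# zero-fills the rest so the shard always comes back with exactly shard_size rows.
import torch

from sglang.srt.layers.utils import pad_or_narrow_weight

w = torch.arange(10, dtype=torch.float32).reshape(10, 1)    # checkpoint tensor, 10 rows
shard = pad_or_narrow_weight(w, input_dim=0, start_idx=8, shard_size=4)
print(shard.squeeze(1))   # tensor([8., 9., 0., 0.]) -> rows 8..9 kept, 2 rows zero-padded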
@@ -265,7 +265,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self.fullatt_block_indexes = vision_config.fullatt_block_indexes
         self.window_size = vision_config.window_size
         self.patch_size = vision_config.patch_size
-        mlp_hidden_size: int = vision_config.intermediate_size
+        mlp_hidden_size: int = ((vision_config.intermediate_size + 7) // 8) * 8
         self.patch_embed = Qwen2_5_VisionPatchEmbed(
             patch_size=patch_size,
             temporal_patch_size=temporal_patch_size,
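
A quick check of the round-up arithmetic introduced above, assuming an illustrative non-8-aligned intermediate_size of 3420 (substitute the actual value from the vision config):

# Illustrative only: round intermediate_size up to the next multiple of 8.
intermediate_size = 3420
mlp_hidden_size = ((intermediate_size + 7) // 8) * 8
print(mlp_hidden_size)            # 3424
print(mlp_hidden_size % 8 == 0)   # True

The extra width created by the rounding does not exist in the checkpoint, which is exactly the case the pad_or_narrow_weight path zero-fills at load time.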