fix: fp8 quantization failure of qwen 2.5 VL 7B model (#10112)
Signed-off-by: PanJason <pyyjason@gmail.com>
This commit is contained in:
@@ -31,6 +31,7 @@ from sglang.srt.layers.parameter import (
|
|||||||
_ColumnvLLMParameter,
|
_ColumnvLLMParameter,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||||
|
from sglang.srt.layers.utils import pad_or_narrow_weight
|
||||||
from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
|
from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -625,6 +626,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
# bitsandbytes loads the weights of the specific portion
|
# bitsandbytes loads the weights of the specific portion
|
||||||
# no need to narrow here
|
# no need to narrow here
|
||||||
if not use_bitsandbytes_4bit and not self.use_presharded_weights:
|
if not use_bitsandbytes_4bit and not self.use_presharded_weights:
|
||||||
|
# Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
|
||||||
|
end_idx = start_idx + shard_size
|
||||||
|
if end_idx > loaded_weight.shape[output_dim]:
|
||||||
|
loaded_weight = pad_or_narrow_weight(
|
||||||
|
loaded_weight, output_dim, start_idx, shard_size
|
||||||
|
)
|
||||||
|
else:
|
||||||
loaded_weight = loaded_weight.narrow(
|
loaded_weight = loaded_weight.narrow(
|
||||||
output_dim, start_idx, shard_size
|
output_dim, start_idx, shard_size
|
||||||
)
|
)
|
||||||
@@ -1302,7 +1310,16 @@ class RowParallelLinear(LinearBase):
|
|||||||
shard_size,
|
shard_size,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
|
# Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
|
||||||
|
end_idx = start_idx + shard_size
|
||||||
|
if end_idx > loaded_weight.shape[input_dim]:
|
||||||
|
loaded_weight = pad_or_narrow_weight(
|
||||||
|
loaded_weight, input_dim, start_idx, shard_size
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
loaded_weight = loaded_weight.narrow(
|
||||||
|
input_dim, start_idx, shard_size
|
||||||
|
)
|
||||||
|
|
||||||
# Special case for loading scales off disk, which often do not
|
# Special case for loading scales off disk, which often do not
|
||||||
# have a shape (such as in the case of AutoFP8).
|
# have a shape (such as in the case of AutoFP8).
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from typing import Callable, Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
|
|
||||||
|
from sglang.srt.layers.utils import pad_or_narrow_weight
|
||||||
from sglang.srt.utils import is_cpu
|
from sglang.srt.utils import is_cpu
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@@ -156,8 +157,16 @@ class _ColumnvLLMParameter(BasevLLMParameter):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if not use_presharded_weights:
|
if not use_presharded_weights:
|
||||||
|
# Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
|
||||||
|
start_idx = tp_rank * shard_size
|
||||||
|
end_idx = start_idx + shard_size
|
||||||
|
if end_idx > loaded_weight.shape[self.output_dim]:
|
||||||
|
loaded_weight = pad_or_narrow_weight(
|
||||||
|
loaded_weight, self.output_dim, start_idx, shard_size
|
||||||
|
)
|
||||||
|
else:
|
||||||
loaded_weight = loaded_weight.narrow(
|
loaded_weight = loaded_weight.narrow(
|
||||||
self.output_dim, tp_rank * shard_size, shard_size
|
self.output_dim, start_idx, shard_size
|
||||||
)
|
)
|
||||||
|
|
||||||
assert param_data.shape == loaded_weight.shape
|
assert param_data.shape == loaded_weight.shape
|
||||||
@@ -257,9 +266,17 @@ class RowvLLMParameter(BasevLLMParameter):
|
|||||||
param_data.copy_(loaded_weight)
|
param_data.copy_(loaded_weight)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
else:
|
||||||
|
# Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
|
||||||
|
start_idx = tp_rank * shard_size
|
||||||
|
end_idx = start_idx + shard_size
|
||||||
|
if end_idx > loaded_weight.shape[self.input_dim]:
|
||||||
|
loaded_weight = pad_or_narrow_weight(
|
||||||
|
loaded_weight, self.input_dim, start_idx, shard_size
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
loaded_weight = loaded_weight.narrow(
|
loaded_weight = loaded_weight.narrow(
|
||||||
self.input_dim, tp_rank * shard_size, shard_size
|
self.input_dim, start_idx, shard_size
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(loaded_weight.shape) == 0:
|
if len(loaded_weight.shape) == 0:
|
||||||
|
|||||||
@@ -393,13 +393,23 @@ class W8A8Int8LinearMethod(LinearMethodBase):
|
|||||||
x.dtype,
|
x.dtype,
|
||||||
True, # is_vnni
|
True, # is_vnni
|
||||||
)
|
)
|
||||||
|
|
||||||
x_q, x_scale = per_token_quant_int8(x)
|
x_q, x_scale = per_token_quant_int8(x)
|
||||||
|
|
||||||
return int8_scaled_mm(
|
x_q_2d = x_q.view(-1, x_q.shape[-1])
|
||||||
x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
|
x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
|
||||||
|
output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
|
||||||
|
|
||||||
|
output = int8_scaled_mm(
|
||||||
|
x_q_2d,
|
||||||
|
layer.weight,
|
||||||
|
x_scale_2d,
|
||||||
|
layer.weight_scale,
|
||||||
|
out_dtype=x.dtype,
|
||||||
|
bias=bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return output.view(output_shape)
|
||||||
|
|
||||||
|
|
||||||
class W8A8Int8MoEMethod(FusedMoEMethodBase):
|
class W8A8Int8MoEMethod(FusedMoEMethodBase):
|
||||||
"""MoE method for INT8.
|
"""MoE method for INT8.
|
||||||
|
|||||||
@@ -15,6 +15,29 @@ def get_layer_id(weight_name):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def pad_or_narrow_weight(
|
||||||
|
loaded_weight: torch.Tensor, input_dim: int, start_idx: int, shard_size: int
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# Padding with zeros for special case such as qwen2_5_VL's mlp which is not 8-aligned
|
||||||
|
valid_size = max(loaded_weight.shape[input_dim] - start_idx, 0)
|
||||||
|
|
||||||
|
if valid_size > 0:
|
||||||
|
loaded_slice = loaded_weight.narrow(input_dim, start_idx, valid_size)
|
||||||
|
pad_shape = list(loaded_weight.shape)
|
||||||
|
pad_shape[input_dim] = shard_size - valid_size
|
||||||
|
pad = torch.zeros(
|
||||||
|
pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
|
||||||
|
)
|
||||||
|
return torch.cat([loaded_slice, pad], dim=input_dim)
|
||||||
|
|
||||||
|
# All padding
|
||||||
|
pad_shape = list(loaded_weight.shape)
|
||||||
|
pad_shape[input_dim] = shard_size
|
||||||
|
return torch.zeros(
|
||||||
|
pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PPMissingLayer(torch.nn.Identity):
|
class PPMissingLayer(torch.nn.Identity):
|
||||||
# Adapted from
|
# Adapted from
|
||||||
# https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
|
# https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
|
||||||
|
|||||||
@@ -265,7 +265,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
self.fullatt_block_indexes = vision_config.fullatt_block_indexes
|
self.fullatt_block_indexes = vision_config.fullatt_block_indexes
|
||||||
self.window_size = vision_config.window_size
|
self.window_size = vision_config.window_size
|
||||||
self.patch_size = vision_config.patch_size
|
self.patch_size = vision_config.patch_size
|
||||||
mlp_hidden_size: int = vision_config.intermediate_size
|
mlp_hidden_size: int = ((vision_config.intermediate_size + 7) // 8) * 8
|
||||||
self.patch_embed = Qwen2_5_VisionPatchEmbed(
|
self.patch_embed = Qwen2_5_VisionPatchEmbed(
|
||||||
patch_size=patch_size,
|
patch_size=patch_size,
|
||||||
temporal_patch_size=temporal_patch_size,
|
temporal_patch_size=temporal_patch_size,
|
||||||
|
|||||||
Reference in New Issue
Block a user