Support for Qwen2.5-VL Model in bitsandbytes Format (#5003)

Authored by yhyang201 on 2025-04-14 17:03:40 +08:00; committed by GitHub
parent defede5073
commit 072df75354
6 changed files with 375 additions and 45 deletions

View File

@@ -1071,6 +1071,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         param_dict = dict(model.named_parameters())
         stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
+        model_type = model_config.hf_config.model_type
         for quant_param_name in quant_state_dict:
             non_stacked_param_name = quant_param_name
@@ -1079,11 +1080,24 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 weight_name,
                 index,
             ) in model.bitsandbytes_stacked_params_mapping.items():
+                if (
+                    model_type in ["qwen2_vl", "qwen2_5_vl"]
+                    and "visual" in quant_param_name
+                ):
+                    break
                 if shard_name in quant_param_name:
                     shard_index = index
                     quant_param_name = quant_param_name.replace(shard_name, weight_name)
                     break
+            if (
+                model_type in ["qwen2_vl", "qwen2_5_vl"]
+                and "visual" in quant_param_name
+            ):
+                quant_param_name = quant_param_name.replace(
+                    r"attn.qkv.", r"attn.qkv_proj."
+                )
             if quant_param_name not in param_dict:
                 raise ValueError(
                     f"Parameter {quant_param_name} not found in the model."
@@ -1111,6 +1125,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                     num_elements[seq] = math.prod(quant_state.shape) // pack_ratio
                 offsets = np.concatenate(([0], np.cumsum(num_elements)))
+                # Make torch infer_schema happy (compatible with vLLM)
+                offsets = torch.tensor(offsets).cpu()
                 set_weight_attrs(param, {"bnb_shard_offsets": offsets})
                 if load_8bit:
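For orientation, here is how bnb_shard_offsets comes out on made-up numbers. pack_ratio = 2 stands in for the 4-bit packing factor (two 4-bit values per uint8 element); the offsets simply mark where each shard's packed elements start inside the flattened buffer, and converting to a CPU torch tensor is the "make torch infer_schema happy" step added above.

import math

import numpy as np
import torch

# Hypothetical per-shard quant-state shapes for a fused weight (q, k, v shards).
shard_shapes = {0: (16, 8), 1: (8, 8), 2: (8, 8)}
pack_ratio = 2  # two 4-bit values per packed uint8 element

num_elements = [0] * len(shard_shapes)
for seq, shape in shard_shapes.items():
    num_elements[seq] = math.prod(shape) // pack_ratio

offsets = np.concatenate(([0], np.cumsum(num_elements)))
offsets = torch.tensor(offsets).cpu()  # mirrors the conversion added above
print(offsets)  # tensor([  0,  64,  96, 128])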

View File

@@ -141,7 +141,7 @@ class Qwen2_5_VisionBlock(nn.Module):
             embed_dim=dim,
             num_heads=num_heads,
             projection_size=dim,
-            use_qkv_parallel=False,
+            use_qkv_parallel=True,
             use_context_forward=use_context_forward,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=flatten_batch,
@@ -325,7 +325,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
     @property
     def dtype(self) -> torch.dtype:
-        return self.blocks[0].mlp.gate_proj.weight.dtype
+        return self.patch_embed.proj.weight.dtype
     @property
     def device(self) -> torch.device:
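Context for the dtype change: once the vision blocks' Linear layers are bitsandbytes-quantized, their weights are stored packed as uint8, so reading gate_proj.weight.dtype no longer reflects the compute dtype. The patch-embedding convolution is not a quantization target and keeps the original dtype. A minimal sketch of the issue, assuming bitsandbytes is installed and a CUDA device is available:

import torch
import bitsandbytes as bnb

lin = bnb.nn.Linear4bit(16, 16, compute_dtype=torch.bfloat16)
lin = lin.cuda()          # 4-bit quantization happens when the layer moves to the GPU
print(lin.weight.dtype)   # torch.uint8 -> misleading as the vision tower's dtype

conv = torch.nn.Conv3d(3, 16, kernel_size=2, dtype=torch.bfloat16)  # stand-in for patch_embed.proj
print(conv.weight.dtype)  # torch.bfloat16 -> what the property should report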
@@ -429,6 +429,25 @@ cached_get_processor = lru_cache(get_processor)
 class Qwen2_5_VLForConditionalGeneration(nn.Module):
+    # BitsAndBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
     def __init__(
         self,
         config: Qwen2_5_VLConfig,
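The target-module patterns above are what the BitsAndBytes loader consults when deciding which weights to quantize; the stacked-params mapping then tells it how per-shard checkpoint weights line up inside the fused qkv_proj / gate_up_proj modules. A rough, simplified stand-in for the pattern matching (the real loader has more cases; the parameter names are typical Qwen2 names, not read from a checkpoint):

target_modules = [
    ".gate_proj.", ".down_proj.", ".up_proj.",
    ".q_proj.", ".k_proj.", ".v_proj.", ".o_proj.",
]

def is_quant_target(param_name: str) -> bool:
    # Simplified substring match; the actual loader logic is richer than this.
    return any(pattern in param_name for pattern in target_modules)

print(is_quant_target("model.layers.0.self_attn.q_proj.weight"))  # True  -> quantized to 4-bit
print(is_quant_target("model.embed_tokens.weight"))               # False -> left in original dtype
print(is_quant_target("visual.patch_embed.proj.weight"))          # False -> left in original dtype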
@@ -441,9 +460,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.visual = Qwen2_5_VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-            # NOTE: Qwen2-VL vision encoder does not support any
-            # quantization method now.
-            quant_config=None,
+            # NOTE: Qwen2_5-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
+            # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
+            quant_config=quant_config,
             prefix=add_prefix("visual", prefix),
         )
@@ -573,23 +592,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                if "visual" in name and "qkv.weight" in name:
-                    visual_num_heads = self.config.vision_config.num_heads
-                    visual_embed_dim = self.config.vision_config.hidden_size
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(
-                        3, visual_num_heads, head_size, visual_embed_dim
-                    )
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
-                elif "visual" in name and "qkv.bias" in name:
-                    visual_num_heads = self.config.vision_config.num_heads
-                    visual_embed_dim = self.config.vision_config.hidden_size
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1)
                 if "visual" in name:
                     # adapt to VisionAttention
                     name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
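The deleted branch re-ordered the checkpoint's fused qkv tensor into the per-head interleaved layout the previous (non-parallel) attention configuration expected. With use_qkv_parallel=True, VisionAttention's qkv_proj consumes the checkpoint's native [q; k; v] concatenation directly, so only the module-name rename remains. A small shape-level sketch of the transform that was removed, using dummy sizes rather than the real vision config:

import torch

num_heads, head_size = 2, 4
embed_dim = num_heads * head_size

# Checkpoint layout: q, k and v stacked along the output dimension.
qkv = torch.randn(3 * embed_dim, embed_dim)

# The removed path: interleave per head -> [q_head0, k_head0, v_head0, q_head1, ...]
interleaved = (
    qkv.view(3, num_heads, head_size, embed_dim)
    .transpose(0, 1)
    .reshape(-1, embed_dim)
)
assert interleaved.shape == qkv.shape  # same size, rows re-ordered per head
assert torch.equal(interleaved[:head_size], qkv[:head_size])  # head 0's q rows stay first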

View File

@@ -152,7 +152,7 @@ class Qwen2VisionBlock(nn.Module):
             embed_dim=dim,
             num_heads=num_heads,
             projection_size=dim,
-            use_qkv_parallel=False,
+            use_qkv_parallel=True,
             use_context_forward=use_context_forward,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=True,
@@ -351,7 +351,7 @@ class Qwen2VisionTransformer(nn.Module):
     @property
     def dtype(self) -> torch.dtype:
-        return next(self.parameters()).dtype
+        return self.patch_embed.proj.weight.dtype
     @property
     def device(self) -> torch.device:
@@ -423,6 +423,25 @@ cached_get_processor = lru_cache(get_processor)
 class Qwen2VLForConditionalGeneration(nn.Module):
+    # BitsAndBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
     def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
         processor = cached_get_processor(self.config._name_or_path)
         grid_t, grid_h, grid_w = image_grid_thw
@@ -447,9 +466,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         self.visual = Qwen2VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-            # NOTE: Qwen2-VL vision encoder does not support any
-            # quantization method now.
-            quant_config=None,
+            # NOTE: Qwen2-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
+            # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
+            quant_config=quant_config,
             prefix=add_prefix("visual", prefix),
         )
@@ -578,24 +597,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                if "visual" in name and "qkv.weight" in name:
-                    visual_num_heads = self.config.vision_config.num_heads
-                    visual_embed_dim = self.config.vision_config.embed_dim
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(
-                        3, visual_num_heads, head_size, visual_embed_dim
-                    )
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
-                elif "visual" in name and "qkv.bias" in name:
-                    visual_num_heads = self.config.vision_config.num_heads
-                    visual_embed_dim = self.config.vision_config.embed_dim
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1)
                 if "visual" in name:
                     # adapt to VisionAttention
                     name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
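As a sanity check after loading, one can look for vision-tower parameters whose storage dtype is uint8 (how bitsandbytes packs 4-bit weights). This is a hedged, illustrative helper, assuming `model` is the loaded Qwen2VLForConditionalGeneration or Qwen2_5_VLForConditionalGeneration instance and that quantized weights surface as uint8 parameters:

import torch

def count_quantized_visual_params(model) -> int:
    # bitsandbytes stores 4-bit weights packed into uint8 buffers.
    return sum(
        1
        for name, param in model.named_parameters()
        if "visual" in name and param.dtype == torch.uint8
    )

# Usage (once the engine has loaded the model):
# print(count_quantized_visual_params(model), "vision parameters are 4-bit quantized")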