Support for Qwen2.5-VL Model in bitsandbytes Format (#5003)
This commit is contained in:
@@ -141,7 +141,7 @@ class Qwen2_5_VisionBlock(nn.Module):
|
||||
embed_dim=dim,
|
||||
num_heads=num_heads,
|
||||
projection_size=dim,
|
||||
use_qkv_parallel=False,
|
||||
use_qkv_parallel=True,
|
||||
use_context_forward=use_context_forward,
|
||||
softmax_in_single_precision=softmax_in_single_precision,
|
||||
flatten_batch=flatten_batch,
|
||||
@@ -325,7 +325,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
return self.blocks[0].mlp.gate_proj.weight.dtype
|
||||
return self.patch_embed.proj.weight.dtype
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
@@ -429,6 +429,25 @@ cached_get_processor = lru_cache(get_processor)
|
||||
|
||||
|
||||
class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
||||
# BitsAndBytes-specific attributes
|
||||
default_bitsandbytes_target_modules = [
|
||||
".gate_proj.",
|
||||
".down_proj.",
|
||||
".up_proj.",
|
||||
".q_proj.",
|
||||
".k_proj.",
|
||||
".v_proj.",
|
||||
".o_proj.",
|
||||
]
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name: (weight_name, index)
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
"k_proj": ("qkv_proj", 1),
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
"gate_proj": ("gate_up_proj", 0),
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen2_5_VLConfig,
|
||||
@@ -441,9 +460,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
||||
self.visual = Qwen2_5_VisionTransformer(
|
||||
config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
# NOTE: Qwen2-VL vision encoder does not support any
|
||||
# quantization method now.
|
||||
quant_config=None,
|
||||
# NOTE: Qwen2_5-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
|
||||
# Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("visual", prefix),
|
||||
)
|
||||
|
||||
@@ -573,23 +592,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
if "visual" in name and "qkv.weight" in name:
|
||||
visual_num_heads = self.config.vision_config.num_heads
|
||||
visual_embed_dim = self.config.vision_config.hidden_size
|
||||
head_size = visual_embed_dim // visual_num_heads
|
||||
loaded_weight = loaded_weight.view(
|
||||
3, visual_num_heads, head_size, visual_embed_dim
|
||||
)
|
||||
loaded_weight = loaded_weight.transpose(0, 1)
|
||||
loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
|
||||
elif "visual" in name and "qkv.bias" in name:
|
||||
visual_num_heads = self.config.vision_config.num_heads
|
||||
visual_embed_dim = self.config.vision_config.hidden_size
|
||||
head_size = visual_embed_dim // visual_num_heads
|
||||
loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
|
||||
loaded_weight = loaded_weight.transpose(0, 1)
|
||||
loaded_weight = loaded_weight.reshape(-1)
|
||||
|
||||
if "visual" in name:
|
||||
# adapt to VisionAttention
|
||||
name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
|
||||
|
||||
@@ -152,7 +152,7 @@ class Qwen2VisionBlock(nn.Module):
|
||||
embed_dim=dim,
|
||||
num_heads=num_heads,
|
||||
projection_size=dim,
|
||||
use_qkv_parallel=False,
|
||||
use_qkv_parallel=True,
|
||||
use_context_forward=use_context_forward,
|
||||
softmax_in_single_precision=softmax_in_single_precision,
|
||||
flatten_batch=True,
|
||||
@@ -351,7 +351,7 @@ class Qwen2VisionTransformer(nn.Module):
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
return next(self.parameters()).dtype
|
||||
return self.patch_embed.proj.weight.dtype
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
@@ -423,6 +423,25 @@ cached_get_processor = lru_cache(get_processor)
|
||||
|
||||
|
||||
class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
# BitsAndBytes-specific attributes
|
||||
default_bitsandbytes_target_modules = [
|
||||
".gate_proj.",
|
||||
".down_proj.",
|
||||
".up_proj.",
|
||||
".q_proj.",
|
||||
".k_proj.",
|
||||
".v_proj.",
|
||||
".o_proj.",
|
||||
]
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name: (weight_name, index)
|
||||
"q_proj": ("qkv_proj", 0),
|
||||
"k_proj": ("qkv_proj", 1),
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
"gate_proj": ("gate_up_proj", 0),
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
|
||||
processor = cached_get_processor(self.config._name_or_path)
|
||||
grid_t, grid_h, grid_w = image_grid_thw
|
||||
@@ -447,9 +466,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
self.visual = Qwen2VisionTransformer(
|
||||
config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
# NOTE: Qwen2-VL vision encoder does not support any
|
||||
# quantization method now.
|
||||
quant_config=None,
|
||||
# NOTE: Qwen2-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
|
||||
# Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("visual", prefix),
|
||||
)
|
||||
|
||||
@@ -578,24 +597,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
|
||||
if "visual" in name and "qkv.weight" in name:
|
||||
visual_num_heads = self.config.vision_config.num_heads
|
||||
visual_embed_dim = self.config.vision_config.embed_dim
|
||||
head_size = visual_embed_dim // visual_num_heads
|
||||
loaded_weight = loaded_weight.view(
|
||||
3, visual_num_heads, head_size, visual_embed_dim
|
||||
)
|
||||
loaded_weight = loaded_weight.transpose(0, 1)
|
||||
loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
|
||||
elif "visual" in name and "qkv.bias" in name:
|
||||
visual_num_heads = self.config.vision_config.num_heads
|
||||
visual_embed_dim = self.config.vision_config.embed_dim
|
||||
head_size = visual_embed_dim // visual_num_heads
|
||||
loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
|
||||
loaded_weight = loaded_weight.transpose(0, 1)
|
||||
loaded_weight = loaded_weight.reshape(-1)
|
||||
|
||||
if "visual" in name:
|
||||
# adapt to VisionAttention
|
||||
name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
|
||||
|
||||
Reference in New Issue
Block a user