Support BNB quantization for llama/mllama (#5038)
Co-authored-by: Yuhao Yang <yyh073@foxmail.com>
@@ -1074,7 +1074,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         model_type = model_config.hf_config.model_type
         for quant_param_name in quant_state_dict:
             non_stacked_param_name = quant_param_name
+            if model_type == "mllama" and "vision_model" in quant_param_name:
+                # adapt to VisionAttention
+                quant_param_name = quant_param_name.replace(
+                    "self_attn.o_proj", "self_attn.proj"
+                )
             shard_index = 0
             for shard_name, (
                 weight_name,
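Note on the hunk above: sglang's shared VisionAttention module names its output projection `proj`, while HF mllama checkpoints (and the language model) use `self_attn.o_proj`, so the loader rewrites BNB quant-state keys for the vision tower before lookup. A minimal self-contained sketch of that rename — the key names below are illustrative, not taken from a real checkpoint:

```python
def remap_vision_keys(model_type: str, quant_state_dict: dict) -> dict:
    """Mirror of the rename in the loader hunk above (simplified)."""
    remapped = {}
    for quant_param_name, quant_state in quant_state_dict.items():
        if model_type == "mllama" and "vision_model" in quant_param_name:
            # sglang's VisionAttention calls its output projection "proj"
            quant_param_name = quant_param_name.replace(
                "self_attn.o_proj", "self_attn.proj"
            )
        remapped[quant_param_name] = quant_state
    return remapped


demo = remap_vision_keys(
    "mllama",
    {
        "vision_model.transformer.layers.0.self_attn.o_proj.weight": "qs0",
        "language_model.model.layers.0.self_attn.o_proj.weight": "qs1",
    },
)
assert "vision_model.transformer.layers.0.self_attn.proj.weight" in demo
assert "language_model.model.layers.0.self_attn.o_proj.weight" in demo
```

The remaining model-side hunks thread `quant_config` through mllama's vision stack so those renamed modules are actually built as quantizable layers.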
@@ -22,6 +22,7 @@ from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
+    ReplicatedLinear,
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -184,6 +185,7 @@ class MllamaVisionEncoderLayer(nn.Module):
     def __init__(
         self,
         config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
         is_gated: bool = False,
         prefix: str = "",
     ):
@@ -199,14 +201,16 @@ class MllamaVisionEncoderLayer(nn.Module):
             self.num_attention_heads,
             self.hidden_size,
             use_qkv_parallel=True,
-            quant_config=None,
+            quant_config=quant_config,
             dropout=0.0,
             use_context_forward=False,
             softmax_in_single_precision=False,
             flatten_batch=False,
             prefix=add_prefix("self_attn", prefix),
         )
-        self.mlp = MllamaVisionMLP(config, prefix=add_prefix("mlp", prefix))
+        self.mlp = MllamaVisionMLP(
+            config, quant_config, prefix=add_prefix("mlp", prefix)
+        )
 
         self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(
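The constructor changes here and in the next several hunks all follow one pattern: accept an `Optional[QuantizationConfig]` and forward it to every quantizable submodule, where `None` preserves the previous unquantized behavior. A toy sketch of the pattern (stand-in classes, not sglang's API):

```python
from typing import Optional


class QuantizationConfig:
    """Toy stand-in for sglang's QuantizationConfig base class."""


class FakeLinear:
    def __init__(self, quant_config: Optional[QuantizationConfig] = None):
        # quant_config=None means "build this projection unquantized"
        self.quant_config = quant_config


class FakeEncoderLayer:
    def __init__(self, quant_config: Optional[QuantizationConfig] = None):
        # one quant_config object is threaded into every quantizable child,
        # exactly the shape of the constructor changes in these hunks
        self.self_attn = FakeLinear(quant_config=quant_config)
        self.mlp = FakeLinear(quant_config=quant_config)


layer = FakeEncoderLayer(quant_config=QuantizationConfig())
assert layer.self_attn.quant_config is layer.mlp.quant_config
```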
@@ -244,6 +248,7 @@ class MllamaVisionEncoder(nn.Module):
     def __init__(
         self,
         config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
         num_layers=32,
         is_gated=False,
         output_hidden_states=None,
@@ -254,7 +259,10 @@ class MllamaVisionEncoder(nn.Module):
         self.layers = nn.ModuleList(
             [
                 MllamaVisionEncoderLayer(
-                    config, is_gated, prefix=add_prefix(f"layers.{i}", prefix)
+                    config,
+                    quant_config,
+                    is_gated,
+                    prefix=add_prefix(f"layers.{i}", prefix),
                 )
                 for i in range(num_layers)
             ]
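The per-layer `prefix=add_prefix(f"layers.{i}", prefix)` matters because weight loading matches parameters by dotted names. A believed-equivalent re-implementation of `add_prefix` from `sglang.srt.utils` (an assumption about the helper's exact behavior; check the real one):

```python
def add_prefix(name: str, prefix: str) -> str:
    # join with "." unless the prefix is empty
    return name if not prefix else f"{prefix}.{name}"


assert add_prefix("layers.3", "vision_model.transformer") == (
    "vision_model.transformer.layers.3"
)
assert add_prefix("self_attn", "") == "self_attn"
```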
@@ -283,7 +291,12 @@ class MllamaVisionEncoder(nn.Module):
 
 
 class MllamaVisionModel(nn.Module):
-    def __init__(self, config: config_mllama.MllamaVisionConfig, prefix: str = ""):
+    def __init__(
+        self,
+        config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
         super().__init__()
         self.image_size = config.image_size
         self.patch_size = config.patch_size
@@ -320,6 +333,7 @@ class MllamaVisionModel(nn.Module):
         # encoders
         self.transformer = MllamaVisionEncoder(
             config,
+            quant_config,
             config.num_hidden_layers,
             is_gated=False,
             output_hidden_states=config.intermediate_layers_indices,
@@ -327,6 +341,7 @@ class MllamaVisionModel(nn.Module):
         )
         self.global_transformer = MllamaVisionEncoder(
             config,
+            quant_config,
             config.num_global_layers,
             is_gated=True,
             prefix=add_prefix("global_transformer", prefix),
@@ -765,6 +780,27 @@ class MllamaForCausalLM(nn.Module):
 
 
 class MllamaForConditionalGeneration(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    # in TP, these weights are partitioned along the column dimension (dim=-1)
+    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     def __init__(
         self,
         config: config_mllama.MllamaConfig,
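How the loader consumes `bitsandbytes_stacked_params_mapping`: checkpoint shards such as `q_proj`/`k_proj`/`v_proj` exist separately on disk but live in one fused `qkv_proj` parameter at runtime, so each shard name resolves to a fused name plus a shard index. A simplified sketch of that resolution (not the loader's exact code):

```python
bitsandbytes_stacked_params_mapping = {
    # shard_name: (stacked_param_name, shard_index)
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}


def resolve_stacked(checkpoint_name: str):
    """Map a checkpoint shard name to (fused_param_name, shard_index)."""
    for shard_name, (stacked_name, index) in (
        bitsandbytes_stacked_params_mapping.items()
    ):
        if shard_name in checkpoint_name:
            return checkpoint_name.replace(shard_name, stacked_name), index
    return checkpoint_name, 0  # non-stacked weights load as-is


assert resolve_stacked("model.layers.0.self_attn.k_proj.weight") == (
    "model.layers.0.self_attn.qkv_proj.weight",
    1,
)
```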
@@ -772,6 +808,7 @@ class MllamaForConditionalGeneration(nn.Module):
         prefix: str = "",
     ):
         super().__init__()
+        self.quant_config = quant_config
         self.vocab_size = config.text_config.vocab_size
         self.hidden_size = config.text_config.hidden_size
         self.max_num_tiles = config.vision_config.max_num_tiles
@@ -782,17 +819,21 @@ class MllamaForConditionalGeneration(nn.Module):
         self.image_size = config.vision_config.image_size
 
         self.vision_model = MllamaVisionModel(
-            config.vision_config, prefix=add_prefix("vision_model", prefix)
+            config.vision_config,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_model", prefix),
         )
         self.language_model = MllamaForCausalLM(
             config.text_config,
             quant_config=quant_config,
             prefix=add_prefix("language_model", prefix),
         )
-        self.multi_modal_projector = nn.Linear(
+        self.multi_modal_projector = ReplicatedLinear(
             config.vision_config.vision_output_dim,
             config.text_config.hidden_size,
             bias=True,
+            quant_config=quant_config,
+            prefix="multi_modal_projector",
         )
         self.logits_processor = LogitsProcessor(config.text_config)
         self.capture_mode = False
@@ -959,7 +1000,9 @@ class MllamaForConditionalGeneration(nn.Module):
             cross_attention_states = self.vision_model(
                 batched_images, batched_ar_ids, batched_ar_mask
             )
-            cross_attention_states = self.multi_modal_projector(cross_attention_states)
+            cross_attention_states, _ = self.multi_modal_projector(
+                cross_attention_states
+            )
 
             bs, _, _, _, image_token_dim = cross_attention_states.shape
             cross_attention_states = cross_attention_states.view(
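The call-site change above follows from swapping `nn.Linear` for `ReplicatedLinear`: sglang's parallel/replicated linear layers return an `(output, output_bias)` tuple rather than a bare tensor, hence the `, _` unpack. A toy stand-in that mirrors only the return convention (the second element is `None` here because `nn.Linear` already adds its bias):

```python
import torch


class TupleReturnLinear(torch.nn.Module):
    """Toy stand-in mirroring the (output, output_bias) convention."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.inner = torch.nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor):
        # the real layers hand the bias back separately so callers decide
        # when to add it; nn.Linear has already applied it here
        return self.inner(x), None


proj = TupleReturnLinear(8, 4)
out, _ = proj(torch.randn(2, 8))  # note the tuple unpack, as in the hunk
assert out.shape == (2, 4)
```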
@@ -1013,7 +1056,6 @@ class MllamaForConditionalGeneration(nn.Module):
             if "vision_model" in name:
                 # adapt to VisionAttention
                 name = name.replace("self_attn.o_proj", "self_attn.proj")
-
             param = params_dict.pop(name)
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
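The final hunks update the BNB integration test — `test_bnb.py`, per its usage docstring — so the vision and language model lists exercise the new llama 3.2 4-bit checkpoints.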
@@ -1,7 +1,7 @@
 """
 Usage:
-python3 -m unittest test_vision_openai_server.TestOpenAIVisionServer.test_mixed_batch
-python3 -m unittest test_vision_openai_server.TestOpenAIVisionServer.test_multi_images_chat_completion
+python3 -m unittest test_bnb.TestVisionModel.test_vlm
+python3 -m unittest test_bnb.TestLanguageModel.test_mmlu
 """
 
 import base64
@@ -31,10 +31,13 @@ from sglang.test.test_utils import (
 VISION_MODELS = [
     ("unsloth/Qwen2.5-VL-7B-Instruct-bnb-4bit", "qwen2-vl"),
     ("unsloth/Qwen2-VL-7B-Instruct-bnb-4bit", "qwen2-vl"),
+    ("unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit", "llama_3_vision"),
+    ("unsloth/Llama-3.2-11B-Vision-bnb-4bit", "llama_3_vision"),
 ]
 LANGUAGE_MODELS = [
     "unsloth/Qwen2.5-7B-Instruct-bnb-4bit",
     "unsloth/Qwen2-7B-Instruct-bnb-4bit",
+    "unsloth/Llama-3.2-3B-Instruct-bnb-4bit",
 ]
 
 # image
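A hedged usage sketch, not part of the diff: serving one of the 4-bit checkpoints above. The flag spellings follow sglang's BitsAndBytes load format and the chat-template names in the VISION_MODELS tuples; confirm both against `python3 -m sglang.launch_server --help` on your version.

```
python3 -m sglang.launch_server \
    --model-path unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit \
    --load-format bitsandbytes \
    --chat-template llama_3_vision
```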