Support BNB quantization for llama/mllama (#5038)
Co-authored-by: Yuhao Yang <yyh073@foxmail.com>
@@ -1074,7 +1074,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         model_type = model_config.hf_config.model_type
         for quant_param_name in quant_state_dict:
             non_stacked_param_name = quant_param_name
-
+            if model_type == "mllama" and "vision_model" in quant_param_name:
+                # adapt to VisionAttention
+                quant_param_name = quant_param_name.replace(
+                    "self_attn.o_proj", "self_attn.proj"
+                )
             shard_index = 0
             for shard_name, (
                 weight_name,
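For context: the mllama vision tower stores its attention output projection as self_attn.o_proj in the checkpoint, while sglang's VisionAttention module names that parameter self_attn.proj, so the BNB quant-state keys are rewritten before the stacked-parameter matching runs. A small self-contained sketch of just that rewrite; the helper name is invented for illustration:

# Illustrative sketch of the rename above; adapt_quant_param_name() is invented,
# only the model_type / "vision_model" / replace() rule mirrors the diff.
def adapt_quant_param_name(model_type: str, quant_param_name: str) -> str:
    if model_type == "mllama" and "vision_model" in quant_param_name:
        # the vision tower's o_proj maps onto VisionAttention.proj
        quant_param_name = quant_param_name.replace(
            "self_attn.o_proj", "self_attn.proj"
        )
    return quant_param_name

print(adapt_quant_param_name(
    "mllama", "vision_model.transformer.layers.0.self_attn.o_proj.weight"
))
# vision_model.transformer.layers.0.self_attn.proj.weight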
@@ -22,6 +22,7 @@ from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
+    ReplicatedLinear,
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -184,6 +185,7 @@ class MllamaVisionEncoderLayer(nn.Module):
     def __init__(
         self,
         config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
         is_gated: bool = False,
         prefix: str = "",
     ):
@@ -199,14 +201,16 @@ class MllamaVisionEncoderLayer(nn.Module):
             self.num_attention_heads,
             self.hidden_size,
             use_qkv_parallel=True,
-            quant_config=None,
+            quant_config=quant_config,
             dropout=0.0,
             use_context_forward=False,
             softmax_in_single_precision=False,
             flatten_batch=False,
             prefix=add_prefix("self_attn", prefix),
         )
-        self.mlp = MllamaVisionMLP(config, prefix=add_prefix("mlp", prefix))
+        self.mlp = MllamaVisionMLP(
+            config, quant_config, prefix=add_prefix("mlp", prefix)
+        )

         self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(
@@ -244,6 +248,7 @@ class MllamaVisionEncoder(nn.Module):
     def __init__(
         self,
         config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
         num_layers=32,
         is_gated=False,
         output_hidden_states=None,
@@ -254,7 +259,10 @@
         self.layers = nn.ModuleList(
             [
                 MllamaVisionEncoderLayer(
-                    config, is_gated, prefix=add_prefix(f"layers.{i}", prefix)
+                    config,
+                    quant_config,
+                    is_gated,
+                    prefix=add_prefix(f"layers.{i}", prefix),
                 )
                 for i in range(num_layers)
             ]
@@ -283,7 +291,12 @@ class MllamaVisionEncoder(nn.Module):


 class MllamaVisionModel(nn.Module):
-    def __init__(self, config: config_mllama.MllamaVisionConfig, prefix: str = ""):
+    def __init__(
+        self,
+        config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
         super().__init__()
         self.image_size = config.image_size
         self.patch_size = config.patch_size
@@ -320,6 +333,7 @@ class MllamaVisionModel(nn.Module):
         # encoders
         self.transformer = MllamaVisionEncoder(
             config,
+            quant_config,
             config.num_hidden_layers,
             is_gated=False,
             output_hidden_states=config.intermediate_layers_indices,
@@ -327,6 +341,7 @@ class MllamaVisionModel(nn.Module):
         )
         self.global_transformer = MllamaVisionEncoder(
             config,
+            quant_config,
             config.num_global_layers,
             is_gated=True,
             prefix=add_prefix("global_transformer", prefix),
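The vision-path hunks above all follow the same pattern: every constructor from MllamaVisionModel down to MllamaVisionEncoderLayer now accepts an optional quant_config and forwards it to its children (VisionAttention and MllamaVisionMLP) instead of hard-coding None, so a BNB config reaches the vision tower's linear layers while passing None keeps the previous unquantized behavior. A toy sketch of that plumbing pattern; the classes below are invented for illustration, not sglang's:

from typing import Optional


class QuantConfig:  # stand-in for sglang's QuantizationConfig
    method = "bitsandbytes"


class VisionLayer:
    def __init__(self, quant_config: Optional[QuantConfig] = None):
        # forwarded from the parent instead of being hard-coded to None
        self.attn_is_quantized = quant_config is not None


class VisionEncoder:
    def __init__(self, quant_config: Optional[QuantConfig] = None, num_layers: int = 2):
        self.layers = [VisionLayer(quant_config) for _ in range(num_layers)]


encoder = VisionEncoder(QuantConfig())
print(all(layer.attn_is_quantized for layer in encoder.layers))  # True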
@@ -765,6 +780,27 @@ class MllamaForCausalLM(nn.Module):


 class MllamaForConditionalGeneration(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    # in TP, these weights are partitioned along the column dimension (dim=-1)
+    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     def __init__(
         self,
         config: config_mllama.MllamaConfig,
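These class attributes are what BitsAndBytesModelLoader reads: default_bitsandbytes_target_modules lists the linear projections to quantize, column_parallel_weights_modules marks the ones sharded along dim=-1 under tensor parallelism, and bitsandbytes_stacked_params_mapping maps per-projection checkpoint names onto the fused qkv_proj / gate_up_proj parameters. A small sketch of how the mapping resolves a checkpoint name; the dict is copied from the hunk, while resolve_stacked() is only an illustrative helper, not the loader's actual code path:

bitsandbytes_stacked_params_mapping = {
    # shard_name, weight_name, index
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}


def resolve_stacked(param_name: str):
    # map a per-projection name onto its fused parameter and shard index
    for shard_name, (weight_name, index) in bitsandbytes_stacked_params_mapping.items():
        if shard_name in param_name:
            return param_name.replace(shard_name, weight_name), index
    return param_name, 0  # unstacked weights (o_proj, down_proj, ...) pass through


print(resolve_stacked("language_model.model.layers.0.self_attn.k_proj.weight"))
# ('language_model.model.layers.0.self_attn.qkv_proj.weight', 1)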
@@ -772,6 +808,7 @@ class MllamaForConditionalGeneration(nn.Module):
         prefix: str = "",
     ):
         super().__init__()
+        self.quant_config = quant_config
         self.vocab_size = config.text_config.vocab_size
         self.hidden_size = config.text_config.hidden_size
         self.max_num_tiles = config.vision_config.max_num_tiles
@@ -782,17 +819,21 @@ class MllamaForConditionalGeneration(nn.Module):
         self.image_size = config.vision_config.image_size

         self.vision_model = MllamaVisionModel(
-            config.vision_config, prefix=add_prefix("vision_model", prefix)
+            config.vision_config,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_model", prefix),
         )
         self.language_model = MllamaForCausalLM(
             config.text_config,
             quant_config=quant_config,
             prefix=add_prefix("language_model", prefix),
         )
-        self.multi_modal_projector = nn.Linear(
+        self.multi_modal_projector = ReplicatedLinear(
             config.vision_config.vision_output_dim,
             config.text_config.hidden_size,
             bias=True,
+            quant_config=quant_config,
+            prefix="multi_modal_projector",
         )
         self.logits_processor = LogitsProcessor(config.text_config)
         self.capture_mode = False
@@ -959,7 +1000,9 @@ class MllamaForConditionalGeneration(nn.Module):
             cross_attention_states = self.vision_model(
                 batched_images, batched_ar_ids, batched_ar_mask
            )
-            cross_attention_states = self.multi_modal_projector(cross_attention_states)
+            cross_attention_states, _ = self.multi_modal_projector(
+                cross_attention_states
+            )

             bs, _, _, _, image_token_dim = cross_attention_states.shape
             cross_attention_states = cross_attention_states.view(
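The call site changes because ReplicatedLinear, like sglang's other parallel linear layers, returns a (output, output_bias) tuple from forward, whereas the nn.Linear it replaces returned a plain tensor. A toy stand-in to show the calling-convention difference; this is not sglang's actual class and the sizes are illustrative:

import torch
import torch.nn as nn


class TupleReturnLinear(nn.Linear):
    """Toy stand-in for ReplicatedLinear's convention: forward -> (output, bias)."""

    def forward(self, x: torch.Tensor):
        return super().forward(x), self.bias


proj = TupleReturnLinear(7680, 4096, bias=True)  # illustrative dims, not the configs'
out, _ = proj(torch.randn(2, 7680))  # callers must unpack, as in the hunk above
print(out.shape)  # torch.Size([2, 4096])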
@@ -1013,7 +1056,6 @@ class MllamaForConditionalGeneration(nn.Module):
                 if "vision_model" in name:
                     # adapt to VisionAttention
                     name = name.replace("self_attn.o_proj", "self_attn.proj")
-
                 param = params_dict.pop(name)
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
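With the loader hook, the per-class BNB attributes, and the quant_config plumbing in place, llama and mllama checkpoints can be weight-quantized with bitsandbytes at load time. A hedged usage sketch via the offline engine API; the model id and flag value are assumptions for illustration, not taken from this commit:

# Usage sketch (kwargs are assumed for illustration; check your sglang version's
# server arguments for the exact bitsandbytes flags).
import sglang as sgl

llm = sgl.Engine(
    model_path="meta-llama/Llama-3.2-11B-Vision-Instruct",  # an mllama checkpoint
    load_format="bitsandbytes",  # assumed: load weights with BNB quantization
)
print(llm.generate("Describe this image in one sentence."))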