From bc24205b32639dd724be49d8214195f757b0fa70 Mon Sep 17 00:00:00 2001
From: ryang <38470282+ryang-max@users.noreply.github.com>
Date: Wed, 16 Apr 2025 09:00:31 +0800
Subject: [PATCH] Support BNB quantization for llama/mllama (#5038)

Co-authored-by: Yuhao Yang
---
 python/sglang/srt/model_loader/loader.py |  6 ++-
 python/sglang/srt/models/mllama.py       | 58 ++++++++++++++++++++----
 test/srt/test_bnb.py                     |  7 ++-
 3 files changed, 60 insertions(+), 11 deletions(-)

diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py
index 0a4aeecf6..94b02c6f5 100644
--- a/python/sglang/srt/model_loader/loader.py
+++ b/python/sglang/srt/model_loader/loader.py
@@ -1074,7 +1074,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         model_type = model_config.hf_config.model_type
         for quant_param_name in quant_state_dict:
             non_stacked_param_name = quant_param_name
-
+            if model_type == "mllama" and "vision_model" in quant_param_name:
+                # adapt to VisionAttention
+                quant_param_name = quant_param_name.replace(
+                    "self_attn.o_proj", "self_attn.proj"
+                )
             shard_index = 0
             for shard_name, (
                 weight_name,
diff --git a/python/sglang/srt/models/mllama.py b/python/sglang/srt/models/mllama.py
index f8bd9b9b6..aea1fdf71 100644
--- a/python/sglang/srt/models/mllama.py
+++ b/python/sglang/srt/models/mllama.py
@@ -22,6 +22,7 @@ from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
+    ReplicatedLinear,
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -184,6 +185,7 @@ class MllamaVisionEncoderLayer(nn.Module):
     def __init__(
         self,
         config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
         is_gated: bool = False,
         prefix: str = "",
     ):
@@ -199,14 +201,16 @@ class MllamaVisionEncoderLayer(nn.Module):
             self.num_attention_heads,
             self.hidden_size,
             use_qkv_parallel=True,
-            quant_config=None,
+            quant_config=quant_config,
             dropout=0.0,
             use_context_forward=False,
             softmax_in_single_precision=False,
             flatten_batch=False,
             prefix=add_prefix("self_attn", prefix),
         )
-        self.mlp = MllamaVisionMLP(config, prefix=add_prefix("mlp", prefix))
+        self.mlp = MllamaVisionMLP(
+            config, quant_config, prefix=add_prefix("mlp", prefix)
+        )
 
         self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(
@@ -244,6 +248,7 @@ class MllamaVisionEncoder(nn.Module):
     def __init__(
         self,
         config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
         num_layers=32,
         is_gated=False,
         output_hidden_states=None,
@@ -254,7 +259,10 @@ class MllamaVisionEncoder(nn.Module):
         self.layers = nn.ModuleList(
             [
                 MllamaVisionEncoderLayer(
-                    config, is_gated, prefix=add_prefix(f"layers.{i}", prefix)
+                    config,
+                    quant_config,
+                    is_gated,
+                    prefix=add_prefix(f"layers.{i}", prefix),
                 )
                 for i in range(num_layers)
             ]
         )
@@ -283,7 +291,12 @@
 
 
 class MllamaVisionModel(nn.Module):
-    def __init__(self, config: config_mllama.MllamaVisionConfig, prefix: str = ""):
+    def __init__(
+        self,
+        config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
         super().__init__()
         self.image_size = config.image_size
         self.patch_size = config.patch_size
@@ -320,6 +333,7 @@ class MllamaVisionModel(nn.Module):
         # encoders
         self.transformer = MllamaVisionEncoder(
             config,
+            quant_config,
             config.num_hidden_layers,
             is_gated=False,
             output_hidden_states=config.intermediate_layers_indices,
@@ -327,6 +341,7 @@ class MllamaVisionModel(nn.Module):
         )
         self.global_transformer = MllamaVisionEncoder(
             config,
+            quant_config,
             config.num_global_layers,
             is_gated=True,
             prefix=add_prefix("global_transformer", prefix),
@@ -765,6 +780,27 @@ class MllamaForCausalLM(nn.Module):
 
 
 class MllamaForConditionalGeneration(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    # in TP, these weights are partitioned along the column dimension (dim=-1)
+    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     def __init__(
         self,
         config: config_mllama.MllamaConfig,
@@ -772,6 +808,7 @@ class MllamaForConditionalGeneration(nn.Module):
         prefix: str = "",
     ):
         super().__init__()
+        self.quant_config = quant_config
         self.vocab_size = config.text_config.vocab_size
         self.hidden_size = config.text_config.hidden_size
         self.max_num_tiles = config.vision_config.max_num_tiles
@@ -782,17 +819,21 @@ class MllamaForConditionalGeneration(nn.Module):
         self.image_size = config.vision_config.image_size
 
         self.vision_model = MllamaVisionModel(
-            config.vision_config, prefix=add_prefix("vision_model", prefix)
+            config.vision_config,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_model", prefix),
         )
         self.language_model = MllamaForCausalLM(
             config.text_config,
             quant_config=quant_config,
             prefix=add_prefix("language_model", prefix),
         )
-        self.multi_modal_projector = nn.Linear(
+        self.multi_modal_projector = ReplicatedLinear(
             config.vision_config.vision_output_dim,
             config.text_config.hidden_size,
             bias=True,
+            quant_config=quant_config,
+            prefix="multi_modal_projector",
         )
         self.logits_processor = LogitsProcessor(config.text_config)
         self.capture_mode = False
@@ -959,7 +1000,9 @@ class MllamaForConditionalGeneration(nn.Module):
         cross_attention_states = self.vision_model(
             batched_images, batched_ar_ids, batched_ar_mask
         )
-        cross_attention_states = self.multi_modal_projector(cross_attention_states)
+        cross_attention_states, _ = self.multi_modal_projector(
+            cross_attention_states
+        )
 
         bs, _, _, _, image_token_dim = cross_attention_states.shape
         cross_attention_states = cross_attention_states.view(
@@ -1013,7 +1056,6 @@ class MllamaForConditionalGeneration(nn.Module):
             if "vision_model" in name:
                 # adapt to VisionAttention
                 name = name.replace("self_attn.o_proj", "self_attn.proj")
-
             param = params_dict.pop(name)
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
diff --git a/test/srt/test_bnb.py b/test/srt/test_bnb.py
index 91056dfd3..4a117e249 100644
--- a/test/srt/test_bnb.py
+++ b/test/srt/test_bnb.py
@@ -1,7 +1,7 @@
 """
 Usage:
-python3 -m unittest test_vision_openai_server.TestOpenAIVisionServer.test_mixed_batch
-python3 -m unittest test_vision_openai_server.TestOpenAIVisionServer.test_multi_images_chat_completion
+python3 -m unittest test_bnb.TestVisionModel.test_vlm
+python3 -m unittest test_bnb.TestLanguageModel.test_mmlu
 """
 
 import base64
@@ -31,10 +31,13 @@ from sglang.test.test_utils import (
 VISION_MODELS = [
     ("unsloth/Qwen2.5-VL-7B-Instruct-bnb-4bit", "qwen2-vl"),
     ("unsloth/Qwen2-VL-7B-Instruct-bnb-4bit", "qwen2-vl"),
("unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit", "llama_3_vision"), + ("unsloth/Llama-3.2-11B-Vision-bnb-4bit", "llama_3_vision"), ] LANGUAGE_MODELS = [ "unsloth/Qwen2.5-7B-Instruct-bnb-4bit", "unsloth/Qwen2-7B-Instruct-bnb-4bit", + "unsloth/Llama-3.2-3B-Instruct-bnb-4bit", ] # image