diff --git a/docs/backend/quantization.md b/docs/backend/quantization.md index cdec87bac..b70f34f4a 100644 --- a/docs/backend/quantization.md +++ b/docs/backend/quantization.md @@ -2,15 +2,25 @@ SGLang supports various quantization methods, including offline quantization and online dynamic quantization. -Offline quantization loads pre-quantized model weights directly during inference. This is useful for methods requiring pre-computed stats such as AWQ, which collects activation stats from the pre-training set. +Offline quantization loads pre-quantized model weights directly during inference. This is required for quantization methods +such as GPTQ and AWQ that collect and pre-compute various stats from the original weights using the calibration dataset. -Online quantization dynamically computes scaling parameters—such as the maximum/minimum values of model weights—during runtime. Like NVIDIA FP8 training's [delayed scaling](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#Mixed-precision-training-with-FP8) mechanism, online quantization calculates the appropriate scaling factors on-the-fly to convert high-precision weights into a lower-precision format. +Online quantization dynamically computes scaling parameters—such as the maximum/minimum values of model weights—during runtime. +Like NVIDIA FP8 training's [delayed scaling](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#Mixed-precision-training-with-FP8) mechanism, online quantization calculates the appropriate scaling factors +on-the-fly to convert high-precision weights into a lower-precision format. -**Note that, for better performance, usability and convenience, offline quantization is recommended over online quantization.** And if you use a pre-quantized model, do not add `--quantization` to enable online quantization at the same time. 
For popular pre-quantized models, please visit [neuralmagic collection](https://huggingface.co/collections/neuralmagic) for some popular quantized LLMs on huggingface. +**Note: For better performance, usability and convenience, offline quantization is recommended over online quantization.** + +If you use a pre-quantized model, do not add `--quantization` to enable online quantization at the same time. +For pre-quantized models, please visit the [ModelCloud](https://huggingface.co/collections/ModelCloud/vortex-673743382af0a52b2a8b9fe2) or [NeuralMagic](https://huggingface.co/collections/neuralmagic) collections on Hugging Face for +popular, quality-validated quantized models. Quantized models must be validated via benchmarks post-quantization +to guard against abnormal quantization loss regressions. ## Offline Quantization -To load already quantized models, simply load the model weights and config. **Again, if the model has been quantized offline, there's no need to add `--quantization` argument when starting the engine. The quantization method will be parsed from the downloaded Hugging Face config. For example, DeepSeek V3/R1 models are already in FP8, so do not add redundant parameters.** +To load already quantized models, simply load the model weights and config. **Again, if the model has been quantized offline, +there's no need to add the `--quantization` argument when starting the engine. The quantization method will be parsed from the +downloaded Hugging Face config. 
For example, DeepSeek V3/R1 models are already in FP8, so do not add redundant parameters.** ```bash python3 -m sglang.launch_server \ @@ -18,9 +28,38 @@ python3 -m sglang.launch_server \ --port 30000 --host 0.0.0.0 ``` -To do offline quantization for your model, firstly you need to install [llm-compressor](https://github.com/vllm-project/llm-compressor/) library: +### Examples of Offline Model Quantization + +#### Using [GPTQModel](https://github.com/ModelCloud/GPTQModel) ```bash +# install +pip install gptqmodel --no-build-isolation -v +``` + +```py +from datasets import load_dataset +from gptqmodel import GPTQModel, QuantizeConfig + +model_id = "meta-llama/Llama-3.2-1B-Instruct" +quant_path = "Llama-3.2-1B-Instruct-gptqmodel-4bit" + +calibration_dataset = load_dataset( + "allenai/c4", data_files="en/c4-train.00001-of-01024.json.gz", + split="train" + ).select(range(1024))["text"] + +quant_config = QuantizeConfig(bits=4, group_size=128) # quantization config +model = GPTQModel.load(model_id, quant_config) # load model + +model.quantize(calibration_dataset, batch_size=2) # quantize +model.save(quant_path) # save model +``` + +#### Using [LLM Compressor](https://github.com/vllm-project/llm-compressor/) + +```bash +# install pip install llmcompressor ``` @@ -99,8 +138,7 @@ python3 -m sglang.launch_server \ ## Reference -- [quantization document of vllm](https://docs.vllm.ai/en/latest/quantization/fp8.html) - -- [torchao](https://github.com/pytorch/ao) - -- [llm-compressor](https://github.com/vllm-project/llm-compressor/) +- [GPTQModel](https://github.com/ModelCloud/GPTQModel) +- [LLM Compressor](https://github.com/vllm-project/llm-compressor/) +- [Torchao: PyTorch Architecture Optimization](https://github.com/pytorch/ao) +- [vLLM Quantization](https://docs.vllm.ai/en/latest/quantization/) diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 0377277e6..f90673191 100644 --- 
a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -19,6 +19,7 @@ from sglang.srt.layers.linear import ( RowParallelLinear, ) from sglang.srt.layers.quantization import QuantizationConfig +from sglang.srt.utils import add_prefix def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: @@ -122,20 +123,20 @@ class VisionAttention(nn.Module): head_size=self.head_size, total_num_heads=num_heads, quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", + prefix=add_prefix("qkv_proj", prefix), ) else: self.qkv_proj = ColumnParallelLinear( input_size=embed_dim, output_size=3 * projection_size, quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", + prefix=add_prefix("qkv_proj", prefix), ) self.proj = RowParallelLinear( input_size=embed_dim, output_size=embed_dim, quant_config=quant_config, - prefix=f"{prefix}.out_proj", + prefix=add_prefix("out_proj", prefix), ) def forward( diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 12391d4e6..a0e38b022 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -417,7 +417,7 @@ class LogitsProcessor(nn.Module): ) else: # GGUF models - logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias) + logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias) if self.logit_scale is not None: logits.mul_(self.logit_scale) diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index 928643b70..1ef8f4381 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -1,5 +1,7 @@ # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py -from typing import Callable, Dict, Optional, Type +import re +from copy import deepcopy +from typing 
import Callable, Dict, Optional, Type, Union import torch from vllm.model_executor.layers.quantization.aqlm import AQLMConfig @@ -16,8 +18,6 @@ from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfi from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config from vllm.model_executor.layers.quantization.gguf import GGUFConfig -from vllm.model_executor.layers.quantization.gptq import GPTQConfig -from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.qqq import QQQConfig @@ -26,6 +26,7 @@ from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config from sglang.srt.layers.quantization.fp8 import Fp8Config +from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config @@ -61,19 +62,119 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: return QUANTIZATION_METHODS[quantization] +# Match dynamic rules with module name (prefix) and override quantize +# config if module (prefix) matches a rule +def override_config(config: QuantizationConfig, prefix: str): + weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits) + if isinstance(weight_bits, int): + config.weight_bits = weight_bits + group_size = get_dynamic_override(config, prefix, "group_size", config.group_size) + if isinstance(group_size, int): + config.group_size = group_size + desc_act = 
get_dynamic_override(config, prefix, "desc_act", config.desc_act) + if isinstance(desc_act, bool): + config.desc_act = desc_act + + config.pack_factor = 32 // config.weight_bits # packed into int32 + if config.get_name() == "gptq_marlin": + is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym) + if isinstance(is_sym, bool): + config.is_sym = is_sym + + if (config.weight_bits, config.is_sym) not in config.TYPE_MAP: + raise ValueError( + "Unsupported quantization config: " + f"bits={config.weight_bits}, sym={config.is_sym}" + ) + + config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)] + elif config.get_name() == "gptq": + if config.weight_bits not in [2, 3, 4, 8]: + raise ValueError( + "Currently, only 2/3/4/8-bit weight quantization is " + f"supported for GPTQ, but got {config.weight_bits} bits." + ) + + +def get_dynamic_override( + config: QuantizationConfig, + layer_name: str, + key: Optional[str] = None, + default_value: Union[int, bool, None] = None, +) -> Union[Dict, int, bool, None]: + for pattern, pattern_dict in config.dynamic.items(): + # Negative match: matched modules are excluded from quantized init + if pattern.startswith("-:"): + if re.match(pattern.removeprefix("-:"), layer_name): + return False + # Positive match: matched modules have quant properties overrides + # base quant config + elif re.match(pattern.removeprefix("+:"), layer_name): + if key is None: + return pattern_dict + else: + return pattern_dict.get(key, default_value) + return default_value + + +def get_linear_quant_method( + config: QuantizationConfig, + layer: torch.nn.Module, + prefix: str, + linear_method_cls: type, +): + + from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod + from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + UnquantizedEmbeddingMethod, + ) + + cloned_config = deepcopy(config) + parallel_lm_head_quantized = ( + isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized + ) + + 
if isinstance(layer, LinearBase) or parallel_lm_head_quantized: + # False = skip module, None = no override, else = Positive match + if ( + get_dynamic_override( # noqa: E712 + cloned_config, layer_name=prefix # noqa: E712 + ) + == False + ): # noqa: E712 + if parallel_lm_head_quantized: + return UnquantizedEmbeddingMethod() + return UnquantizedLinearMethod() + + if prefix: + # Dynamic per module/layer rules may override base config + override_config(cloned_config, prefix=prefix) + + return linear_method_cls(cloned_config) + return None + + def gptq_get_quant_method(self, layer, prefix): + from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinLinearMethod, GPTQMarlinMoEMethod, ) - from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE - if isinstance(layer, LinearBase): - return GPTQMarlinLinearMethod(self) - elif isinstance(layer, FusedMoE): + if isinstance(layer, FusedMoE): return GPTQMarlinMoEMethod(self) + + if isinstance(self, GPTQConfig): + return get_linear_quant_method( + self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod + ) + elif isinstance(self, GPTQMarlinConfig): + return get_linear_quant_method( + self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod + ) return None @@ -155,6 +256,7 @@ def apply_monkey_patches(): from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method) + setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method) setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method) setattr(AWQMoEMethod, "apply", awq_moe_method_apply) diff --git a/python/sglang/srt/layers/quantization/gptq.py b/python/sglang/srt/layers/quantization/gptq.py new file mode 100644 index 000000000..b15498864 --- /dev/null +++ b/python/sglang/srt/layers/quantization/gptq.py @@ -0,0 +1,416 @@ 
+import logging +from fractions import Fraction +from typing import Any, Dict, List, Optional, Union + +import torch +from vllm.scalar_type import scalar_types + +from sglang.srt.layers.linear import LinearBase +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead + +logger = logging.getLogger(__name__) + + +class GPTQConfig(QuantizationConfig): + """Config class for GPTQ. + + Reference: https://arxiv.org/abs/2210.17323 + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + desc_act: bool, + lm_head_quantized: bool, + dynamic: Dict[str, Dict[str, Union[int, bool]]], + ) -> None: + # GPTQModel use `dynamic` config property to allow per module + # quantization config so each module can be individually optimized. + # Format is Dict[str, Dict] where key is a regex string that can + # perform both positive ("+:" prefixed) or negative ("-:" prefixed) + # matching of a module. + # Default to positive match, override base quant config mode, if no + # prefix is used. Value is in dict format of field key and override + # value. + # Negative matching will skip quantization init for this module + # entirely: + # non-quantized inference. 
More details and quantization examples can be + # found at: https://github.com/ModelCloud/GPTQModel + # Example: + # # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9 + # # last 1/4 of the layers 16-21 has 8bit and group_size 64 + # dynamic = { + # #`.*\.` matches the layers_node prefix + # # positive match layer 10-15 + # r"+:.*\.(?:1[0-5])\..*": {"bits": 8,}, + # # positive match layer 16-21 + # r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,}, + # r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers + # } + super().__init__() + self.dynamic = dynamic + + self.weight_bits = weight_bits + self.group_size = group_size + self.desc_act = desc_act + self.lm_head_quantized = lm_head_quantized + self.pack_factor = Fraction(32, self.weight_bits) + if self.weight_bits not in [2, 3, 4, 8]: + raise ValueError( + "Currently, only 2/3/4/8-bit weight quantization is " + f"supported for GPTQ, but got {self.weight_bits} bits." + ) + + def __repr__(self) -> str: + return ( + f"GPTQConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act})," + f"lm_head_quantized={self.lm_head_quantized}), " + f"dynamic={self.dynamic}" + ) + + def get_scaled_act_names(self) -> List[str]: + """Returns the activation function names that should be post-scaled. + + For now, this is only used by AWQ. 
+ """ + raise NotImplementedError + + @classmethod + def get_name(cls) -> str: + return "gptq" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig": + dynamic = cls.get_from_keys_or(config, ["dynamic"], default={}) + dynamic = {} if dynamic is None else dynamic + + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + desc_act = cls.get_from_keys(config, ["desc_act"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) + return cls(weight_bits, group_size, desc_act, lm_head_quantized, dynamic) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["GPTQLinearMethod"]: + from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod + + from sglang.srt.layers.quantization import get_linear_quant_method + + return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod) + + +class GPTQMarlinConfig(QuantizationConfig): + """Config class for GPTQ Marlin""" + + # (num_bits, is_sym) -> quant_type + TYPE_MAP = { + (4, True): scalar_types.uint4b8, + (8, True): scalar_types.uint8b128, + } + + def __init__( + self, + weight_bits: int, + group_size: int, + desc_act: bool, + is_sym: bool, + lm_head_quantized: bool, + dynamic: Dict[str, Dict[str, Union[int, bool]]], + full_config: Dict[str, Any], + ) -> None: + super().__init__() + if desc_act and group_size == -1: + # In this case, act_order == True is the same as act_order == False + # (since we have only one group per output channel) + desc_act = False + + # GPTQModel use `dynamic` config property to allow per module + # quantization config so each module can be individually optimized. 
+ # Format is Dict[str, Dict] where key is a regex string that can + # perform both positive ("+:" prefixed) or negative ("-:" prefixed) + # matching of a module. + # Default to positive match, override base quant config mode, if no + # prefix is used. Value is in dict format of field key and override + # value. + # Negative matching will skip quantization init for this module + # entirely: + # non-quantized inference. More details and quantization examples can be + # found at: https://github.com/ModelCloud/GPTQModel + # Example: + # # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9 + # # last 1/4 of the layers 16-21 has 8bit and group_size 64 + # dynamic = { + # #`.*\.` matches the layers_node prefix + # # positive match layer 10-15 + # r"+:.*\.(?:1[0-5])\..*": {"bits": 8,}, + # # positive match layer 16-21 + # r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,}, + # r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers + # } + self.dynamic = dynamic + + self.weight_bits = weight_bits + self.is_sym = is_sym + + self.pack_factor = 32 // weight_bits # packed into int32 + self.group_size = group_size + self.desc_act = desc_act + self.lm_head_quantized = lm_head_quantized + self.full_config = full_config + + if (weight_bits, is_sym) not in self.TYPE_MAP: + raise ValueError( + "Unsupported quantization config: " f"bits={weight_bits}, sym={is_sym}" + ) + + self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)] + + def __repr__(self) -> str: + return ( + f"GPTQMarlinConfig(quant_type={self.quant_type}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act}, " + f"lm_head_quantized={self.lm_head_quantized}), " + f"dynamic={self.dynamic}" + ) + + def get_scaled_act_names(self) -> List[str]: + """Returns the activation function names that should be post-scaled. + + For now, this is only used by AWQ. 
+ """ + raise NotImplementedError + + @classmethod + def get_name(cls) -> str: + return "gptq_marlin" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig": + dynamic = cls.get_from_keys_or(config, ["dynamic"], default={}) + dynamic = {} if dynamic is None else dynamic + + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + desc_act = cls.get_from_keys(config, ["desc_act"]) + is_sym = cls.get_from_keys(config, ["sym"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) + return cls( + weight_bits, + group_size, + desc_act, + is_sym, + lm_head_quantized, + dynamic, + config, + ) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: + can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg) + + is_valid_user_quant = ( + user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin" + ) + + if can_convert and is_valid_user_quant: + msg = ( + "The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name()) + ) + logger.info(msg) + return cls.get_name() + + if can_convert and user_quant == "gptq": + logger.info( + "Detected that the model can run with gptq_marlin" + ", however you specified quantization=gptq explicitly," + " so forcing gptq. 
Use quantization=gptq_marlin for" + " faster inference" + ) + return None + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: + from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinLinearMethod, + GPTQMarlinMoEMethod, + ) + + from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + from sglang.srt.layers.quantization import get_linear_quant_method + + if isinstance(layer, FusedMoE): + return GPTQMarlinMoEMethod(self) + # TODO: re-enable after SGLang syncs with vllm >= 0.7.3 + # if layer.num_experts > 32: + # # For MoEs with many experts the moe_wna16 kernel is faster + # return MoeWNA16Config.from_config(self.full_config).get_quant_method( + # layer, prefix + # ) + # else: + # return GPTQMarlinMoEMethod(self) + return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod) + + @classmethod + def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]): + quant_method = quant_config.get("quant_method", "").lower() + num_bits = quant_config.get("bits") + group_size = quant_config.get("group_size") + sym = quant_config.get("sym") + desc_act = quant_config.get("desc_act") + + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + check_marlin_supported, + ) + from vllm.platforms import current_platform + + if not current_platform.is_cuda(): + return False + + if quant_method != "gptq": + return False + + # Marlin conversion is only valid if required properties are found + if num_bits is None or group_size is None or sym is None or desc_act is None: + return False + + if (num_bits, sym) not in cls.TYPE_MAP: + return False + + return check_marlin_supported( + quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size + ) + + +class MarlinConfig(QuantizationConfig): + """Config class for Marlin. 
+ + Reference: https://github.com/IST-DASLab/marlin/tree/master + """ + + def __init__( + self, + group_size: int, + lm_head_quantized: bool, + ) -> None: + # Group size for the quantization. + self.group_size = group_size + self.lm_head_quantized = lm_head_quantized + if self.group_size != 128 and self.group_size != -1: + raise ValueError( + "Currently, only group size 128 and -1 (channelwise) " + "is supported for Marlin, but got group_size of " + f"{self.group_size}" + ) + + # 4 Bits packed into 32 bit datatype. + self.pack_factor = 32 // 4 + + # Tile size used by marlin kernels. + self.tile_size = 16 + + # Min out_features dim + self.min_n_threads = 64 + + # Min in_features dim + self.min_k_threads = 128 + + # Max parallel problems to solve at once (improves large + # batch performance) + self.max_parallel = 16 + + # Permutation length used by the marlin kernels. + self.perm_len = 1024 + + def __repr__(self) -> str: + return ( + f"MarlinConfig(group_size={self.group_size}, " + f"lm_head_quantized={self.lm_head_quantized})" + ) + + @classmethod + def get_name(cls) -> str: + return "marlin" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig": + group_size = cls.get_from_keys(config, ["group_size"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) + return cls(group_size, lm_head_quantized) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: + # compat: autogptq >=0.8.0 use checkpoint_format: str + # compat: autogptq <=0.7.1 is_marlin_format: bool + is_marlin_format = hf_quant_cfg.get( + "checkpoint_format" + ) == "marlin" or hf_quant_cfg.get("is_marlin_format", False) + + 
is_valid_user_quant = ( + user_quant is None or user_quant == "gptq" or user_quant == "marlin" + ) + + if is_marlin_format and is_valid_user_quant: + msg = "The model is serialized in {} format. Using {} kernel.".format( + cls.get_name(), cls.get_name() + ) + logger.info(msg) + return cls.get_name() + + return None + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["MarlinLinearMethod"]: + from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod + + if isinstance(layer, LinearBase) or ( + isinstance(layer, ParallelLMHead) and self.lm_head_quantized + ): + return MarlinLinearMethod(self) + return None diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 0d46e7bba..507f6e950 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -34,6 +34,7 @@ class RadixAttention(nn.Module): v_head_dim: int = -1, sliding_window_size: int = -1, is_cross_attention: bool = False, + prefix: str = "", ): super().__init__() self.tp_q_head_num = num_heads diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py index 22229b643..08b487262 100644 --- a/python/sglang/srt/layers/vocab_parallel_embedding.py +++ b/python/sglang/srt/layers/vocab_parallel_embedding.py @@ -261,26 +261,27 @@ class VocabParallelEmbedding(torch.nn.Module): ) self.embedding_dim = embedding_dim - linear_method = None + quant_method = None if quant_config is not None: - linear_method = quant_config.get_quant_method(self, prefix=prefix) - if linear_method is None: - linear_method = UnquantizedEmbeddingMethod() + quant_method = quant_config.get_quant_method(self, prefix=prefix) + print("quant_method", quant_method) + if quant_method is None: + quant_method = UnquantizedEmbeddingMethod() # If we are making an embedding layer, then our quantization linear # method must implement the embedding 
operation. If we are another # layer type like ParallelLMHead, this is not important. is_embedding_layer = type(self.__class__) is VocabParallelEmbedding - linear_method_implements_embedding = method_has_implemented_embedding( - type(linear_method) + quant_method_implements_embedding = method_has_implemented_embedding( + type(quant_method) ) - if is_embedding_layer and not linear_method_implements_embedding: + if is_embedding_layer and not quant_method_implements_embedding: raise NotImplementedError( - f"The class {type(linear_method).__name__} must implement " + f"The class {type(quant_method).__name__} must implement " "the 'embedding' method, see UnquantizedEmbeddingMethod." ) - self.linear_method: QuantizeMethodBase = linear_method + self.quant_method: QuantizeMethodBase = quant_method if params_dtype is None: params_dtype = torch.get_default_dtype() @@ -301,7 +302,7 @@ class VocabParallelEmbedding(torch.nn.Module): - self.shard_indices.added_vocab_start_index ) - self.linear_method.create_weights( + self.quant_method.create_weights( self, self.embedding_dim, [self.num_embeddings_per_partition], @@ -446,7 +447,7 @@ class VocabParallelEmbedding(torch.nn.Module): packed_factor = ( param.packed_factor if isinstance(param, BasevLLMParameter) - else param.pack_factor + else param.packed_factor ) assert loaded_weight.shape[output_dim] == ( self.org_vocab_size // param.packed_factor @@ -479,7 +480,7 @@ class VocabParallelEmbedding(torch.nn.Module): else: masked_input = input_ # Get the embeddings. - output_parallel = self.linear_method.embedding(self, masked_input.long()) + output_parallel = self.quant_method.embedding(self, masked_input.long()) # Mask the output embedding. 
if self.tp_size > 1: output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) diff --git a/python/sglang/srt/models/baichuan.py b/python/sglang/srt/models/baichuan.py index 066157f05..578935012 100644 --- a/python/sglang/srt/models/baichuan.py +++ b/python/sglang/srt/models/baichuan.py @@ -46,6 +46,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import add_prefix def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: @@ -80,13 +81,22 @@ class BaiChuanMLP(nn.Module): intermediate_size: int, hidden_act: str, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), ) self.down_proj = RowParallelLinear( - intermediate_size, hidden_size, bias=False, quant_config=quant_config + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("down_proj", prefix), ) if hidden_act != "silu": raise ValueError( @@ -114,6 +124,7 @@ class BaiChuanAttention(nn.Module): max_position_embeddings: int = 8192, quant_config: Optional[QuantizationConfig] = None, layer_id: int = 0, + prefix: str = "", ): super().__init__() self.hidden_size = hidden_size @@ -167,6 +178,7 @@ class BaiChuanAttention(nn.Module): scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + prefix=add_prefix("attn", prefix), ) else: self.rotary_emb = get_rope( @@ -182,6 +194,7 @@ class BaiChuanAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + prefix=add_prefix("attn", prefix), ) def forward( @@ -207,6 +220,7 @@ class BaiChuanDecoderLayer(nn.Module): 
position_embedding: str, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.hidden_size = config.hidden_size @@ -220,12 +234,14 @@ class BaiChuanDecoderLayer(nn.Module): layer_id=layer_id, max_position_embeddings=max_position_embeddings, quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), ) self.mlp = BaiChuanMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, + prefix=add_prefix("mlp", prefix), ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( @@ -264,6 +280,7 @@ class BaiChuanModel(nn.Module): config: PretrainedConfig, position_embedding: str, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -281,6 +298,7 @@ class BaiChuanModel(nn.Module): layer_id=i, position_embedding=position_embedding, quant_config=quant_config, + prefix=add_prefix(f"layers.{i}", prefix), ) for i in range(config.num_hidden_layers) ] @@ -330,18 +348,24 @@ class BaiChuanBaseForCausalLM(nn.Module): config: PretrainedConfig, position_embedding: str, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config self.quant_config = quant_config - self.model = BaiChuanModel(config, position_embedding, quant_config) + self.model = BaiChuanModel( + config, position_embedding, quant_config, prefix=add_prefix("model", prefix) + ) if self.config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, quant_config=quant_config + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), ) self.logits_processor = LogitsProcessor(config) @@ -404,11 +428,12 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM): self, config, 
quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): if config.hidden_size == 4096: # baichuan2 7b - super().__init__(config, "ROPE", quant_config) + super().__init__(config, "ROPE", quant_config, prefix=prefix) else: # baichuan 13b, baichuan2 13b - super().__init__(config, "ALIBI", quant_config) + super().__init__(config, "ALIBI", quant_config, prefix=prefix) EntryClass = [BaichuanForCausalLM] diff --git a/python/sglang/srt/models/chatglm.py b/python/sglang/srt/models/chatglm.py index 222cc3e2d..4692a5812 100644 --- a/python/sglang/srt/models/chatglm.py +++ b/python/sglang/srt/models/chatglm.py @@ -41,6 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import add_prefix LoraConfig = None @@ -51,6 +52,7 @@ class GLMAttention(nn.Module): config, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.hidden_size = config.hidden_size @@ -85,12 +87,14 @@ class GLMAttention(nn.Module): self.total_num_kv_heads, bias=config.add_bias_linear or config.add_qkv_bias, quant_config=quant_config, + prefix=add_prefix("query_key_value", prefix), ) self.dense = RowParallelLinear( self.total_num_heads * self.head_dim, config.hidden_size, bias=config.add_bias_linear, quant_config=quant_config, + prefix=add_prefix("dense", prefix), ) # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141 @@ -109,6 +113,7 @@ class GLMAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + prefix=add_prefix("attn", prefix), ) def forward( @@ -142,6 +147,7 @@ class GLMMLP(nn.Module): self, config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() @@ -153,6 +159,7 @@ class GLMMLP(nn.Module): 
[config.ffn_hidden_size] * 2, bias=config.add_bias_linear, quant_config=quant_config, + prefix=add_prefix("dense_h_to_4h", prefix), ) self.activation_func = SiluAndMul() @@ -163,6 +170,7 @@ class GLMMLP(nn.Module): config.hidden_size, bias=config.add_bias_linear, quant_config=quant_config, + prefix=add_prefix("dense_4h_to_h", prefix), ) def forward(self, hidden_states): @@ -186,6 +194,7 @@ class GLMBlock(nn.Module): config, layer_id: int, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.apply_residual_connection_post_layernorm = ( @@ -201,7 +210,9 @@ class GLMBlock(nn.Module): ) # Self attention. - self.self_attention = GLMAttention(config, layer_id, quant_config) + self.self_attention = GLMAttention( + config, layer_id, quant_config, prefix=add_prefix("self_attention", prefix) + ) self.hidden_dropout = config.hidden_dropout # Layernorm on the attention output @@ -210,7 +221,7 @@ class GLMBlock(nn.Module): ) # MLP - self.mlp = GLMMLP(config, quant_config) + self.mlp = GLMMLP(config, quant_config, prefix=add_prefix("mlp", prefix)) def forward( self, @@ -257,6 +268,7 @@ class GLMTransformer(nn.Module): self, config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.post_layer_norm = config.post_layer_norm @@ -266,7 +278,15 @@ class GLMTransformer(nn.Module): # Transformer layers. 
self.layers = nn.ModuleList( - [GLMBlock(config, i, quant_config) for i in range(self.num_layers)] + [ + GLMBlock( + config, + i, + quant_config, + prefix=add_prefix(f"layers.{i}", prefix), + ) + for i in range(self.num_layers) + ] ) if self.post_layer_norm: @@ -301,19 +321,28 @@ class ChatGLMM(nn.Module): self, config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.embedding = VocabParallelEmbedding( - config.padded_vocab_size, config.hidden_size + config.padded_vocab_size, + config.hidden_size, + prefix=add_prefix("embedding", prefix), ) self.num_layers = config.num_layers self.multi_query_group_num = config.multi_query_group_num self.kv_channels = config.kv_channels - self.encoder = GLMTransformer(config, quant_config) + self.encoder = GLMTransformer( + config, quant_config, add_prefix("encoder", prefix) + ) - self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size) + self.output_layer = ParallelLMHead( + config.padded_vocab_size, + config.hidden_size, + prefix=add_prefix("output_layer", prefix), + ) def forward( self, @@ -351,12 +380,15 @@ class ChatGLMForCausalLM(nn.Module): self, config: ChatGLMConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config: ChatGLMConfig = config self.quant_config = quant_config self.max_position_embeddings = getattr(config, "max_sequence_length", 8192) - self.transformer = ChatGLMM(config, quant_config) + self.transformer = ChatGLMM( + config, quant_config, prefix=add_prefix("transformer", prefix) + ) self.lm_head = self.transformer.output_layer self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index e4b291b66..7cdf0e135 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -65,7 +65,7 @@ from sglang.srt.model_loader.weight_utils import ( default_weight_loader, 
maybe_remap_kv_scale_name, ) -from sglang.srt.utils import get_compiler_backend, set_weight_attrs +from sglang.srt.utils import add_prefix, get_compiler_backend, set_weight_attrs @torch.compile(backend=get_compiler_backend()) @@ -110,6 +110,7 @@ class CohereMLP(nn.Module): self, config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -120,12 +121,14 @@ class CohereMLP(nn.Module): [self.intermediate_size] * 2, bias=False, quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), ) self.down_proj = RowParallelLinear( self.intermediate_size, self.hidden_size, bias=False, quant_config=quant_config, + prefix=add_prefix("down_proj", prefix), ) self.act_fn = SiluAndMul() @@ -142,6 +145,7 @@ class CohereAttention(nn.Module): config: PretrainedConfig, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() tp_size = get_tensor_model_parallel_world_size() @@ -177,12 +181,14 @@ class CohereAttention(nn.Module): self.total_num_kv_heads, bias=False, quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, self.hidden_size, bias=False, quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), ) self.rotary_emb = get_rope( self.head_dim, @@ -198,6 +204,7 @@ class CohereAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + prefix=add_prefix("attn", prefix), ) if self.use_qk_norm: self.q_norm = LayerNorm( @@ -239,15 +246,23 @@ class CohereDecoderLayer(nn.Module): config: PretrainedConfig, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.hidden_size = config.hidden_size self.self_attn = CohereAttention( - config, layer_id=layer_id, quant_config=quant_config + config, + layer_id=layer_id, + quant_config=quant_config, + prefix=add_prefix("self_attn", 
prefix), ) - self.mlp = CohereMLP(config, quant_config=quant_config) + self.mlp = CohereMLP( + config, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) self.input_layernorm = LayerNorm( param_shape=(config.hidden_size), eps=config.layer_norm_eps ) @@ -279,6 +294,7 @@ class CohereModel(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -288,7 +304,12 @@ class CohereModel(nn.Module): ) self.layers = nn.ModuleList( [ - CohereDecoderLayer(config, i, quant_config=quant_config) + CohereDecoderLayer( + config, + i, + quant_config=quant_config, + prefix=add_prefix(f"layers.{i}", prefix), + ) for i in range(config.num_hidden_layers) ] ) @@ -321,12 +342,15 @@ class CohereForCausalLM(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config self.quant_config = quant_config self.logits_processor = LogitsProcessor(config) - self.model = CohereModel(config, quant_config) + self.model = CohereModel( + config, quant_config, prefix=add_prefix("model", prefix) + ) @torch.no_grad() def forward( diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index 92fc67939..b1bc79872 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -46,7 +46,7 @@ from sglang.srt.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, ) -from sglang.srt.utils import set_weight_attrs +from sglang.srt.utils import add_prefix, set_weight_attrs class DbrxRouter(nn.Module): @@ -58,6 +58,7 @@ class DbrxRouter(nn.Module): self, config: DbrxConfig, params_dtype: Optional[torch.dtype] = None, + prefix: str = "", ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() @@ -89,6 +90,7 @@ class DbrxExperts(nn.Module): config: DbrxConfig, quant_config: 
Optional[QuantizationConfig] = None, params_dtype: Optional[torch.dtype] = None, + prefix: str = "", ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() @@ -189,6 +191,7 @@ class DbrxAttention(nn.Module): config: DbrxConfig, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.d_model = config.d_model @@ -207,12 +210,14 @@ class DbrxAttention(nn.Module): self.total_num_kv_heads, bias=False, quant_config=quant_config, + prefix=add_prefix("Wqkv", prefix), ) self.out_proj = RowParallelLinear( self.d_model, self.d_model, bias=False, quant_config=quant_config, + prefix=add_prefix("out_proj", prefix), ) self.rotary_emb = get_rope( self.head_dim, @@ -244,6 +249,7 @@ class DbrxAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + prefix=add_prefix("attn", prefix), ) def forward( @@ -268,10 +274,16 @@ class DbrxFusedNormAttention(nn.Module): config: DbrxConfig, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.d_model = config.d_model - self.attn = DbrxAttention(config, layer_id, quant_config=quant_config) + self.attn = DbrxAttention( + config, + layer_id, + quant_config=quant_config, + prefix=add_prefix("attn", prefix), + ) self.norm_1 = nn.LayerNorm(self.d_model) self.norm_2 = nn.LayerNorm(self.d_model) @@ -300,10 +312,14 @@ class DbrxBlock(nn.Module): config: DbrxConfig, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.norm_attn_norm = DbrxFusedNormAttention( - config, layer_id, quant_config=quant_config + config, + layer_id, + quant_config=quant_config, + prefix=add_prefix("norm_attn_norm", prefix), ) self.ffn = DbrxExperts(config, quant_config=quant_config) @@ -328,6 +344,7 @@ class DbrxModel(nn.Module): self, config: DbrxConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): 
super().__init__() self.wte = VocabParallelEmbedding( @@ -336,7 +353,12 @@ class DbrxModel(nn.Module): ) self.blocks = nn.ModuleList( [ - DbrxBlock(config, i, quant_config=quant_config) + DbrxBlock( + config, + i, + quant_config=quant_config, + prefix=add_prefix(f"blocks.{i}", prefix), + ) for i in range(config.n_layers) ] ) @@ -369,17 +391,21 @@ class DbrxForCausalLM(nn.Module): self, config: DbrxConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config self.quant_config = quant_config self.unpadded_vocab_size = config.vocab_size - self.transformer = DbrxModel(config, quant_config=quant_config) + self.transformer = DbrxModel( + config, quant_config=quant_config, prefix=add_prefix("transformer", prefix) + ) self.lm_head = ParallelLMHead( config.vocab_size, config.d_model, org_num_embeddings=config.vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE, + prefix=add_prefix("lm_head", prefix), ) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py index 7d2c0700f..216aca9c2 100644 --- a/python/sglang/srt/models/deepseek.py +++ b/python/sglang/srt/models/deepseek.py @@ -46,6 +46,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import add_prefix class DeepseekMLP(nn.Module): @@ -57,10 +58,15 @@ class DeepseekMLP(nn.Module): hidden_act: str, quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = True, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), ) self.down_proj = 
RowParallelLinear( intermediate_size, @@ -68,6 +74,7 @@ class DeepseekMLP(nn.Module): bias=False, quant_config=quant_config, reduce_results=reduce_results, + prefix=add_prefix("down_proj", prefix), ) if hidden_act != "silu": raise ValueError( @@ -89,6 +96,7 @@ class DeepseekMoE(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -110,6 +118,7 @@ class DeepseekMoE(nn.Module): hidden_act=config.hidden_act, quant_config=quant_config, reduce_results=False, + prefix=add_prefix(f"experts.{idx}", prefix), ) for idx in range(self.n_routed_experts) ] @@ -117,7 +126,11 @@ class DeepseekMoE(nn.Module): self.pack_params() self.gate = ReplicatedLinear( - config.hidden_size, self.n_routed_experts, bias=False, quant_config=None + config.hidden_size, + self.n_routed_experts, + bias=False, + quant_config=None, + prefix=add_prefix("gate", prefix), ) if config.n_shared_experts is not None: @@ -128,6 +141,7 @@ class DeepseekMoE(nn.Module): hidden_act=config.hidden_act, quant_config=quant_config, reduce_results=False, + prefix=add_prefix("shared_experts", prefix), ) def pack_params(self): @@ -185,6 +199,7 @@ class DeepseekAttention(nn.Module): rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -216,6 +231,7 @@ class DeepseekAttention(nn.Module): self.total_num_kv_heads, bias=False, quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), ) self.o_proj = RowParallelLinear( @@ -223,6 +239,7 @@ class DeepseekAttention(nn.Module): hidden_size, bias=False, quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), ) self.rotary_emb = get_rope( @@ -238,6 +255,7 @@ class DeepseekAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + prefix=add_prefix("attn", prefix), ) 
def forward( @@ -261,6 +279,7 @@ class DeepseekDecoderLayer(nn.Module): config: PretrainedConfig, layer_id: int, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -276,19 +295,25 @@ class DeepseekDecoderLayer(nn.Module): rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), ) if ( config.n_routed_experts is not None and layer_id >= config.first_k_dense_replace and layer_id % config.moe_layer_freq == 0 ): - self.mlp = DeepseekMoE(config=config, quant_config=quant_config) + self.mlp = DeepseekMoE( + config=config, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) else: self.mlp = DeepseekMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, + prefix=add_prefix("mlp", prefix), ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( @@ -328,6 +353,7 @@ class DeepseekModel(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -339,7 +365,12 @@ class DeepseekModel(nn.Module): ) self.layers = nn.ModuleList( [ - DeepseekDecoderLayer(config, layer_id, quant_config=quant_config) + DeepseekDecoderLayer( + config, + layer_id, + quant_config=quant_config, + prefix=add_prefix(f"layers.{layer_id}", prefix), + ) for layer_id in range(config.num_hidden_layers) ] ) @@ -368,13 +399,19 @@ class DeepseekForCausalLM(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = DeepseekModel(config, quant_config) + self.model = DeepseekModel( + config, quant_config, 
prefix=add_prefix("model", prefix) + ) self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, quant_config=quant_config + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), ) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py index 0dfa69a2e..c1c99be54 100644 --- a/python/sglang/srt/models/deepseek_nextn.py +++ b/python/sglang/srt/models/deepseek_nextn.py @@ -38,7 +38,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM -from sglang.srt.utils import is_hip +from sglang.srt.utils import add_prefix, is_hip is_hip_ = is_hip() @@ -48,6 +48,7 @@ class DeepseekModelNextN(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.vocab_size = config.vocab_size @@ -56,6 +57,7 @@ class DeepseekModelNextN(nn.Module): config.vocab_size, config.hidden_size, enable_tp=not global_server_args_dict["enable_dp_attention"], + prefix=add_prefix("embed_tokens", prefix), ) self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -64,7 +66,11 @@ class DeepseekModelNextN(nn.Module): self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False) self.decoder = DeepseekV2DecoderLayer( - config, 0, quant_config=quant_config, is_nextn=True + config, + 0, + quant_config=quant_config, + is_nextn=True, + prefix=add_prefix("decoder", prefix), ) self.shared_head = nn.Module() @@ -108,18 +114,22 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> 
None: nn.Module.__init__(self) self.config = config self.quant_config = quant_config - self.model = DeepseekModelNextN(config, quant_config) + self.model = DeepseekModelNextN( + config, quant_config, prefix=add_prefix("model", prefix) + ) if global_server_args_dict["enable_dp_attention"]: self.lm_head = ReplicatedLinear( config.hidden_size, config.vocab_size, bias=False, + prefix=add_prefix("model.shared_head.head", prefix), ) self.logits_processor = LogitsProcessor(config, skip_all_gather=True) else: @@ -127,6 +137,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM): config.vocab_size, config.hidden_size, quant_config=quant_config, + prefix=add_prefix("model.shared_head.head", prefix), ) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 892d8f594..18d38ccd0 100755 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -63,7 +63,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import is_cuda_available, is_hip +from sglang.srt.utils import add_prefix, is_cuda_available, is_hip is_hip_ = is_hip() @@ -79,10 +79,15 @@ class DeepseekV2MLP(nn.Module): hidden_act: str, quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = True, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), ) self.down_proj = RowParallelLinear( intermediate_size, @@ -90,6 +95,7 @@ class DeepseekV2MLP(nn.Module): bias=False, 
quant_config=quant_config, reduce_results=reduce_results, + prefix=add_prefix("down_proj", prefix), ) if hidden_act != "silu": raise ValueError( @@ -106,7 +112,11 @@ class DeepseekV2MLP(nn.Module): class MoEGate(nn.Module): - def __init__(self, config): + def __init__( + self, + config, + prefix: str = "", + ): super().__init__() self.weight = nn.Parameter( torch.empty((config.n_routed_experts, config.hidden_size)) @@ -129,6 +139,7 @@ class DeepseekV2MoE(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() @@ -147,7 +158,7 @@ class DeepseekV2MoE(nn.Module): "Only silu is supported for now." ) - self.gate = MoEGate(config=config) + self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix)) MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE self.experts = MoEImpl( @@ -161,6 +172,7 @@ class DeepseekV2MoE(nn.Module): num_expert_group=config.n_group, topk_group=config.topk_group, correction_bias=self.gate.e_score_correction_bias, + prefix=add_prefix("experts", prefix), ) if config.n_shared_experts is not None: @@ -171,6 +183,7 @@ class DeepseekV2MoE(nn.Module): hidden_act=config.hidden_act, quant_config=quant_config, reduce_results=False, + prefix=add_prefix("shared_experts", prefix), ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -217,6 +230,7 @@ class DeepseekV2Attention(nn.Module): max_position_embeddings: int = 8192, quant_config: Optional[QuantizationConfig] = None, layer_id=None, + prefix: str = "", ) -> None: super().__init__() self.layer_id = layer_id @@ -241,6 +255,7 @@ class DeepseekV2Attention(nn.Module): self.q_lora_rank, bias=False, quant_config=quant_config, + prefix=add_prefix("q_a_proj", prefix), ) self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) self.q_b_proj = ColumnParallelLinear( @@ -248,6 +263,7 @@ class DeepseekV2Attention(nn.Module): 
self.num_heads * self.qk_head_dim, bias=False, quant_config=quant_config, + prefix=add_prefix("q_b_proj", prefix), ) else: self.q_proj = ColumnParallelLinear( @@ -255,6 +271,7 @@ class DeepseekV2Attention(nn.Module): self.num_heads * self.qk_head_dim, bias=False, quant_config=quant_config, + prefix=add_prefix("q_proj", prefix), ) self.kv_a_proj_with_mqa = ReplicatedLinear( @@ -262,8 +279,7 @@ class DeepseekV2Attention(nn.Module): self.kv_lora_rank + self.qk_rope_head_dim, bias=False, quant_config=quant_config, - # FIXME: quick fix for skip quantization - prefix=f"self_attn.kv_a_proj_with_mqa", + prefix=add_prefix("kv_a_proj_with_mqa", prefix), ) self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) self.kv_b_proj = ColumnParallelLinear( @@ -271,6 +287,7 @@ class DeepseekV2Attention(nn.Module): self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False, quant_config=quant_config, + prefix=add_prefix("kv_b_proj", prefix), ) # O projection. self.o_proj = RowParallelLinear( @@ -278,6 +295,7 @@ class DeepseekV2Attention(nn.Module): self.hidden_size, bias=False, quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), ) rope_scaling["rope_type"] = "deepseek_yarn" self.rotary_emb = get_rope_wrapper( @@ -303,6 +321,7 @@ class DeepseekV2Attention(nn.Module): self.scaling, num_kv_heads=self.num_local_heads, layer_id=layer_id, + prefix=add_prefix("attn", prefix), ) def forward( @@ -368,6 +387,7 @@ class DeepseekV2AttentionMLA(nn.Module): quant_config: Optional[QuantizationConfig] = None, layer_id=None, use_dp=False, + prefix: str = "", ) -> None: super().__init__() self.layer_id = layer_id @@ -394,6 +414,7 @@ class DeepseekV2AttentionMLA(nn.Module): self.q_lora_rank, bias=False, quant_config=quant_config, + prefix=add_prefix("q_a_proj", prefix), ) self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) self.q_b_proj = ReplicatedLinear( @@ -401,6 +422,7 @@ class DeepseekV2AttentionMLA(nn.Module): self.num_heads * 
self.qk_head_dim, bias=False, quant_config=quant_config, + prefix=add_prefix("q_b_proj", prefix), ) else: self.q_proj = ReplicatedLinear( @@ -408,12 +430,14 @@ class DeepseekV2AttentionMLA(nn.Module): self.num_heads * self.qk_head_dim, bias=False, quant_config=quant_config, + prefix=add_prefix("q_proj", prefix), ) self.kv_b_proj = ReplicatedLinear( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False, quant_config=quant_config, + prefix=add_prefix("kv_b_proj", prefix), ) # O projection. self.o_proj = ReplicatedLinear( @@ -421,6 +445,7 @@ class DeepseekV2AttentionMLA(nn.Module): self.hidden_size, bias=False, quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), ) else: # For tensor parallel attention @@ -430,6 +455,7 @@ class DeepseekV2AttentionMLA(nn.Module): self.q_lora_rank, bias=False, quant_config=quant_config, + prefix=add_prefix("q_a_proj", prefix), ) self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) self.q_b_proj = ColumnParallelLinear( @@ -437,6 +463,7 @@ class DeepseekV2AttentionMLA(nn.Module): self.num_heads * self.qk_head_dim, bias=False, quant_config=quant_config, + prefix=add_prefix("q_b_proj", prefix), ) else: self.q_proj = ColumnParallelLinear( @@ -444,12 +471,14 @@ class DeepseekV2AttentionMLA(nn.Module): self.num_heads * self.qk_head_dim, bias=False, quant_config=quant_config, + prefix=add_prefix("q_proj", prefix), ) self.kv_b_proj = ColumnParallelLinear( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False, quant_config=quant_config, + prefix=add_prefix("kv_b_proj", prefix), ) # O projection. 
self.o_proj = RowParallelLinear( @@ -457,6 +486,7 @@ class DeepseekV2AttentionMLA(nn.Module): self.hidden_size, bias=False, quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), ) self.kv_a_proj_with_mqa = ReplicatedLinear( @@ -464,8 +494,7 @@ class DeepseekV2AttentionMLA(nn.Module): self.kv_lora_rank + self.qk_rope_head_dim, bias=False, quant_config=quant_config, - # FIXME: quick fix for skip quantization - prefix=f"self_attn.kv_a_proj_with_mqa", + prefix=add_prefix("kv_a_proj_with_mqa", prefix), ) self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) @@ -496,6 +525,7 @@ class DeepseekV2AttentionMLA(nn.Module): num_kv_heads=1, layer_id=layer_id, v_head_dim=self.kv_lora_rank, + prefix=add_prefix("attn_mqa", prefix), ) self.attn_mha = RadixAttention( @@ -505,6 +535,7 @@ class DeepseekV2AttentionMLA(nn.Module): num_kv_heads=self.num_local_heads, layer_id=layer_id, v_head_dim=self.v_head_dim, + prefix=add_prefix("attn_mha", prefix), ) self.w_kc = None @@ -848,6 +879,7 @@ class DeepseekV2DecoderLayer(nn.Module): layer_id: int, quant_config: Optional[QuantizationConfig] = None, is_nextn: bool = False, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -880,6 +912,7 @@ class DeepseekV2DecoderLayer(nn.Module): quant_config=quant_config, layer_id=layer_id, use_dp=self.enable_dp_attention, + prefix=add_prefix("self_attn", prefix), ) else: self.self_attn = DeepseekV2Attention( @@ -898,19 +931,25 @@ class DeepseekV2DecoderLayer(nn.Module): max_position_embeddings=max_position_embeddings, quant_config=quant_config, layer_id=layer_id, + prefix=add_prefix("self_attn", prefix), ) if is_nextn or ( config.n_routed_experts is not None and layer_id >= config.first_k_dense_replace and layer_id % config.moe_layer_freq == 0 ): - self.mlp = DeepseekV2MoE(config=config, quant_config=quant_config) + self.mlp = DeepseekV2MoE( + config=config, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) else: 
self.mlp = DeepseekV2MLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, + prefix=add_prefix("mlp", prefix), ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( @@ -962,6 +1001,7 @@ class DeepseekV2Model(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.padding_id = config.pad_token_id @@ -978,6 +1018,7 @@ class DeepseekV2Model(nn.Module): config, layer_id, quant_config=quant_config, + prefix=add_prefix(f"layers.{layer_id}", prefix), ) for layer_id in range(config.num_hidden_layers) ] @@ -1008,21 +1049,28 @@ class DeepseekV2ForCausalLM(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = DeepseekV2Model(config, quant_config) + self.model = DeepseekV2Model( + config, quant_config, prefix=add_prefix("model", prefix) + ) if global_server_args_dict["enable_dp_attention"]: self.lm_head = ReplicatedLinear( config.hidden_size, config.vocab_size, bias=False, + prefix=add_prefix("lm_head", prefix), ) self.logits_processor = LogitsProcessor(config, skip_all_gather=True) else: self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, quant_config=quant_config + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), ) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/exaone.py b/python/sglang/srt/models/exaone.py index 10be1e74d..5b301c801 100644 --- a/python/sglang/srt/models/exaone.py +++ b/python/sglang/srt/models/exaone.py @@ -39,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ) from sglang.srt.model_executor.forward_batch_info import 
ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import add_prefix class ExaoneGatedMLP(nn.Module): @@ -56,14 +57,14 @@ class ExaoneGatedMLP(nn.Module): [intermediate_size] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj", + prefix=add_prefix("gate_up_proj", prefix), ) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, bias=False, quant_config=quant_config, - prefix=f"{prefix}.c_proj", + prefix=add_prefix("c_proj", prefix), ) if hidden_act != "silu": raise ValueError( @@ -130,14 +131,14 @@ class ExaoneAttention(nn.Module): self.total_num_kv_heads, bias=False, quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", + prefix=add_prefix("qkv_proj", prefix), ) self.out_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, quant_config=quant_config, - prefix=f"{prefix}.out_proj", + prefix=add_prefix("out_proj", prefix), ) self.rotary_emb = get_rope( @@ -201,14 +202,14 @@ class ExaoneDecoderLayer(nn.Module): rope_is_neox_style=rope_is_neox_style, max_position_embeddings=max_position_embeddings, quant_config=quant_config, - prefix=f"{prefix}.self_attn", + prefix=add_prefix("self_attn", prefix), ) self.mlp = ExaoneGatedMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.activation_function, quant_config=quant_config, - prefix=f"{prefix}.mlp", + prefix=add_prefix("mlp", prefix), ) rms_norm_eps = config.layer_norm_epsilon self.ln_1 = RMSNorm(config.hidden_size, eps=rms_norm_eps) @@ -244,6 +245,7 @@ class ExaoneModel(nn.Module): self, config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -256,7 +258,10 @@ class ExaoneModel(nn.Module): self.h = nn.ModuleList( [ ExaoneDecoderLayer( - config, i, quant_config=quant_config, prefix=f"model.h.{i}" + config, + i, + quant_config=quant_config, + prefix=add_prefix(f"h.{i}", 
prefix), ) for i in range(config.num_hidden_layers) ] @@ -293,12 +298,17 @@ class ExaoneForCausalLM(nn.Module): self, config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.transformer = ExaoneModel(config, quant_config=quant_config) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.transformer = ExaoneModel( + config, quant_config=quant_config, prefix=add_prefix("transformer", prefix) + ) + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix) + ) self.logits_processor = LogitsProcessor(config) @torch.no_grad() diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index 811ce9d51..8ab8abd4f 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -37,6 +37,7 @@ from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import add_prefix class GemmaMLP(nn.Module): @@ -45,6 +46,7 @@ class GemmaMLP(nn.Module): hidden_size: int, intermediate_size: int, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -52,12 +54,14 @@ class GemmaMLP(nn.Module): [intermediate_size] * 2, bias=False, quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), ) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, bias=False, quant_config=quant_config, + prefix=add_prefix("down_proj", prefix), ) self.act_fn = GeluAndMul("none") @@ -79,6 +83,7 @@ class GemmaAttention(nn.Module): max_position_embeddings: int = 8192, rope_theta: float = 10000, quant_config: 
Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -109,12 +114,14 @@ class GemmaAttention(nn.Module): self.total_num_kv_heads, bias=False, quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), ) self.rotary_emb = get_rope( @@ -130,6 +137,7 @@ class GemmaAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + prefix=add_prefix("attn", prefix), ) def forward( @@ -152,6 +160,7 @@ class GemmaDecoderLayer(nn.Module): config: PretrainedConfig, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -164,11 +173,13 @@ class GemmaDecoderLayer(nn.Module): max_position_embeddings=config.max_position_embeddings, rope_theta=config.rope_theta, quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), ) self.mlp = GemmaMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, quant_config=quant_config, + prefix=add_prefix("mlp", prefix), ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( @@ -205,6 +216,7 @@ class GemmaModel(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -215,7 +227,12 @@ class GemmaModel(nn.Module): ) self.layers = nn.ModuleList( [ - GemmaDecoderLayer(config, i, quant_config=quant_config) + GemmaDecoderLayer( + config, + i, + quant_config=quant_config, + prefix=add_prefix(f"layers.{i}", prefix), + ) for i in range(config.num_hidden_layers) ] ) @@ -277,11 +294,14 @@ class GemmaForCausalLM(nn.Module): self, config: PretrainedConfig, quant_config: 
Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = GemmaModel(config, quant_config=quant_config) + self.model = GemmaModel( + config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) self.logits_processor = LogitsProcessor(config) @torch.no_grad() diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index 76b04cc86..87cd7dbe0 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -39,7 +39,7 @@ from sglang.srt.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, ) -from sglang.srt.utils import make_layers +from sglang.srt.utils import add_prefix, make_layers # Aligned with HF's implementation, using sliding window inclusive with the last token @@ -56,13 +56,22 @@ class Gemma2MLP(nn.Module): hidden_act: str, hidden_activation: str, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), ) self.down_proj = RowParallelLinear( - intermediate_size, hidden_size, bias=False, quant_config=quant_config + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("down_proj", prefix), ) if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"): raise ValueError( @@ -91,6 +100,7 @@ class Gemma2Attention(nn.Module): max_position_embeddings: int, rope_theta: float, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.layer_id = layer_id @@ -123,12 +133,14 @@ class Gemma2Attention(nn.Module): self.total_num_kv_heads, bias=config.attention_bias, 
quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=config.attention_bias, quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), ) self.rotary_emb = get_rope( self.head_dim, @@ -151,6 +163,7 @@ class Gemma2Attention(nn.Module): if use_sliding_window else None ), + prefix=add_prefix("attn", prefix), ) def forward( @@ -173,6 +186,7 @@ class Gemma2DecoderLayer(nn.Module): layer_id: int, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -186,6 +200,7 @@ class Gemma2DecoderLayer(nn.Module): max_position_embeddings=config.max_position_embeddings, rope_theta=config.rope_theta, quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), ) self.hidden_size = config.hidden_size self.mlp = Gemma2MLP( @@ -194,6 +209,7 @@ class Gemma2DecoderLayer(nn.Module): hidden_act=config.hidden_act, hidden_activation=config.hidden_activation, quant_config=quant_config, + prefix=add_prefix("mlp", prefix), ) self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = GemmaRMSNorm( @@ -238,6 +254,7 @@ class Gemma2Model(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -253,7 +270,7 @@ class Gemma2Model(nn.Module): config=config, quant_config=quant_config, ), - prefix="", + prefix=add_prefix("layers", prefix), ) self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -339,11 +356,14 @@ class Gemma2ForCausalLM(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = Gemma2Model(config, quant_config) + self.model = 
Gemma2Model( + config, quant_config, prefix=add_prefix("model", prefix) + ) self.logits_processor = LogitsProcessor(config) @torch.no_grad() diff --git a/python/sglang/srt/models/gemma2_reward.py b/python/sglang/srt/models/gemma2_reward.py index 1fe87c30a..03bea4d10 100644 --- a/python/sglang/srt/models/gemma2_reward.py +++ b/python/sglang/srt/models/gemma2_reward.py @@ -22,6 +22,7 @@ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.models.gemma2 import Gemma2ForCausalLM, Gemma2Model +from sglang.srt.utils import add_prefix class Gemma2ForSequenceClassification(nn.Module): @@ -29,12 +30,15 @@ class Gemma2ForSequenceClassification(nn.Module): self, config: Gemma2Config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config self.quant_config = quant_config self.num_labels = config.num_labels - self.model = Gemma2Model(config, quant_config=quant_config) + self.model = Gemma2Model( + config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False) diff --git a/python/sglang/srt/models/gpt2.py b/python/sglang/srt/models/gpt2.py index c9b78e6f6..15374afaa 100644 --- a/python/sglang/srt/models/gpt2.py +++ b/python/sglang/srt/models/gpt2.py @@ -36,6 +36,7 @@ from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import add_prefix class GPT2Attention(nn.Module): @@ -62,14 +63,14 @@ class GPT2Attention(nn.Module): total_num_heads, 
bias=True, quant_config=quant_config, - prefix=f"{prefix}.c_attn", + prefix=add_prefix("c_attn", prefix), ) self.c_proj = RowParallelLinear( self.hidden_size, self.hidden_size, bias=True, quant_config=quant_config, - prefix=f"{prefix}.c_proj", + prefix=add_prefix("c_proj", prefix), ) self.attn = RadixAttention( self.num_heads, @@ -108,14 +109,14 @@ class GPT2MLP(nn.Module): intermediate_size, bias=True, quant_config=quant_config, - prefix=f"{prefix}.c_fc", + prefix=add_prefix("c_fc", prefix), ) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, bias=True, quant_config=quant_config, - prefix=f"{prefix}.c_proj", + prefix=add_prefix("c_proj", prefix), ) self.act = act_layer() @@ -145,7 +146,7 @@ class GPT2Block(nn.Module): self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.attn = GPT2Attention( - layer_id, config, quant_config, prefix=f"{prefix}.attn" + layer_id, config, quant_config, prefix=add_prefix("attn", prefix) ) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = GPT2MLP( @@ -153,7 +154,7 @@ class GPT2Block(nn.Module): config, act_layer=act_layer, quant_config=quant_config, - prefix=f"{prefix}.mlp", + prefix=add_prefix("mlp", prefix), ) def forward( @@ -196,7 +197,12 @@ class GPT2Model(nn.Module): self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.h = nn.ModuleList( [ - GPT2Block(i, config, quant_config=quant_config) + GPT2Block( + i, + config, + quant_config=quant_config, + prefix=add_prefix(f"h.{i}", prefix), + ) for i in range(config.num_hidden_layers) ] ) @@ -227,11 +233,14 @@ class GPT2LMHeadModel(nn.Module): self, config: GPT2Config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = GPT2Model(config, quant_config, prefix="transformer") + self.transformer = GPT2Model( + config, quant_config, prefix=add_prefix("transformer", prefix) + ) self.lm_head = 
self.transformer.wte self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/gpt_bigcode.py b/python/sglang/srt/models/gpt_bigcode.py index 0d705fb41..631da1298 100644 --- a/python/sglang/srt/models/gpt_bigcode.py +++ b/python/sglang/srt/models/gpt_bigcode.py @@ -35,6 +35,7 @@ from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import add_prefix class GPTBigCodeAttention(nn.Module): @@ -44,6 +45,7 @@ class GPTBigCodeAttention(nn.Module): layer_id: int, config: GPTBigCodeConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.hidden_size = config.hidden_size @@ -69,6 +71,7 @@ class GPTBigCodeAttention(nn.Module): total_num_kv_heads, bias=True, quant_config=quant_config, + prefix=add_prefix("c_attn", prefix), ) self.c_proj = RowParallelLinear( @@ -76,6 +79,7 @@ class GPTBigCodeAttention(nn.Module): self.hidden_size, bias=True, quant_config=quant_config, + prefix=add_prefix("c_proj", prefix), ) self.attn = RadixAttention( self.num_heads, @@ -83,6 +87,7 @@ class GPTBigCodeAttention(nn.Module): scaling=self.scale, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + prefix=add_prefix("attn", prefix), ) def forward( @@ -111,6 +116,7 @@ class GPTBigMLP(nn.Module): intermediate_size: int, config: GPTBigCodeConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() hidden_size = config.hidden_size @@ -119,12 +125,14 @@ class GPTBigMLP(nn.Module): intermediate_size, bias=True, quant_config=quant_config, + prefix=add_prefix("c_fc", prefix), ) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, bias=True, quant_config=quant_config, + prefix=add_prefix("c_proj", prefix), ) self.act = get_act_fn( 
config.activation_function, quant_config, intermediate_size @@ -144,15 +152,20 @@ class GPTBigCodeBlock(nn.Module): layer_id: int, config: GPTBigCodeConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() hidden_size = config.hidden_size inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPTBigCodeAttention(layer_id, config, quant_config) + self.attn = GPTBigCodeAttention( + layer_id, config, quant_config, prefix=add_prefix("attn", prefix) + ) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.mlp = GPTBigMLP(inner_dim, config, quant_config) + self.mlp = GPTBigMLP( + inner_dim, config, quant_config, prefix=add_prefix("mlp", prefix) + ) def forward( self, @@ -181,6 +194,7 @@ class GPTBigCodeModel(nn.Module): self, config: GPTBigCodeConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -190,12 +204,17 @@ class GPTBigCodeModel(nn.Module): lora_vocab = 0 self.vocab_size = config.vocab_size + lora_vocab self.wte = VocabParallelEmbedding( - self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size + self.vocab_size, + self.embed_dim, + org_num_embeddings=config.vocab_size, + prefix=add_prefix("wte", prefix), ) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.h = nn.ModuleList( [ - GPTBigCodeBlock(i, config, quant_config) + GPTBigCodeBlock( + i, config, quant_config, prefix=add_prefix(f"h.{i}", prefix) + ) for i in range(config.num_hidden_layers) ] ) @@ -235,13 +254,16 @@ class GPTBigCodeForCausalLM(nn.Module): self, config: GPTBigCodeConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = GPTBigCodeModel(config, quant_config) + self.transformer = GPTBigCodeModel( + 
config, quant_config, prefix=add_prefix("transformer", prefix) + ) self.lm_head = self.transformer.wte self.unpadded_vocab_size = config.vocab_size self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/granite.py b/python/sglang/srt/models/granite.py index 255f23227..086a8fb82 100644 --- a/python/sglang/srt/models/granite.py +++ b/python/sglang/srt/models/granite.py @@ -42,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import add_prefix from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) @@ -62,14 +63,14 @@ class GraniteMLP(nn.Module): [intermediate_size] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj", + prefix=add_prefix("gate_up_proj", prefix), ) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, bias=False, quant_config=quant_config, - prefix=f"{prefix}.down_proj", + prefix=add_prefix("down_proj", prefix), ) if hidden_act != "silu": raise ValueError( @@ -133,14 +134,14 @@ class GraniteAttention(nn.Module): self.total_num_kv_heads, bias=False, quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", + prefix=add_prefix("qkv_proj", prefix), ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, quant_config=quant_config, - prefix=f"{prefix}.o_proj", + prefix=add_prefix("o_proj", prefix), ) self.rotary_emb = get_rope( @@ -157,6 +158,7 @@ class GraniteAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + prefix=add_prefix("attn", prefix), ) def forward( @@ -205,14 +207,14 @@ class GraniteDecoderLayer(nn.Module): rope_is_neox_style=rope_is_neox_style, max_position_embeddings=max_position_embeddings, quant_config=quant_config, - prefix=f"{prefix}.self_attn", + prefix=add_prefix("self_attn", 
prefix), ) self.mlp = GraniteMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - prefix=f"{prefix}.mlp", + prefix=add_prefix("mlp", prefix), ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( @@ -252,6 +254,7 @@ class GraniteModel(nn.Module): self, config: GraniteConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -263,7 +266,10 @@ class GraniteModel(nn.Module): self.layers = nn.ModuleList( [ GraniteDecoderLayer( - config, i, quant_config=quant_config, prefix=f"model.layers.{i}" + config, + i, + quant_config=quant_config, + prefix=add_prefix(f"layers.{i}", prefix), ) for i in range(config.num_hidden_layers) ] @@ -300,17 +306,23 @@ class GraniteForCausalLM(nn.Module): self, config: GraniteConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = GraniteModel(config, quant_config=quant_config) + self.model = GraniteModel( + config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) # If tie_word_embeddings == True, then input and output embeddings are # the same tensor. Enforce during object creation so that weights will # load correctly even if the LM head weights don't have a separate entry # in the state dict. 
self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, quant_config=quant_config + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), ) if self.config.tie_word_embeddings: self.lm_head.tie_weights(self.model.embed_tokens) diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index 269d32eaa..ff56bcef8 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -47,6 +47,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.loader import DefaultModelLoader from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import add_prefix class Grok1MLP(nn.Module): @@ -65,7 +66,7 @@ class Grok1MLP(nn.Module): [intermediate_size] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj", + prefix=add_prefix("gate_up_proj", prefix), use_presharded_weights=use_presharded_weights, ) self.down_proj = RowParallelLinear( @@ -73,7 +74,7 @@ class Grok1MLP(nn.Module): hidden_size, bias=False, quant_config=quant_config, - prefix=f"{prefix}.down_proj", + prefix=add_prefix("down_proj", prefix), reduce_results=reduce_results, use_presharded_weights=use_presharded_weights, ) @@ -107,6 +108,7 @@ class Grok1MoE(nn.Module): tp_size: Optional[int] = None, reduce_results=True, use_presharded_weights: bool = False, + prefix: str = "", ): super().__init__() self.hidden_size = hidden_size @@ -118,6 +120,7 @@ class Grok1MoE(nn.Module): bias=False, params_dtype=params_dtype, quant_config=None, + prefix=add_prefix("gate", prefix), ) self.router_logit_softcapping = getattr( @@ -135,6 +138,7 @@ class Grok1MoE(nn.Module): tp_size=tp_size, activation="gelu", use_presharded_weights=use_presharded_weights, + prefix=add_prefix("experts", prefix), ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ 
-163,6 +167,7 @@ class Grok1Attention(nn.Module): rope_theta: float = 10000, quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = True, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -195,6 +200,7 @@ class Grok1Attention(nn.Module): self.total_num_kv_heads, bias=False, quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, @@ -202,6 +208,7 @@ class Grok1Attention(nn.Module): bias=False, quant_config=quant_config, reduce_results=reduce_results, + prefix=add_prefix("o_proj", prefix), ) self.rotary_emb = get_rope( self.head_dim, @@ -220,6 +227,7 @@ class Grok1Attention(nn.Module): num_kv_heads=self.num_kv_heads, layer_id=layer_id, logit_cap=logit_cap, + prefix=add_prefix("attn", prefix), ) def forward( @@ -243,6 +251,7 @@ class Grok1DecoderLayer(nn.Module): layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, use_presharded_weights: bool = False, + prefix: str = "", ) -> None: super().__init__() self.num_experts = config.num_local_experts @@ -259,6 +268,7 @@ class Grok1DecoderLayer(nn.Module): layer_id=layer_id, rope_theta=rope_theta, quant_config=quant_config, + prefix=add_prefix("attn", prefix), ) self.block_sparse_moe = Grok1MoE( config=config, @@ -273,6 +283,7 @@ class Grok1DecoderLayer(nn.Module): quant_config=quant_config, reduce_results=True, use_presharded_weights=use_presharded_weights, + prefix=add_prefix("block_sparse_moe", prefix), ) self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -311,6 +322,7 @@ class Grok1Model(nn.Module): config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, use_presharded_weights: bool = False, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -320,6 +332,7 @@ class Grok1Model(nn.Module): self.embed_tokens = VocabParallelEmbedding( 
config.vocab_size, config.hidden_size, + prefix=add_prefix("embed_tokens", prefix), ) self.layers = nn.ModuleList( [ @@ -328,6 +341,7 @@ class Grok1Model(nn.Module): i, quant_config=quant_config, use_presharded_weights=use_presharded_weights, + prefix=add_prefix(f"layers.{i}", prefix), ) for i in range(config.num_hidden_layers) ] @@ -359,6 +373,7 @@ class Grok1ForCausalLM(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -377,8 +392,11 @@ class Grok1ForCausalLM(nn.Module): config, quant_config=quant_config, use_presharded_weights=self.use_presharded_weights, + prefix=add_prefix("model", prefix), + ) + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix) ) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) def forward( diff --git a/python/sglang/srt/models/internlm2.py b/python/sglang/srt/models/internlm2.py index ce8f9a3cf..fe39dd1a4 100644 --- a/python/sglang/srt/models/internlm2.py +++ b/python/sglang/srt/models/internlm2.py @@ -38,6 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import add_prefix class InternLM2MLP(nn.Module): @@ -47,13 +48,22 @@ class InternLM2MLP(nn.Module): intermediate_size: int, hidden_act: str, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), ) self.w2 = RowParallelLinear( - intermediate_size, hidden_size, bias=False, 
quant_config=quant_config + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("w2", prefix), ) if hidden_act != "silu": raise ValueError( @@ -80,6 +90,7 @@ class InternLM2Attention(nn.Module): max_position_embeddings: int = 8192, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -111,12 +122,14 @@ class InternLM2Attention(nn.Module): self.total_num_kv_heads, bias=False, quant_config=quant_config, + prefix=add_prefix("wqkv", prefix), ) self.wo = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, quant_config=quant_config, + prefix=add_prefix("wo", prefix), ) self.rotary_emb = get_rope( @@ -127,7 +140,12 @@ class InternLM2Attention(nn.Module): rope_scaling=rope_scaling, ) self.attn = RadixAttention( - self.num_heads, self.head_dim, self.scaling, self.num_kv_heads, layer_id + self.num_heads, + self.head_dim, + self.scaling, + self.num_kv_heads, + layer_id, + prefix=add_prefix("attn", prefix), ) def forward( @@ -150,6 +168,7 @@ class InternLMDecoderLayer(nn.Module): config: PretrainedConfig, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -165,12 +184,14 @@ class InternLMDecoderLayer(nn.Module): max_position_embeddings=max_position_embeddings, layer_id=layer_id, quant_config=quant_config, + prefix=add_prefix("attention", prefix), ) self.feed_forward = InternLM2MLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, + prefix=add_prefix("feed_forward", prefix), ) self.attention_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -205,6 +226,7 @@ class InternLM2Model(nn.Module): self, config: PretrainedConfig, quant_config: 
Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -213,10 +235,13 @@ class InternLM2Model(nn.Module): self.tok_embeddings = VocabParallelEmbedding( config.vocab_size, config.hidden_size, + prefix=add_prefix("tok_embeddings", prefix), ) self.layers = nn.ModuleList( [ - InternLMDecoderLayer(config, i, quant_config) + InternLMDecoderLayer( + config, i, quant_config, prefix=add_prefix(f"layers.{i}", prefix) + ) for i in range(config.num_hidden_layers) ] ) @@ -251,12 +276,17 @@ class InternLM2ForCausalLM(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = InternLM2Model(config, quant_config) - self.output = ParallelLMHead(config.vocab_size, config.hidden_size) + self.model = InternLM2Model( + config, quant_config, prefix=add_prefix("model", prefix) + ) + self.output = ParallelLMHead( + config.vocab_size, config.hidden_size, prefix=add_prefix("output", prefix) + ) self.logits_processor = LogitsProcessor(config) @torch.no_grad() diff --git a/python/sglang/srt/models/internlm2_reward.py b/python/sglang/srt/models/internlm2_reward.py index d5fe9c059..68be8d001 100644 --- a/python/sglang/srt/models/internlm2_reward.py +++ b/python/sglang/srt/models/internlm2_reward.py @@ -22,6 +22,7 @@ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.models.internlm2 import InternLM2ForCausalLM, InternLM2Model +from sglang.srt.utils import add_prefix class InternLM2ForRewardModel(nn.Module): @@ -29,12 +30,15 @@ class InternLM2ForRewardModel(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() 
self.config = config self.quant_config = quant_config self.vocab_size = config.vocab_size - self.model = InternLM2Model(config, quant_config) + self.model = InternLM2Model( + config, quant_config, prefix=add_prefix("model", prefix) + ) self.v_head = nn.Linear(config.hidden_size, 1, bias=False) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False) diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 27b4277cf..4127bfcdf 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -49,7 +49,7 @@ from sglang.srt.model_loader.weight_utils import ( kv_cache_scales_loader, maybe_remap_kv_scale_name, ) -from sglang.srt.utils import make_layers +from sglang.srt.utils import add_prefix, make_layers from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) @@ -70,14 +70,14 @@ class LlamaMLP(nn.Module): [intermediate_size] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj", + prefix=add_prefix("gate_up_proj", prefix), ) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, bias=False, quant_config=quant_config, - prefix=f"{prefix}.down_proj", + prefix=add_prefix("down_proj", prefix), ) if hidden_act != "silu": raise ValueError( @@ -142,14 +142,14 @@ class LlamaAttention(nn.Module): self.total_num_kv_heads, bias=bias, quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", + prefix=add_prefix("qkv_proj", prefix), ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=bias, quant_config=quant_config, - prefix=f"{prefix}.o_proj", + prefix=add_prefix("o_proj", prefix), ) self.rotary_emb = get_rope( @@ -166,6 +166,7 @@ class LlamaAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + prefix=add_prefix("attn", prefix), ) def forward( @@ -218,7 +219,7 @@ class LlamaDecoderLayer(nn.Module): rope_is_neox_style=rope_is_neox_style, 
max_position_embeddings=max_position_embeddings, quant_config=quant_config, - prefix=f"{prefix}.self_attn", + prefix=add_prefix("self_attn", prefix), bias=attention_bias, ) self.mlp = LlamaMLP( @@ -226,7 +227,7 @@ class LlamaDecoderLayer(nn.Module): intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - prefix=f"{prefix}.mlp", + prefix=add_prefix("mlp", prefix), ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( @@ -263,6 +264,7 @@ class LlamaModel(nn.Module): self, config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -272,6 +274,7 @@ class LlamaModel(nn.Module): config.vocab_size, config.hidden_size, quant_config=quant_config, + prefix=add_prefix("embed_tokens", prefix), ) self.layers = make_layers( config.num_hidden_layers, @@ -358,18 +361,24 @@ class LlamaForCausalLM(nn.Module): self, config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = LlamaModel(config, quant_config=quant_config) + self.model = LlamaModel( + config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) # Llama 3.2 1B Instruct set tie_word_embeddings to True # Llama 3.1 8B Instruct set tie_word_embeddings to False if self.config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, quant_config=quant_config + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), ) self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) diff --git a/python/sglang/srt/models/llama_classification.py b/python/sglang/srt/models/llama_classification.py index 
75e8af9af..8387d2030 100644 --- a/python/sglang/srt/models/llama_classification.py +++ b/python/sglang/srt/models/llama_classification.py @@ -23,6 +23,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel +from sglang.srt.utils import add_prefix class LlamaForClassification(nn.Module): @@ -30,11 +31,14 @@ class LlamaForClassification(nn.Module): self, config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = LlamaModel(config, quant_config=quant_config) + self.model = LlamaModel( + config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) self.classification_head = nn.Linear( config.hidden_size, config.classification_out_size, bias=False diff --git a/python/sglang/srt/models/llama_eagle.py b/python/sglang/srt/models/llama_eagle.py index a9dbe8275..769ee6736 100644 --- a/python/sglang/srt/models/llama_eagle.py +++ b/python/sglang/srt/models/llama_eagle.py @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
""" +from sglang.srt.utils import add_prefix + # Adapted from # https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py """Inference-only LLaMA-EAGLE model compatible with HuggingFace weights.""" @@ -55,6 +57,7 @@ class LlamaModel(nn.Module): self, config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -62,11 +65,15 @@ class LlamaModel(nn.Module): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, + prefix=add_prefix("embed_tokens", prefix), ) self.layers = nn.ModuleList( [ LlamaDecoderLayer( - config, i, quant_config=quant_config, prefix=f"model.layers.{i}" + config, + i, + quant_config=quant_config, + prefix=add_prefix(f"layers.{i}", prefix), ) for i in range(config.num_hidden_layers) ] @@ -106,24 +113,26 @@ class LlamaForCausalLMEagle(LlamaForCausalLM): self, config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: nn.Module.__init__(self) self.config = config self.quant_config = quant_config - self.model = LlamaModel(config, quant_config=quant_config) + self.model = LlamaModel( + config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) # Llama 3.2 1B Instruct set tie_word_embeddings to True # Llama 3.1 8B Instruct set tie_word_embeddings to False if self.config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: - if hasattr(config, "hot_vocab_size"): - self.lm_head = ParallelLMHead( - config.hot_vocab_size, config.hidden_size, quant_config=quant_config - ) - else: - self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, quant_config=quant_config - ) + self.lm_head = ParallelLMHead( + getattr(config, "hot_vocab_size", config.vocab_size), + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + ) + self.logits_processor = LogitsProcessor(config) def load_weights(self, weights: Iterable[Tuple[str, 
torch.Tensor]]): diff --git a/python/sglang/srt/models/llama_embedding.py b/python/sglang/srt/models/llama_embedding.py index 34b316dda..ba448f7fc 100644 --- a/python/sglang/srt/models/llama_embedding.py +++ b/python/sglang/srt/models/llama_embedding.py @@ -8,6 +8,7 @@ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType from sglang.srt.model_executor.model_runner import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.llama import LlamaModel +from sglang.srt.utils import add_prefix class LlamaEmbeddingModel(nn.Module): @@ -15,9 +16,12 @@ class LlamaEmbeddingModel(nn.Module): self, config: LlamaConfig, quant_config=None, + prefix: str = "", ) -> None: super().__init__() - self.model = LlamaModel(config, quant_config=quant_config) + self.model = LlamaModel( + config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) @torch.no_grad() diff --git a/python/sglang/srt/models/llama_reward.py b/python/sglang/srt/models/llama_reward.py index 6550ee411..2f78dfa1b 100644 --- a/python/sglang/srt/models/llama_reward.py +++ b/python/sglang/srt/models/llama_reward.py @@ -22,6 +22,7 @@ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel +from sglang.srt.utils import add_prefix class LlamaForSequenceClassification(nn.Module): @@ -29,12 +30,15 @@ class LlamaForSequenceClassification(nn.Module): self, config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config self.quant_config = quant_config self.num_labels = config.num_labels - self.model = LlamaModel(config, quant_config=quant_config) + self.model = 
LlamaModel( + config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False) @@ -82,8 +86,9 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific self, config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: - super().__init__(config, quant_config) + super().__init__(config, quant_config, prefix=prefix) self.weights = self.Weights(config.hidden_size, self.num_labels) @torch.no_grad() diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index a1f06e186..60d40c39a 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -42,6 +42,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.llama import LlamaForCausalLM from sglang.srt.models.mistral import MistralForCausalLM from sglang.srt.models.qwen2 import Qwen2ForCausalLM +from sglang.srt.utils import add_prefix class LlavaBaseForCausalLM(nn.Module): @@ -475,6 +476,7 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM): self, config: LlavaConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() @@ -484,7 +486,11 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM): self.config.text_config.hidden_size = config.hidden_size self.multi_modal_projector = LlavaMultiModalProjector(config) - self.language_model = LlamaForCausalLM(config, quant_config=quant_config) + self.language_model = LlamaForCausalLM( + config, + quant_config=quant_config, + prefix=add_prefix("language_model", prefix), + ) if "unpad" in getattr(config, "mm_patch_merge_type", ""): self.language_model.model.image_newline = nn.Parameter( torch.empty(config.text_config.hidden_size, dtype=torch.float16) @@ -496,6 +502,7 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM): self, 
config: LlavaConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() @@ -516,7 +523,11 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM): self.config.image_token_index = 151646 self.multi_modal_projector = LlavaMultiModalProjector(config) - self.language_model = Qwen2ForCausalLM(config, quant_config=quant_config) + self.language_model = Qwen2ForCausalLM( + config, + quant_config=quant_config, + prefix=add_prefix("language_model", prefix), + ) if "unpad" in getattr(config, "mm_patch_merge_type", ""): self.language_model.model.image_newline = nn.Parameter( torch.empty(config.text_config.hidden_size, dtype=torch.float16) @@ -528,6 +539,7 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM): self, config: LlavaConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() @@ -548,7 +560,11 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM): self.config.image_token_index = 32000 self.multi_modal_projector = LlavaMultiModalProjector(config) - self.language_model = MistralForCausalLM(config, quant_config=quant_config) + self.language_model = MistralForCausalLM( + config, + quant_config=quant_config, + prefix=add_prefix("language_model", prefix), + ) if "unpad" in getattr(config, "mm_patch_merge_type", ""): self.language_model.model.image_newline = nn.Parameter( torch.empty(config.text_config.hidden_size, dtype=torch.float16) diff --git a/python/sglang/srt/models/llavavid.py b/python/sglang/srt/models/llavavid.py index 37b5a7882..79bcf2329 100644 --- a/python/sglang/srt/models/llavavid.py +++ b/python/sglang/srt/models/llavavid.py @@ -26,6 +26,7 @@ from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.llama import LlamaForCausalLM +from sglang.srt.utils import add_prefix class 
LlavaVidForCausalLM(nn.Module): @@ -33,6 +34,7 @@ class LlavaVidForCausalLM(nn.Module): self, config: LlavaConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -44,7 +46,11 @@ class LlavaVidForCausalLM(nn.Module): self.resampler = nn.AvgPool2d( kernel_size=self.mm_spatial_pool_stride, stride=self.mm_spatial_pool_stride ) - self.language_model = LlamaForCausalLM(config, quant_config=quant_config) + self.language_model = LlamaForCausalLM( + config, + quant_config=quant_config, + prefix=add_prefix("language_model", prefix), + ) self.num_frames = getattr(self.config, "num_frames", 16) if "unpad" in getattr(config, "mm_patch_merge_type", ""): self.language_model.model.image_newline = nn.Parameter( diff --git a/python/sglang/srt/models/minicpm.py b/python/sglang/srt/models/minicpm.py index 6f8b500a4..f7133bcce 100644 --- a/python/sglang/srt/models/minicpm.py +++ b/python/sglang/srt/models/minicpm.py @@ -37,6 +37,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import add_prefix class MiniCPMMLP(nn.Module): @@ -46,6 +47,7 @@ class MiniCPMMLP(nn.Module): intermediate_size: int, hidden_act: str, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -53,12 +55,14 @@ class MiniCPMMLP(nn.Module): [intermediate_size] * 2, bias=False, quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), ) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, bias=False, quant_config=quant_config, + prefix=add_prefix("down_proj", prefix), ) if hidden_act != "silu": raise ValueError( @@ -85,6 +89,7 @@ class MiniCPMAttention(nn.Module): rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 
8192, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -116,12 +121,14 @@ class MiniCPMAttention(nn.Module): self.total_num_kv_heads, bias=False, quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), ) self.rotary_emb = get_rope( @@ -139,6 +146,7 @@ class MiniCPMAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + prefix=add_prefix("attn", prefix), ) def forward( @@ -164,6 +172,7 @@ class MiniCPMDecoderLayer(nn.Module): config, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -180,12 +189,14 @@ class MiniCPMDecoderLayer(nn.Module): rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), ) self.mlp = MiniCPMMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, + prefix=add_prefix("mlp", prefix), ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( @@ -227,6 +238,7 @@ class MiniCPMModel(nn.Module): self, config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -236,10 +248,16 @@ class MiniCPMModel(nn.Module): self.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, + prefix=add_prefix("embed_tokens", prefix), ) self.layers = nn.ModuleList( [ - MiniCPMDecoderLayer(config, i, quant_config=quant_config) + MiniCPMDecoderLayer( + config, + i, + quant_config=quant_config, + prefix=add_prefix(f"layers.{i}", prefix), + ) for i in range(config.num_hidden_layers) ] 
) @@ -275,19 +293,23 @@ class MiniCPMForCausalLM(nn.Module): self, config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config self.num_experts = getattr(self.config, "num_experts", 0) self.quant_config = quant_config - self.model = MiniCPMModel(config, quant_config=quant_config) + self.model = MiniCPMModel( + config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) if not self.config.tie_word_embeddings: self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, + prefix=add_prefix("lm_head", prefix), ) self.scale_width = self.config.hidden_size / self.config.dim_model_base diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py index f7b331bab..f1c08c5fe 100644 --- a/python/sglang/srt/models/minicpm3.py +++ b/python/sglang/srt/models/minicpm3.py @@ -40,7 +40,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import is_cuda_available +from sglang.srt.utils import add_prefix, is_cuda_available if is_cuda_available(): from sgl_kernel import bmm_fp8 @@ -53,6 +53,7 @@ class MiniCPM3MLP(nn.Module): intermediate_size: int, hidden_act: str, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -60,12 +61,14 @@ class MiniCPM3MLP(nn.Module): [intermediate_size] * 2, bias=False, quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), ) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, bias=False, quant_config=quant_config, + prefix=add_prefix("down_proj", prefix), ) 
if hidden_act != "silu": raise ValueError( @@ -107,6 +110,7 @@ class MiniCPM3Attention(nn.Module): max_position_embeddings: int = 8192, quant_config: Optional[QuantizationConfig] = None, layer_id=None, + prefix: str = "", ) -> None: super().__init__() self.layer_id = layer_id @@ -131,6 +135,7 @@ class MiniCPM3Attention(nn.Module): self.q_lora_rank, bias=False, quant_config=quant_config, + prefix=add_prefix("q_a_proj", prefix), ) self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) self.q_b_proj = ColumnParallelLinear( @@ -138,6 +143,7 @@ class MiniCPM3Attention(nn.Module): self.num_heads * self.qk_head_dim, bias=False, quant_config=quant_config, + prefix=add_prefix("q_b_proj", prefix), ) else: self.q_proj = ColumnParallelLinear( @@ -145,6 +151,7 @@ class MiniCPM3Attention(nn.Module): self.num_heads * self.qk_head_dim, bias=False, quant_config=quant_config, + prefix=add_prefix("q_proj", prefix), ) self.kv_a_proj_with_mqa = ReplicatedLinear( @@ -152,6 +159,7 @@ class MiniCPM3Attention(nn.Module): self.kv_lora_rank + self.qk_rope_head_dim, bias=False, quant_config=quant_config, + prefix=add_prefix("kv_a_proj_with_mqa", prefix), ) self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) self.kv_b_proj = ColumnParallelLinear( @@ -159,6 +167,7 @@ class MiniCPM3Attention(nn.Module): self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False, quant_config=quant_config, + prefix=add_prefix("kv_b_proj", prefix), ) # O projection. 
self.o_proj = RowParallelLinear( @@ -166,6 +175,7 @@ class MiniCPM3Attention(nn.Module): self.hidden_size, bias=False, quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), ) self.rotary_emb = get_rope( qk_rope_head_dim, @@ -182,6 +192,7 @@ class MiniCPM3Attention(nn.Module): self.scaling, num_kv_heads=self.num_local_heads, layer_id=layer_id, + prefix=add_prefix("attn", prefix), ) def forward( @@ -250,6 +261,7 @@ class MiniCPM3AttentionMLA(nn.Module): max_position_embeddings: int = 8192, quant_config: Optional[QuantizationConfig] = None, layer_id=None, + prefix: str = "", ) -> None: super().__init__() self.layer_id = layer_id @@ -274,6 +286,7 @@ class MiniCPM3AttentionMLA(nn.Module): self.q_lora_rank, bias=False, quant_config=quant_config, + prefix=add_prefix("q_a_proj", prefix), ) self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) self.q_b_proj = ColumnParallelLinear( @@ -281,6 +294,7 @@ class MiniCPM3AttentionMLA(nn.Module): self.num_heads * self.qk_head_dim, bias=False, quant_config=quant_config, + prefix=add_prefix("q_b_proj", prefix), ) else: self.q_proj = ColumnParallelLinear( @@ -288,6 +302,7 @@ class MiniCPM3AttentionMLA(nn.Module): self.num_heads * self.qk_head_dim, bias=False, quant_config=quant_config, + prefix=add_prefix("q_proj", prefix), ) self.kv_a_proj_with_mqa = ReplicatedLinear( @@ -295,6 +310,7 @@ class MiniCPM3AttentionMLA(nn.Module): self.kv_lora_rank + self.qk_rope_head_dim, bias=False, quant_config=quant_config, + prefix=add_prefix("kv_a_proj_with_mqa", prefix), ) self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) self.kv_b_proj = ColumnParallelLinear( @@ -302,6 +318,7 @@ class MiniCPM3AttentionMLA(nn.Module): self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False, quant_config=quant_config, + prefix=add_prefix("kv_b_proj", prefix), ) # O projection. 
self.o_proj = RowParallelLinear( @@ -309,6 +326,7 @@ class MiniCPM3AttentionMLA(nn.Module): self.hidden_size, bias=False, quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), ) self.rotary_emb = get_rope( qk_rope_head_dim, @@ -325,6 +343,7 @@ class MiniCPM3AttentionMLA(nn.Module): num_kv_heads=1, layer_id=layer_id, v_head_dim=self.kv_lora_rank, + prefix=add_prefix("attn", prefix), ) self.w_kc = None @@ -405,6 +424,7 @@ class MiniCPM3DecoderLayer(nn.Module): config: PretrainedConfig, layer_id: int, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -429,6 +449,7 @@ class MiniCPM3DecoderLayer(nn.Module): max_position_embeddings=max_position_embeddings, quant_config=quant_config, layer_id=layer_id, + prefix=add_prefix("self_attn", prefix), ) else: self.self_attn = MiniCPM3Attention( @@ -447,12 +468,14 @@ class MiniCPM3DecoderLayer(nn.Module): max_position_embeddings=max_position_embeddings, quant_config=quant_config, layer_id=layer_id, + prefix=add_prefix("self_attn", prefix), ) self.mlp = MiniCPM3MLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, + prefix=add_prefix("mlp", prefix), ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( @@ -494,6 +517,7 @@ class MiniCPM3Model(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -503,10 +527,16 @@ class MiniCPM3Model(nn.Module): self.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, + prefix=add_prefix("embed_tokens", prefix), ) self.layers = nn.ModuleList( [ - MiniCPM3DecoderLayer(config, i, quant_config=quant_config) + MiniCPM3DecoderLayer( + config, + i, + quant_config=quant_config, + prefix=add_prefix(f"layers.{i}", prefix), + ) for i in 
range(config.num_hidden_layers) ] ) @@ -542,19 +572,23 @@ class MiniCPM3ForCausalLM(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config self.num_experts = getattr(self.config, "num_experts", 0) self.quant_config = quant_config - self.model = MiniCPM3Model(config, quant_config=quant_config) + self.model = MiniCPM3Model( + config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) if not self.config.tie_word_embeddings: self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, + prefix=add_prefix("lm_head", prefix), ) self.scale_width = self.config.hidden_size / self.config.dim_model_base diff --git a/python/sglang/srt/models/minicpmv.py b/python/sglang/srt/models/minicpmv.py index 7b02b4ced..7905c808b 100644 --- a/python/sglang/srt/models/minicpmv.py +++ b/python/sglang/srt/models/minicpmv.py @@ -56,6 +56,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.utils import set_default_torch_dtype from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM +from sglang.srt.utils import add_prefix RawImageType = Union[Image.Image, torch.Tensor] @@ -158,14 +159,14 @@ class Idefics2VisionMLP(nn.Module): config.intermediate_size, bias=True, quant_config=quant_config, - prefix=f"{prefix}.fc1", + prefix=add_prefix("fc1", prefix), ) self.fc2 = RowParallelLinear( config.intermediate_size, config.hidden_size, bias=True, quant_config=quant_config, - prefix=f"{prefix}.fc2", + prefix=add_prefix("fc2", prefix), ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -199,10 +200,14 @@ class Idefics2EncoderLayer(nn.Module): use_context_forward=False, use_full_precision_softmax=True, 
flatten_batch=False, - prefix=f"{prefix}.self_attn", + prefix=add_prefix("self_attn", prefix), ) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - self.mlp = Idefics2VisionMLP(config, quant_config=quant_config) + self.mlp = Idefics2VisionMLP( + config, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) def forward( @@ -242,6 +247,7 @@ class Idefics2Encoder(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() @@ -251,8 +257,9 @@ class Idefics2Encoder(nn.Module): Idefics2EncoderLayer( config, quant_config=quant_config, + prefix=add_prefix(f"layers.{i}", prefix), ) - for _ in range(config.num_hidden_layers) + for i in range(config.num_hidden_layers) ] ) @@ -379,13 +386,18 @@ class Idefics2VisionTransformer(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() embed_dim = config.hidden_size self.config = config self.embeddings = Idefics2VisionEmbeddings(config) - self.encoder = Idefics2Encoder(config=config, quant_config=quant_config) + self.encoder = Idefics2Encoder( + config=config, + quant_config=quant_config, + prefix=add_prefix("encoder", prefix), + ) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) def get_input_embeddings(self): @@ -503,7 +515,7 @@ class BaseResampler(nn.Module): embed_dim, bias=False, quant_config=quant_config, - prefix=f"{prefix}.kv_proj", + prefix=add_prefix("kv_proj", prefix), ) else: # Maintain the same return value with ReplicatedLinear.forward @@ -660,6 +672,7 @@ class MiniCPMVBaseModel(nn.Module): *, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() # All MiniCPM-V models disable `tie_word_embeddings` but @@ -669,8 +682,12 @@ class 
MiniCPMVBaseModel(nn.Module): self.config = config self.version = get_version_by_config(self.config) - self.llm = self.init_llm(config=config, quant_config=quant_config) - self.vpm = self.init_vision_module(config, quant_config) + self.llm = self.init_llm( + config=config, quant_config=quant_config, prefix=add_prefix("llm", prefix) + ) + self.vpm = self.init_vision_module( + config, quant_config, add_prefix("vpm", prefix) + ) self.vision_dim = ( self.vpm.embed_dim if self.version == (2, 0) @@ -679,7 +696,10 @@ class MiniCPMVBaseModel(nn.Module): self.embed_dim = self.config.hidden_size self.resampler = self.init_resampler( - self.embed_dim, self.vision_dim, quant_config=quant_config + self.embed_dim, + self.vision_dim, + quant_config=quant_config, + prefix=add_prefix("resampler", prefix), ) self.logits_processor = LogitsProcessor(config) @@ -937,6 +957,7 @@ class MiniCPMVBaseModel(nn.Module): self, config: Qwen2Config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> nn.Module: raise NotImplementedError @@ -944,6 +965,7 @@ class MiniCPMVBaseModel(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig], + prefix: str = "", ) -> nn.Module: raise NotImplementedError @@ -952,6 +974,7 @@ class MiniCPMVBaseModel(nn.Module): embed_dim: int, vision_dim: int, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> nn.Module: raise NotImplementedError @@ -1011,24 +1034,27 @@ class MiniCPMV2_6(MiniCPMVBaseModel): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): - super().__init__(config=config, quant_config=quant_config) + super().__init__(config=config, quant_config=quant_config, prefix=prefix) assert self.version == (2, 6) def init_llm( self, config: Qwen2Config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> nn.Module: - return Qwen2ForCausalLM(config=config, quant_config=quant_config) + return 
Qwen2ForCausalLM(config=config, quant_config=quant_config, prefix=prefix) def init_vision_module( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig], + prefix: str = "", ) -> nn.Module: model = Idefics2VisionTransformer( - config=config.vision_config, quant_config=quant_config + config=config.vision_config, quant_config=quant_config, prefix=prefix ) if self.config.drop_vision_last_layer: model.encoder.layers = model.encoder.layers[:-1] @@ -1042,6 +1068,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel): embed_dim: int, vision_dim: int, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> nn.Module: with set_default_torch_dtype(torch.float16): # The resampler in 2.6 remains consistent with the one in 2.5. @@ -1051,6 +1078,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel): num_heads=embed_dim // 128, kv_dim=vision_dim, quant_config=quant_config, + prefix=prefix, ) return resampler.to(device="cuda", dtype=torch.get_default_dtype()) @@ -1207,6 +1235,7 @@ class MiniCPMV: self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() @@ -1221,7 +1250,9 @@ class MiniCPMV: raise ValueError("Currently, MiniCPMV only supports versions 2.6") try: - minicpmv = instance_class(config=config, quant_config=quant_config) + minicpmv = instance_class( + config=config, quant_config=quant_config, prefix=prefix + ) self.minicpmv = minicpmv except Exception as e: print(f"Failed to instantiate MiniCPMV: {e}") diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index 4ea734836..058f96fdd 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -45,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import 
default_weight_loader +from sglang.srt.utils import add_prefix class MixtralMoE(nn.Module): @@ -78,7 +79,7 @@ class MixtralMoE(nn.Module): bias=False, params_dtype=params_dtype, quant_config=None, - prefix=f"{prefix}.gate", + prefix=add_prefix("gate", prefix), ) MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE self.experts = MoEImpl( @@ -90,7 +91,7 @@ class MixtralMoE(nn.Module): renormalize=True, quant_config=quant_config, tp_size=tp_size, - prefix=f"{prefix}.experts", + prefix=add_prefix("experts", prefix), ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -146,14 +147,14 @@ class MixtralAttention(nn.Module): self.total_num_kv_heads, bias=False, quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", + prefix=add_prefix("qkv_proj", prefix), ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, quant_config=quant_config, - prefix=f"{prefix}.o_proj", + prefix=add_prefix("o_proj", prefix), ) self.rotary_emb = get_rope( self.head_dim, @@ -168,6 +169,7 @@ class MixtralAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + prefix=add_prefix("attn", prefix), ) def forward( @@ -204,7 +206,7 @@ class MixtralDecoderLayer(nn.Module): layer_id=layer_id, rope_theta=rope_theta, quant_config=quant_config, - prefix=f"{prefix}.self_attn", + prefix=add_prefix("self_attn", prefix), ) self.block_sparse_moe = MixtralMoE( num_experts=config.num_local_experts, @@ -212,7 +214,7 @@ class MixtralDecoderLayer(nn.Module): hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, quant_config=quant_config, - prefix=f"{prefix}.block_sparse_moe", + prefix=add_prefix("block_sparse_moe", prefix), ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( @@ -258,11 +260,15 @@ class MixtralModel(nn.Module): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, + 
prefix=add_prefix("embed_tokens", prefix), ) self.layers = nn.ModuleList( [ MixtralDecoderLayer( - config, i, quant_config=quant_config, prefix=f"{prefix}.layers" + config, + i, + quant_config=quant_config, + prefix=add_prefix(f"layers.{i}", prefix), ) for i in range(config.num_hidden_layers) ] @@ -296,12 +302,17 @@ class MixtralForCausalLM(nn.Module): self, config: MixtralConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = MixtralModel(config, quant_config=quant_config, prefix="model") - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.model = MixtralModel( + config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix) + ) self.logits_processor = LogitsProcessor(config) def forward( diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py index 244dc7df2..c3ba17bc9 100644 --- a/python/sglang/srt/models/mixtral_quant.py +++ b/python/sglang/srt/models/mixtral_quant.py @@ -45,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import add_prefix class MixtralMLP(nn.Module): @@ -54,6 +55,7 @@ class MixtralMLP(nn.Module): hidden_size: int, intermediate_size: int, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.num_experts = num_experts @@ -61,13 +63,25 @@ class MixtralMLP(nn.Module): self.hidden_dim = hidden_size self.w1 = ReplicatedLinear( - self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config + self.hidden_dim, + self.ffn_dim, + bias=False, + quant_config=quant_config, + prefix=add_prefix("w1", prefix), 
) self.w2 = ReplicatedLinear( - self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config + self.ffn_dim, + self.hidden_dim, + bias=False, + quant_config=quant_config, + prefix=add_prefix("w2", prefix), ) self.w3 = ReplicatedLinear( - self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config + self.hidden_dim, + self.ffn_dim, + bias=False, + quant_config=quant_config, + prefix=add_prefix("w3", prefix), ) # TODO: Use vllm's SiluAndMul @@ -87,6 +101,7 @@ class MixtralMoE(nn.Module): self, config: MixtralConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -114,6 +129,7 @@ class MixtralMoE(nn.Module): config.hidden_size, config.intermediate_size, quant_config=quant_config, + prefix=add_prefix(f"experts.{idx}", prefix), ) if idx in self.expert_indicies else None @@ -122,7 +138,11 @@ class MixtralMoE(nn.Module): ] ) self.gate = ReplicatedLinear( - config.hidden_size, self.num_total_experts, bias=False, quant_config=None + config.hidden_size, + self.num_total_experts, + bias=False, + quant_config=None, + prefix=add_prefix("gate", prefix), ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -159,6 +179,7 @@ class MixtralAttention(nn.Module): max_position: int = 4096 * 32, rope_theta: float = 10000, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -189,12 +210,14 @@ class MixtralAttention(nn.Module): self.total_num_kv_heads, bias=False, quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), ) self.rotary_emb = get_rope( self.head_dim, @@ -209,6 +232,7 @@ class MixtralAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + prefix=add_prefix("attn", prefix), ) def forward( @@ 
-231,6 +255,7 @@ class MixtralDecoderLayer(nn.Module): config: MixtralConfig, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -244,8 +269,13 @@ class MixtralDecoderLayer(nn.Module): layer_id=layer_id, rope_theta=rope_theta, quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), + ) + self.block_sparse_moe = MixtralMoE( + config=config, + quant_config=quant_config, + prefix=add_prefix("block_sparse_moe", prefix), ) - self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps @@ -281,6 +311,7 @@ class MixtralModel(nn.Module): self, config: MixtralConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -289,10 +320,16 @@ class MixtralModel(nn.Module): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, + prefix=add_prefix("embed_tokens", prefix), ) self.layers = nn.ModuleList( [ - MixtralDecoderLayer(config, i, quant_config=quant_config) + MixtralDecoderLayer( + config, + i, + quant_config=quant_config, + prefix=add_prefix(f"layers.{i}", prefix), + ) for i in range(config.num_hidden_layers) ] ) @@ -324,12 +361,17 @@ class QuantMixtralForCausalLM(nn.Module): self, config: MixtralConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = MixtralModel(config, quant_config=quant_config) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.model = MixtralModel( + config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, 
prefix=add_prefix("lm_head", prefix) + ) self.logits_processor = LogitsProcessor(config) @torch.no_grad() diff --git a/python/sglang/srt/models/mllama.py b/python/sglang/srt/models/mllama.py index 05069edb6..dd52ae6fd 100644 --- a/python/sglang/srt/models/mllama.py +++ b/python/sglang/srt/models/mllama.py @@ -36,6 +36,7 @@ from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP +from sglang.srt.utils import add_prefix class ColumnParallelConv2dPatch(torch.nn.Module): @@ -147,7 +148,12 @@ class MllamaPrecomputedPositionEmbedding(nn.Module): class MllamaVisionMLP(nn.Module): - def __init__(self, config, quant_config: Optional[QuantizationConfig] = None): + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.config = config self.activation_fn = get_act_fn(config.hidden_act) @@ -156,12 +162,14 @@ class MllamaVisionMLP(nn.Module): config.intermediate_size, bias=True, quant_config=quant_config, + prefix=add_prefix("fc1", prefix), ) self.fc2 = RowParallelLinear( config.intermediate_size, config.hidden_size, bias=True, quant_config=quant_config, + prefix=add_prefix("fc2", prefix), ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -174,7 +182,10 @@ class MllamaVisionMLP(nn.Module): class MllamaVisionEncoderLayer(nn.Module): def __init__( - self, config: config_mllama.MllamaVisionConfig, is_gated: bool = False + self, + config: config_mllama.MllamaVisionConfig, + is_gated: bool = False, + prefix: str = "", ): super().__init__() @@ -193,8 +204,9 @@ class MllamaVisionEncoderLayer(nn.Module): use_context_forward=False, use_full_precision_softmax=False, flatten_batch=False, + prefix=add_prefix("self_attn", prefix), ) - self.mlp = MllamaVisionMLP(config) + self.mlp = 
MllamaVisionMLP(config, prefix=add_prefix("mlp", prefix)) self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps) self.post_attention_layernorm = nn.LayerNorm( @@ -235,11 +247,17 @@ class MllamaVisionEncoder(nn.Module): num_layers=32, is_gated=False, output_hidden_states=None, + prefix: str = "", ): super().__init__() self.config = config self.layers = nn.ModuleList( - [MllamaVisionEncoderLayer(config, is_gated) for _ in range(num_layers)] + [ + MllamaVisionEncoderLayer( + config, is_gated, prefix=add_prefix(f"layers.{i}", prefix) + ) + for i in range(num_layers) + ] ) self.output_hidden_states = output_hidden_states or [] @@ -265,7 +283,7 @@ class MllamaVisionEncoder(nn.Module): class MllamaVisionModel(nn.Module): - def __init__(self, config: config_mllama.MllamaVisionConfig): + def __init__(self, config: config_mllama.MllamaVisionConfig, prefix: str = ""): super().__init__() self.image_size = config.image_size self.patch_size = config.patch_size @@ -305,9 +323,13 @@ class MllamaVisionModel(nn.Module): config.num_hidden_layers, is_gated=False, output_hidden_states=config.intermediate_layers_indices, + prefix=add_prefix("transformer", prefix), ) self.global_transformer = MllamaVisionEncoder( - config, config.num_global_layers, is_gated=True + config, + config.num_global_layers, + is_gated=True, + prefix=add_prefix("global_transformer", prefix), ) def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: @@ -464,6 +486,7 @@ class MllamaTextCrossAttention(nn.Module): config: Optional[config_mllama.MllamaTextConfig] = None, layer_id: Optional[int] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -489,6 +512,7 @@ class MllamaTextCrossAttention(nn.Module): self.num_key_value_heads, bias=False, quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), ) self.o_proj = RowParallelLinear( self.num_heads * self.head_dim, @@ -496,6 +520,7 @@ class 
MllamaTextCrossAttention(nn.Module): bias=False, input_is_parallel=True, quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), ) # vllm.model_executor.layers.layernorm.RMSNorm has precision issue, # use huggingface's instead @@ -510,6 +535,7 @@ class MllamaTextCrossAttention(nn.Module): self.num_local_key_value_heads, layer_id=layer_id, is_cross_attention=True, + prefix=add_prefix("attn", prefix), ) def forward( @@ -551,6 +577,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): config: config_mllama.MllamaTextConfig, layer_id: int, quant_config: Optional[QuantizationConfig], + prefix: str = "", ) -> None: super().__init__() self.layer_id = layer_id @@ -558,6 +585,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): config=config, layer_id=layer_id, quant_config=quant_config, + prefix=add_prefix("cross_attn", prefix), ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -568,6 +596,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, + prefix=add_prefix("mlp", prefix), ) self.post_attention_layernorm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps @@ -610,12 +639,15 @@ class MllamaTextModel(nn.Module): self, config: config_mllama.MllamaTextConfig, quant_config: Optional[QuantizationConfig], + prefix: str = "", ): super().__init__() self.padding_id = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( - config.vocab_size + 8, config.hidden_size + config.vocab_size + 8, + config.hidden_size, + prefix=add_prefix("embed_tokens", prefix), ) self.cross_attention_layers = config.cross_attention_layers @@ -624,14 +656,20 @@ class MllamaTextModel(nn.Module): if layer_id in self.cross_attention_layers: layers.append( MllamaCrossAttentionDecoderLayer( - config, layer_id, quant_config=quant_config + config, + layer_id, + quant_config=quant_config, + 
prefix=add_prefix(f"layers.{layer_id}", prefix), ) ) else: # TODO: force LlamaDecoderLayer to config.attention_bias=False layers.append( LlamaDecoderLayer( - config, quant_config=quant_config, layer_id=layer_id + config, + quant_config=quant_config, + layer_id=layer_id, + prefix=add_prefix(f"layers.{layer_id}", prefix), ) ) @@ -687,16 +725,20 @@ class MllamaForCausalLM(nn.Module): self, config: config_mllama.MllamaTextConfig, quant_config: Optional[QuantizationConfig], + prefix: str = "", ): super().__init__() self.vocab_size = config.vocab_size - self.model = MllamaTextModel(config, quant_config) + self.model = MllamaTextModel( + config, quant_config, prefix=add_prefix("model", prefix) + ) self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), ) def forward( @@ -726,6 +768,7 @@ class MllamaForConditionalGeneration(nn.Module): self, config: config_mllama.MllamaConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.vocab_size = config.text_config.vocab_size @@ -737,10 +780,13 @@ class MllamaForConditionalGeneration(nn.Module): ) self.image_size = config.vision_config.image_size - self.vision_model = MllamaVisionModel(config.vision_config) + self.vision_model = MllamaVisionModel( + config.vision_config, prefix=add_prefix("vision_model", prefix) + ) self.language_model = MllamaForCausalLM( config.text_config, quant_config=quant_config, + prefix=add_prefix("language_model", prefix), ) self.multi_modal_projector = nn.Linear( config.vision_config.vision_output_dim, diff --git a/python/sglang/srt/models/olmo.py b/python/sglang/srt/models/olmo.py index 9f118ea6b..686cb01ac 100644 --- a/python/sglang/srt/models/olmo.py +++ b/python/sglang/srt/models/olmo.py @@ -38,7 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ) from 
sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import make_layers +from sglang.srt.utils import add_prefix, make_layers class OlmoAttention(nn.Module): @@ -53,6 +53,7 @@ class OlmoAttention(nn.Module): config: OlmoConfig, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -75,6 +76,7 @@ class OlmoAttention(nn.Module): self.head_dim, self.total_num_heads, bias=config.attention_bias, + prefix=add_prefix("qkv_proj", prefix), ) # Rotary embeddings. @@ -91,6 +93,7 @@ class OlmoAttention(nn.Module): self.scaling, num_kv_heads=self.num_heads, layer_id=layer_id, + prefix=add_prefix("attn", prefix), ) # Attention output projection. @@ -98,6 +101,7 @@ class OlmoAttention(nn.Module): self.hidden_size, self.hidden_size, bias=config.attention_bias, + prefix=add_prefix("o_proj", prefix), ) def forward( @@ -127,6 +131,7 @@ class OlmoMLP(nn.Module): self, config: OlmoConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -139,6 +144,7 @@ class OlmoMLP(nn.Module): [self.intermediate_size] * 2, bias=False, quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), ) # Activation function. @@ -150,6 +156,7 @@ class OlmoMLP(nn.Module): self.hidden_size, bias=False, quant_config=quant_config, + prefix=add_prefix("down_proj", prefix), ) def forward( @@ -174,13 +181,23 @@ class OlmoDecoderLayer(nn.Module): config: OlmoConfig, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() # Attention block. - self.self_attn = OlmoAttention(config, layer_id, quant_config) + self.self_attn = OlmoAttention( + config, + layer_id, + quant_config, + prefix=add_prefix("self_attn", prefix), + ) # MLP block. 
- self.mlp = OlmoMLP(config, quant_config) + self.mlp = OlmoMLP( + config, + quant_config, + prefix=add_prefix("mlp", prefix), + ) # LayerNorm self.input_layernorm = nn.LayerNorm( @@ -213,13 +230,18 @@ class OlmoDecoderLayer(nn.Module): class OlmoModel(nn.Module): def __init__( - self, config: OlmoConfig, quant_config: Optional[QuantizationConfig] = None + self, + config: OlmoConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, config.hidden_size + config.vocab_size, + config.hidden_size, + prefix=add_prefix("embed_tokens", prefix), ) self.layers = make_layers( config.num_hidden_layers, @@ -227,7 +249,9 @@ class OlmoModel(nn.Module): layer_id=idx, config=config, quant_config=quant_config, + prefix=prefix, ), + prefix=add_prefix("layers", prefix), ) self.norm = nn.LayerNorm( config.hidden_size, elementwise_affine=False, bias=False @@ -275,10 +299,11 @@ class OlmoForCausalLM(nn.Module): self, config: OlmoConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config - self.model = OlmoModel(config, quant_config) + self.model = OlmoModel(config, quant_config, prefix=add_prefix("model", prefix)) if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: @@ -288,6 +313,7 @@ class OlmoForCausalLM(nn.Module): config.hidden_size, org_num_embeddings=config.vocab_size, quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), ) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/olmo2.py b/python/sglang/srt/models/olmo2.py index a8af7bc1a..716ae99e4 100644 --- a/python/sglang/srt/models/olmo2.py +++ b/python/sglang/srt/models/olmo2.py @@ -45,7 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils 
import default_weight_loader -from sglang.srt.utils import make_layers +from sglang.srt.utils import add_prefix, make_layers class Olmo2Attention(nn.Module): @@ -60,6 +60,7 @@ class Olmo2Attention(nn.Module): config: PretrainedConfig, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -93,6 +94,8 @@ class Olmo2Attention(nn.Module): self.head_dim, self.total_num_heads, bias=config.attention_bias, + quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), ) self.tp_rank = get_tensor_model_parallel_rank() @@ -115,6 +118,7 @@ class Olmo2Attention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + prefix=add_prefix("attn", prefix), ) # Attention output projection. @@ -122,6 +126,8 @@ class Olmo2Attention(nn.Module): self.head_dim * self.total_num_heads, self.hidden_size, bias=config.attention_bias, + quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), ) def _apply_qk_norm( @@ -164,6 +170,7 @@ class Olmo2MLP(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -176,6 +183,7 @@ class Olmo2MLP(nn.Module): [self.intermediate_size] * 2, bias=False, quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), ) # Activation function. @@ -187,6 +195,7 @@ class Olmo2MLP(nn.Module): self.hidden_size, bias=False, quant_config=quant_config, + prefix=add_prefix("down_proj", prefix), ) def forward( @@ -211,13 +220,16 @@ class Olmo2DecoderLayer(nn.Module): config: PretrainedConfig, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() # Attention block. - self.self_attn = Olmo2Attention(config, layer_id, quant_config) + self.self_attn = Olmo2Attention( + config, layer_id, quant_config, prefix=add_prefix("self_attn", prefix) + ) # MLP block. 
- self.mlp = Olmo2MLP(config, quant_config) + self.mlp = Olmo2MLP(config, quant_config, prefix=add_prefix("mlp", prefix)) # RMSNorm self.post_attention_layernorm = RMSNorm( @@ -254,12 +266,15 @@ class Olmo2Model(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, config.hidden_size + config.vocab_size, + config.hidden_size, + prefix=add_prefix("embed_tokens", prefix), ) self.layers = make_layers( config.num_hidden_layers, @@ -267,7 +282,9 @@ class Olmo2Model(nn.Module): layer_id=idx, config=config, quant_config=quant_config, + prefix=prefix, ), + prefix=add_prefix("layers", prefix), ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -313,10 +330,13 @@ class Olmo2ForCausalLM(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config - self.model = Olmo2Model(config, quant_config) + self.model = Olmo2Model( + config, quant_config, prefix=add_prefix("model", prefix) + ) if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: @@ -326,6 +346,7 @@ class Olmo2ForCausalLM(nn.Module): config.hidden_size, org_num_embeddings=config.vocab_size, quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), ) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/olmoe.py b/python/sglang/srt/models/olmoe.py index 10b781d72..df3bd0dbf 100644 --- a/python/sglang/srt/models/olmoe.py +++ b/python/sglang/srt/models/olmoe.py @@ -41,7 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import make_layers, print_warning_once +from sglang.srt.utils import add_prefix, 
make_layers, print_warning_once class OlmoeMoE(nn.Module): @@ -69,7 +69,11 @@ class OlmoeMoE(nn.Module): # Gate always runs at half / full precision for now. self.gate = ReplicatedLinear( - hidden_size, num_experts, bias=False, quant_config=None + hidden_size, + num_experts, + bias=False, + quant_config=None, + prefix=add_prefix("gate", prefix), ) self.experts = FusedMoE( @@ -81,6 +85,7 @@ class OlmoeMoE(nn.Module): renormalize=False, quant_config=quant_config, tp_size=tp_size, + prefix=add_prefix("experts", prefix), ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -107,6 +112,7 @@ class OlmoeAttention(nn.Module): rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 4096, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -138,6 +144,7 @@ class OlmoeAttention(nn.Module): self.total_num_kv_heads, bias=False, quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), ) self.q_norm = RMSNorm(hidden_size, eps=1e-5) self.k_norm = RMSNorm(hidden_size, eps=1e-5) @@ -146,6 +153,7 @@ class OlmoeAttention(nn.Module): hidden_size, bias=False, quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), ) self.rotary_emb = get_rope( @@ -162,6 +170,7 @@ class OlmoeAttention(nn.Module): self.scaling, layer_id=layer_id, num_kv_heads=self.num_kv_heads, + prefix=add_prefix("attn", prefix), ) def forward( @@ -186,6 +195,7 @@ class OlmoeDecoderLayer(nn.Module): config: PretrainedConfig, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -202,6 +212,7 @@ class OlmoeDecoderLayer(nn.Module): rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), ) self.mlp = OlmoeMoE( @@ -210,6 +221,7 @@ class OlmoeDecoderLayer(nn.Module): 
hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, quant_config=quant_config, + prefix=add_prefix("mlp", prefix), ) self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5) @@ -246,6 +258,7 @@ class OlmoeModel(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -254,6 +267,7 @@ class OlmoeModel(nn.Module): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, + prefix=add_prefix("embed_tokens", prefix), ) self.layers = make_layers( config.num_hidden_layers, @@ -261,7 +275,9 @@ class OlmoeModel(nn.Module): config=config, quant_config=quant_config, layer_id=idx, + prefix=prefix, ), + prefix=add_prefix("layers", prefix), ) self.norm = RMSNorm(config.hidden_size, eps=1e-5) @@ -294,13 +310,19 @@ class OlmoeForCausalLM(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = OlmoeModel(config, quant_config) + self.model = OlmoeModel( + config, quant_config, prefix=add_prefix("model", prefix) + ) self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, quant_config=quant_config + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), ) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/phi3_small.py b/python/sglang/srt/models/phi3_small.py index fa365b98c..af85b5966 100644 --- a/python/sglang/srt/models/phi3_small.py +++ b/python/sglang/srt/models/phi3_small.py @@ -24,7 +24,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import 
default_weight_loader -from sglang.srt.utils import make_layers +from sglang.srt.utils import add_prefix, make_layers @torch.jit.script @@ -70,13 +70,14 @@ class Phi3SmallMLP(nn.Module): 2 * [self.intermediate_size], bias=True, quant_config=quant_config, - prefix=f"{prefix}.up_proj", + prefix=add_prefix("up_proj", prefix), ) self.down_proj = RowParallelLinear( self.intermediate_size, self.hidden_size, bias=True, quant_config=quant_config, + prefix=add_prefix("down_proj", prefix), ) def forward(self, x): @@ -140,7 +141,7 @@ class Phi3SmallSelfAttention(nn.Module): self.num_key_value_heads, bias=True, quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", + prefix=add_prefix("qkv_proj", prefix), ) self.dense = RowParallelLinear( @@ -148,7 +149,7 @@ class Phi3SmallSelfAttention(nn.Module): self.hidden_size, bias=True, quant_config=quant_config, - prefix=f"{prefix}.o_proj", + prefix=add_prefix("o_proj", prefix), ) if getattr(self.config, "rope_scaling", None) is not None: @@ -201,6 +202,7 @@ class Phi3SmallSelfAttention(nn.Module): self.scale, num_kv_heads=self.num_kv_heads_per_partion, layer_id=layer_id, + prefix=add_prefix("attn", prefix), ) def forward( @@ -234,13 +236,21 @@ class Phi3SmallDecoderLayer(nn.Module): config: PretrainedConfig, layer_id: int, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.hidden_size = config.hidden_size self.self_attn = Phi3SmallSelfAttention( - config, layer_id, quant_config=quant_config + config, + layer_id, + quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), + ) + self.mlp = Phi3SmallMLP( + config, + quant_config, + prefix=add_prefix("mlp", prefix), ) - self.mlp = Phi3SmallMLP(config, quant_config) self.input_layernorm = nn.LayerNorm( config.hidden_size, eps=config.layer_norm_epsilon @@ -284,15 +294,20 @@ class Phi3SmallModel(nn.Module): self.config = config self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, config.hidden_size + config.vocab_size, 
+ config.hidden_size, + prefix=add_prefix("embed_tokens", prefix), ) self.mup_embedding_multiplier = config.mup_embedding_multiplier self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: Phi3SmallDecoderLayer( - config, int(prefix.split(".")[-1]), quant_config + config, + int(prefix.split(".")[-1]), + quant_config, + prefix=prefix, ), - prefix=f"{prefix}.layers", + prefix=add_prefix("layers", prefix), ) self.final_layernorm = nn.LayerNorm( @@ -335,6 +350,7 @@ class Phi3SmallForCausalLM(nn.Module): self, config: Phi3Config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() @@ -344,7 +360,7 @@ class Phi3SmallForCausalLM(nn.Module): self.model = Phi3SmallModel( config=config, quant_config=quant_config, - prefix="model", + prefix=add_prefix("model", prefix), ) self.vocab_size = config.vocab_size self.mup_width_multiplier = config.mup_width_multiplier @@ -354,6 +370,7 @@ class Phi3SmallForCausalLM(nn.Module): org_num_embeddings=config.vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), ) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index 2c99da926..cd94a9103 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -39,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import add_prefix class QWenMLP(nn.Module): @@ -48,6 +49,7 @@ class QWenMLP(nn.Module): intermediate_size: int, hidden_act: str = "silu", quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -56,6 +58,7 @@ class 
QWenMLP(nn.Module): bias=False, gather_output=False, quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), ) self.c_proj = RowParallelLinear( intermediate_size, @@ -63,6 +66,7 @@ class QWenMLP(nn.Module): bias=False, input_is_parallel=True, quant_config=quant_config, + prefix=add_prefix("c_proj", prefix), ) if hidden_act != "silu": raise ValueError( @@ -88,6 +92,7 @@ class QWenAttention(nn.Module): rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.hidden_size = hidden_size @@ -104,6 +109,7 @@ class QWenAttention(nn.Module): self.total_num_heads, bias=True, quant_config=quant_config, + prefix=add_prefix("c_attn", prefix), ) self.c_proj = RowParallelLinear( self.total_num_heads * self.head_dim, @@ -111,6 +117,7 @@ class QWenAttention(nn.Module): bias=False, input_is_parallel=True, quant_config=quant_config, + prefix=add_prefix("c_proj", prefix), ) self.rotary_emb = get_rope( self.head_dim, @@ -126,6 +133,7 @@ class QWenAttention(nn.Module): self.scaling, num_kv_heads=self.num_heads, layer_id=layer_id, + prefix=add_prefix("attn", prefix), ) def forward( @@ -148,6 +156,7 @@ class QWenBlock(nn.Module): config: PretrainedConfig, layer_id, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -162,6 +171,7 @@ class QWenBlock(nn.Module): rope_scaling=rope_scaling, layer_id=layer_id, quant_config=quant_config, + prefix=add_prefix("attn", prefix), ) self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -170,6 +180,7 @@ class QWenBlock(nn.Module): config.hidden_size, config.intermediate_size // 2, quant_config=quant_config, + prefix=add_prefix("mlp", prefix), ) def forward( @@ -201,6 +212,7 @@ class QWenModel(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + 
prefix: str = "", ): super().__init__() self.config = config @@ -210,10 +222,16 @@ class QWenModel(nn.Module): self.wte = VocabParallelEmbedding( vocab_size, config.hidden_size, + prefix=add_prefix("wte", prefix), ) self.h = nn.ModuleList( [ - QWenBlock(config, i, quant_config=quant_config) + QWenBlock( + config, + i, + quant_config=quant_config, + prefix=add_prefix(f"h.{i}", prefix), + ) for i in range(config.num_hidden_layers) ] ) @@ -242,12 +260,17 @@ class QWenLMHeadModel(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config - self.transformer = QWenModel(config, quant_config=quant_config) + self.transformer = QWenModel( + config, quant_config=quant_config, prefix=add_prefix("transformer", prefix) + ) vocab_size = ((config.vocab_size + 63) // 64) * 64 - self.lm_head = ParallelLMHead(vocab_size, config.hidden_size) + self.lm_head = ParallelLMHead( + vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix) + ) self.logits_processor = LogitsProcessor(config) @torch.no_grad() diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index d53d9561f..2100845e0 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -15,7 +15,7 @@ # Adapted from llama2.py # Modify details for the adaptation of Qwen2 model.
"""Inference-only Qwen2 model compatible with HuggingFace weights.""" from typing import Any, Dict, Iterable, Optional, Tuple import torch @@ -46,7 +46,7 @@ from sglang.srt.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, ) -from sglang.srt.utils import make_layers +from sglang.srt.utils import add_prefix, make_layers Qwen2Config = None @@ -58,6 +58,7 @@ class Qwen2MLP(nn.Module): intermediate_size: int, hidden_act: str, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -65,12 +66,14 @@ class Qwen2MLP(nn.Module): [intermediate_size] * 2, bias=False, quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), ) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, bias=False, quant_config=quant_config, + prefix=add_prefix("down_proj", prefix), ) if hidden_act != "silu": raise ValueError( @@ -97,6 +100,7 @@ class Qwen2Attention(nn.Module): rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 32768, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -128,12 +132,14 @@ class Qwen2Attention(nn.Module): self.total_num_kv_heads, bias=True, quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), ) self.rotary_emb = get_rope( @@ -149,6 +155,7 @@ class Qwen2Attention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + prefix=add_prefix("attn", prefix), ) def forward( @@ -171,6 +178,7 @@ class Qwen2DecoderLayer(nn.Module): config: Qwen2Config, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__()
self.hidden_size = config.hidden_size @@ -186,12 +194,14 @@ class Qwen2DecoderLayer(nn.Module): rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), ) self.mlp = Qwen2MLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, + prefix=add_prefix("mlp", prefix), ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( @@ -228,6 +238,7 @@ class Qwen2Model(nn.Module): self, config: Qwen2Config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -237,6 +248,7 @@ class Qwen2Model(nn.Module): config.vocab_size, config.hidden_size, quant_config=quant_config, + prefix=add_prefix("embed_tokens", prefix), ) self.layers = make_layers( config.num_hidden_layers, @@ -244,7 +256,9 @@ class Qwen2Model(nn.Module): layer_id=idx, config=config, quant_config=quant_config, + prefix=prefix, ), + prefix=add_prefix("layers", prefix), ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -325,16 +339,22 @@ class Qwen2ForCausalLM(nn.Module): self, config: Qwen2Config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = Qwen2Model(config, quant_config=quant_config) + self.model = Qwen2Model( + config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, quant_config=quant_config + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), ) self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, 
normalize=True) diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index 4fdea2ec7..cee599a76 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -52,6 +52,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.qwen2 import Qwen2Model from sglang.srt.models.qwen2_vl import Qwen2VLImageInputs, Qwen2VLVideoInputs +from sglang.srt.utils import add_prefix logger = logging.getLogger(__name__) @@ -65,16 +66,29 @@ class Qwen2_5_VLMLP(nn.Module): bias: bool = True, hidden_act="silu", quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.gate_proj = ColumnParallelLinear( - in_features, hidden_features, bias=bias, quant_config=quant_config + in_features, + hidden_features, + bias=bias, + quant_config=quant_config, + prefix=add_prefix("gate_proj", prefix), ) self.up_proj = ColumnParallelLinear( - in_features, hidden_features, bias=bias, quant_config=quant_config + in_features, + hidden_features, + bias=bias, + quant_config=quant_config, + prefix=add_prefix("up_proj", prefix), ) self.down_proj = RowParallelLinear( - hidden_features, in_features, bias=bias, quant_config=quant_config + hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=add_prefix("down_proj", prefix), ) self.act = ACT2FN[hidden_act] @@ -98,6 +112,7 @@ class Qwen2_5_VisionBlock(nn.Module): norm_layer: Type[nn.Module] = None, attn_implementation: Optional[str] = "sdpa", quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() if norm_layer is None: @@ -123,9 +138,14 @@ class Qwen2_5_VisionBlock(nn.Module): use_full_precision_softmax=use_full_precision_softmax, flatten_batch=True, quant_config=quant_config, + prefix=add_prefix("attn", prefix), ) self.mlp = Qwen2_5_VLMLP( - dim, intermediate_dim, 
hidden_act=hidden_act, quant_config=quant_config + dim, + intermediate_dim, + hidden_act=hidden_act, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), ) def forward( @@ -178,6 +198,7 @@ class Qwen2_5_VisionPatchMerger(nn.Module): context_dim: int, spatial_merge_size: int = 2, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = context_dim * (spatial_merge_size**2) @@ -189,10 +210,15 @@ class Qwen2_5_VisionPatchMerger(nn.Module): self.hidden_size, bias=True, quant_config=quant_config, + prefix=add_prefix("mlp.0", prefix), ), nn.GELU(), RowParallelLinear( - self.hidden_size, dim, bias=True, quant_config=quant_config + self.hidden_size, + dim, + bias=True, + quant_config=quant_config, + prefix=add_prefix("mlp.2", prefix), ), ] ) @@ -250,6 +276,7 @@ class Qwen2_5_VisionTransformer(nn.Module): vision_config: Qwen2_5_VLVisionConfig, norm_eps: float = 1e-6, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() @@ -286,8 +313,9 @@ class Qwen2_5_VisionTransformer(nn.Module): norm_layer=norm_layer, attn_implementation="sdpa", quant_config=quant_config, + prefix=add_prefix(f"blocks.{i}", prefix), ) - for _ in range(depth) + for i in range(depth) ] ) self.merger = Qwen2_5_VisionPatchMerger( @@ -295,6 +323,7 @@ class Qwen2_5_VisionTransformer(nn.Module): context_dim=hidden_size, spatial_merge_size=spatial_merge_size, quant_config=quant_config, + prefix=add_prefix("merger", prefix), ) def get_window_index(self, grid_thw): @@ -447,6 +476,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module): self, config: Qwen2VLConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() @@ -457,15 +487,23 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module): # NOTE: Qwen2-VL vision encoder does not support any # quantization method now. 
quant_config=None, + prefix=add_prefix("visual", prefix), ) - self.model = Qwen2Model(config, quant_config) + self.model = Qwen2Model( + config, + quant_config, + prefix=add_prefix("model", prefix), + ) if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, quant_config=quant_config + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), ) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/qwen2_eagle.py b/python/sglang/srt/models/qwen2_eagle.py index 12a4e6b3f..793d91560 100644 --- a/python/sglang/srt/models/qwen2_eagle.py +++ b/python/sglang/srt/models/qwen2_eagle.py @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. """ +from sglang.srt.utils import add_prefix + # Adapted from # https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py """Inference-only LLaMA-EAGLE model compatible with HuggingFace weights.""" @@ -42,7 +44,7 @@ class Qwen2DecoderLayer(Qwen2DecoderLayer): quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: - super().__init__(config, layer_id, quant_config) + super().__init__(config, layer_id, quant_config, prefix=prefix) # Skip the input_layernorm # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427 @@ -56,6 +58,7 @@ class Qwen2Model(nn.Module): self, config: Qwen2Config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -63,11 +66,15 @@ class Qwen2Model(nn.Module): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, + prefix=add_prefix("embed_tokens", prefix), ) self.layers = nn.ModuleList( [ Qwen2DecoderLayer( - config, i, quant_config=quant_config, prefix=f"model.layers.{i}" + config, + i, + 
quant_config=quant_config, + prefix=add_prefix(f"layers.{i}", prefix), ) for i in range(config.num_hidden_layers) ] @@ -107,16 +114,22 @@ class Qwen2ForCausalLMEagle(Qwen2ForCausalLM): self, config: Qwen2Config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: nn.Module.__init__(self) self.config = config self.quant_config = quant_config - self.model = Qwen2Model(config, quant_config=quant_config) + self.model = Qwen2Model( + config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) if self.config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, quant_config=quant_config + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), ) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 6183f30da..b5ef472ce 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -46,6 +46,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import add_prefix class Qwen2MoeMLP(nn.Module): @@ -56,10 +57,15 @@ class Qwen2MoeMLP(nn.Module): hidden_act: str, quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = True, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), ) self.down_proj = RowParallelLinear( intermediate_size, @@ -67,6 +73,7 @@ class Qwen2MoeMLP(nn.Module): bias=False, quant_config=quant_config, 
reduce_results=reduce_results, + prefix=add_prefix("down_proj", prefix), ) if hidden_act != "silu": raise ValueError( @@ -87,6 +94,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() @@ -105,10 +113,15 @@ class Qwen2MoeSparseMoeBlock(nn.Module): reduce_results=False, renormalize=config.norm_topk_prob, quant_config=quant_config, + prefix=add_prefix("experts", prefix), ) self.gate = ReplicatedLinear( - config.hidden_size, config.num_experts, bias=False, quant_config=None + config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + prefix=add_prefix("gate", prefix), ) if config.shared_expert_intermediate_size > 0: self.shared_expert = Qwen2MoeMLP( @@ -117,6 +130,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module): hidden_act=config.hidden_act, quant_config=quant_config, reduce_results=False, + prefix=add_prefix("shared_expert", prefix), ) else: self.shared_expert = None @@ -157,6 +171,7 @@ class Qwen2MoeAttention(nn.Module): rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -188,6 +203,7 @@ class Qwen2MoeAttention(nn.Module): self.total_num_kv_heads, bias=True, quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), ) self.o_proj = RowParallelLinear( @@ -195,6 +211,7 @@ class Qwen2MoeAttention(nn.Module): hidden_size, bias=False, quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), ) self.rotary_emb = get_rope( @@ -210,6 +227,7 @@ class Qwen2MoeAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + prefix=add_prefix("attn", prefix), ) def forward( @@ -232,6 +250,7 @@ class Qwen2MoeDecoderLayer(nn.Module): config: PretrainedConfig, layer_id: int, quant_config: 
Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -247,6 +266,7 @@ class Qwen2MoeDecoderLayer(nn.Module): rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), ) # Note: Qwen/Qwen2-57B-A14B-Instruct does not have @@ -257,13 +277,18 @@ class Qwen2MoeDecoderLayer(nn.Module): if (layer_id not in mlp_only_layers) and ( config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0 ): - self.mlp = Qwen2MoeSparseMoeBlock(config=config, quant_config=quant_config) + self.mlp = Qwen2MoeSparseMoeBlock( + config=config, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) else: self.mlp = Qwen2MoeMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, + prefix=add_prefix("mlp", prefix), ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( @@ -300,6 +325,7 @@ class Qwen2MoeModel(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -308,10 +334,16 @@ class Qwen2MoeModel(nn.Module): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, + prefix=add_prefix("embed_tokens", prefix), ) self.layers = nn.ModuleList( [ - Qwen2MoeDecoderLayer(config, layer_id, quant_config=quant_config) + Qwen2MoeDecoderLayer( + config, + layer_id, + quant_config=quant_config, + prefix=add_prefix(f"layers.{layer_id}", prefix), + ) for layer_id in range(config.num_hidden_layers) ] ) @@ -346,13 +378,19 @@ class Qwen2MoeForCausalLM(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config 
self.quant_config = quant_config - self.model = Qwen2MoeModel(config, quant_config) + self.model = Qwen2MoeModel( + config, quant_config, prefix=add_prefix("model", prefix) + ) self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, quant_config=quant_config + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), ) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/qwen2_rm.py b/python/sglang/srt/models/qwen2_rm.py index 39ed15fa5..f5ed9eae2 100644 --- a/python/sglang/srt/models/qwen2_rm.py +++ b/python/sglang/srt/models/qwen2_rm.py @@ -22,6 +22,7 @@ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.models.qwen2 import Qwen2ForCausalLM, Qwen2Model +from sglang.srt.utils import add_prefix class Qwen2ForRewardModel(nn.Module): @@ -29,12 +30,15 @@ class Qwen2ForRewardModel(nn.Module): self, config: Qwen2Config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config self.quant_config = quant_config self.num_labels = 1 - self.model = Qwen2Model(config, quant_config=quant_config) + self.model = Qwen2Model( + config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) self.score = nn.Sequential( nn.Linear(config.hidden_size, config.hidden_size), nn.ReLU(), diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index d8e190deb..63b479113 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -46,6 +46,7 @@ from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from 
sglang.srt.models.qwen2 import Qwen2Model +from sglang.srt.utils import add_prefix logger = logging.getLogger(__name__) @@ -91,14 +92,21 @@ class Qwen2VisionMLP(nn.Module): hidden_features: int = None, act_layer: Type[nn.Module] = QuickGELU, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.fc1 = ColumnParallelLinear( - in_features, hidden_features, quant_config=quant_config + in_features, + hidden_features, + quant_config=quant_config, + prefix=add_prefix("fc1", prefix), ) self.act = act_layer() self.fc2 = RowParallelLinear( - hidden_features, in_features, quant_config=quant_config + hidden_features, + in_features, + quant_config=quant_config, + prefix=add_prefix("fc2", prefix), ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -119,6 +127,7 @@ class Qwen2VisionBlock(nn.Module): norm_layer: Type[nn.Module] = None, attn_implementation: Optional[str] = "sdpa", quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() if norm_layer is None: @@ -145,9 +154,14 @@ class Qwen2VisionBlock(nn.Module): use_full_precision_softmax=use_full_precision_softmax, flatten_batch=True, quant_config=quant_config, + prefix=add_prefix("attn", prefix), ) self.mlp = Qwen2VisionMLP( - dim, mlp_hidden_dim, act_layer=act_layer, quant_config=quant_config + dim, + mlp_hidden_dim, + act_layer=act_layer, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), ) def forward( @@ -199,6 +213,7 @@ class Qwen2VisionPatchMerger(nn.Module): norm_layer: Type[nn.Module] = None, spatial_merge_size: int = 2, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = context_dim * (spatial_merge_size**2) @@ -212,10 +227,15 @@ class Qwen2VisionPatchMerger(nn.Module): self.hidden_size, bias=True, quant_config=quant_config, + prefix=add_prefix("mlp.0", prefix), ), nn.GELU(), RowParallelLinear( - self.hidden_size, d_model, bias=True, 
quant_config=quant_config + self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=add_prefix("mlp.2", prefix), ), ] ) @@ -273,6 +293,7 @@ class Qwen2VisionTransformer(nn.Module): vision_config: Qwen2VLVisionConfig, norm_eps: float = 1e-6, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() @@ -307,8 +328,9 @@ class Qwen2VisionTransformer(nn.Module): norm_layer=norm_layer, attn_implementation="sdpa", quant_config=quant_config, + prefix=add_prefix(f"blocks.{i}", prefix), ) - for _ in range(depth) + for i in range(depth) ] ) self.merger = Qwen2VisionPatchMerger( @@ -316,6 +338,7 @@ class Qwen2VisionTransformer(nn.Module): context_dim=embed_dim, norm_layer=norm_layer, quant_config=quant_config, + prefix=add_prefix("merger", prefix), ) @property @@ -440,6 +463,7 @@ class Qwen2VLForConditionalGeneration(nn.Module): self, config: Qwen2VLConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() @@ -450,15 +474,21 @@ class Qwen2VLForConditionalGeneration(nn.Module): # NOTE: Qwen2-VL vision encoder does not support any # quantization method now. 
quant_config=None, + prefix=add_prefix("visual", prefix), ) - self.model = Qwen2Model(config, quant_config) + self.model = Qwen2Model( + config, quant_config, prefix=add_prefix("model", prefix) + ) if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, quant_config=quant_config + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), ) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py index c169dd6fb..45ac90c97 100644 --- a/python/sglang/srt/models/stablelm.py +++ b/python/sglang/srt/models/stablelm.py @@ -42,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import add_prefix class StablelmMLP(nn.Module): @@ -49,6 +50,7 @@ class StablelmMLP(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -59,12 +61,14 @@ class StablelmMLP(nn.Module): [config.intermediate_size] * 2, bias=False, quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), ) self.down_proj = RowParallelLinear( config.intermediate_size, config.hidden_size, bias=False, quant_config=quant_config, + prefix=add_prefix("down_proj", prefix), ) self.act_fn = SiluAndMul() @@ -81,6 +85,7 @@ class StablelmAttention(nn.Module): config: PretrainedConfig, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -122,11 +127,15 @@ class StablelmAttention(nn.Module): self.total_num_heads, self.total_num_key_value_heads, self.qkv_bias, + quant_config=quant_config, + 
prefix=add_prefix("qkv_proj", prefix), ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, self.hidden_size, bias=False, + quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), ) self.rotary_emb = get_rope( self.head_dim, @@ -140,6 +149,7 @@ class StablelmAttention(nn.Module): self.scaling, num_kv_heads=self.num_key_value_heads, layer_id=layer_id, + prefix=add_prefix("attn", prefix), ) def forward( @@ -162,10 +172,15 @@ class StablelmDecoderLayer(nn.Module): config: PretrainedConfig, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() - self.self_attn = StablelmAttention(config, layer_id=layer_id) - self.mlp = StablelmMLP(config, quant_config=quant_config) + self.self_attn = StablelmAttention( + config, layer_id=layer_id, prefix=add_prefix("self_attn", prefix) + ) + self.mlp = StablelmMLP( + config, quant_config=quant_config, prefix=add_prefix("mlp", prefix) + ) norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05)) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps) @@ -200,15 +215,22 @@ class StableLMEpochModel(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, + prefix=add_prefix("embed_tokens", prefix), ) self.layers = nn.ModuleList( [ - StablelmDecoderLayer(config, i, quant_config=quant_config) + StablelmDecoderLayer( + config, + i, + quant_config=quant_config, + prefix=add_prefix(f"layers.{i}", prefix), + ) for i in range(config.num_hidden_layers) ] ) @@ -242,12 +264,17 @@ class StableLmForCausalLM(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config 
self.quant_config = quant_config - self.model = StableLMEpochModel(config, quant_config=quant_config) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.model = StableLMEpochModel( + config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix) + ) self.logits_processor = LogitsProcessor(config) @torch.no_grad() diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py index 0612e3e7d..1c99c52a8 100644 --- a/python/sglang/srt/models/torch_native_llama.py +++ b/python/sglang/srt/models/torch_native_llama.py @@ -64,6 +64,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import add_prefix tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() @@ -294,14 +295,14 @@ class LlamaDecoderLayer(nn.Module): rope_is_neox_style=rope_is_neox_style, max_position_embeddings=max_position_embeddings, quant_config=quant_config, - prefix=f"{prefix}.self_attn", + prefix=add_prefix("self_attn", prefix), ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - prefix=f"{prefix}.mlp", + prefix=add_prefix("mlp", prefix), ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( diff --git a/python/sglang/srt/models/xverse.py b/python/sglang/srt/models/xverse.py index 7fd241823..2162f7a44 100644 --- a/python/sglang/srt/models/xverse.py +++ b/python/sglang/srt/models/xverse.py @@ -40,6 +40,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ) from sglang.srt.model_executor.model_runner import ForwardBatch from 
sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import add_prefix class XverseMLP(nn.Module): @@ -57,14 +58,14 @@ class XverseMLP(nn.Module): [intermediate_size] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj", + prefix=add_prefix("gate_up_proj", prefix), ) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, bias=False, quant_config=quant_config, - prefix=f"{prefix}.down_proj", + prefix=add_prefix("down_proj", prefix), ) if hidden_act != "silu": raise ValueError( @@ -128,14 +129,14 @@ class XverseAttention(nn.Module): self.total_num_kv_heads, bias=False, quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", + prefix=add_prefix("qkv_proj", prefix), ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, quant_config=quant_config, - prefix=f"{prefix}.o_proj", + prefix=add_prefix("o_proj", prefix), ) self.rotary_emb = get_rope( @@ -152,6 +153,7 @@ class XverseAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + prefix=add_prefix("attn", prefix), ) def forward( @@ -202,14 +204,14 @@ class XverseDecoderLayer(nn.Module): rope_is_neox_style=rope_is_neox_style, max_position_embeddings=max_position_embeddings, quant_config=quant_config, - prefix=f"{prefix}.self_attn", + prefix=add_prefix("self_attn", prefix), ) self.mlp = XverseMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - prefix=f"{prefix}.mlp", + prefix=add_prefix("mlp", prefix), ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( @@ -246,6 +248,7 @@ class XverseModel(nn.Module): self, config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -254,11 +257,15 @@ class XverseModel(nn.Module): self.embed_tokens = 
VocabParallelEmbedding( config.vocab_size, config.hidden_size, + prefix=add_prefix("embed_tokens", prefix), ) self.layers = nn.ModuleList( [ XverseDecoderLayer( - config, i, quant_config=quant_config, prefix=f"model.layers.{i}" + config, + i, + quant_config=quant_config, + prefix=add_prefix(f"layers.{i}", prefix), ) for i in range(config.num_hidden_layers) ] @@ -295,12 +302,17 @@ class XverseForCausalLM(nn.Module): self, config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = XverseModel(config, quant_config=quant_config) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.model = XverseModel( + config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix) + ) self.logits_processor = LogitsProcessor(config) @torch.no_grad() diff --git a/python/sglang/srt/models/xverse_moe.py b/python/sglang/srt/models/xverse_moe.py index 218b96f9c..a7c79ec8c 100644 --- a/python/sglang/srt/models/xverse_moe.py +++ b/python/sglang/srt/models/xverse_moe.py @@ -43,6 +43,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import add_prefix class XverseMLP(nn.Module): @@ -54,10 +55,15 @@ class XverseMLP(nn.Module): hidden_act: str, quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = True, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), ) self.down_proj = 
RowParallelLinear( intermediate_size, @@ -65,6 +71,7 @@ class XverseMLP(nn.Module): bias=False, quant_config=quant_config, reduce_results=reduce_results, + prefix=add_prefix("down_proj", prefix), ) if hidden_act != "silu": raise ValueError( @@ -86,6 +93,7 @@ class XverseMoE(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -107,14 +115,19 @@ class XverseMoE(nn.Module): hidden_act=config.hidden_act, quant_config=quant_config, reduce_results=False, + prefix=add_prefix(f"experts.{i}", prefix), ) - for _ in range(self.n_routed_experts) + for i in range(self.n_routed_experts) ] ) self.pack_params() self.router = ReplicatedLinear( - config.hidden_size, self.n_routed_experts, bias=False, quant_config=None + config.hidden_size, + self.n_routed_experts, + bias=False, + quant_config=None, + prefix=add_prefix("router", prefix), ) if config.num_shared_experts is not None: @@ -125,6 +138,7 @@ class XverseMoE(nn.Module): hidden_act=config.hidden_act, quant_config=quant_config, reduce_results=False, + prefix=add_prefix("shared_experts", prefix), ) def pack_params(self): @@ -182,6 +196,7 @@ class XverseAttention(nn.Module): rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -213,6 +228,7 @@ class XverseAttention(nn.Module): self.total_num_kv_heads, bias=False, quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), ) self.o_proj = RowParallelLinear( @@ -220,6 +236,7 @@ class XverseAttention(nn.Module): hidden_size, bias=False, quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), ) self.rotary_emb = get_rope( @@ -235,6 +252,7 @@ class XverseAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_id, + prefix=add_prefix("attn", prefix), ) def forward( @@ 
-258,6 +276,7 @@ class XverseDecoderLayer(nn.Module): config: PretrainedConfig, layer_id: int, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -276,15 +295,21 @@ class XverseDecoderLayer(nn.Module): rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), ) if config.num_experts is not None: - self.mlp = XverseMoE(config=config, quant_config=quant_config) + self.mlp = XverseMoE( + config=config, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) else: self.mlp = XverseMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, + prefix=add_prefix("mlp", prefix), ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( @@ -324,6 +349,7 @@ class XverseModel(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -332,10 +358,16 @@ class XverseModel(nn.Module): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, + prefix=add_prefix("embed_tokens", prefix), ) self.layers = nn.ModuleList( [ - XverseDecoderLayer(config, layer_id, quant_config=quant_config) + XverseDecoderLayer( + config, + layer_id, + quant_config=quant_config, + prefix=add_prefix(f"layers.{layer_id}", prefix), + ) for layer_id in range(config.num_hidden_layers) ] ) @@ -364,13 +396,19 @@ class XverseMoeForCausalLM(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = XverseModel(config, quant_config) + self.model = XverseModel( + config, quant_config, 
prefix=add_prefix("model", prefix) + ) self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, quant_config=quant_config + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), ) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/yivl.py b/python/sglang/srt/models/yivl.py index 97ee5946c..4c50b0d3c 100644 --- a/python/sglang/srt/models/yivl.py +++ b/python/sglang/srt/models/yivl.py @@ -29,8 +29,9 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM): self, config: LlavaConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: - super().__init__(config, quant_config) + super().__init__(config, quant_config, prefix=prefix) self.multi_modal_projector = YiVLMultiModalProjector(self.config) self.vision_tower_subfolder = self.config.mm_vision_tower.replace( diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index e0afe22a0..76c05749d 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -313,7 +313,7 @@ def make_layers( """Make a list of layers with the given layer function""" modules = torch.nn.ModuleList( [ - maybe_offload_to_cpu(layer_fn(idx=idx, prefix=f"{prefix}.{idx}")) + maybe_offload_to_cpu(layer_fn(idx=idx, prefix=add_prefix(idx, prefix))) for idx in range(num_hidden_layers) ] ) @@ -1464,3 +1464,16 @@ def set_cuda_arch(): capability = torch.cuda.get_device_capability() arch = f"{capability[0]}.{capability[1]}" os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}" + + +def add_prefix(name: str, prefix: str) -> str: + """Add a weight path prefix to a module name. + + Args: + name: base module name. + prefix: weight prefix str to be added to the front of `name`, joined with `.`. + + Returns: + The string `prefix.name` if prefix is non-empty, otherwise just `name`. 
+ """ + return name if not prefix else f"{prefix}.{name}" diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 8fb0a4314..5ad56e5f5 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -12,6 +12,7 @@ suites = { "models/test_generation_models.py", "models/test_qwen_models.py", "models/test_reward_models.py", + "test_gptqmodel_dynamic.py", "test_abort.py", "test_chunked_prefill.py", "test_custom_allreduce.py", diff --git a/test/srt/test_gptqmodel_dynamic.py b/test/srt/test_gptqmodel_dynamic.py new file mode 100644 index 000000000..f22f37f1d --- /dev/null +++ b/test/srt/test_gptqmodel_dynamic.py @@ -0,0 +1,211 @@ +import time +import unittest + +import requests +import torch + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +def check_quant_method(model_path: str, use_marlin_kernel: bool): + from sglang.srt.configs.device_config import DeviceConfig + from sglang.srt.configs.load_config import LoadConfig + from sglang.srt.configs.model_config import AttentionArch, ModelConfig + from sglang.srt.distributed import ( + get_tp_group, + init_distributed_environment, + initialize_model_parallel, + set_custom_all_reduce, + ) + from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state + from sglang.srt.layers.quantization import get_dynamic_override + from sglang.srt.model_loader import get_model + from sglang.srt.server_args import PortArgs, ServerArgs + + try: + init_distributed_environment( + backend="nccl", + world_size=1, + rank=0, + local_rank=0, + distributed_init_method="tcp://127.0.0.1:2646", + ) + initialize_model_parallel(tensor_model_parallel_size=1) + monkey_patch_vllm_parallel_state() + except AssertionError: + # ignore this error: tensor model parallel group is already initialized + pass + + server_args = ServerArgs(model_path=model_path, dtype=torch.float16) + model_config = 
ModelConfig( + server_args.model_path, + trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, + context_length=server_args.context_length, + model_override_args=server_args.json_model_override_args, + is_embedding=server_args.is_embedding, + dtype=server_args.dtype, + quantization=server_args.quantization, + ) + + load_config = LoadConfig() + device_config = DeviceConfig("cuda") + model = get_model( + model_config=model_config, load_config=load_config, device_config=device_config + ) + + from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod + from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinLinearMethod, + ) + + from sglang.srt.layers.linear import UnquantizedLinearMethod + + linear_method_cls = ( + GPTQMarlinLinearMethod if use_marlin_kernel else (GPTQLinearMethod) + ) + + for name, submodule in model.named_modules(): + if name == "lm_head": + assert isinstance(submodule.quant_method, linear_method_cls) + elif name == "model.layers.0.self_attn.qkv_proj": + # The first layer is quantized using bits=4, group_size=128 + # desc_act=True + assert isinstance(submodule.quant_method, linear_method_cls) + config = submodule.quant_method.quant_config + assert config.weight_bits == 4 + assert config.group_size == 128 + assert config.desc_act + elif name == "model.layers.1.self_attn.qkv_proj": + # The second layer is quantized using bits=8, group_size=32 + # desc_act=False + assert isinstance(submodule.quant_method, linear_method_cls) + config = submodule.quant_method.quant_config + assert get_dynamic_override(config, layer_name=name, key="bits") == 8 + assert get_dynamic_override(config, layer_name=name, key="group_size") == 32 + assert not get_dynamic_override(config, layer_name=name, key="desc_act") + elif ( + name == "model.layers.2.self_attn.qkv_proj" + or name == "model.layers.2.mlp.gate_up_proj" + ): + # All other layers (layer index >= 2) are not quantized + assert 
isinstance(submodule.quant_method, UnquantizedLinearMethod) + + del model + + +# GPTQ with Dynamic Per/Module Quantization Control +# Leverages GPTQModel (pypi) to produce the `dynamic` models +# Test GPTQ fallback kernel that is not Marlin +class TestGPTQModelDynamic(unittest.TestCase): + MODEL_PATH = ( + "ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symFalse" + ) + + @classmethod + def setUpClass(cls): + cls.model = cls.MODEL_PATH + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--dtype", "float16"], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def run_decode(self, max_new_tokens): + response = requests.post( + self.base_url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "max_new_tokens": max_new_tokens, + }, + }, + ) + return response.json() + + def test_throughput(self): + max_tokens = 256 + + tic = time.time() + result = self.run_decode(max_tokens) + tok = time.time() + + print(f"result = `{result}`") + + assert "paris" in result["text"].lower() + + throughput = max_tokens / (tok - tic) + print(f"Throughput: {throughput} tokens/s") + assert throughput >= 140 + + def test_gptq_module(self): + check_quant_method(self.MODEL_PATH, use_marlin_kernel=False) + + +# GPTQ with Dynamic Per/Module Quantization Control +# Leverages GPTQModel (pypi) to produce the `dynamic` models +# Test Marlin kernel +class TestGPTQModelDynamicWithMarlin(unittest.TestCase): + MODEL_PATH = ( + "ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symTrue" + ) + + @classmethod + def setUpClass(cls): + cls.model = cls.MODEL_PATH + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--dtype", "float16"], + ) + + @classmethod + def tearDownClass(cls): + 
kill_process_tree(cls.process.pid) + + def run_decode(self, max_new_tokens): + response = requests.post( + self.base_url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "max_new_tokens": max_new_tokens, + }, + }, + ) + return response.json() + + def test_throughput(self): + max_tokens = 256 + + tic = time.time() + result = self.run_decode(max_tokens) + tok = time.time() + + print(f"result = `{result}`") + + assert "paris" in result["text"].lower() + + throughput = max_tokens / (tok - tic) + print(f"Throughput: {throughput} tokens/s") + assert throughput >= 140 + + def test_gptq_marlin_module(self): + check_quant_method(self.MODEL_PATH, use_marlin_kernel=True) + + +if __name__ == "__main__": + unittest.main()