[QUANT] Add GPTQModel Dynamic Quantization + lm_head Quantization (#3790)
Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai> Co-authored-by: ZX-ModelCloud <zx@modelcloud.ai>
This commit is contained in:
committed by
GitHub
parent
583d6af71b
commit
56a724eba3
@@ -2,15 +2,25 @@
|
||||
|
||||
SGLang supports various quantization methods, including offline quantization and online dynamic quantization.
|
||||
|
||||
Offline quantization loads pre-quantized model weights directly during inference. This is useful for methods requiring pre-computed stats such as AWQ, which collects activation stats from the pre-training set.
|
||||
Offline quantization loads pre-quantized model weights directly during inference. This is required for quantization methods
|
||||
such as GPTQ and AWQ that collects and pre-compute various stats from the original weights using the calibration dataset.
|
||||
|
||||
Online quantization dynamically computes scaling parameters—such as the maximum/minimum values of model weights—during runtime. Like NVIDIA FP8 training's [delayed scaling](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#Mixed-precision-training-with-FP8) mechanism, online quantization calculates the appropriate scaling factors on-the-fly to convert high-precision weights into a lower-precision format.
|
||||
Online quantization dynamically computes scaling parameters—such as the maximum/minimum values of model weights—during runtime.
|
||||
Like NVIDIA FP8 training's [delayed scaling](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#Mixed-precision-training-with-FP8) mechanism, online quantization calculates the appropriate scaling factors
|
||||
on-the-fly to convert high-precision weights into a lower-precision format.
|
||||
|
||||
**Note that, for better performance, usability and convenience, offline quantization is recommended over online quantization.** And if you use a pre-quantized model, do not add `--quantization` to enable online quantization at the same time. For popular pre-quantized models, please visit [neuralmagic collection](https://huggingface.co/collections/neuralmagic) for some popular quantized LLMs on huggingface.
|
||||
**Note: For better performance, usability and convenience, offline quantization is recommended over online quantization.**
|
||||
|
||||
If you use a pre-quantized model, do not add `--quantization` to enable online quantization at the same time.
|
||||
For popular pre-quantized models, please visit [ModelCloud](https://huggingface.co/collections/ModelCloud/vortex-673743382af0a52b2a8b9fe2) or [NeuralMagic](https://huggingface.co/collections/neuralmagic) collections on HF for some
|
||||
popular quality validated quantized models. Quantized models must be validated via benchmarks post-quantization
|
||||
to guard against abnormal quantization loss regressions.
|
||||
|
||||
## Offline Quantization
|
||||
|
||||
To load already quantized models, simply load the model weights and config. **Again, if the model has been quantized offline, there's no need to add `--quantization` argument when starting the engine. The quantization method will be parsed from the downloaded Hugging Face config. For example, DeepSeek V3/R1 models are already in FP8, so do not add redundant parameters.**
|
||||
To load already quantized models, simply load the model weights and config. **Again, if the model has been quantized offline,
|
||||
there's no need to add `--quantization` argument when starting the engine. The quantization method will be parsed from the
|
||||
downloaded Hugging Face config. For example, DeepSeek V3/R1 models are already in FP8, so do not add redundant parameters.**
|
||||
|
||||
```bash
|
||||
python3 -m sglang.launch_server \
|
||||
@@ -18,9 +28,38 @@ python3 -m sglang.launch_server \
|
||||
--port 30000 --host 0.0.0.0
|
||||
```
|
||||
|
||||
To do offline quantization for your model, firstly you need to install [llm-compressor](https://github.com/vllm-project/llm-compressor/) library:
|
||||
### Examples of Offline Model Quantization
|
||||
|
||||
#### Using [GPTQModel](https://github.com/ModelCloud/GPTQModel)
|
||||
|
||||
```bash
|
||||
# install
|
||||
pip install gptqmodel --no-build-isolation -v
|
||||
```
|
||||
|
||||
```py
|
||||
from datasets import load_dataset
|
||||
from gptqmodel import GPTQModel, QuantizeConfig
|
||||
|
||||
model_id = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
quant_path = "Llama-3.2-1B-Instruct-gptqmodel-4bit"
|
||||
|
||||
calibration_dataset = load_dataset(
|
||||
"allenai/c4", data_files="en/c4-train.00001-of-01024.json.gz",
|
||||
split="train"
|
||||
).select(range(1024))["text"]
|
||||
|
||||
quant_config = QuantizeConfig(bits=4, group_size=128) # quantization config
|
||||
model = GPTQModel.load(model_id, quant_config) # load model
|
||||
|
||||
model.quantize(calibration_dataset, batch_size=2) # quantize
|
||||
model.save(quant_path) # save model
|
||||
```
|
||||
|
||||
#### Using [LLM Compressor](https://github.com/vllm-project/llm-compressor/)
|
||||
|
||||
```bash
|
||||
# install
|
||||
pip install llmcompressor
|
||||
```
|
||||
|
||||
@@ -99,8 +138,7 @@ python3 -m sglang.launch_server \
|
||||
|
||||
## Reference
|
||||
|
||||
- [quantization document of vllm](https://docs.vllm.ai/en/latest/quantization/fp8.html)
|
||||
|
||||
- [torchao](https://github.com/pytorch/ao)
|
||||
|
||||
- [llm-compressor](https://github.com/vllm-project/llm-compressor/)
|
||||
- [GPTQModel](https://github.com/ModelCloud/GPTQModel)
|
||||
- [LLM Compressor](https://github.com/vllm-project/llm-compressor/)
|
||||
- [Torchao: PyTorch Architecture Optimization](https://github.com/pytorch/ao)
|
||||
- [vLLM Quantization](https://docs.vllm.ai/en/latest/quantization/)
|
||||
|
||||
@@ -19,6 +19,7 @@ from sglang.srt.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.quantization import QuantizationConfig
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
|
||||
@@ -122,20 +123,20 @@ class VisionAttention(nn.Module):
|
||||
head_size=self.head_size,
|
||||
total_num_heads=num_heads,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
else:
|
||||
self.qkv_proj = ColumnParallelLinear(
|
||||
input_size=embed_dim,
|
||||
output_size=3 * projection_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
self.proj = RowParallelLinear(
|
||||
input_size=embed_dim,
|
||||
output_size=embed_dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
prefix=add_prefix("out_proj", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -417,7 +417,7 @@ class LogitsProcessor(nn.Module):
|
||||
)
|
||||
else:
|
||||
# GGUF models
|
||||
logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
|
||||
logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)
|
||||
|
||||
if self.logit_scale is not None:
|
||||
logits.mul_(self.logit_scale)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
|
||||
from typing import Callable, Dict, Optional, Type
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import Callable, Dict, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
|
||||
@@ -16,8 +18,6 @@ from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfi
|
||||
from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
|
||||
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
|
||||
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
|
||||
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
|
||||
from vllm.model_executor.layers.quantization.qqq import QQQConfig
|
||||
@@ -26,6 +26,7 @@ from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
|
||||
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
||||
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
|
||||
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
|
||||
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
||||
|
||||
@@ -61,19 +62,119 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
||||
return QUANTIZATION_METHODS[quantization]
|
||||
|
||||
|
||||
# Match dynamic rules with module name (prefix) and override quantize
|
||||
# config if module (prefix) matches a rule
|
||||
def override_config(config: QuantizationConfig, prefix: str):
|
||||
weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
|
||||
if isinstance(weight_bits, int):
|
||||
config.weight_bits = weight_bits
|
||||
group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
|
||||
if isinstance(group_size, int):
|
||||
config.group_size = group_size
|
||||
desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
|
||||
if isinstance(desc_act, bool):
|
||||
config.desc_act = desc_act
|
||||
|
||||
config.pack_factor = 32 // config.weight_bits # packed into int32
|
||||
if config.get_name() == "gptq_marlin":
|
||||
is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
|
||||
if isinstance(is_sym, bool):
|
||||
config.is_sym = is_sym
|
||||
|
||||
if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
|
||||
raise ValueError(
|
||||
"Unsupported quantization config: "
|
||||
f"bits={config.weight_bits}, sym={config.is_sym}"
|
||||
)
|
||||
|
||||
config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
|
||||
elif config.get_name() == "gptq":
|
||||
if config.weight_bits not in [2, 3, 4, 8]:
|
||||
raise ValueError(
|
||||
"Currently, only 2/3/4/8-bit weight quantization is "
|
||||
f"supported for GPTQ, but got {config.weight_bits} bits."
|
||||
)
|
||||
|
||||
|
||||
def get_dynamic_override(
|
||||
config: QuantizationConfig,
|
||||
layer_name: str,
|
||||
key: Optional[str] = None,
|
||||
default_value: Union[int, bool, None] = None,
|
||||
) -> Union[Dict, int, bool, None]:
|
||||
for pattern, pattern_dict in config.dynamic.items():
|
||||
# Negative match: matched modules are excluded from quantized init
|
||||
if pattern.startswith("-:"):
|
||||
if re.match(pattern.removeprefix("-:"), layer_name):
|
||||
return False
|
||||
# Positive match: matched modules have quant properties overrides
|
||||
# base quant config
|
||||
elif re.match(pattern.removeprefix("+:"), layer_name):
|
||||
if key is None:
|
||||
return pattern_dict
|
||||
else:
|
||||
return pattern_dict.get(key, default_value)
|
||||
return default_value
|
||||
|
||||
|
||||
def get_linear_quant_method(
|
||||
config: QuantizationConfig,
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
linear_method_cls: type,
|
||||
):
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
UnquantizedEmbeddingMethod,
|
||||
)
|
||||
|
||||
cloned_config = deepcopy(config)
|
||||
parallel_lm_head_quantized = (
|
||||
isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
|
||||
)
|
||||
|
||||
if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
|
||||
# False = skip module, None = no override, else = Positive match
|
||||
if (
|
||||
get_dynamic_override( # noqa: E712
|
||||
cloned_config, layer_name=prefix # noqa: E712
|
||||
)
|
||||
== False
|
||||
): # noqa: E712
|
||||
if parallel_lm_head_quantized:
|
||||
return UnquantizedEmbeddingMethod()
|
||||
return UnquantizedLinearMethod()
|
||||
|
||||
if prefix:
|
||||
# Dynamic per module/layer rules may override base config
|
||||
override_config(cloned_config, prefix=prefix)
|
||||
|
||||
return linear_method_cls(cloned_config)
|
||||
return None
|
||||
|
||||
|
||||
def gptq_get_quant_method(self, layer, prefix):
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinLinearMethod,
|
||||
GPTQMarlinMoEMethod,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
return GPTQMarlinLinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
if isinstance(layer, FusedMoE):
|
||||
return GPTQMarlinMoEMethod(self)
|
||||
|
||||
if isinstance(self, GPTQConfig):
|
||||
return get_linear_quant_method(
|
||||
self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
|
||||
)
|
||||
elif isinstance(self, GPTQMarlinConfig):
|
||||
return get_linear_quant_method(
|
||||
self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@@ -155,6 +256,7 @@ def apply_monkey_patches():
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
|
||||
|
||||
setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
|
||||
setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
|
||||
setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
|
||||
setattr(AWQMoEMethod, "apply", awq_moe_method_apply)
|
||||
|
||||
|
||||
416
python/sglang/srt/layers/quantization/gptq.py
Normal file
416
python/sglang/srt/layers/quantization/gptq.py
Normal file
@@ -0,0 +1,416 @@
|
||||
import logging
|
||||
from fractions import Fraction
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GPTQConfig(QuantizationConfig):
|
||||
"""Config class for GPTQ.
|
||||
|
||||
Reference: https://arxiv.org/abs/2210.17323
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
desc_act: bool,
|
||||
lm_head_quantized: bool,
|
||||
dynamic: Dict[str, Dict[str, Union[int, bool]]],
|
||||
) -> None:
|
||||
# GPTQModel use `dynamic` config property to allow per module
|
||||
# quantization config so each module can be individually optimized.
|
||||
# Format is Dict[str, Dict] where key is a regex string that can
|
||||
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
|
||||
# matching of a module.
|
||||
# Default to positive match, override base quant config mode, if no
|
||||
# prefix is used. Value is in dict format of field key and override
|
||||
# value.
|
||||
# Negative matching will skip quantization init for this module
|
||||
# entirely:
|
||||
# non-quantized inference. More details and quantization examples can be
|
||||
# found at: https://github.com/ModelCloud/GPTQModel
|
||||
# Example:
|
||||
# # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
|
||||
# # last 1/4 of the layers 16-21 has 8bit and group_size 64
|
||||
# dynamic = {
|
||||
# #`.*\.` matches the layers_node prefix
|
||||
# # positive match layer 10-15
|
||||
# r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
|
||||
# # positive match layer 16-21
|
||||
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
|
||||
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
|
||||
# }
|
||||
super().__init__()
|
||||
self.dynamic = dynamic
|
||||
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.desc_act = desc_act
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.pack_factor = Fraction(32, self.weight_bits)
|
||||
if self.weight_bits not in [2, 3, 4, 8]:
|
||||
raise ValueError(
|
||||
"Currently, only 2/3/4/8-bit weight quantization is "
|
||||
f"supported for GPTQ, but got {self.weight_bits} bits."
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"GPTQConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"desc_act={self.desc_act}),"
|
||||
f"lm_head_quantized={self.lm_head_quantized}), "
|
||||
f"dynamic={self.dynamic}"
|
||||
)
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
"""Returns the activation function names that should be post-scaled.
|
||||
|
||||
For now, this is only used by AWQ.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "gptq"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
# Need to figure it out
|
||||
def get_min_capability(cls) -> int:
|
||||
return 60
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
|
||||
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
|
||||
dynamic = {} if dynamic is None else dynamic
|
||||
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
desc_act = cls.get_from_keys(config, ["desc_act"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
|
||||
return cls(weight_bits, group_size, desc_act, lm_head_quantized, dynamic)
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["GPTQLinearMethod"]:
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
|
||||
from sglang.srt.layers.quantization import get_linear_quant_method
|
||||
|
||||
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
|
||||
|
||||
|
||||
class GPTQMarlinConfig(QuantizationConfig):
|
||||
"""Config class for GPTQ Marlin"""
|
||||
|
||||
# (num_bits, is_sym) -> quant_type
|
||||
TYPE_MAP = {
|
||||
(4, True): scalar_types.uint4b8,
|
||||
(8, True): scalar_types.uint8b128,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
desc_act: bool,
|
||||
is_sym: bool,
|
||||
lm_head_quantized: bool,
|
||||
dynamic: Dict[str, Dict[str, Union[int, bool]]],
|
||||
full_config: Dict[str, Any],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if desc_act and group_size == -1:
|
||||
# In this case, act_order == True is the same as act_order == False
|
||||
# (since we have only one group per output channel)
|
||||
desc_act = False
|
||||
|
||||
# GPTQModel use `dynamic` config property to allow per module
|
||||
# quantization config so each module can be individually optimized.
|
||||
# Format is Dict[str, Dict] where key is a regex string that can
|
||||
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
|
||||
# matching of a module.
|
||||
# Default to positive match, override base quant config mode, if no
|
||||
# prefix is used. Value is in dict format of field key and override
|
||||
# value.
|
||||
# Negative matching will skip quantization init for this module
|
||||
# entirely:
|
||||
# non-quantized inference. More details and quantization examples can be
|
||||
# found at: https://github.com/ModelCloud/GPTQModel
|
||||
# Example:
|
||||
# # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
|
||||
# # last 1/4 of the layers 16-21 has 8bit and group_size 64
|
||||
# dynamic = {
|
||||
# #`.*\.` matches the layers_node prefix
|
||||
# # positive match layer 10-15
|
||||
# r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
|
||||
# # positive match layer 16-21
|
||||
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
|
||||
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
|
||||
# }
|
||||
self.dynamic = dynamic
|
||||
|
||||
self.weight_bits = weight_bits
|
||||
self.is_sym = is_sym
|
||||
|
||||
self.pack_factor = 32 // weight_bits # packed into int32
|
||||
self.group_size = group_size
|
||||
self.desc_act = desc_act
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.full_config = full_config
|
||||
|
||||
if (weight_bits, is_sym) not in self.TYPE_MAP:
|
||||
raise ValueError(
|
||||
"Unsupported quantization config: " f"bits={weight_bits}, sym={is_sym}"
|
||||
)
|
||||
|
||||
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"GPTQMarlinConfig(quant_type={self.quant_type}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"desc_act={self.desc_act}, "
|
||||
f"lm_head_quantized={self.lm_head_quantized}), "
|
||||
f"dynamic={self.dynamic}"
|
||||
)
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
"""Returns the activation function names that should be post-scaled.
|
||||
|
||||
For now, this is only used by AWQ.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "gptq_marlin"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
|
||||
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
|
||||
dynamic = {} if dynamic is None else dynamic
|
||||
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
desc_act = cls.get_from_keys(config, ["desc_act"])
|
||||
is_sym = cls.get_from_keys(config, ["sym"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
|
||||
return cls(
|
||||
weight_bits,
|
||||
group_size,
|
||||
desc_act,
|
||||
is_sym,
|
||||
lm_head_quantized,
|
||||
dynamic,
|
||||
config,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
|
||||
can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
|
||||
|
||||
is_valid_user_quant = (
|
||||
user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin"
|
||||
)
|
||||
|
||||
if can_convert and is_valid_user_quant:
|
||||
msg = (
|
||||
"The model is convertible to {} during runtime."
|
||||
" Using {} kernel.".format(cls.get_name(), cls.get_name())
|
||||
)
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
if can_convert and user_quant == "gptq":
|
||||
logger.info(
|
||||
"Detected that the model can run with gptq_marlin"
|
||||
", however you specified quantization=gptq explicitly,"
|
||||
" so forcing gptq. Use quantization=gptq_marlin for"
|
||||
" faster inference"
|
||||
)
|
||||
return None
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinLinearMethod,
|
||||
GPTQMarlinMoEMethod,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.quantization import get_linear_quant_method
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
return GPTQMarlinMoEMethod(self)
|
||||
# TODO: re-enable after SGLang syncs with vllm >= 0.7.3
|
||||
# if layer.num_experts > 32:
|
||||
# # For MoEs with many experts the moe_wna16 kernel is faster
|
||||
# return MoeWNA16Config.from_config(self.full_config).get_quant_method(
|
||||
# layer, prefix
|
||||
# )
|
||||
# else:
|
||||
# return GPTQMarlinMoEMethod(self)
|
||||
return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod)
|
||||
|
||||
@classmethod
|
||||
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits")
|
||||
group_size = quant_config.get("group_size")
|
||||
sym = quant_config.get("sym")
|
||||
desc_act = quant_config.get("desc_act")
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supported,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
return False
|
||||
|
||||
if quant_method != "gptq":
|
||||
return False
|
||||
|
||||
# Marlin conversion is only valid if required properties are found
|
||||
if num_bits is None or group_size is None or sym is None or desc_act is None:
|
||||
return False
|
||||
|
||||
if (num_bits, sym) not in cls.TYPE_MAP:
|
||||
return False
|
||||
|
||||
return check_marlin_supported(
|
||||
quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size
|
||||
)
|
||||
|
||||
|
||||
class MarlinConfig(QuantizationConfig):
|
||||
"""Config class for Marlin.
|
||||
|
||||
Reference: https://github.com/IST-DASLab/marlin/tree/master
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
group_size: int,
|
||||
lm_head_quantized: bool,
|
||||
) -> None:
|
||||
# Group size for the quantization.
|
||||
self.group_size = group_size
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
if self.group_size != 128 and self.group_size != -1:
|
||||
raise ValueError(
|
||||
"Currently, only group size 128 and -1 (channelwise) "
|
||||
"is supported for Marlin, but got group_size of "
|
||||
f"{self.group_size}"
|
||||
)
|
||||
|
||||
# 4 Bits packed into 32 bit datatype.
|
||||
self.pack_factor = 32 // 4
|
||||
|
||||
# Tile size used by marlin kernels.
|
||||
self.tile_size = 16
|
||||
|
||||
# Min out_features dim
|
||||
self.min_n_threads = 64
|
||||
|
||||
# Min in_features dim
|
||||
self.min_k_threads = 128
|
||||
|
||||
# Max parallel problems to solve at once (improves large
|
||||
# batch performance)
|
||||
self.max_parallel = 16
|
||||
|
||||
# Permutation length used by the marlin kernels.
|
||||
self.perm_len = 1024
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"MarlinConfig(group_size={self.group_size}, "
|
||||
f"lm_head_quantized={self.lm_head_quantized})"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "marlin"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
# Need to figure it out
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
|
||||
return cls(group_size, lm_head_quantized)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
|
||||
# compat: autogptq >=0.8.0 use checkpoint_format: str
|
||||
# compat: autogptq <=0.7.1 is_marlin_format: bool
|
||||
is_marlin_format = hf_quant_cfg.get(
|
||||
"checkpoint_format"
|
||||
) == "marlin" or hf_quant_cfg.get("is_marlin_format", False)
|
||||
|
||||
is_valid_user_quant = (
|
||||
user_quant is None or user_quant == "gptq" or user_quant == "marlin"
|
||||
)
|
||||
|
||||
if is_marlin_format and is_valid_user_quant:
|
||||
msg = "The model is serialized in {} format. Using {} kernel.".format(
|
||||
cls.get_name(), cls.get_name()
|
||||
)
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
return None
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["MarlinLinearMethod"]:
|
||||
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
|
||||
|
||||
if isinstance(layer, LinearBase) or (
|
||||
isinstance(layer, ParallelLMHead) and self.lm_head_quantized
|
||||
):
|
||||
return MarlinLinearMethod(self)
|
||||
return None
|
||||
@@ -34,6 +34,7 @@ class RadixAttention(nn.Module):
|
||||
v_head_dim: int = -1,
|
||||
sliding_window_size: int = -1,
|
||||
is_cross_attention: bool = False,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_q_head_num = num_heads
|
||||
|
||||
@@ -261,26 +261,27 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
)
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
linear_method = None
|
||||
quant_method = None
|
||||
if quant_config is not None:
|
||||
linear_method = quant_config.get_quant_method(self, prefix=prefix)
|
||||
if linear_method is None:
|
||||
linear_method = UnquantizedEmbeddingMethod()
|
||||
quant_method = quant_config.get_quant_method(self, prefix=prefix)
|
||||
print("quant_method", quant_method)
|
||||
if quant_method is None:
|
||||
quant_method = UnquantizedEmbeddingMethod()
|
||||
|
||||
# If we are making an embedding layer, then our quantization linear
|
||||
# method must implement the embedding operation. If we are another
|
||||
# layer type like ParallelLMHead, this is not important.
|
||||
is_embedding_layer = type(self.__class__) is VocabParallelEmbedding
|
||||
linear_method_implements_embedding = method_has_implemented_embedding(
|
||||
type(linear_method)
|
||||
quant_method_implements_embedding = method_has_implemented_embedding(
|
||||
type(quant_method)
|
||||
)
|
||||
if is_embedding_layer and not linear_method_implements_embedding:
|
||||
if is_embedding_layer and not quant_method_implements_embedding:
|
||||
raise NotImplementedError(
|
||||
f"The class {type(linear_method).__name__} must implement "
|
||||
f"The class {type(quant_method).__name__} must implement "
|
||||
"the 'embedding' method, see UnquantizedEmbeddingMethod."
|
||||
)
|
||||
|
||||
self.linear_method: QuantizeMethodBase = linear_method
|
||||
self.quant_method: QuantizeMethodBase = quant_method
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
@@ -301,7 +302,7 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
- self.shard_indices.added_vocab_start_index
|
||||
)
|
||||
|
||||
self.linear_method.create_weights(
|
||||
self.quant_method.create_weights(
|
||||
self,
|
||||
self.embedding_dim,
|
||||
[self.num_embeddings_per_partition],
|
||||
@@ -446,7 +447,7 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
packed_factor = (
|
||||
param.packed_factor
|
||||
if isinstance(param, BasevLLMParameter)
|
||||
else param.pack_factor
|
||||
else param.packed_factor
|
||||
)
|
||||
assert loaded_weight.shape[output_dim] == (
|
||||
self.org_vocab_size // param.packed_factor
|
||||
@@ -479,7 +480,7 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
else:
|
||||
masked_input = input_
|
||||
# Get the embeddings.
|
||||
output_parallel = self.linear_method.embedding(self, masked_input.long())
|
||||
output_parallel = self.quant_method.embedding(self, masked_input.long())
|
||||
# Mask the output embedding.
|
||||
if self.tp_size > 1:
|
||||
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
|
||||
|
||||
@@ -46,6 +46,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
||||
@@ -80,13 +81,22 @@ class BaiChuanMLP(nn.Module):
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size, hidden_size, bias=False, quant_config=quant_config
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(
|
||||
@@ -114,6 +124,7 @@ class BaiChuanAttention(nn.Module):
|
||||
max_position_embeddings: int = 8192,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
layer_id: int = 0,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@@ -167,6 +178,7 @@ class BaiChuanAttention(nn.Module):
|
||||
scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
else:
|
||||
self.rotary_emb = get_rope(
|
||||
@@ -182,6 +194,7 @@ class BaiChuanAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -207,6 +220,7 @@ class BaiChuanDecoderLayer(nn.Module):
|
||||
position_embedding: str,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -220,12 +234,14 @@ class BaiChuanDecoderLayer(nn.Module):
|
||||
layer_id=layer_id,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
self.mlp = BaiChuanMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
@@ -264,6 +280,7 @@ class BaiChuanModel(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
position_embedding: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -281,6 +298,7 @@ class BaiChuanModel(nn.Module):
|
||||
layer_id=i,
|
||||
position_embedding=position_embedding,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix(f"layers.{i}", prefix),
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
@@ -330,18 +348,24 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
position_embedding: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.model = BaiChuanModel(config, position_embedding, quant_config)
|
||||
self.model = BaiChuanModel(
|
||||
config, position_embedding, quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
@@ -404,11 +428,12 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
if config.hidden_size == 4096: # baichuan2 7b
|
||||
super().__init__(config, "ROPE", quant_config)
|
||||
super().__init__(config, "ROPE", quant_config, prefix=prefix)
|
||||
else: # baichuan 13b, baichuan2 13b
|
||||
super().__init__(config, "ALIBI", quant_config)
|
||||
super().__init__(config, "ALIBI", quant_config, prefix=prefix)
|
||||
|
||||
|
||||
EntryClass = [BaichuanForCausalLM]
|
||||
|
||||
@@ -41,6 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
LoraConfig = None
|
||||
|
||||
@@ -51,6 +52,7 @@ class GLMAttention(nn.Module):
|
||||
config,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -85,12 +87,14 @@ class GLMAttention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=config.add_bias_linear or config.add_qkv_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("query_key_value", prefix),
|
||||
)
|
||||
self.dense = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
config.hidden_size,
|
||||
bias=config.add_bias_linear,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("dense", prefix),
|
||||
)
|
||||
|
||||
# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
|
||||
@@ -109,6 +113,7 @@ class GLMAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -142,6 +147,7 @@ class GLMMLP(nn.Module):
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -153,6 +159,7 @@ class GLMMLP(nn.Module):
|
||||
[config.ffn_hidden_size] * 2,
|
||||
bias=config.add_bias_linear,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("dense_h_to_4h", prefix),
|
||||
)
|
||||
|
||||
self.activation_func = SiluAndMul()
|
||||
@@ -163,6 +170,7 @@ class GLMMLP(nn.Module):
|
||||
config.hidden_size,
|
||||
bias=config.add_bias_linear,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("dense_4h_to_h", prefix),
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
@@ -186,6 +194,7 @@ class GLMBlock(nn.Module):
|
||||
config,
|
||||
layer_id: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.apply_residual_connection_post_layernorm = (
|
||||
@@ -201,7 +210,9 @@ class GLMBlock(nn.Module):
|
||||
)
|
||||
|
||||
# Self attention.
|
||||
self.self_attention = GLMAttention(config, layer_id, quant_config)
|
||||
self.self_attention = GLMAttention(
|
||||
config, layer_id, quant_config, prefix=add_prefix("self_attention", prefix)
|
||||
)
|
||||
self.hidden_dropout = config.hidden_dropout
|
||||
|
||||
# Layernorm on the attention output
|
||||
@@ -210,7 +221,7 @@ class GLMBlock(nn.Module):
|
||||
)
|
||||
|
||||
# MLP
|
||||
self.mlp = GLMMLP(config, quant_config)
|
||||
self.mlp = GLMMLP(config, quant_config, prefix=add_prefix("mlp", prefix))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -257,6 +268,7 @@ class GLMTransformer(nn.Module):
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.post_layer_norm = config.post_layer_norm
|
||||
@@ -266,7 +278,15 @@ class GLMTransformer(nn.Module):
|
||||
|
||||
# Transformer layers.
|
||||
self.layers = nn.ModuleList(
|
||||
[GLMBlock(config, i, quant_config) for i in range(self.num_layers)]
|
||||
[
|
||||
GLMBlock(
|
||||
config,
|
||||
i,
|
||||
quant_config,
|
||||
prefix=add_prefix(f"layers.{i}", prefix),
|
||||
)
|
||||
for i in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
if self.post_layer_norm:
|
||||
@@ -301,19 +321,28 @@ class ChatGLMM(nn.Module):
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.embedding = VocabParallelEmbedding(
|
||||
config.padded_vocab_size, config.hidden_size
|
||||
config.padded_vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=add_prefix("embedding", prefix),
|
||||
)
|
||||
|
||||
self.num_layers = config.num_layers
|
||||
self.multi_query_group_num = config.multi_query_group_num
|
||||
self.kv_channels = config.kv_channels
|
||||
self.encoder = GLMTransformer(config, quant_config)
|
||||
self.encoder = GLMTransformer(
|
||||
config, quant_config, add_prefix("encoder", prefix)
|
||||
)
|
||||
|
||||
self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size)
|
||||
self.output_layer = ParallelLMHead(
|
||||
config.padded_vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=add_prefix("output_layer", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -351,12 +380,15 @@ class ChatGLMForCausalLM(nn.Module):
|
||||
self,
|
||||
config: ChatGLMConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config: ChatGLMConfig = config
|
||||
self.quant_config = quant_config
|
||||
self.max_position_embeddings = getattr(config, "max_sequence_length", 8192)
|
||||
self.transformer = ChatGLMM(config, quant_config)
|
||||
self.transformer = ChatGLMM(
|
||||
config, quant_config, prefix=add_prefix("transformer", prefix)
|
||||
)
|
||||
self.lm_head = self.transformer.output_layer
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ from sglang.srt.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
maybe_remap_kv_scale_name,
|
||||
)
|
||||
from sglang.srt.utils import get_compiler_backend, set_weight_attrs
|
||||
from sglang.srt.utils import add_prefix, get_compiler_backend, set_weight_attrs
|
||||
|
||||
|
||||
@torch.compile(backend=get_compiler_backend())
|
||||
@@ -110,6 +110,7 @@ class CohereMLP(nn.Module):
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -120,12 +121,14 @@ class CohereMLP(nn.Module):
|
||||
[self.intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
self.intermediate_size,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
)
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
@@ -142,6 +145,7 @@ class CohereAttention(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
@@ -177,12 +181,14 @@ class CohereAttention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
@@ -198,6 +204,7 @@ class CohereAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
if self.use_qk_norm:
|
||||
self.q_norm = LayerNorm(
|
||||
@@ -239,15 +246,23 @@ class CohereDecoderLayer(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = CohereAttention(
|
||||
config, layer_id=layer_id, quant_config=quant_config
|
||||
config,
|
||||
layer_id=layer_id,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
|
||||
self.mlp = CohereMLP(config, quant_config=quant_config)
|
||||
self.mlp = CohereMLP(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.input_layernorm = LayerNorm(
|
||||
param_shape=(config.hidden_size), eps=config.layer_norm_eps
|
||||
)
|
||||
@@ -279,6 +294,7 @@ class CohereModel(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -288,7 +304,12 @@ class CohereModel(nn.Module):
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
CohereDecoderLayer(config, i, quant_config=quant_config)
|
||||
CohereDecoderLayer(
|
||||
config,
|
||||
i,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix(f"layers.{i}", prefix),
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
@@ -321,12 +342,15 @@ class CohereForCausalLM(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.model = CohereModel(config, quant_config)
|
||||
self.model = CohereModel(
|
||||
config, quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
|
||||
@@ -46,7 +46,7 @@ from sglang.srt.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
maybe_remap_kv_scale_name,
|
||||
)
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
from sglang.srt.utils import add_prefix, set_weight_attrs
|
||||
|
||||
|
||||
class DbrxRouter(nn.Module):
|
||||
@@ -58,6 +58,7 @@ class DbrxRouter(nn.Module):
|
||||
self,
|
||||
config: DbrxConfig,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
@@ -89,6 +90,7 @@ class DbrxExperts(nn.Module):
|
||||
config: DbrxConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
@@ -189,6 +191,7 @@ class DbrxAttention(nn.Module):
|
||||
config: DbrxConfig,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.d_model = config.d_model
|
||||
@@ -207,12 +210,14 @@ class DbrxAttention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("Wqkv", prefix),
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
self.d_model,
|
||||
self.d_model,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("out_proj", prefix),
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
@@ -244,6 +249,7 @@ class DbrxAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -268,10 +274,16 @@ class DbrxFusedNormAttention(nn.Module):
|
||||
config: DbrxConfig,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.d_model = config.d_model
|
||||
self.attn = DbrxAttention(config, layer_id, quant_config=quant_config)
|
||||
self.attn = DbrxAttention(
|
||||
config,
|
||||
layer_id,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
self.norm_1 = nn.LayerNorm(self.d_model)
|
||||
self.norm_2 = nn.LayerNorm(self.d_model)
|
||||
|
||||
@@ -300,10 +312,14 @@ class DbrxBlock(nn.Module):
|
||||
config: DbrxConfig,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_attn_norm = DbrxFusedNormAttention(
|
||||
config, layer_id, quant_config=quant_config
|
||||
config,
|
||||
layer_id,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("norm_attn_norm", prefix),
|
||||
)
|
||||
self.ffn = DbrxExperts(config, quant_config=quant_config)
|
||||
|
||||
@@ -328,6 +344,7 @@ class DbrxModel(nn.Module):
|
||||
self,
|
||||
config: DbrxConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.wte = VocabParallelEmbedding(
|
||||
@@ -336,7 +353,12 @@ class DbrxModel(nn.Module):
|
||||
)
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
DbrxBlock(config, i, quant_config=quant_config)
|
||||
DbrxBlock(
|
||||
config,
|
||||
i,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix(f"blocks.{i}", prefix),
|
||||
)
|
||||
for i in range(config.n_layers)
|
||||
]
|
||||
)
|
||||
@@ -369,17 +391,21 @@ class DbrxForCausalLM(nn.Module):
|
||||
self,
|
||||
config: DbrxConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
self.transformer = DbrxModel(config, quant_config=quant_config)
|
||||
self.transformer = DbrxModel(
|
||||
config, quant_config=quant_config, prefix=add_prefix("transformer", prefix)
|
||||
)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.d_model,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class DeepseekMLP(nn.Module):
|
||||
@@ -57,10 +58,15 @@ class DeepseekMLP(nn.Module):
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
reduce_results: bool = True,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
@@ -68,6 +74,7 @@ class DeepseekMLP(nn.Module):
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(
|
||||
@@ -89,6 +96,7 @@ class DeepseekMoE(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -110,6 +118,7 @@ class DeepseekMoE(nn.Module):
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
reduce_results=False,
|
||||
prefix=add_prefix(f"{idx}.experts", prefix),
|
||||
)
|
||||
for idx in range(self.n_routed_experts)
|
||||
]
|
||||
@@ -117,7 +126,11 @@ class DeepseekMoE(nn.Module):
|
||||
self.pack_params()
|
||||
|
||||
self.gate = ReplicatedLinear(
|
||||
config.hidden_size, self.n_routed_experts, bias=False, quant_config=None
|
||||
config.hidden_size,
|
||||
self.n_routed_experts,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=add_prefix("gate", prefix),
|
||||
)
|
||||
|
||||
if config.n_shared_experts is not None:
|
||||
@@ -128,6 +141,7 @@ class DeepseekMoE(nn.Module):
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
reduce_results=False,
|
||||
prefix=add_prefix("shared_experts", prefix),
|
||||
)
|
||||
|
||||
def pack_params(self):
|
||||
@@ -185,6 +199,7 @@ class DeepseekAttention(nn.Module):
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@@ -216,6 +231,7 @@ class DeepseekAttention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
|
||||
self.o_proj = RowParallelLinear(
|
||||
@@ -223,6 +239,7 @@ class DeepseekAttention(nn.Module):
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
@@ -238,6 +255,7 @@ class DeepseekAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -261,6 +279,7 @@ class DeepseekDecoderLayer(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
layer_id: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -276,19 +295,25 @@ class DeepseekDecoderLayer(nn.Module):
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
if (
|
||||
config.n_routed_experts is not None
|
||||
and layer_id >= config.first_k_dense_replace
|
||||
and layer_id % config.moe_layer_freq == 0
|
||||
):
|
||||
self.mlp = DeepseekMoE(config=config, quant_config=quant_config)
|
||||
self.mlp = DeepseekMoE(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
else:
|
||||
self.mlp = DeepseekMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
@@ -328,6 +353,7 @@ class DeepseekModel(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.padding_idx = config.pad_token_id
|
||||
@@ -339,7 +365,12 @@ class DeepseekModel(nn.Module):
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
DeepseekDecoderLayer(config, layer_id, quant_config=quant_config)
|
||||
DeepseekDecoderLayer(
|
||||
config,
|
||||
layer_id,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix(f"layers.{layer_id}", prefix),
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
@@ -368,13 +399,19 @@ class DeepseekForCausalLM(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = DeepseekModel(config, quant_config)
|
||||
self.model = DeepseekModel(
|
||||
config, quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
|
||||
from sglang.srt.utils import is_hip
|
||||
from sglang.srt.utils import add_prefix, is_hip
|
||||
|
||||
is_hip_ = is_hip()
|
||||
|
||||
@@ -48,6 +48,7 @@ class DeepseekModelNextN(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.vocab_size = config.vocab_size
|
||||
@@ -56,6 +57,7 @@ class DeepseekModelNextN(nn.Module):
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
enable_tp=not global_server_args_dict["enable_dp_attention"],
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
|
||||
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@@ -64,7 +66,11 @@ class DeepseekModelNextN(nn.Module):
|
||||
self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
|
||||
|
||||
self.decoder = DeepseekV2DecoderLayer(
|
||||
config, 0, quant_config=quant_config, is_nextn=True
|
||||
config,
|
||||
0,
|
||||
quant_config=quant_config,
|
||||
is_nextn=True,
|
||||
prefix=add_prefix("decoder", prefix),
|
||||
)
|
||||
|
||||
self.shared_head = nn.Module()
|
||||
@@ -108,18 +114,22 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.model = DeepseekModelNextN(config, quant_config)
|
||||
self.model = DeepseekModelNextN(
|
||||
config, quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
|
||||
if global_server_args_dict["enable_dp_attention"]:
|
||||
self.lm_head = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
config.vocab_size,
|
||||
bias=False,
|
||||
prefix=add_prefix("model.shared_head.head", prefix),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
|
||||
else:
|
||||
@@ -127,6 +137,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("model.shared_head.head", prefix),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import is_cuda_available, is_hip
|
||||
from sglang.srt.utils import add_prefix, is_cuda_available, is_hip
|
||||
|
||||
is_hip_ = is_hip()
|
||||
|
||||
@@ -79,10 +79,15 @@ class DeepseekV2MLP(nn.Module):
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
reduce_results: bool = True,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
@@ -90,6 +95,7 @@ class DeepseekV2MLP(nn.Module):
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(
|
||||
@@ -106,7 +112,11 @@ class DeepseekV2MLP(nn.Module):
|
||||
|
||||
|
||||
class MoEGate(nn.Module):
|
||||
def __init__(self, config):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(
|
||||
torch.empty((config.n_routed_experts, config.hidden_size))
|
||||
@@ -129,6 +139,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
@@ -147,7 +158,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
"Only silu is supported for now."
|
||||
)
|
||||
|
||||
self.gate = MoEGate(config=config)
|
||||
self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
|
||||
|
||||
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
|
||||
self.experts = MoEImpl(
|
||||
@@ -161,6 +172,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
num_expert_group=config.n_group,
|
||||
topk_group=config.topk_group,
|
||||
correction_bias=self.gate.e_score_correction_bias,
|
||||
prefix=add_prefix("experts", prefix),
|
||||
)
|
||||
|
||||
if config.n_shared_experts is not None:
|
||||
@@ -171,6 +183,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
reduce_results=False,
|
||||
prefix=add_prefix("shared_experts", prefix),
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@@ -217,6 +230,7 @@ class DeepseekV2Attention(nn.Module):
|
||||
max_position_embeddings: int = 8192,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
layer_id=None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
@@ -241,6 +255,7 @@ class DeepseekV2Attention(nn.Module):
|
||||
self.q_lora_rank,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("q_a_proj", prefix),
|
||||
)
|
||||
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
|
||||
self.q_b_proj = ColumnParallelLinear(
|
||||
@@ -248,6 +263,7 @@ class DeepseekV2Attention(nn.Module):
|
||||
self.num_heads * self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("q_b_proj", prefix),
|
||||
)
|
||||
else:
|
||||
self.q_proj = ColumnParallelLinear(
|
||||
@@ -255,6 +271,7 @@ class DeepseekV2Attention(nn.Module):
|
||||
self.num_heads * self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("q_proj", prefix),
|
||||
)
|
||||
|
||||
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
||||
@@ -262,8 +279,7 @@ class DeepseekV2Attention(nn.Module):
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
# FIXME: quick fix for skip quantization
|
||||
prefix=f"self_attn.kv_a_proj_with_mqa",
|
||||
prefix=add_prefix("kv_a_proj_with_mqa", prefix),
|
||||
)
|
||||
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
|
||||
self.kv_b_proj = ColumnParallelLinear(
|
||||
@@ -271,6 +287,7 @@ class DeepseekV2Attention(nn.Module):
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("kv_b_proj", prefix),
|
||||
)
|
||||
# O projection.
|
||||
self.o_proj = RowParallelLinear(
|
||||
@@ -278,6 +295,7 @@ class DeepseekV2Attention(nn.Module):
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
rope_scaling["rope_type"] = "deepseek_yarn"
|
||||
self.rotary_emb = get_rope_wrapper(
|
||||
@@ -303,6 +321,7 @@ class DeepseekV2Attention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_local_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -368,6 +387,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
layer_id=None,
|
||||
use_dp=False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
@@ -394,6 +414,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
self.q_lora_rank,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("q_a_proj", prefix),
|
||||
)
|
||||
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
|
||||
self.q_b_proj = ReplicatedLinear(
|
||||
@@ -401,6 +422,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
self.num_heads * self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("q_b_proj", prefix),
|
||||
)
|
||||
else:
|
||||
self.q_proj = ReplicatedLinear(
|
||||
@@ -408,12 +430,14 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
self.num_heads * self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("q_proj", prefix),
|
||||
)
|
||||
self.kv_b_proj = ReplicatedLinear(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("kv_b_proj", prefix),
|
||||
)
|
||||
# O projection.
|
||||
self.o_proj = ReplicatedLinear(
|
||||
@@ -421,6 +445,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
else:
|
||||
# For tensor parallel attention
|
||||
@@ -430,6 +455,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
self.q_lora_rank,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("q_a_proj", prefix),
|
||||
)
|
||||
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
|
||||
self.q_b_proj = ColumnParallelLinear(
|
||||
@@ -437,6 +463,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
self.num_heads * self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("q_b_proj", prefix),
|
||||
)
|
||||
else:
|
||||
self.q_proj = ColumnParallelLinear(
|
||||
@@ -444,12 +471,14 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
self.num_heads * self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("q_proj", prefix),
|
||||
)
|
||||
self.kv_b_proj = ColumnParallelLinear(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("kv_b_proj", prefix),
|
||||
)
|
||||
# O projection.
|
||||
self.o_proj = RowParallelLinear(
|
||||
@@ -457,6 +486,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
|
||||
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
||||
@@ -464,8 +494,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
# FIXME: quick fix for skip quantization
|
||||
prefix=f"self_attn.kv_a_proj_with_mqa",
|
||||
prefix=add_prefix("kv_a_proj_with_mqa", prefix),
|
||||
)
|
||||
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
|
||||
|
||||
@@ -496,6 +525,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
num_kv_heads=1,
|
||||
layer_id=layer_id,
|
||||
v_head_dim=self.kv_lora_rank,
|
||||
prefix=add_prefix("attn_mqa", prefix),
|
||||
)
|
||||
|
||||
self.attn_mha = RadixAttention(
|
||||
@@ -505,6 +535,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
num_kv_heads=self.num_local_heads,
|
||||
layer_id=layer_id,
|
||||
v_head_dim=self.v_head_dim,
|
||||
prefix=add_prefix("attn_mha", prefix),
|
||||
)
|
||||
|
||||
self.w_kc = None
|
||||
@@ -848,6 +879,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
layer_id: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
is_nextn: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -880,6 +912,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
quant_config=quant_config,
|
||||
layer_id=layer_id,
|
||||
use_dp=self.enable_dp_attention,
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
else:
|
||||
self.self_attn = DeepseekV2Attention(
|
||||
@@ -898,19 +931,25 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
if is_nextn or (
|
||||
config.n_routed_experts is not None
|
||||
and layer_id >= config.first_k_dense_replace
|
||||
and layer_id % config.moe_layer_freq == 0
|
||||
):
|
||||
self.mlp = DeepseekV2MoE(config=config, quant_config=quant_config)
|
||||
self.mlp = DeepseekV2MoE(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
else:
|
||||
self.mlp = DeepseekV2MLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
@@ -962,6 +1001,7 @@ class DeepseekV2Model(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.padding_id = config.pad_token_id
|
||||
@@ -978,6 +1018,7 @@ class DeepseekV2Model(nn.Module):
|
||||
config,
|
||||
layer_id,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix(f"layers.{layer_id}", prefix),
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
@@ -1008,21 +1049,28 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = DeepseekV2Model(config, quant_config)
|
||||
self.model = DeepseekV2Model(
|
||||
config, quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
if global_server_args_dict["enable_dp_attention"]:
|
||||
self.lm_head = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
config.vocab_size,
|
||||
bias=False,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
|
||||
@@ -39,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class ExaoneGatedMLP(nn.Module):
|
||||
@@ -56,14 +57,14 @@ class ExaoneGatedMLP(nn.Module):
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.c_proj",
|
||||
prefix=add_prefix("c_proj", prefix),
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(
|
||||
@@ -130,14 +131,14 @@ class ExaoneAttention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
prefix=add_prefix("out_proj", prefix),
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
@@ -201,14 +202,14 @@ class ExaoneDecoderLayer(nn.Module):
|
||||
rope_is_neox_style=rope_is_neox_style,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
self.mlp = ExaoneGatedMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.activation_function,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
rms_norm_eps = config.layer_norm_epsilon
|
||||
self.ln_1 = RMSNorm(config.hidden_size, eps=rms_norm_eps)
|
||||
@@ -244,6 +245,7 @@ class ExaoneModel(nn.Module):
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -256,7 +258,10 @@ class ExaoneModel(nn.Module):
|
||||
self.h = nn.ModuleList(
|
||||
[
|
||||
ExaoneDecoderLayer(
|
||||
config, i, quant_config=quant_config, prefix=f"model.h.{i}"
|
||||
config,
|
||||
i,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix(f"h.{i}", prefix),
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
@@ -293,12 +298,17 @@ class ExaoneForCausalLM(nn.Module):
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.transformer = ExaoneModel(config, quant_config=quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.transformer = ExaoneModel(
|
||||
config, quant_config=quant_config, prefix=add_prefix("transformer", prefix)
|
||||
)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -37,6 +37,7 @@ from sglang.srt.layers.rotary_embedding import get_rope
|
||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class GemmaMLP(nn.Module):
|
||||
@@ -45,6 +46,7 @@ class GemmaMLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
@@ -52,12 +54,14 @@ class GemmaMLP(nn.Module):
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
)
|
||||
self.act_fn = GeluAndMul("none")
|
||||
|
||||
@@ -79,6 +83,7 @@ class GemmaAttention(nn.Module):
|
||||
max_position_embeddings: int = 8192,
|
||||
rope_theta: float = 10000,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@@ -109,12 +114,14 @@ class GemmaAttention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
@@ -130,6 +137,7 @@ class GemmaAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -152,6 +160,7 @@ class GemmaDecoderLayer(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -164,11 +173,13 @@ class GemmaDecoderLayer(nn.Module):
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
rope_theta=config.rope_theta,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
self.mlp = GemmaMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
@@ -205,6 +216,7 @@ class GemmaModel(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -215,7 +227,12 @@ class GemmaModel(nn.Module):
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
GemmaDecoderLayer(config, i, quant_config=quant_config)
|
||||
GemmaDecoderLayer(
|
||||
config,
|
||||
i,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix(f"layers.{i}", prefix),
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
@@ -277,11 +294,14 @@ class GemmaForCausalLM(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = GemmaModel(config, quant_config=quant_config)
|
||||
self.model = GemmaModel(
|
||||
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -39,7 +39,7 @@ from sglang.srt.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
maybe_remap_kv_scale_name,
|
||||
)
|
||||
from sglang.srt.utils import make_layers
|
||||
from sglang.srt.utils import add_prefix, make_layers
|
||||
|
||||
|
||||
# Aligned with HF's implementation, using sliding window inclusive with the last token
|
||||
@@ -56,13 +56,22 @@ class Gemma2MLP(nn.Module):
|
||||
hidden_act: str,
|
||||
hidden_activation: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size, hidden_size, bias=False, quant_config=quant_config
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
)
|
||||
if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"):
|
||||
raise ValueError(
|
||||
@@ -91,6 +100,7 @@ class Gemma2Attention(nn.Module):
|
||||
max_position_embeddings: int,
|
||||
rope_theta: float,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
@@ -123,12 +133,14 @@ class Gemma2Attention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=config.attention_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=config.attention_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
@@ -151,6 +163,7 @@ class Gemma2Attention(nn.Module):
|
||||
if use_sliding_window
|
||||
else None
|
||||
),
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -173,6 +186,7 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
layer_id: int,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -186,6 +200,7 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
rope_theta=config.rope_theta,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
self.hidden_size = config.hidden_size
|
||||
self.mlp = Gemma2MLP(
|
||||
@@ -194,6 +209,7 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
hidden_act=config.hidden_act,
|
||||
hidden_activation=config.hidden_activation,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = GemmaRMSNorm(
|
||||
@@ -238,6 +254,7 @@ class Gemma2Model(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -253,7 +270,7 @@ class Gemma2Model(nn.Module):
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
prefix="",
|
||||
prefix=add_prefix("layers", prefix),
|
||||
)
|
||||
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
@@ -339,11 +356,14 @@ class Gemma2ForCausalLM(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = Gemma2Model(config, quant_config)
|
||||
self.model = Gemma2Model(
|
||||
config, quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -22,6 +22,7 @@ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.models.gemma2 import Gemma2ForCausalLM, Gemma2Model
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class Gemma2ForSequenceClassification(nn.Module):
|
||||
@@ -29,12 +30,15 @@ class Gemma2ForSequenceClassification(nn.Module):
|
||||
self,
|
||||
config: Gemma2Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.num_labels = config.num_labels
|
||||
self.model = Gemma2Model(config, quant_config=quant_config)
|
||||
self.model = Gemma2Model(
|
||||
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class GPT2Attention(nn.Module):
|
||||
@@ -62,14 +63,14 @@ class GPT2Attention(nn.Module):
|
||||
total_num_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.c_attn",
|
||||
prefix=add_prefix("c_attn", prefix),
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.c_proj",
|
||||
prefix=add_prefix("c_proj", prefix),
|
||||
)
|
||||
self.attn = RadixAttention(
|
||||
self.num_heads,
|
||||
@@ -108,14 +109,14 @@ class GPT2MLP(nn.Module):
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.c_fc",
|
||||
prefix=add_prefix("c_fc", prefix),
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.c_proj",
|
||||
prefix=add_prefix("c_proj", prefix),
|
||||
)
|
||||
self.act = act_layer()
|
||||
|
||||
@@ -145,7 +146,7 @@ class GPT2Block(nn.Module):
|
||||
|
||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.attn = GPT2Attention(
|
||||
layer_id, config, quant_config, prefix=f"{prefix}.attn"
|
||||
layer_id, config, quant_config, prefix=add_prefix("attn", prefix)
|
||||
)
|
||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mlp = GPT2MLP(
|
||||
@@ -153,7 +154,7 @@ class GPT2Block(nn.Module):
|
||||
config,
|
||||
act_layer=act_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -196,7 +197,12 @@ class GPT2Model(nn.Module):
|
||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||
self.h = nn.ModuleList(
|
||||
[
|
||||
GPT2Block(i, config, quant_config=quant_config)
|
||||
GPT2Block(
|
||||
i,
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix(f"h.{i}", prefix),
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
@@ -227,11 +233,14 @@ class GPT2LMHeadModel(nn.Module):
|
||||
self,
|
||||
config: GPT2Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.transformer = GPT2Model(config, quant_config, prefix="transformer")
|
||||
self.transformer = GPT2Model(
|
||||
config, quant_config, prefix=add_prefix("transformer", prefix)
|
||||
)
|
||||
self.lm_head = self.transformer.wte
|
||||
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
@@ -35,6 +35,7 @@ from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class GPTBigCodeAttention(nn.Module):
|
||||
@@ -44,6 +45,7 @@ class GPTBigCodeAttention(nn.Module):
|
||||
layer_id: int,
|
||||
config: GPTBigCodeConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -69,6 +71,7 @@ class GPTBigCodeAttention(nn.Module):
|
||||
total_num_kv_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("c_attn", prefix),
|
||||
)
|
||||
|
||||
self.c_proj = RowParallelLinear(
|
||||
@@ -76,6 +79,7 @@ class GPTBigCodeAttention(nn.Module):
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("c_proj", prefix),
|
||||
)
|
||||
self.attn = RadixAttention(
|
||||
self.num_heads,
|
||||
@@ -83,6 +87,7 @@ class GPTBigCodeAttention(nn.Module):
|
||||
scaling=self.scale,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -111,6 +116,7 @@ class GPTBigMLP(nn.Module):
|
||||
intermediate_size: int,
|
||||
config: GPTBigCodeConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
@@ -119,12 +125,14 @@ class GPTBigMLP(nn.Module):
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("c_fc", prefix),
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("c_proj", prefix),
|
||||
)
|
||||
self.act = get_act_fn(
|
||||
config.activation_function, quant_config, intermediate_size
|
||||
@@ -144,15 +152,20 @@ class GPTBigCodeBlock(nn.Module):
|
||||
layer_id: int,
|
||||
config: GPTBigCodeConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
|
||||
|
||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.attn = GPTBigCodeAttention(layer_id, config, quant_config)
|
||||
self.attn = GPTBigCodeAttention(
|
||||
layer_id, config, quant_config, prefix=add_prefix("attn", prefix)
|
||||
)
|
||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mlp = GPTBigMLP(inner_dim, config, quant_config)
|
||||
self.mlp = GPTBigMLP(
|
||||
inner_dim, config, quant_config, prefix=add_prefix("mlp", prefix)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -181,6 +194,7 @@ class GPTBigCodeModel(nn.Module):
|
||||
self,
|
||||
config: GPTBigCodeConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -190,12 +204,17 @@ class GPTBigCodeModel(nn.Module):
|
||||
lora_vocab = 0
|
||||
self.vocab_size = config.vocab_size + lora_vocab
|
||||
self.wte = VocabParallelEmbedding(
|
||||
self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size
|
||||
self.vocab_size,
|
||||
self.embed_dim,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
prefix=add_prefix("wte", prefix),
|
||||
)
|
||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||
self.h = nn.ModuleList(
|
||||
[
|
||||
GPTBigCodeBlock(i, config, quant_config)
|
||||
GPTBigCodeBlock(
|
||||
i, config, quant_config, prefix=add_prefix(f"h.{i}", prefix)
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
@@ -235,13 +254,16 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
self,
|
||||
config: GPTBigCodeConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.transformer = GPTBigCodeModel(config, quant_config)
|
||||
self.transformer = GPTBigCodeModel(
|
||||
config, quant_config, prefix=add_prefix("transformer", prefix)
|
||||
)
|
||||
self.lm_head = self.transformer.wte
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
@@ -42,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -62,14 +63,14 @@ class GraniteMLP(nn.Module):
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(
|
||||
@@ -133,14 +134,14 @@ class GraniteAttention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
@@ -157,6 +158,7 @@ class GraniteAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -205,14 +207,14 @@ class GraniteDecoderLayer(nn.Module):
|
||||
rope_is_neox_style=rope_is_neox_style,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
self.mlp = GraniteMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
@@ -252,6 +254,7 @@ class GraniteModel(nn.Module):
|
||||
self,
|
||||
config: GraniteConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -263,7 +266,10 @@ class GraniteModel(nn.Module):
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
GraniteDecoderLayer(
|
||||
config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
|
||||
config,
|
||||
i,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix(f"layers.{i}", prefix),
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
@@ -300,17 +306,23 @@ class GraniteForCausalLM(nn.Module):
|
||||
self,
|
||||
config: GraniteConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = GraniteModel(config, quant_config=quant_config)
|
||||
self.model = GraniteModel(
|
||||
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
# If tie_word_embeddings == True, then input and output embeddings are
|
||||
# the same tensor. Enforce during object creation so that weights will
|
||||
# load correctly even if the LM head weights don't have a separate entry
|
||||
# in the state dict.
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.tie_weights(self.model.embed_tokens)
|
||||
|
||||
@@ -47,6 +47,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.loader import DefaultModelLoader
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class Grok1MLP(nn.Module):
|
||||
@@ -65,7 +66,7 @@ class Grok1MLP(nn.Module):
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
use_presharded_weights=use_presharded_weights,
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
@@ -73,7 +74,7 @@ class Grok1MLP(nn.Module):
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
reduce_results=reduce_results,
|
||||
use_presharded_weights=use_presharded_weights,
|
||||
)
|
||||
@@ -107,6 +108,7 @@ class Grok1MoE(nn.Module):
|
||||
tp_size: Optional[int] = None,
|
||||
reduce_results=True,
|
||||
use_presharded_weights: bool = False,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@@ -118,6 +120,7 @@ class Grok1MoE(nn.Module):
|
||||
bias=False,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=None,
|
||||
prefix=add_prefix("gate", prefix),
|
||||
)
|
||||
|
||||
self.router_logit_softcapping = getattr(
|
||||
@@ -135,6 +138,7 @@ class Grok1MoE(nn.Module):
|
||||
tp_size=tp_size,
|
||||
activation="gelu",
|
||||
use_presharded_weights=use_presharded_weights,
|
||||
prefix=add_prefix("experts", prefix),
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@@ -163,6 +167,7 @@ class Grok1Attention(nn.Module):
|
||||
rope_theta: float = 10000,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
reduce_results: bool = True,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -195,6 +200,7 @@ class Grok1Attention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
@@ -202,6 +208,7 @@ class Grok1Attention(nn.Module):
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
@@ -220,6 +227,7 @@ class Grok1Attention(nn.Module):
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
logit_cap=logit_cap,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -243,6 +251,7 @@ class Grok1DecoderLayer(nn.Module):
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
use_presharded_weights: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_experts = config.num_local_experts
|
||||
@@ -259,6 +268,7 @@ class Grok1DecoderLayer(nn.Module):
|
||||
layer_id=layer_id,
|
||||
rope_theta=rope_theta,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
self.block_sparse_moe = Grok1MoE(
|
||||
config=config,
|
||||
@@ -273,6 +283,7 @@ class Grok1DecoderLayer(nn.Module):
|
||||
quant_config=quant_config,
|
||||
reduce_results=True,
|
||||
use_presharded_weights=use_presharded_weights,
|
||||
prefix=add_prefix("block_sparse_moe", prefix),
|
||||
)
|
||||
self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@@ -311,6 +322,7 @@ class Grok1Model(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
use_presharded_weights: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -320,6 +332,7 @@ class Grok1Model(nn.Module):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
@@ -328,6 +341,7 @@ class Grok1Model(nn.Module):
|
||||
i,
|
||||
quant_config=quant_config,
|
||||
use_presharded_weights=use_presharded_weights,
|
||||
prefix=add_prefix(f"layers.{i}", prefix),
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
@@ -359,6 +373,7 @@ class Grok1ForCausalLM(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -377,8 +392,11 @@ class Grok1ForCausalLM(nn.Module):
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
use_presharded_weights=self.use_presharded_weights,
|
||||
prefix=add_prefix("model", prefix),
|
||||
)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
|
||||
)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -38,6 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class InternLM2MLP(nn.Module):
|
||||
@@ -47,13 +48,22 @@ class InternLM2MLP(nn.Module):
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
)
|
||||
self.w2 = RowParallelLinear(
|
||||
intermediate_size, hidden_size, bias=False, quant_config=quant_config
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("w2", prefix),
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(
|
||||
@@ -80,6 +90,7 @@ class InternLM2Attention(nn.Module):
|
||||
max_position_embeddings: int = 8192,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@@ -111,12 +122,14 @@ class InternLM2Attention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("wqkv", prefix),
|
||||
)
|
||||
self.wo = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("wo", prefix),
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
@@ -127,7 +140,12 @@ class InternLM2Attention(nn.Module):
|
||||
rope_scaling=rope_scaling,
|
||||
)
|
||||
self.attn = RadixAttention(
|
||||
self.num_heads, self.head_dim, self.scaling, self.num_kv_heads, layer_id
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
self.num_kv_heads,
|
||||
layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -150,6 +168,7 @@ class InternLMDecoderLayer(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -165,12 +184,14 @@ class InternLMDecoderLayer(nn.Module):
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
layer_id=layer_id,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("attention", prefix),
|
||||
)
|
||||
self.feed_forward = InternLM2MLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("feed_forward", prefix),
|
||||
)
|
||||
self.attention_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@@ -205,6 +226,7 @@ class InternLM2Model(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -213,10 +235,13 @@ class InternLM2Model(nn.Module):
|
||||
self.tok_embeddings = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=add_prefix("tok_embeddings", prefix),
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
InternLMDecoderLayer(config, i, quant_config)
|
||||
InternLMDecoderLayer(
|
||||
config, i, quant_config, prefix=add_prefix(f"layers.{i}", prefix)
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
@@ -251,12 +276,17 @@ class InternLM2ForCausalLM(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = InternLM2Model(config, quant_config)
|
||||
self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.model = InternLM2Model(
|
||||
config, quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
self.output = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, prefix=add_prefix("output", prefix)
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -22,6 +22,7 @@ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.models.internlm2 import InternLM2ForCausalLM, InternLM2Model
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class InternLM2ForRewardModel(nn.Module):
|
||||
@@ -29,12 +30,15 @@ class InternLM2ForRewardModel(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.vocab_size = config.vocab_size
|
||||
self.model = InternLM2Model(config, quant_config)
|
||||
self.model = InternLM2Model(
|
||||
config, quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
self.v_head = nn.Linear(config.hidden_size, 1, bias=False)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ from sglang.srt.model_loader.weight_utils import (
|
||||
kv_cache_scales_loader,
|
||||
maybe_remap_kv_scale_name,
|
||||
)
|
||||
from sglang.srt.utils import make_layers
|
||||
from sglang.srt.utils import add_prefix, make_layers
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -70,14 +70,14 @@ class LlamaMLP(nn.Module):
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(
|
||||
@@ -142,14 +142,14 @@ class LlamaAttention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
@@ -166,6 +166,7 @@ class LlamaAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -218,7 +219,7 @@ class LlamaDecoderLayer(nn.Module):
|
||||
rope_is_neox_style=rope_is_neox_style,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
bias=attention_bias,
|
||||
)
|
||||
self.mlp = LlamaMLP(
|
||||
@@ -226,7 +227,7 @@ class LlamaDecoderLayer(nn.Module):
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
@@ -263,6 +264,7 @@ class LlamaModel(nn.Module):
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -272,6 +274,7 @@ class LlamaModel(nn.Module):
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
@@ -358,18 +361,24 @@ class LlamaForCausalLM(nn.Module):
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = LlamaModel(config, quant_config=quant_config)
|
||||
self.model = LlamaModel(
|
||||
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
# Llama 3.2 1B Instruct set tie_word_embeddings to True
|
||||
# Llama 3.1 8B Instruct set tie_word_embeddings to False
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
@@ -23,6 +23,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class LlamaForClassification(nn.Module):
|
||||
@@ -30,11 +31,14 @@ class LlamaForClassification(nn.Module):
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = LlamaModel(config, quant_config=quant_config)
|
||||
self.model = LlamaModel(
|
||||
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
|
||||
self.classification_head = nn.Linear(
|
||||
config.hidden_size, config.classification_out_size, bias=False
|
||||
|
||||
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py
|
||||
"""Inference-only LLaMA-EAGLE model compatible with HuggingFace weights."""
|
||||
@@ -55,6 +57,7 @@ class LlamaModel(nn.Module):
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -62,11 +65,15 @@ class LlamaModel(nn.Module):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
LlamaDecoderLayer(
|
||||
config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
|
||||
config,
|
||||
i,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix(f"layers.{i}", prefix),
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
@@ -106,24 +113,26 @@ class LlamaForCausalLMEagle(LlamaForCausalLM):
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = LlamaModel(config, quant_config=quant_config)
|
||||
self.model = LlamaModel(
|
||||
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
# Llama 3.2 1B Instruct set tie_word_embeddings to True
|
||||
# Llama 3.1 8B Instruct set tie_word_embeddings to False
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
if hasattr(config, "hot_vocab_size"):
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.hot_vocab_size, config.hidden_size, quant_config=quant_config
|
||||
)
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
)
|
||||
self.lm_head = ParallelLMHead(
|
||||
getattr(config, "hot_vocab_size", config.vocab_size),
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -8,6 +8,7 @@ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
|
||||
from sglang.srt.model_executor.model_runner import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.llama import LlamaModel
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class LlamaEmbeddingModel(nn.Module):
|
||||
@@ -15,9 +16,12 @@ class LlamaEmbeddingModel(nn.Module):
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config=None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.model = LlamaModel(config, quant_config=quant_config)
|
||||
self.model = LlamaModel(
|
||||
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -22,6 +22,7 @@ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class LlamaForSequenceClassification(nn.Module):
|
||||
@@ -29,12 +30,15 @@ class LlamaForSequenceClassification(nn.Module):
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.num_labels = config.num_labels
|
||||
self.model = LlamaModel(config, quant_config=quant_config)
|
||||
self.model = LlamaModel(
|
||||
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
|
||||
|
||||
@@ -82,8 +86,9 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(config, quant_config)
|
||||
super().__init__(config, quant_config, prefix=prefix)
|
||||
self.weights = self.Weights(config.hidden_size, self.num_labels)
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -42,6 +42,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.llama import LlamaForCausalLM
|
||||
from sglang.srt.models.mistral import MistralForCausalLM
|
||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class LlavaBaseForCausalLM(nn.Module):
|
||||
@@ -475,6 +476,7 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
|
||||
self,
|
||||
config: LlavaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -484,7 +486,11 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
|
||||
self.config.text_config.hidden_size = config.hidden_size
|
||||
|
||||
self.multi_modal_projector = LlavaMultiModalProjector(config)
|
||||
self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
|
||||
self.language_model = LlamaForCausalLM(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("language_model", prefix),
|
||||
)
|
||||
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
|
||||
self.language_model.model.image_newline = nn.Parameter(
|
||||
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
|
||||
@@ -496,6 +502,7 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
|
||||
self,
|
||||
config: LlavaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -516,7 +523,11 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
|
||||
self.config.image_token_index = 151646
|
||||
|
||||
self.multi_modal_projector = LlavaMultiModalProjector(config)
|
||||
self.language_model = Qwen2ForCausalLM(config, quant_config=quant_config)
|
||||
self.language_model = Qwen2ForCausalLM(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("language_model", prefix),
|
||||
)
|
||||
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
|
||||
self.language_model.model.image_newline = nn.Parameter(
|
||||
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
|
||||
@@ -528,6 +539,7 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
|
||||
self,
|
||||
config: LlavaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -548,7 +560,11 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
|
||||
self.config.image_token_index = 32000
|
||||
|
||||
self.multi_modal_projector = LlavaMultiModalProjector(config)
|
||||
self.language_model = MistralForCausalLM(config, quant_config=quant_config)
|
||||
self.language_model = MistralForCausalLM(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("language_model", prefix),
|
||||
)
|
||||
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
|
||||
self.language_model.model.image_newline = nn.Parameter(
|
||||
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
|
||||
|
||||
@@ -26,6 +26,7 @@ from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.llama import LlamaForCausalLM
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class LlavaVidForCausalLM(nn.Module):
|
||||
@@ -33,6 +34,7 @@ class LlavaVidForCausalLM(nn.Module):
|
||||
self,
|
||||
config: LlavaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -44,7 +46,11 @@ class LlavaVidForCausalLM(nn.Module):
|
||||
self.resampler = nn.AvgPool2d(
|
||||
kernel_size=self.mm_spatial_pool_stride, stride=self.mm_spatial_pool_stride
|
||||
)
|
||||
self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
|
||||
self.language_model = LlamaForCausalLM(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("language_model", prefix),
|
||||
)
|
||||
self.num_frames = getattr(self.config, "num_frames", 16)
|
||||
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
|
||||
self.language_model.model.image_newline = nn.Parameter(
|
||||
|
||||
@@ -37,6 +37,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class MiniCPMMLP(nn.Module):
|
||||
@@ -46,6 +47,7 @@ class MiniCPMMLP(nn.Module):
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
@@ -53,12 +55,14 @@ class MiniCPMMLP(nn.Module):
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(
|
||||
@@ -85,6 +89,7 @@ class MiniCPMAttention(nn.Module):
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@@ -116,12 +121,14 @@ class MiniCPMAttention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
@@ -139,6 +146,7 @@ class MiniCPMAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -164,6 +172,7 @@ class MiniCPMDecoderLayer(nn.Module):
|
||||
config,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -180,12 +189,14 @@ class MiniCPMDecoderLayer(nn.Module):
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
self.mlp = MiniCPMMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
@@ -227,6 +238,7 @@ class MiniCPMModel(nn.Module):
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -236,10 +248,16 @@ class MiniCPMModel(nn.Module):
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
MiniCPMDecoderLayer(config, i, quant_config=quant_config)
|
||||
MiniCPMDecoderLayer(
|
||||
config,
|
||||
i,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix(f"layers.{i}", prefix),
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
@@ -275,19 +293,23 @@ class MiniCPMForCausalLM(nn.Module):
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.num_experts = getattr(self.config, "num_experts", 0)
|
||||
self.quant_config = quant_config
|
||||
self.model = MiniCPMModel(config, quant_config=quant_config)
|
||||
self.model = MiniCPMModel(
|
||||
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
# self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
if not self.config.tie_word_embeddings:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
|
||||
self.scale_width = self.config.hidden_size / self.config.dim_model_base
|
||||
|
||||
@@ -40,7 +40,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import is_cuda_available
|
||||
from sglang.srt.utils import add_prefix, is_cuda_available
|
||||
|
||||
if is_cuda_available():
|
||||
from sgl_kernel import bmm_fp8
|
||||
@@ -53,6 +53,7 @@ class MiniCPM3MLP(nn.Module):
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
@@ -60,12 +61,14 @@ class MiniCPM3MLP(nn.Module):
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(
|
||||
@@ -107,6 +110,7 @@ class MiniCPM3Attention(nn.Module):
|
||||
max_position_embeddings: int = 8192,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
layer_id=None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
@@ -131,6 +135,7 @@ class MiniCPM3Attention(nn.Module):
|
||||
self.q_lora_rank,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("q_a_proj", prefix),
|
||||
)
|
||||
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
|
||||
self.q_b_proj = ColumnParallelLinear(
|
||||
@@ -138,6 +143,7 @@ class MiniCPM3Attention(nn.Module):
|
||||
self.num_heads * self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("q_b_proj", prefix),
|
||||
)
|
||||
else:
|
||||
self.q_proj = ColumnParallelLinear(
|
||||
@@ -145,6 +151,7 @@ class MiniCPM3Attention(nn.Module):
|
||||
self.num_heads * self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("q_proj", prefix),
|
||||
)
|
||||
|
||||
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
||||
@@ -152,6 +159,7 @@ class MiniCPM3Attention(nn.Module):
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("kv_a_proj_with_mqa", prefix),
|
||||
)
|
||||
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
|
||||
self.kv_b_proj = ColumnParallelLinear(
|
||||
@@ -159,6 +167,7 @@ class MiniCPM3Attention(nn.Module):
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("kv_b_proj", prefix),
|
||||
)
|
||||
# O projection.
|
||||
self.o_proj = RowParallelLinear(
|
||||
@@ -166,6 +175,7 @@ class MiniCPM3Attention(nn.Module):
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
qk_rope_head_dim,
|
||||
@@ -182,6 +192,7 @@ class MiniCPM3Attention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_local_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -250,6 +261,7 @@ class MiniCPM3AttentionMLA(nn.Module):
|
||||
max_position_embeddings: int = 8192,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
layer_id=None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
@@ -274,6 +286,7 @@ class MiniCPM3AttentionMLA(nn.Module):
|
||||
self.q_lora_rank,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("q_a_proj", prefix),
|
||||
)
|
||||
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
|
||||
self.q_b_proj = ColumnParallelLinear(
|
||||
@@ -281,6 +294,7 @@ class MiniCPM3AttentionMLA(nn.Module):
|
||||
self.num_heads * self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("q_b_proj", prefix),
|
||||
)
|
||||
else:
|
||||
self.q_proj = ColumnParallelLinear(
|
||||
@@ -288,6 +302,7 @@ class MiniCPM3AttentionMLA(nn.Module):
|
||||
self.num_heads * self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("q_proj", prefix),
|
||||
)
|
||||
|
||||
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
||||
@@ -295,6 +310,7 @@ class MiniCPM3AttentionMLA(nn.Module):
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("kv_a_proj_with_mqa", prefix),
|
||||
)
|
||||
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
|
||||
self.kv_b_proj = ColumnParallelLinear(
|
||||
@@ -302,6 +318,7 @@ class MiniCPM3AttentionMLA(nn.Module):
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("kv_b_proj", prefix),
|
||||
)
|
||||
# O projection.
|
||||
self.o_proj = RowParallelLinear(
|
||||
@@ -309,6 +326,7 @@ class MiniCPM3AttentionMLA(nn.Module):
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
qk_rope_head_dim,
|
||||
@@ -325,6 +343,7 @@ class MiniCPM3AttentionMLA(nn.Module):
|
||||
num_kv_heads=1,
|
||||
layer_id=layer_id,
|
||||
v_head_dim=self.kv_lora_rank,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
self.w_kc = None
|
||||
@@ -405,6 +424,7 @@ class MiniCPM3DecoderLayer(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
layer_id: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -429,6 +449,7 @@ class MiniCPM3DecoderLayer(nn.Module):
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
else:
|
||||
self.self_attn = MiniCPM3Attention(
|
||||
@@ -447,12 +468,14 @@ class MiniCPM3DecoderLayer(nn.Module):
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
self.mlp = MiniCPM3MLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
@@ -494,6 +517,7 @@ class MiniCPM3Model(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -503,10 +527,16 @@ class MiniCPM3Model(nn.Module):
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
MiniCPM3DecoderLayer(config, i, quant_config=quant_config)
|
||||
MiniCPM3DecoderLayer(
|
||||
config,
|
||||
i,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix(f"layers.{i}", prefix),
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
@@ -542,19 +572,23 @@ class MiniCPM3ForCausalLM(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.num_experts = getattr(self.config, "num_experts", 0)
|
||||
self.quant_config = quant_config
|
||||
self.model = MiniCPM3Model(config, quant_config=quant_config)
|
||||
self.model = MiniCPM3Model(
|
||||
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
# self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
if not self.config.tie_word_embeddings:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
|
||||
self.scale_width = self.config.hidden_size / self.config.dim_model_base
|
||||
|
||||
@@ -56,6 +56,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.utils import set_default_torch_dtype
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
RawImageType = Union[Image.Image, torch.Tensor]
|
||||
|
||||
@@ -158,14 +159,14 @@ class Idefics2VisionMLP(nn.Module):
|
||||
config.intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc1",
|
||||
prefix=add_prefix("fc1", prefix),
|
||||
)
|
||||
self.fc2 = RowParallelLinear(
|
||||
config.intermediate_size,
|
||||
config.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.fc2",
|
||||
prefix=add_prefix("fc2", prefix),
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@@ -199,10 +200,14 @@ class Idefics2EncoderLayer(nn.Module):
|
||||
use_context_forward=False,
|
||||
use_full_precision_softmax=True,
|
||||
flatten_batch=False,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.mlp = Idefics2VisionMLP(config, quant_config=quant_config)
|
||||
self.mlp = Idefics2VisionMLP(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(
|
||||
@@ -242,6 +247,7 @@ class Idefics2Encoder(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -251,8 +257,9 @@ class Idefics2Encoder(nn.Module):
|
||||
Idefics2EncoderLayer(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix(f"layers.{i}", prefix),
|
||||
)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
@@ -379,13 +386,18 @@ class Idefics2VisionTransformer(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
embed_dim = config.hidden_size
|
||||
self.config = config
|
||||
self.embeddings = Idefics2VisionEmbeddings(config)
|
||||
self.encoder = Idefics2Encoder(config=config, quant_config=quant_config)
|
||||
self.encoder = Idefics2Encoder(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("encoder", prefix),
|
||||
)
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
@@ -503,7 +515,7 @@ class BaseResampler(nn.Module):
|
||||
embed_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.kv_proj",
|
||||
prefix=add_prefix("kv_proj", prefix),
|
||||
)
|
||||
else:
|
||||
# Maintain the same return value with ReplicatedLinear.forward
|
||||
@@ -660,6 +672,7 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
*,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
# All MiniCPM-V models disable `tie_word_embeddings` but
|
||||
@@ -669,8 +682,12 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
self.config = config
|
||||
|
||||
self.version = get_version_by_config(self.config)
|
||||
self.llm = self.init_llm(config=config, quant_config=quant_config)
|
||||
self.vpm = self.init_vision_module(config, quant_config)
|
||||
self.llm = self.init_llm(
|
||||
config=config, quant_config=quant_config, prefix=add_prefix("llm", prefix)
|
||||
)
|
||||
self.vpm = self.init_vision_module(
|
||||
config, quant_config, add_prefix("vpm", prefix)
|
||||
)
|
||||
self.vision_dim = (
|
||||
self.vpm.embed_dim
|
||||
if self.version == (2, 0)
|
||||
@@ -679,7 +696,10 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
self.embed_dim = self.config.hidden_size
|
||||
|
||||
self.resampler = self.init_resampler(
|
||||
self.embed_dim, self.vision_dim, quant_config=quant_config
|
||||
self.embed_dim,
|
||||
self.vision_dim,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("resampler", prefix),
|
||||
)
|
||||
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
@@ -937,6 +957,7 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
self,
|
||||
config: Qwen2Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> nn.Module:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -944,6 +965,7 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
prefix: str = "",
|
||||
) -> nn.Module:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -952,6 +974,7 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
embed_dim: int,
|
||||
vision_dim: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> nn.Module:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1011,24 +1034,27 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__(config=config, quant_config=quant_config)
|
||||
super().__init__(config=config, quant_config=quant_config, prefix=prefix)
|
||||
assert self.version == (2, 6)
|
||||
|
||||
def init_llm(
|
||||
self,
|
||||
config: Qwen2Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> nn.Module:
|
||||
return Qwen2ForCausalLM(config=config, quant_config=quant_config)
|
||||
return Qwen2ForCausalLM(config=config, quant_config=quant_config, prefix=prefix)
|
||||
|
||||
def init_vision_module(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
prefix: str = "",
|
||||
) -> nn.Module:
|
||||
model = Idefics2VisionTransformer(
|
||||
config=config.vision_config, quant_config=quant_config
|
||||
config=config.vision_config, quant_config=quant_config, prefix=prefix
|
||||
)
|
||||
if self.config.drop_vision_last_layer:
|
||||
model.encoder.layers = model.encoder.layers[:-1]
|
||||
@@ -1042,6 +1068,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
|
||||
embed_dim: int,
|
||||
vision_dim: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> nn.Module:
|
||||
with set_default_torch_dtype(torch.float16):
|
||||
# The resampler in 2.6 remains consistent with the one in 2.5.
|
||||
@@ -1051,6 +1078,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
|
||||
num_heads=embed_dim // 128,
|
||||
kv_dim=vision_dim,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
return resampler.to(device="cuda", dtype=torch.get_default_dtype())
|
||||
@@ -1207,6 +1235,7 @@ class MiniCPMV:
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -1221,7 +1250,9 @@ class MiniCPMV:
|
||||
raise ValueError("Currently, MiniCPMV only supports versions 2.6")
|
||||
|
||||
try:
|
||||
minicpmv = instance_class(config=config, quant_config=quant_config)
|
||||
minicpmv = instance_class(
|
||||
config=config, quant_config=quant_config, prefix=prefix
|
||||
)
|
||||
self.minicpmv = minicpmv
|
||||
except Exception as e:
|
||||
print(f"Failed to instantiate MiniCPMV: {e}")
|
||||
|
||||
@@ -45,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class MixtralMoE(nn.Module):
|
||||
@@ -78,7 +79,7 @@ class MixtralMoE(nn.Module):
|
||||
bias=False,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.gate",
|
||||
prefix=add_prefix("gate", prefix),
|
||||
)
|
||||
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
|
||||
self.experts = MoEImpl(
|
||||
@@ -90,7 +91,7 @@ class MixtralMoE(nn.Module):
|
||||
renormalize=True,
|
||||
quant_config=quant_config,
|
||||
tp_size=tp_size,
|
||||
prefix=f"{prefix}.experts",
|
||||
prefix=add_prefix("experts", prefix),
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@@ -146,14 +147,14 @@ class MixtralAttention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
@@ -168,6 +169,7 @@ class MixtralAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -204,7 +206,7 @@ class MixtralDecoderLayer(nn.Module):
|
||||
layer_id=layer_id,
|
||||
rope_theta=rope_theta,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
self.block_sparse_moe = MixtralMoE(
|
||||
num_experts=config.num_local_experts,
|
||||
@@ -212,7 +214,7 @@ class MixtralDecoderLayer(nn.Module):
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.block_sparse_moe",
|
||||
prefix=add_prefix("block_sparse_moe", prefix),
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
@@ -258,11 +260,15 @@ class MixtralModel(nn.Module):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
MixtralDecoderLayer(
|
||||
config, i, quant_config=quant_config, prefix=f"{prefix}.layers"
|
||||
config,
|
||||
i,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix(f"layers.{i}", prefix),
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
@@ -296,12 +302,17 @@ class MixtralForCausalLM(nn.Module):
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.model = MixtralModel(
|
||||
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -45,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class MixtralMLP(nn.Module):
|
||||
@@ -54,6 +55,7 @@ class MixtralMLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_experts = num_experts
|
||||
@@ -61,13 +63,25 @@ class MixtralMLP(nn.Module):
|
||||
self.hidden_dim = hidden_size
|
||||
|
||||
self.w1 = ReplicatedLinear(
|
||||
self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
|
||||
self.hidden_dim,
|
||||
self.ffn_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("w1", prefix),
|
||||
)
|
||||
self.w2 = ReplicatedLinear(
|
||||
self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
|
||||
self.ffn_dim,
|
||||
self.hidden_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("w2", prefix),
|
||||
)
|
||||
self.w3 = ReplicatedLinear(
|
||||
self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
|
||||
self.hidden_dim,
|
||||
self.ffn_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("w3", prefix),
|
||||
)
|
||||
|
||||
# TODO: Use vllm's SiluAndMul
|
||||
@@ -87,6 +101,7 @@ class MixtralMoE(nn.Module):
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -114,6 +129,7 @@ class MixtralMoE(nn.Module):
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix(f"experts.{idx}", prefix),
|
||||
)
|
||||
if idx in self.expert_indicies
|
||||
else None
|
||||
@@ -122,7 +138,11 @@ class MixtralMoE(nn.Module):
|
||||
]
|
||||
)
|
||||
self.gate = ReplicatedLinear(
|
||||
config.hidden_size, self.num_total_experts, bias=False, quant_config=None
|
||||
config.hidden_size,
|
||||
self.num_total_experts,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=add_prefix("gate", prefix),
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@@ -159,6 +179,7 @@ class MixtralAttention(nn.Module):
|
||||
max_position: int = 4096 * 32,
|
||||
rope_theta: float = 10000,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@@ -189,12 +210,14 @@ class MixtralAttention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
@@ -209,6 +232,7 @@ class MixtralAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -231,6 +255,7 @@ class MixtralDecoderLayer(nn.Module):
|
||||
config: MixtralConfig,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -244,8 +269,13 @@ class MixtralDecoderLayer(nn.Module):
|
||||
layer_id=layer_id,
|
||||
rope_theta=rope_theta,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
self.block_sparse_moe = MixtralMoE(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("block_sparse_moe", prefix),
|
||||
)
|
||||
self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
@@ -281,6 +311,7 @@ class MixtralModel(nn.Module):
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.padding_idx = config.pad_token_id
|
||||
@@ -289,10 +320,16 @@ class MixtralModel(nn.Module):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
MixtralDecoderLayer(config, i, quant_config=quant_config)
|
||||
MixtralDecoderLayer(
|
||||
config,
|
||||
i,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix(f"layers.{i}", prefix),
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
@@ -324,12 +361,17 @@ class QuantMixtralForCausalLM(nn.Module):
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = MixtralModel(config, quant_config=quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.model = MixtralModel(
|
||||
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -36,6 +36,7 @@ from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class ColumnParallelConv2dPatch(torch.nn.Module):
|
||||
@@ -147,7 +148,12 @@ class MllamaPrecomputedPositionEmbedding(nn.Module):
|
||||
|
||||
|
||||
class MllamaVisionMLP(nn.Module):
|
||||
def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = get_act_fn(config.hidden_act)
|
||||
@@ -156,12 +162,14 @@ class MllamaVisionMLP(nn.Module):
|
||||
config.intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("fc1", prefix),
|
||||
)
|
||||
self.fc2 = RowParallelLinear(
|
||||
config.intermediate_size,
|
||||
config.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("fc2", prefix),
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@@ -174,7 +182,10 @@ class MllamaVisionMLP(nn.Module):
|
||||
|
||||
class MllamaVisionEncoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self, config: config_mllama.MllamaVisionConfig, is_gated: bool = False
|
||||
self,
|
||||
config: config_mllama.MllamaVisionConfig,
|
||||
is_gated: bool = False,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -193,8 +204,9 @@ class MllamaVisionEncoderLayer(nn.Module):
|
||||
use_context_forward=False,
|
||||
use_full_precision_softmax=False,
|
||||
flatten_batch=False,
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
self.mlp = MllamaVisionMLP(config)
|
||||
self.mlp = MllamaVisionMLP(config, prefix=add_prefix("mlp", prefix))
|
||||
|
||||
self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
@@ -235,11 +247,17 @@ class MllamaVisionEncoder(nn.Module):
|
||||
num_layers=32,
|
||||
is_gated=False,
|
||||
output_hidden_states=None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layers = nn.ModuleList(
|
||||
[MllamaVisionEncoderLayer(config, is_gated) for _ in range(num_layers)]
|
||||
[
|
||||
MllamaVisionEncoderLayer(
|
||||
config, is_gated, prefix=add_prefix(f"layers.{i}", prefix)
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.output_hidden_states = output_hidden_states or []
|
||||
|
||||
@@ -265,7 +283,7 @@ class MllamaVisionEncoder(nn.Module):
|
||||
|
||||
|
||||
class MllamaVisionModel(nn.Module):
|
||||
def __init__(self, config: config_mllama.MllamaVisionConfig):
|
||||
def __init__(self, config: config_mllama.MllamaVisionConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
@@ -305,9 +323,13 @@ class MllamaVisionModel(nn.Module):
|
||||
config.num_hidden_layers,
|
||||
is_gated=False,
|
||||
output_hidden_states=config.intermediate_layers_indices,
|
||||
prefix=add_prefix("transformer", prefix),
|
||||
)
|
||||
self.global_transformer = MllamaVisionEncoder(
|
||||
config, config.num_global_layers, is_gated=True
|
||||
config,
|
||||
config.num_global_layers,
|
||||
is_gated=True,
|
||||
prefix=add_prefix("global_transformer", prefix),
|
||||
)
|
||||
|
||||
def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
||||
@@ -464,6 +486,7 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
config: Optional[config_mllama.MllamaTextConfig] = None,
|
||||
layer_id: Optional[int] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -489,6 +512,7 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
self.num_key_value_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.num_heads * self.head_dim,
|
||||
@@ -496,6 +520,7 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
# vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
|
||||
# use huggingface's instead
|
||||
@@ -510,6 +535,7 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
self.num_local_key_value_heads,
|
||||
layer_id=layer_id,
|
||||
is_cross_attention=True,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -551,6 +577,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
|
||||
config: config_mllama.MllamaTextConfig,
|
||||
layer_id: int,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
@@ -558,6 +585,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
|
||||
config=config,
|
||||
layer_id=layer_id,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("cross_attn", prefix),
|
||||
)
|
||||
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@@ -568,6 +596,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
@@ -610,12 +639,15 @@ class MllamaTextModel(nn.Module):
|
||||
self,
|
||||
config: config_mllama.MllamaTextConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.padding_id = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size + 8, config.hidden_size
|
||||
config.vocab_size + 8,
|
||||
config.hidden_size,
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
self.cross_attention_layers = config.cross_attention_layers
|
||||
|
||||
@@ -624,14 +656,20 @@ class MllamaTextModel(nn.Module):
|
||||
if layer_id in self.cross_attention_layers:
|
||||
layers.append(
|
||||
MllamaCrossAttentionDecoderLayer(
|
||||
config, layer_id, quant_config=quant_config
|
||||
config,
|
||||
layer_id,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix(f"layers.{layer_id}", prefix),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# TODO: force LlamaDecoderLayer to config.attention_bias=False
|
||||
layers.append(
|
||||
LlamaDecoderLayer(
|
||||
config, quant_config=quant_config, layer_id=layer_id
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix(f"layers.{layer_id}", prefix),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -687,16 +725,20 @@ class MllamaForCausalLM(nn.Module):
|
||||
self,
|
||||
config: config_mllama.MllamaTextConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.vocab_size = config.vocab_size
|
||||
self.model = MllamaTextModel(config, quant_config)
|
||||
self.model = MllamaTextModel(
|
||||
config, quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -726,6 +768,7 @@ class MllamaForConditionalGeneration(nn.Module):
|
||||
self,
|
||||
config: config_mllama.MllamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
@@ -737,10 +780,13 @@ class MllamaForConditionalGeneration(nn.Module):
|
||||
)
|
||||
self.image_size = config.vision_config.image_size
|
||||
|
||||
self.vision_model = MllamaVisionModel(config.vision_config)
|
||||
self.vision_model = MllamaVisionModel(
|
||||
config.vision_config, prefix=add_prefix("vision_model", prefix)
|
||||
)
|
||||
self.language_model = MllamaForCausalLM(
|
||||
config.text_config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("language_model", prefix),
|
||||
)
|
||||
self.multi_modal_projector = nn.Linear(
|
||||
config.vision_config.vision_output_dim,
|
||||
|
||||
@@ -38,7 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import make_layers
|
||||
from sglang.srt.utils import add_prefix, make_layers
|
||||
|
||||
|
||||
class OlmoAttention(nn.Module):
|
||||
@@ -53,6 +53,7 @@ class OlmoAttention(nn.Module):
|
||||
config: OlmoConfig,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -75,6 +76,7 @@ class OlmoAttention(nn.Module):
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
bias=config.attention_bias,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
|
||||
# Rotary embeddings.
|
||||
@@ -91,6 +93,7 @@ class OlmoAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
# Attention output projection.
|
||||
@@ -98,6 +101,7 @@ class OlmoAttention(nn.Module):
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=config.attention_bias,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -127,6 +131,7 @@ class OlmoMLP(nn.Module):
|
||||
self,
|
||||
config: OlmoConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -139,6 +144,7 @@ class OlmoMLP(nn.Module):
|
||||
[self.intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
)
|
||||
|
||||
# Activation function.
|
||||
@@ -150,6 +156,7 @@ class OlmoMLP(nn.Module):
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -174,13 +181,23 @@ class OlmoDecoderLayer(nn.Module):
|
||||
config: OlmoConfig,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
# Attention block.
|
||||
self.self_attn = OlmoAttention(config, layer_id, quant_config)
|
||||
self.self_attn = OlmoAttention(
|
||||
config,
|
||||
layer_id,
|
||||
quant_config,
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
|
||||
# MLP block.
|
||||
self.mlp = OlmoMLP(config, quant_config)
|
||||
self.mlp = OlmoMLP(
|
||||
config,
|
||||
quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
|
||||
# LayerNorm
|
||||
self.input_layernorm = nn.LayerNorm(
|
||||
@@ -213,13 +230,18 @@ class OlmoDecoderLayer(nn.Module):
|
||||
class OlmoModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self, config: OlmoConfig, quant_config: Optional[QuantizationConfig] = None
|
||||
self,
|
||||
config: OlmoConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size, config.hidden_size
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
@@ -227,7 +249,9 @@ class OlmoModel(nn.Module):
|
||||
layer_id=idx,
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
),
|
||||
prefix=add_prefix("layers", prefix),
|
||||
)
|
||||
self.norm = nn.LayerNorm(
|
||||
config.hidden_size, elementwise_affine=False, bias=False
|
||||
@@ -275,10 +299,11 @@ class OlmoForCausalLM(nn.Module):
|
||||
self,
|
||||
config: OlmoConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = OlmoModel(config, quant_config)
|
||||
self.model = OlmoModel(config, quant_config, prefix=add_prefix("model", prefix))
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
@@ -288,6 +313,7 @@ class OlmoForCausalLM(nn.Module):
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import make_layers
|
||||
from sglang.srt.utils import add_prefix, make_layers
|
||||
|
||||
|
||||
class Olmo2Attention(nn.Module):
|
||||
@@ -60,6 +60,7 @@ class Olmo2Attention(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -93,6 +94,8 @@ class Olmo2Attention(nn.Module):
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
bias=config.attention_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
@@ -115,6 +118,7 @@ class Olmo2Attention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
# Attention output projection.
|
||||
@@ -122,6 +126,8 @@ class Olmo2Attention(nn.Module):
|
||||
self.head_dim * self.total_num_heads,
|
||||
self.hidden_size,
|
||||
bias=config.attention_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
|
||||
def _apply_qk_norm(
|
||||
@@ -164,6 +170,7 @@ class Olmo2MLP(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -176,6 +183,7 @@ class Olmo2MLP(nn.Module):
|
||||
[self.intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
)
|
||||
|
||||
# Activation function.
|
||||
@@ -187,6 +195,7 @@ class Olmo2MLP(nn.Module):
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -211,13 +220,16 @@ class Olmo2DecoderLayer(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
# Attention block.
|
||||
self.self_attn = Olmo2Attention(config, layer_id, quant_config)
|
||||
self.self_attn = Olmo2Attention(
|
||||
config, layer_id, quant_config, prefix=add_prefix("self_attn", prefix)
|
||||
)
|
||||
|
||||
# MLP block.
|
||||
self.mlp = Olmo2MLP(config, quant_config)
|
||||
self.mlp = Olmo2MLP(config, quant_config, prefix=add_prefix("mlp", prefix))
|
||||
|
||||
# RMSNorm
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
@@ -254,12 +266,15 @@ class Olmo2Model(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size, config.hidden_size
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
@@ -267,7 +282,9 @@ class Olmo2Model(nn.Module):
|
||||
layer_id=idx,
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
),
|
||||
prefix=add_prefix("layers", prefix),
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
@@ -313,10 +330,13 @@ class Olmo2ForCausalLM(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = Olmo2Model(config, quant_config)
|
||||
self.model = Olmo2Model(
|
||||
config, quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
@@ -326,6 +346,7 @@ class Olmo2ForCausalLM(nn.Module):
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import make_layers, print_warning_once
|
||||
from sglang.srt.utils import add_prefix, make_layers, print_warning_once
|
||||
|
||||
|
||||
class OlmoeMoE(nn.Module):
|
||||
@@ -69,7 +69,11 @@ class OlmoeMoE(nn.Module):
|
||||
|
||||
# Gate always runs at half / full precision for now.
|
||||
self.gate = ReplicatedLinear(
|
||||
hidden_size, num_experts, bias=False, quant_config=None
|
||||
hidden_size,
|
||||
num_experts,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=add_prefix("gate", prefix),
|
||||
)
|
||||
|
||||
self.experts = FusedMoE(
|
||||
@@ -81,6 +85,7 @@ class OlmoeMoE(nn.Module):
|
||||
renormalize=False,
|
||||
quant_config=quant_config,
|
||||
tp_size=tp_size,
|
||||
prefix=add_prefix("experts", prefix),
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@@ -107,6 +112,7 @@ class OlmoeAttention(nn.Module):
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 4096,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@@ -138,6 +144,7 @@ class OlmoeAttention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
self.q_norm = RMSNorm(hidden_size, eps=1e-5)
|
||||
self.k_norm = RMSNorm(hidden_size, eps=1e-5)
|
||||
@@ -146,6 +153,7 @@ class OlmoeAttention(nn.Module):
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
@@ -162,6 +170,7 @@ class OlmoeAttention(nn.Module):
|
||||
self.scaling,
|
||||
layer_id=layer_id,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -186,6 +195,7 @@ class OlmoeDecoderLayer(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -202,6 +212,7 @@ class OlmoeDecoderLayer(nn.Module):
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
|
||||
self.mlp = OlmoeMoE(
|
||||
@@ -210,6 +221,7 @@ class OlmoeDecoderLayer(nn.Module):
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
|
||||
@@ -246,6 +258,7 @@ class OlmoeModel(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.padding_idx = config.pad_token_id
|
||||
@@ -254,6 +267,7 @@ class OlmoeModel(nn.Module):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
@@ -261,7 +275,9 @@ class OlmoeModel(nn.Module):
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
layer_id=idx,
|
||||
prefix=prefix,
|
||||
),
|
||||
prefix=add_prefix("layers", prefix),
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=1e-5)
|
||||
|
||||
@@ -294,13 +310,19 @@ class OlmoeForCausalLM(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = OlmoeModel(config, quant_config)
|
||||
self.model = OlmoeModel(
|
||||
config, quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import make_layers
|
||||
from sglang.srt.utils import add_prefix, make_layers
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@@ -70,13 +70,14 @@ class Phi3SmallMLP(nn.Module):
|
||||
2 * [self.intermediate_size],
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.up_proj",
|
||||
prefix=add_prefix("up_proj", prefix),
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
self.intermediate_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@@ -140,7 +141,7 @@ class Phi3SmallSelfAttention(nn.Module):
|
||||
self.num_key_value_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
|
||||
self.dense = RowParallelLinear(
|
||||
@@ -148,7 +149,7 @@ class Phi3SmallSelfAttention(nn.Module):
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
|
||||
if getattr(self.config, "rope_scaling", None) is not None:
|
||||
@@ -201,6 +202,7 @@ class Phi3SmallSelfAttention(nn.Module):
|
||||
self.scale,
|
||||
num_kv_heads=self.num_kv_heads_per_partion,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -234,13 +236,21 @@ class Phi3SmallDecoderLayer(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
layer_id: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = Phi3SmallSelfAttention(
|
||||
config, layer_id, quant_config=quant_config
|
||||
config,
|
||||
layer_id,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
self.mlp = Phi3SmallMLP(
|
||||
config,
|
||||
quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.mlp = Phi3SmallMLP(config, quant_config)
|
||||
|
||||
self.input_layernorm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.layer_norm_epsilon
|
||||
@@ -284,15 +294,20 @@ class Phi3SmallModel(nn.Module):
|
||||
|
||||
self.config = config
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size, config.hidden_size
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
self.mup_embedding_multiplier = config.mup_embedding_multiplier
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: Phi3SmallDecoderLayer(
|
||||
config, int(prefix.split(".")[-1]), quant_config
|
||||
config,
|
||||
int(prefix.split(".")[-1]),
|
||||
quant_config,
|
||||
prefix=prefix,
|
||||
),
|
||||
prefix=f"{prefix}.layers",
|
||||
prefix=add_prefix("layers", prefix),
|
||||
)
|
||||
|
||||
self.final_layernorm = nn.LayerNorm(
|
||||
@@ -335,6 +350,7 @@ class Phi3SmallForCausalLM(nn.Module):
|
||||
self,
|
||||
config: Phi3Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
@@ -344,7 +360,7 @@ class Phi3SmallForCausalLM(nn.Module):
|
||||
self.model = Phi3SmallModel(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix="model",
|
||||
prefix=add_prefix("model", prefix),
|
||||
)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.mup_width_multiplier = config.mup_width_multiplier
|
||||
@@ -354,6 +370,7 @@ class Phi3SmallForCausalLM(nn.Module):
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
|
||||
@@ -39,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class QWenMLP(nn.Module):
|
||||
@@ -48,6 +49,7 @@ class QWenMLP(nn.Module):
|
||||
intermediate_size: int,
|
||||
hidden_act: str = "silu",
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
@@ -56,6 +58,7 @@ class QWenMLP(nn.Module):
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
@@ -63,6 +66,7 @@ class QWenMLP(nn.Module):
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("c_proj", prefix),
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(
|
||||
@@ -88,6 +92,7 @@ class QWenAttention(nn.Module):
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@@ -104,6 +109,7 @@ class QWenAttention(nn.Module):
|
||||
self.total_num_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("c_attn", prefix),
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
@@ -111,6 +117,7 @@ class QWenAttention(nn.Module):
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("c_proj", prefix),
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
@@ -126,6 +133,7 @@ class QWenAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -148,6 +156,7 @@ class QWenBlock(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
layer_id,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
@@ -162,6 +171,7 @@ class QWenBlock(nn.Module):
|
||||
rope_scaling=rope_scaling,
|
||||
layer_id=layer_id,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
@@ -170,6 +180,7 @@ class QWenBlock(nn.Module):
|
||||
config.hidden_size,
|
||||
config.intermediate_size // 2,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -201,6 +212,7 @@ class QWenModel(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -210,10 +222,16 @@ class QWenModel(nn.Module):
|
||||
self.wte = VocabParallelEmbedding(
|
||||
vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=add_prefix("wte", prefix),
|
||||
)
|
||||
self.h = nn.ModuleList(
|
||||
[
|
||||
QWenBlock(config, i, quant_config=quant_config)
|
||||
QWenBlock(
|
||||
config,
|
||||
i,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix(f"h.{i}", prefix),
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
@@ -242,12 +260,17 @@ class QWenLMHeadModel(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.transformer = QWenModel(config, quant_config=quant_config)
|
||||
self.transformer = QWenModel(
|
||||
config, quant_config=quant_config, prefix=add_prefix("transformer", prefix)
|
||||
)
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
|
||||
self.lm_head = ParallelLMHead(
|
||||
vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# Adapted from llama2.py
|
||||
# Modify details for the adaptation of Qwen2 model.
|
||||
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
|
||||
|
||||
from readline import add_history
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -46,7 +46,7 @@ from sglang.srt.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
kv_cache_scales_loader,
|
||||
)
|
||||
from sglang.srt.utils import make_layers
|
||||
from sglang.srt.utils import add_prefix, make_layers
|
||||
|
||||
Qwen2Config = None
|
||||
|
||||
@@ -58,6 +58,7 @@ class Qwen2MLP(nn.Module):
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
@@ -65,12 +66,14 @@ class Qwen2MLP(nn.Module):
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(
|
||||
@@ -97,6 +100,7 @@ class Qwen2Attention(nn.Module):
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 32768,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@@ -128,12 +132,14 @@ class Qwen2Attention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
@@ -149,6 +155,7 @@ class Qwen2Attention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -171,6 +178,7 @@ class Qwen2DecoderLayer(nn.Module):
|
||||
config: Qwen2Config,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -186,12 +194,14 @@ class Qwen2DecoderLayer(nn.Module):
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
self.mlp = Qwen2MLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
@@ -228,6 +238,7 @@ class Qwen2Model(nn.Module):
|
||||
self,
|
||||
config: Qwen2Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -237,6 +248,7 @@ class Qwen2Model(nn.Module):
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
@@ -244,7 +256,9 @@ class Qwen2Model(nn.Module):
|
||||
layer_id=idx,
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
),
|
||||
prefix=add_prefix("layers", prefix),
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
@@ -325,16 +339,22 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
self,
|
||||
config: Qwen2Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = Qwen2Model(config, quant_config=quant_config)
|
||||
self.model = Qwen2Model(
|
||||
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
@@ -52,6 +52,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.qwen2 import Qwen2Model
|
||||
from sglang.srt.models.qwen2_vl import Qwen2VLImageInputs, Qwen2VLVideoInputs
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -65,16 +66,29 @@ class Qwen2_5_VLMLP(nn.Module):
|
||||
bias: bool = True,
|
||||
hidden_act="silu",
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_proj = ColumnParallelLinear(
|
||||
in_features, hidden_features, bias=bias, quant_config=quant_config
|
||||
in_features,
|
||||
hidden_features,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("gate_proj", prefix),
|
||||
)
|
||||
self.up_proj = ColumnParallelLinear(
|
||||
in_features, hidden_features, bias=bias, quant_config=quant_config
|
||||
in_features,
|
||||
hidden_features,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("up_proj", prefix),
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
hidden_features, in_features, bias=bias, quant_config=quant_config
|
||||
hidden_features,
|
||||
in_features,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
)
|
||||
self.act = ACT2FN[hidden_act]
|
||||
|
||||
@@ -98,6 +112,7 @@ class Qwen2_5_VisionBlock(nn.Module):
|
||||
norm_layer: Type[nn.Module] = None,
|
||||
attn_implementation: Optional[str] = "sdpa",
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if norm_layer is None:
|
||||
@@ -123,9 +138,14 @@ class Qwen2_5_VisionBlock(nn.Module):
|
||||
use_full_precision_softmax=use_full_precision_softmax,
|
||||
flatten_batch=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
self.mlp = Qwen2_5_VLMLP(
|
||||
dim, intermediate_dim, hidden_act=hidden_act, quant_config=quant_config
|
||||
dim,
|
||||
intermediate_dim,
|
||||
hidden_act=hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -178,6 +198,7 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
|
||||
context_dim: int,
|
||||
spatial_merge_size: int = 2,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = context_dim * (spatial_merge_size**2)
|
||||
@@ -189,10 +210,15 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp.0", prefix),
|
||||
),
|
||||
nn.GELU(),
|
||||
RowParallelLinear(
|
||||
self.hidden_size, dim, bias=True, quant_config=quant_config
|
||||
self.hidden_size,
|
||||
dim,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp.2", prefix),
|
||||
),
|
||||
]
|
||||
)
|
||||
@@ -250,6 +276,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
vision_config: Qwen2_5_VLVisionConfig,
|
||||
norm_eps: float = 1e-6,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -286,8 +313,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
norm_layer=norm_layer,
|
||||
attn_implementation="sdpa",
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix(f"blocks.{i}", prefix),
|
||||
)
|
||||
for _ in range(depth)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
self.merger = Qwen2_5_VisionPatchMerger(
|
||||
@@ -295,6 +323,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
context_dim=hidden_size,
|
||||
spatial_merge_size=spatial_merge_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("merger", prefix),
|
||||
)
|
||||
|
||||
def get_window_index(self, grid_thw):
|
||||
@@ -447,6 +476,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
||||
self,
|
||||
config: Qwen2VLConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -457,15 +487,23 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
||||
# NOTE: Qwen2-VL vision encoder does not support any
|
||||
# quantization method now.
|
||||
quant_config=None,
|
||||
prefix=add_prefix("visual", prefix),
|
||||
)
|
||||
|
||||
self.model = Qwen2Model(config, quant_config)
|
||||
self.model = Qwen2Model(
|
||||
config,
|
||||
quant_config,
|
||||
prefix=add_prefix("model", prefix),
|
||||
)
|
||||
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py
|
||||
"""Inference-only LLaMA-EAGLE model compatible with HuggingFace weights."""
|
||||
@@ -42,7 +44,7 @@ class Qwen2DecoderLayer(Qwen2DecoderLayer):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(config, layer_id, quant_config)
|
||||
super().__init__(config, layer_id, quant_config, prefix=prefix)
|
||||
|
||||
# Skip the input_layernorm
|
||||
# https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
|
||||
@@ -56,6 +58,7 @@ class Qwen2Model(nn.Module):
|
||||
self,
|
||||
config: Qwen2Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -63,11 +66,15 @@ class Qwen2Model(nn.Module):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
Qwen2DecoderLayer(
|
||||
config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
|
||||
config,
|
||||
i,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix(f"layers.{i}", prefix),
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
@@ -107,16 +114,22 @@ class Qwen2ForCausalLMEagle(Qwen2ForCausalLM):
|
||||
self,
|
||||
config: Qwen2Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = Qwen2Model(config, quant_config=quant_config)
|
||||
self.model = Qwen2Model(
|
||||
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class Qwen2MoeMLP(nn.Module):
|
||||
@@ -56,10 +57,15 @@ class Qwen2MoeMLP(nn.Module):
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
reduce_results: bool = True,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
@@ -67,6 +73,7 @@ class Qwen2MoeMLP(nn.Module):
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(
|
||||
@@ -87,6 +94,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
@@ -105,10 +113,15 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
||||
reduce_results=False,
|
||||
renormalize=config.norm_topk_prob,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("experts", prefix),
|
||||
)
|
||||
|
||||
self.gate = ReplicatedLinear(
|
||||
config.hidden_size, config.num_experts, bias=False, quant_config=None
|
||||
config.hidden_size,
|
||||
config.num_experts,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=add_prefix("gate", prefix),
|
||||
)
|
||||
if config.shared_expert_intermediate_size > 0:
|
||||
self.shared_expert = Qwen2MoeMLP(
|
||||
@@ -117,6 +130,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
reduce_results=False,
|
||||
prefix=add_prefix("shared_expert", prefix),
|
||||
)
|
||||
else:
|
||||
self.shared_expert = None
|
||||
@@ -157,6 +171,7 @@ class Qwen2MoeAttention(nn.Module):
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@@ -188,6 +203,7 @@ class Qwen2MoeAttention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
|
||||
self.o_proj = RowParallelLinear(
|
||||
@@ -195,6 +211,7 @@ class Qwen2MoeAttention(nn.Module):
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
@@ -210,6 +227,7 @@ class Qwen2MoeAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -232,6 +250,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
layer_id: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -247,6 +266,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
|
||||
# Note: Qwen/Qwen2-57B-A14B-Instruct does not have
|
||||
@@ -257,13 +277,18 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
||||
if (layer_id not in mlp_only_layers) and (
|
||||
config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
|
||||
):
|
||||
self.mlp = Qwen2MoeSparseMoeBlock(config=config, quant_config=quant_config)
|
||||
self.mlp = Qwen2MoeSparseMoeBlock(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
else:
|
||||
self.mlp = Qwen2MoeMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
@@ -300,6 +325,7 @@ class Qwen2MoeModel(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.padding_idx = config.pad_token_id
|
||||
@@ -308,10 +334,16 @@ class Qwen2MoeModel(nn.Module):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
Qwen2MoeDecoderLayer(config, layer_id, quant_config=quant_config)
|
||||
Qwen2MoeDecoderLayer(
|
||||
config,
|
||||
layer_id,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix(f"layers.{layer_id}", prefix),
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
@@ -346,13 +378,19 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = Qwen2MoeModel(config, quant_config)
|
||||
self.model = Qwen2MoeModel(
|
||||
config, quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM, Qwen2Model
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class Qwen2ForRewardModel(nn.Module):
|
||||
@@ -29,12 +30,15 @@ class Qwen2ForRewardModel(nn.Module):
|
||||
self,
|
||||
config: Qwen2Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.num_labels = 1
|
||||
self.model = Qwen2Model(config, quant_config=quant_config)
|
||||
self.model = Qwen2Model(
|
||||
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
self.score = nn.Sequential(
|
||||
nn.Linear(config.hidden_size, config.hidden_size),
|
||||
nn.ReLU(),
|
||||
|
||||
@@ -46,6 +46,7 @@ from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.qwen2 import Qwen2Model
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -91,14 +92,21 @@ class Qwen2VisionMLP(nn.Module):
|
||||
hidden_features: int = None,
|
||||
act_layer: Type[nn.Module] = QuickGELU,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.fc1 = ColumnParallelLinear(
|
||||
in_features, hidden_features, quant_config=quant_config
|
||||
in_features,
|
||||
hidden_features,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("fc1", prefix),
|
||||
)
|
||||
self.act = act_layer()
|
||||
self.fc2 = RowParallelLinear(
|
||||
hidden_features, in_features, quant_config=quant_config
|
||||
hidden_features,
|
||||
in_features,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("fc2", prefix),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -119,6 +127,7 @@ class Qwen2VisionBlock(nn.Module):
|
||||
norm_layer: Type[nn.Module] = None,
|
||||
attn_implementation: Optional[str] = "sdpa",
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if norm_layer is None:
|
||||
@@ -145,9 +154,14 @@ class Qwen2VisionBlock(nn.Module):
|
||||
use_full_precision_softmax=use_full_precision_softmax,
|
||||
flatten_batch=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
self.mlp = Qwen2VisionMLP(
|
||||
dim, mlp_hidden_dim, act_layer=act_layer, quant_config=quant_config
|
||||
dim,
|
||||
mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -199,6 +213,7 @@ class Qwen2VisionPatchMerger(nn.Module):
|
||||
norm_layer: Type[nn.Module] = None,
|
||||
spatial_merge_size: int = 2,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = context_dim * (spatial_merge_size**2)
|
||||
@@ -212,10 +227,15 @@ class Qwen2VisionPatchMerger(nn.Module):
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp.0", prefix),
|
||||
),
|
||||
nn.GELU(),
|
||||
RowParallelLinear(
|
||||
self.hidden_size, d_model, bias=True, quant_config=quant_config
|
||||
self.hidden_size,
|
||||
d_model,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp.2", prefix),
|
||||
),
|
||||
]
|
||||
)
|
||||
@@ -273,6 +293,7 @@ class Qwen2VisionTransformer(nn.Module):
|
||||
vision_config: Qwen2VLVisionConfig,
|
||||
norm_eps: float = 1e-6,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -307,8 +328,9 @@ class Qwen2VisionTransformer(nn.Module):
|
||||
norm_layer=norm_layer,
|
||||
attn_implementation="sdpa",
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix(f"blocks.{i}", prefix),
|
||||
)
|
||||
for _ in range(depth)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
self.merger = Qwen2VisionPatchMerger(
|
||||
@@ -316,6 +338,7 @@ class Qwen2VisionTransformer(nn.Module):
|
||||
context_dim=embed_dim,
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("merger", prefix),
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -440,6 +463,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
self,
|
||||
config: Qwen2VLConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -450,15 +474,21 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
# NOTE: Qwen2-VL vision encoder does not support any
|
||||
# quantization method now.
|
||||
quant_config=None,
|
||||
prefix=add_prefix("visual", prefix),
|
||||
)
|
||||
|
||||
self.model = Qwen2Model(config, quant_config)
|
||||
self.model = Qwen2Model(
|
||||
config, quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
@@ -42,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class StablelmMLP(nn.Module):
|
||||
@@ -49,6 +50,7 @@ class StablelmMLP(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -59,12 +61,14 @@ class StablelmMLP(nn.Module):
|
||||
[config.intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
config.intermediate_size,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
)
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
@@ -81,6 +85,7 @@ class StablelmAttention(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -122,11 +127,15 @@ class StablelmAttention(nn.Module):
|
||||
self.total_num_heads,
|
||||
self.total_num_key_value_heads,
|
||||
self.qkv_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
@@ -140,6 +149,7 @@ class StablelmAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_key_value_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -162,10 +172,15 @@ class StablelmDecoderLayer(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.self_attn = StablelmAttention(config, layer_id=layer_id)
|
||||
self.mlp = StablelmMLP(config, quant_config=quant_config)
|
||||
self.self_attn = StablelmAttention(
|
||||
config, layer_id=layer_id, prefix=add_prefix("self_attn", prefix)
|
||||
)
|
||||
self.mlp = StablelmMLP(
|
||||
config, quant_config=quant_config, prefix=add_prefix("mlp", prefix)
|
||||
)
|
||||
norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05))
|
||||
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
|
||||
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
|
||||
@@ -200,15 +215,22 @@ class StableLMEpochModel(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
StablelmDecoderLayer(config, i, quant_config=quant_config)
|
||||
StablelmDecoderLayer(
|
||||
config,
|
||||
i,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix(f"layers.{i}", prefix),
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
@@ -242,12 +264,17 @@ class StableLmForCausalLM(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = StableLMEpochModel(config, quant_config=quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.model = StableLMEpochModel(
|
||||
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -64,6 +64,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
@@ -294,14 +295,14 @@ class LlamaDecoderLayer(nn.Module):
|
||||
rope_is_neox_style=rope_is_neox_style,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
self.mlp = LlamaMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
|
||||
@@ -40,6 +40,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.model_executor.model_runner import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class XverseMLP(nn.Module):
|
||||
@@ -57,14 +58,14 @@ class XverseMLP(nn.Module):
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(
|
||||
@@ -128,14 +129,14 @@ class XverseAttention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
@@ -152,6 +153,7 @@ class XverseAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -202,14 +204,14 @@ class XverseDecoderLayer(nn.Module):
|
||||
rope_is_neox_style=rope_is_neox_style,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
self.mlp = XverseMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
@@ -246,6 +248,7 @@ class XverseModel(nn.Module):
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -254,11 +257,15 @@ class XverseModel(nn.Module):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
XverseDecoderLayer(
|
||||
config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
|
||||
config,
|
||||
i,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix(f"layers.{i}", prefix),
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
@@ -295,12 +302,17 @@ class XverseForCausalLM(nn.Module):
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = XverseModel(config, quant_config=quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.model = XverseModel(
|
||||
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -43,6 +43,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
class XverseMLP(nn.Module):
|
||||
@@ -54,10 +55,15 @@ class XverseMLP(nn.Module):
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
reduce_results: bool = True,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
@@ -65,6 +71,7 @@ class XverseMLP(nn.Module):
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(
|
||||
@@ -86,6 +93,7 @@ class XverseMoE(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -107,14 +115,19 @@ class XverseMoE(nn.Module):
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
reduce_results=False,
|
||||
prefix=add_prefix(f"experts.{i}", prefix),
|
||||
)
|
||||
for _ in range(self.n_routed_experts)
|
||||
for i in range(self.n_routed_experts)
|
||||
]
|
||||
)
|
||||
self.pack_params()
|
||||
|
||||
self.router = ReplicatedLinear(
|
||||
config.hidden_size, self.n_routed_experts, bias=False, quant_config=None
|
||||
config.hidden_size,
|
||||
self.n_routed_experts,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=add_prefix("router", prefix),
|
||||
)
|
||||
|
||||
if config.num_shared_experts is not None:
|
||||
@@ -125,6 +138,7 @@ class XverseMoE(nn.Module):
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
reduce_results=False,
|
||||
prefix=add_prefix("shared_experts", prefix),
|
||||
)
|
||||
|
||||
def pack_params(self):
|
||||
@@ -182,6 +196,7 @@ class XverseAttention(nn.Module):
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@@ -213,6 +228,7 @@ class XverseAttention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
|
||||
self.o_proj = RowParallelLinear(
|
||||
@@ -220,6 +236,7 @@ class XverseAttention(nn.Module):
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
@@ -235,6 +252,7 @@ class XverseAttention(nn.Module):
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -258,6 +276,7 @@ class XverseDecoderLayer(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
layer_id: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -276,15 +295,21 @@ class XverseDecoderLayer(nn.Module):
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
if config.num_experts is not None:
|
||||
self.mlp = XverseMoE(config=config, quant_config=quant_config)
|
||||
self.mlp = XverseMoE(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
else:
|
||||
self.mlp = XverseMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
@@ -324,6 +349,7 @@ class XverseModel(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.padding_idx = config.pad_token_id
|
||||
@@ -332,10 +358,16 @@ class XverseModel(nn.Module):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
XverseDecoderLayer(config, layer_id, quant_config=quant_config)
|
||||
XverseDecoderLayer(
|
||||
config,
|
||||
layer_id,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix(f"layers.{layer_id}", prefix),
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
@@ -364,13 +396,19 @@ class XverseMoeForCausalLM(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = XverseModel(config, quant_config)
|
||||
self.model = XverseModel(
|
||||
config, quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
|
||||
@@ -29,8 +29,9 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
|
||||
self,
|
||||
config: LlavaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(config, quant_config)
|
||||
super().__init__(config, quant_config, prefix=prefix)
|
||||
|
||||
self.multi_modal_projector = YiVLMultiModalProjector(self.config)
|
||||
self.vision_tower_subfolder = self.config.mm_vision_tower.replace(
|
||||
|
||||
@@ -313,7 +313,7 @@ def make_layers(
|
||||
"""Make a list of layers with the given layer function"""
|
||||
modules = torch.nn.ModuleList(
|
||||
[
|
||||
maybe_offload_to_cpu(layer_fn(idx=idx, prefix=f"{prefix}.{idx}"))
|
||||
maybe_offload_to_cpu(layer_fn(idx=idx, prefix=add_prefix(idx, prefix)))
|
||||
for idx in range(num_hidden_layers)
|
||||
]
|
||||
)
|
||||
@@ -1464,3 +1464,16 @@ def set_cuda_arch():
|
||||
capability = torch.cuda.get_device_capability()
|
||||
arch = f"{capability[0]}.{capability[1]}"
|
||||
os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}"
|
||||
|
||||
|
||||
def add_prefix(name: str, prefix: str) -> str:
|
||||
"""Add a weight path prefix to a module name.
|
||||
|
||||
Args:
|
||||
name: base module name.
|
||||
prefix: weight prefix str to added to the front of `name` concatenated with `.`.
|
||||
|
||||
Returns:
|
||||
The string `prefix.name` if prefix is non-empty, otherwise just `name`.
|
||||
"""
|
||||
return name if not prefix else f"{prefix}.{name}"
|
||||
|
||||
@@ -12,6 +12,7 @@ suites = {
|
||||
"models/test_generation_models.py",
|
||||
"models/test_qwen_models.py",
|
||||
"models/test_reward_models.py",
|
||||
"test_gptqmodel_dynamic.py",
|
||||
"test_abort.py",
|
||||
"test_chunked_prefill.py",
|
||||
"test_custom_allreduce.py",
|
||||
|
||||
211
test/srt/test_gptqmodel_dynamic.py
Normal file
211
test/srt/test_gptqmodel_dynamic.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
def check_quant_method(model_path: str, use_marlin_kernel: bool):
|
||||
from sglang.srt.configs.device_config import DeviceConfig
|
||||
from sglang.srt.configs.load_config import LoadConfig
|
||||
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
||||
from sglang.srt.distributed import (
|
||||
get_tp_group,
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
set_custom_all_reduce,
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
|
||||
from sglang.srt.layers.quantization import get_dynamic_override
|
||||
from sglang.srt.model_loader import get_model
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
|
||||
try:
|
||||
init_distributed_environment(
|
||||
backend="nccl",
|
||||
world_size=1,
|
||||
rank=0,
|
||||
local_rank=0,
|
||||
distributed_init_method="tcp://127.0.0.1:2646",
|
||||
)
|
||||
initialize_model_parallel(tensor_model_parallel_size=1)
|
||||
monkey_patch_vllm_parallel_state()
|
||||
except AssertionError:
|
||||
# ignore this error: tensor model parallel group is already initialized
|
||||
pass
|
||||
|
||||
server_args = ServerArgs(model_path=model_path, dtype=torch.float16)
|
||||
model_config = ModelConfig(
|
||||
server_args.model_path,
|
||||
trust_remote_code=server_args.trust_remote_code,
|
||||
revision=server_args.revision,
|
||||
context_length=server_args.context_length,
|
||||
model_override_args=server_args.json_model_override_args,
|
||||
is_embedding=server_args.is_embedding,
|
||||
dtype=server_args.dtype,
|
||||
quantization=server_args.quantization,
|
||||
)
|
||||
|
||||
load_config = LoadConfig()
|
||||
device_config = DeviceConfig("cuda")
|
||||
model = get_model(
|
||||
model_config=model_config, load_config=load_config, device_config=device_config
|
||||
)
|
||||
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinLinearMethod,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.linear import UnquantizedLinearMethod
|
||||
|
||||
linear_method_cls = (
|
||||
GPTQMarlinLinearMethod if use_marlin_kernel else (GPTQLinearMethod)
|
||||
)
|
||||
|
||||
for name, submodule in model.named_modules():
|
||||
if name == "lm_head":
|
||||
assert isinstance(submodule.quant_method, linear_method_cls)
|
||||
elif name == "model.layers.0.self_attn.qkv_proj":
|
||||
# The first layer is quantized using bits=4, group_size=128
|
||||
# desc_act=True
|
||||
assert isinstance(submodule.quant_method, linear_method_cls)
|
||||
config = submodule.quant_method.quant_config
|
||||
assert config.weight_bits == 4
|
||||
assert config.group_size == 128
|
||||
assert config.desc_act
|
||||
elif name == "model.layers.1.self_attn.qkv_proj":
|
||||
# The second layer is quantized using bits=8, group_size=32
|
||||
# desc_act=False
|
||||
assert isinstance(submodule.quant_method, linear_method_cls)
|
||||
config = submodule.quant_method.quant_config
|
||||
assert get_dynamic_override(config, layer_name=name, key="bits") == 8
|
||||
assert get_dynamic_override(config, layer_name=name, key="group_size") == 32
|
||||
assert not get_dynamic_override(config, layer_name=name, key="desc_act")
|
||||
elif (
|
||||
name == "model.layers.2.self_attn.qkv_proj"
|
||||
or name == "model.layers.2.mlp.gate_up_proj"
|
||||
):
|
||||
# All other layers (layer index >= 2) are not quantized
|
||||
assert isinstance(submodule.quant_method, UnquantizedLinearMethod)
|
||||
|
||||
del model
|
||||
|
||||
|
||||
# GPTQ with Dynamic Per/Module Quantization Control
|
||||
# Leverages GPTQModel (pypi) to produce the `dynamic` models
|
||||
# Test GPTQ fallback kernel that is not Marlin
|
||||
class TestGPTQModelDynamic(unittest.TestCase):
|
||||
MODEL_PATH = (
|
||||
"ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symFalse"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = cls.MODEL_PATH
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=["--dtype", "float16"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def run_decode(self, max_new_tokens):
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"text": "The capital of France is",
|
||||
"sampling_params": {
|
||||
"max_new_tokens": max_new_tokens,
|
||||
},
|
||||
},
|
||||
)
|
||||
return response.json()
|
||||
|
||||
def test_throughput(self):
|
||||
max_tokens = 256
|
||||
|
||||
tic = time.time()
|
||||
result = self.run_decode(max_tokens)
|
||||
tok = time.time()
|
||||
|
||||
print(f"result = `{result}`")
|
||||
|
||||
assert "paris" in result["text"].lower()
|
||||
|
||||
throughput = max_tokens / (tok - tic)
|
||||
print(f"Throughput: {throughput} tokens/s")
|
||||
assert throughput >= 140
|
||||
|
||||
def test_gptq_module(self):
|
||||
check_quant_method(self.MODEL_PATH, use_marlin_kernel=False)
|
||||
|
||||
|
||||
# GPTQ with Dynamic Per/Module Quantization Control
|
||||
# Leverages GPTQModel (pypi) to produce the `dynamic` models
|
||||
# Test Marlin kernel
|
||||
class TestGPTQModelDynamicWithMarlin(unittest.TestCase):
|
||||
MODEL_PATH = (
|
||||
"ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symTrue"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = cls.MODEL_PATH
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=["--dtype", "float16"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def run_decode(self, max_new_tokens):
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"text": "The capital of France is",
|
||||
"sampling_params": {
|
||||
"max_new_tokens": max_new_tokens,
|
||||
},
|
||||
},
|
||||
)
|
||||
return response.json()
|
||||
|
||||
def test_throughput(self):
|
||||
max_tokens = 256
|
||||
|
||||
tic = time.time()
|
||||
result = self.run_decode(max_tokens)
|
||||
tok = time.time()
|
||||
|
||||
print(f"result = `{result}`")
|
||||
|
||||
assert "paris" in result["text"].lower()
|
||||
|
||||
throughput = max_tokens / (tok - tic)
|
||||
print(f"Throughput: {throughput} tokens/s")
|
||||
assert throughput >= 140
|
||||
|
||||
def test_gptq_marlin_module(self):
|
||||
check_quant_method(self.MODEL_PATH, use_marlin_kernel=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user