[FEAT] Add transformers backend support (#5929)
This commit is contained in:
@@ -63,6 +63,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
|
|||||||
| `kv_cache_dtype` | Dtype of the kv cache. | `auto` |
|
| `kv_cache_dtype` | Dtype of the kv cache. | `auto` |
|
||||||
| `context_length` | The model's maximum context length. Defaults to None (will use the value from the model's config.json instead). Note that extending the default might lead to strange behavior. | None |
|
| `context_length` | The model's maximum context length. Defaults to None (will use the value from the model's config.json instead). Note that extending the default might lead to strange behavior. | None |
|
||||||
| `device` | The device we put the model. | None |
|
| `device` | The device we put the model. | None |
|
||||||
|
| `impl` | The implementation of the model to use. Defaults to SGlang implementation and fall back to transformers if needed | `auto` |
|
||||||
| `served_model_name` | Override the model name returned by the v1/models endpoint in OpenAI API server.| None |
|
| `served_model_name` | Override the model name returned by the v1/models endpoint in OpenAI API server.| None |
|
||||||
| `is_embedding` | Set to `true` to perform [embedding](./openai_api_embeddings.ipynb) / [encode](https://docs.sglang.ai/backend/native_api#Encode-(embedding-model)) and [reward](https://docs.sglang.ai/backend/native_api#Classify-(reward-model)) tasks. | `False` |
|
| `is_embedding` | Set to `true` to perform [embedding](./openai_api_embeddings.ipynb) / [encode](https://docs.sglang.ai/backend/native_api#Encode-(embedding-model)) and [reward](https://docs.sglang.ai/backend/native_api#Classify-(reward-model)) tasks. | `False` |
|
||||||
| `revision` | Adjust if a specific version of the model should be used. | None |
|
| `revision` | Adjust if a specific version of the model should be used. | None |
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ The core features include:
|
|||||||
supported_models/embedding_models.md
|
supported_models/embedding_models.md
|
||||||
supported_models/reward_models.md
|
supported_models/reward_models.md
|
||||||
supported_models/support_new_models.md
|
supported_models/support_new_models.md
|
||||||
|
supported_models/transformers_fallback.md
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:maxdepth: 1
|
:maxdepth: 1
|
||||||
|
|||||||
58
docs/supported_models/transformers_fallback.md
Normal file
58
docs/supported_models/transformers_fallback.md
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
# Transformers fallback in SGLang
|
||||||
|
|
||||||
|
`sglang` can fall back to using models that are available in `transformers`. This works for most decoder-style language models and support for vision-language models is coming soon!
|
||||||
|
|
||||||
|
## Example launch Command
|
||||||
|
|
||||||
|
By default, we will use sglang implementation if it is available. Otherwise, we will fall back to transformers one. However, you can switch the implementation by setting `impl` to `transformers`.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
python3 -m sglang.launch_server \
|
||||||
|
--model-path meta-llama/Llama-3.2-1B-Instruct \
|
||||||
|
--host 0.0.0.0 \
|
||||||
|
--port 30000 \
|
||||||
|
--impl transformers
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Supported features
|
||||||
|
|
||||||
|
##### Quantization
|
||||||
|
|
||||||
|
Transformers fall back has supported most of available quantization in SGLang (except GGUF). See [Quantization page](https://docs.sglang.ai/backend/quantization.html) for more information about supported quantization in SGLang.
|
||||||
|
|
||||||
|
##### Remote code
|
||||||
|
|
||||||
|
This fallback also means that any model on the hub that can be used in `transformers` with `trust_remote_code=True` that correctly implements attention can be used in production!
|
||||||
|
|
||||||
|
A model just needs the following two things:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from transformers import PreTrainedModel
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
class MyAttention(nn.Module):
|
||||||
|
|
||||||
|
def forward(self, hidden_states, **kwargs): # <- kwargs are required
|
||||||
|
|
||||||
|
...
|
||||||
|
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||||
|
attn_output, attn_weights = attention_interface(
|
||||||
|
self,
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
...
|
||||||
|
|
||||||
|
class MyModel(PreTrainedModel):
|
||||||
|
_supports_attention_backend = True
|
||||||
|
```
|
||||||
|
|
||||||
|
Here is what happens in the background:
|
||||||
|
|
||||||
|
1. The config is loaded
|
||||||
|
2. `MyModel` python class is loaded from the `auto_map`, and we check that the model `_supports_attention_backend`.
|
||||||
|
3. The `TransformersModel` backend is used. See `/srt/models/transformers`, which leverages `self.config._attn_implementation = "sglang"`, thus the need to use `ALL_ATTENTION_FUNCTIONS`.
|
||||||
|
|
||||||
|
That's it!
|
||||||
@@ -16,7 +16,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from enum import IntEnum, auto
|
from enum import Enum, IntEnum, auto
|
||||||
from typing import List, Optional, Set, Union
|
from typing import List, Optional, Set, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -39,6 +39,12 @@ class AttentionArch(IntEnum):
|
|||||||
MHA = auto()
|
MHA = auto()
|
||||||
|
|
||||||
|
|
||||||
|
class ModelImpl(str, Enum):
|
||||||
|
AUTO = "auto"
|
||||||
|
SGLANG = "sglang"
|
||||||
|
TRANSFORMERS = "transformers"
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig:
|
class ModelConfig:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -53,11 +59,13 @@ class ModelConfig:
|
|||||||
quantization: Optional[str] = None,
|
quantization: Optional[str] = None,
|
||||||
override_config_file: Optional[str] = None,
|
override_config_file: Optional[str] = None,
|
||||||
is_draft_model: bool = False,
|
is_draft_model: bool = False,
|
||||||
|
impl: Union[str, ModelImpl] = ModelImpl.AUTO,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
self.revision = revision
|
self.revision = revision
|
||||||
self.quantization = quantization
|
self.quantization = quantization
|
||||||
|
self.impl = impl
|
||||||
|
|
||||||
# Parse args
|
# Parse args
|
||||||
self.maybe_pull_model_tokenizer_from_remote()
|
self.maybe_pull_model_tokenizer_from_remote()
|
||||||
@@ -256,6 +264,7 @@ class ModelConfig:
|
|||||||
enable_multimodal=server_args.enable_multimodal,
|
enable_multimodal=server_args.enable_multimodal,
|
||||||
dtype=server_args.dtype,
|
dtype=server_args.dtype,
|
||||||
quantization=server_args.quantization,
|
quantization=server_args.quantization,
|
||||||
|
impl=server_args.impl,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -2,12 +2,17 @@
|
|||||||
|
|
||||||
"""Utilities for selecting and loading models."""
|
"""Utilities for selecting and loading models."""
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import logging
|
||||||
from typing import Tuple, Type
|
from typing import Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import transformers
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from transformers.dynamic_module_utils import get_class_from_dynamic_module
|
||||||
|
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig, ModelImpl
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
@@ -19,6 +24,61 @@ def set_default_torch_dtype(dtype: torch.dtype):
|
|||||||
torch.set_default_dtype(old_dtype)
|
torch.set_default_dtype(old_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_transformers_arch(model_config: ModelConfig, architectures: list[str]):
|
||||||
|
for i, arch in enumerate(architectures):
|
||||||
|
if arch == "TransformersForCausalLM":
|
||||||
|
continue
|
||||||
|
auto_map: dict[str, str] = (
|
||||||
|
getattr(model_config.hf_config, "auto_map", None) or dict()
|
||||||
|
)
|
||||||
|
# Make sure that config class is always initialized before model class,
|
||||||
|
# otherwise the model class won't be able to access the config class,
|
||||||
|
# the expected auto_map should have correct order like:
|
||||||
|
# "auto_map": {
|
||||||
|
# "AutoConfig": "<your-repo-name>--<config-name>",
|
||||||
|
# "AutoModel": "<your-repo-name>--<config-name>",
|
||||||
|
# "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
|
||||||
|
# },
|
||||||
|
auto_modules = {
|
||||||
|
name: get_class_from_dynamic_module(
|
||||||
|
module, model_config.model_path, revision=model_config.revision
|
||||||
|
)
|
||||||
|
for name, module in sorted(auto_map.items(), key=lambda x: x[0])
|
||||||
|
}
|
||||||
|
model_module = getattr(transformers, arch, None)
|
||||||
|
if model_module is None:
|
||||||
|
if "AutoModel" not in auto_map:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot find model module. '{arch}' is not a registered "
|
||||||
|
"model in the Transformers library (only relevant if the "
|
||||||
|
"model is meant to be in Transformers) and 'AutoModel' is "
|
||||||
|
"not present in the model config's 'auto_map' (relevant "
|
||||||
|
"if the model is custom)."
|
||||||
|
)
|
||||||
|
model_module = auto_modules["AutoModel"]
|
||||||
|
if model_config.impl == ModelImpl.TRANSFORMERS:
|
||||||
|
if not model_module.is_backend_compatible():
|
||||||
|
raise ValueError(
|
||||||
|
f"The Transformers implementation of {arch} is not "
|
||||||
|
"compatible with vLLM."
|
||||||
|
)
|
||||||
|
architectures[i] = "TransformersForCausalLM"
|
||||||
|
if model_config.impl == ModelImpl.AUTO:
|
||||||
|
if not model_module.is_backend_compatible():
|
||||||
|
raise ValueError(
|
||||||
|
f"{arch} has no SGlang implementation and the Transformers "
|
||||||
|
"implementation is not compatible with SGLang."
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
"%s has no SGLang implementation, falling back to Transformers "
|
||||||
|
"implementation. Some features may not be supported and "
|
||||||
|
"performance may not be optimal.",
|
||||||
|
arch,
|
||||||
|
)
|
||||||
|
architectures[i] = "TransformersForCausalLM"
|
||||||
|
return architectures
|
||||||
|
|
||||||
|
|
||||||
def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
|
def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
|
||||||
from sglang.srt.models.registry import ModelRegistry
|
from sglang.srt.models.registry import ModelRegistry
|
||||||
|
|
||||||
@@ -34,6 +94,12 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module],
|
|||||||
):
|
):
|
||||||
architectures = ["QuantMixtralForCausalLM"]
|
architectures = ["QuantMixtralForCausalLM"]
|
||||||
|
|
||||||
|
supported_archs = ModelRegistry.get_supported_archs()
|
||||||
|
is_native_supported = any(arch in supported_archs for arch in architectures)
|
||||||
|
|
||||||
|
if not is_native_supported or model_config.impl == ModelImpl.TRANSFORMERS:
|
||||||
|
architectures = resolve_transformers_arch(model_config, architectures)
|
||||||
|
|
||||||
return ModelRegistry.resolve_model_cls(architectures)
|
return ModelRegistry.resolve_model_cls(architectures)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -49,7 +49,15 @@ class _ModelRegistry:
|
|||||||
if not architectures:
|
if not architectures:
|
||||||
logger.warning("No model architectures are specified")
|
logger.warning("No model architectures are specified")
|
||||||
|
|
||||||
return architectures
|
# filter out support architectures
|
||||||
|
normalized_arch = list(
|
||||||
|
filter(lambda model: model in self.models, architectures)
|
||||||
|
)
|
||||||
|
|
||||||
|
# make sure Transformers backend is put at the last as a fallback
|
||||||
|
if len(normalized_arch) != len(architectures):
|
||||||
|
normalized_arch.append("TransformersForCausalLM")
|
||||||
|
return normalized_arch
|
||||||
|
|
||||||
def resolve_model_cls(
|
def resolve_model_cls(
|
||||||
self,
|
self,
|
||||||
|
|||||||
291
python/sglang/srt/models/transformers.py
Normal file
291
python/sglang/srt/models/transformers.py
Normal file
@@ -0,0 +1,291 @@
|
|||||||
|
# Copyright 2025 SGLang Team
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
# Adapted from
|
||||||
|
# https://github.com/vllm-project/vllm/blob/a1a2aaadb9122f05667140e39cf67e5736c8b6d6/vllm/model_executor/models/transformers.py
|
||||||
|
"""Wrapper around `transformers` models"""
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import Iterable, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers import AutoModel, PretrainedConfig, PreTrainedModel
|
||||||
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||||
|
|
||||||
|
from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size
|
||||||
|
from sglang.srt.layers.linear import (
|
||||||
|
ColumnParallelLinear,
|
||||||
|
ReplicatedLinear,
|
||||||
|
RowParallelLinear,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
||||||
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
|
ParallelLMHead,
|
||||||
|
VocabParallelEmbedding,
|
||||||
|
)
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_prefix(prefix: str, name: str) -> str:
|
||||||
|
"""Add a prefix to a name if the prefix is non-empty.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prefix: The prefix to add. If empty, no prefix will be added.
|
||||||
|
name: The name to potentially prefix.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The string "prefix.name" if prefix was non-empty, otherwise just "name".
|
||||||
|
"""
|
||||||
|
return name if not prefix else f"{prefix}.{name}"
|
||||||
|
|
||||||
|
|
||||||
|
def sglang_flash_attention_forward(
|
||||||
|
# Transformers args
|
||||||
|
module: torch.nn.Module,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
# sglang kwargs
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
# Transformers kwargs
|
||||||
|
scaling: float = None,
|
||||||
|
attention_instances: list[RadixAttention] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self_attn: RadixAttention = attention_instances[module.layer_idx]
|
||||||
|
if scaling is not None:
|
||||||
|
self_attn.scaling = float(scaling)
|
||||||
|
hidden = query.shape[-2]
|
||||||
|
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
|
||||||
|
query, key, value = (x.reshape(hidden, -1) for x in (query, key, value))
|
||||||
|
return self_attn.forward(query, key, value, forward_batch=forward_batch), None
|
||||||
|
|
||||||
|
|
||||||
|
ALL_ATTENTION_FUNCTIONS["sglang"] = sglang_flash_attention_forward
|
||||||
|
|
||||||
|
|
||||||
|
class HFColumnParallelLinear(ColumnParallelLinear):
|
||||||
|
|
||||||
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
return super().forward(input)[0]
|
||||||
|
|
||||||
|
|
||||||
|
class HFRowParallelLinear(RowParallelLinear):
|
||||||
|
|
||||||
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
return super().forward(input)[0]
|
||||||
|
|
||||||
|
|
||||||
|
def replace_linear_class(
|
||||||
|
linear: nn.Linear,
|
||||||
|
style: Literal["colwise", "rowwise"],
|
||||||
|
quant_config: QuantizationConfig,
|
||||||
|
) -> Union[ColumnParallelLinear, RowParallelLinear]:
|
||||||
|
"""
|
||||||
|
Replace nn.Linear with one of vLLM's tensor parallel linear classes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
linear (nn.Linear): `nn.Linear` to be replaced.
|
||||||
|
style (str): Tensor parallel style of the new linear, e.g. "colwise".
|
||||||
|
quant_config (QuantConfig): Quantization config for the new linear.
|
||||||
|
Returns:
|
||||||
|
Union[ColumnParallelLinear, RowParallelLinear]: The new linear.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not isinstance(style, str):
|
||||||
|
raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")
|
||||||
|
|
||||||
|
sglang_linear_cls = {
|
||||||
|
"colwise": ColumnParallelLinear,
|
||||||
|
"rowwise": RowParallelLinear,
|
||||||
|
}.get(style, ReplicatedLinear)
|
||||||
|
|
||||||
|
class HFCompatibleLinear(sglang_linear_cls):
|
||||||
|
"""
|
||||||
|
Wrapper class that removes `output_bias` from returned output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parent_cls(self) -> type:
|
||||||
|
return sglang_linear_cls
|
||||||
|
|
||||||
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
return super().forward(input)[0]
|
||||||
|
|
||||||
|
return HFCompatibleLinear(
|
||||||
|
input_size=linear.in_features,
|
||||||
|
output_size=linear.out_features,
|
||||||
|
bias=linear.bias is not None,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformersForCausalLM(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: PretrainedConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
logger.info("Using Transformers backend.")
|
||||||
|
|
||||||
|
self.quant_config = quant_config
|
||||||
|
self.config = config
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
self.unpadded_vocab_size = config.vocab_size
|
||||||
|
|
||||||
|
# model is loaded under set_default_torch_dtype(model_config.dtype)
|
||||||
|
self.model: PreTrainedModel = AutoModel.from_config(
|
||||||
|
self.config,
|
||||||
|
torch_dtype=torch.get_default_dtype(),
|
||||||
|
attn_implementation="sglang",
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Attention modifications (assumes 1 attention op per hidden layer)
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
|
# MLP modifications
|
||||||
|
self.tensor_parallel(tp_size)
|
||||||
|
|
||||||
|
head_dim = (
|
||||||
|
(config.hidden_size // config.num_attention_heads)
|
||||||
|
if not hasattr(config, "head_dim")
|
||||||
|
else config.head_dim
|
||||||
|
)
|
||||||
|
self.attention_instances = [
|
||||||
|
RadixAttention(
|
||||||
|
num_heads=divide(config.num_attention_heads, tp_size),
|
||||||
|
head_dim=head_dim,
|
||||||
|
# NOTE: We use Llama scale as default, if it's set by
|
||||||
|
# Transformers, it's updated in sglang_flash_attention_forward
|
||||||
|
scaling=head_dim**-0.5,
|
||||||
|
num_kv_heads=divide(config.num_key_value_heads, tp_size),
|
||||||
|
layer_id=i,
|
||||||
|
quant_config=self.quant_config,
|
||||||
|
prefix=f"{i}.attn",
|
||||||
|
)
|
||||||
|
for i in range(config.num_hidden_layers)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Model modifications
|
||||||
|
self.replace_vocab_embed_class(self.model)
|
||||||
|
|
||||||
|
# ForCausalLM modifications
|
||||||
|
self.lm_head = ParallelLMHead(
|
||||||
|
config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
quant_config=self.quant_config,
|
||||||
|
prefix=maybe_prefix(prefix, "lm_head"),
|
||||||
|
)
|
||||||
|
if config.tie_word_embeddings:
|
||||||
|
self.lm_head.weight = self.model.get_input_embeddings().weight
|
||||||
|
|
||||||
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
|
def log_replacement(self, name: str, old_module: nn.Module, new_module: nn.Module):
|
||||||
|
logger.debug("%s: %s -> %s", name, old_module, new_module)
|
||||||
|
|
||||||
|
def tensor_parallel(self, tp_size: int):
|
||||||
|
"""
|
||||||
|
Apply the model's tensor parallelization plan.
|
||||||
|
Currently only supports linear layers.
|
||||||
|
"""
|
||||||
|
if not self.model.supports_tp_plan:
|
||||||
|
if tp_size <= 1:
|
||||||
|
return
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"{type(self.model)} does not support tensor parallel yet!"
|
||||||
|
)
|
||||||
|
|
||||||
|
tp_plan = self.model._tp_plan
|
||||||
|
|
||||||
|
def _tensor_parallel(module: nn.Module, prefix: str = ""):
|
||||||
|
for child_name, child_module in module.named_children():
|
||||||
|
qual_name = maybe_prefix(prefix, child_name)
|
||||||
|
for pattern, style in tp_plan.items():
|
||||||
|
if re.match(pattern, qual_name) and isinstance(
|
||||||
|
child_module, nn.Linear
|
||||||
|
):
|
||||||
|
new_module = replace_linear_class(
|
||||||
|
child_module, style, self.quant_config
|
||||||
|
)
|
||||||
|
setattr(module, child_name, new_module)
|
||||||
|
self.log_replacement(qual_name, child_module, new_module)
|
||||||
|
else:
|
||||||
|
_tensor_parallel(child_module, prefix=qual_name)
|
||||||
|
|
||||||
|
_tensor_parallel(self.model)
|
||||||
|
|
||||||
|
def replace_vocab_embed_class(self, module: nn.Module):
|
||||||
|
# Use native set input embeddings
|
||||||
|
new_module = VocabParallelEmbedding(
|
||||||
|
self.vocab_size,
|
||||||
|
self.config.hidden_size,
|
||||||
|
org_num_embeddings=self.config.vocab_size,
|
||||||
|
quant_config=None,
|
||||||
|
)
|
||||||
|
self.log_replacement(
|
||||||
|
"input embedding", self.model.get_input_embeddings(), new_module
|
||||||
|
)
|
||||||
|
self.model.set_input_embeddings(new_module)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
input_embeds: torch.Tensor = None,
|
||||||
|
get_embedding: bool = False,
|
||||||
|
) -> LogitsProcessorOutput:
|
||||||
|
assert get_embedding is False, "embedding is not supported yet"
|
||||||
|
aux_hidden_states = None
|
||||||
|
hidden_states = self.model(
|
||||||
|
input_ids[None, ...],
|
||||||
|
use_cache=False,
|
||||||
|
position_ids=positions[None, ...],
|
||||||
|
forward_batch=forward_batch,
|
||||||
|
attention_instances=self.attention_instances,
|
||||||
|
return_dict=False,
|
||||||
|
)[0][
|
||||||
|
0, ...
|
||||||
|
] # we remove batch dimension for now
|
||||||
|
|
||||||
|
return self.logits_processor(
|
||||||
|
input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
|
||||||
|
)
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
if name not in params_dict:
|
||||||
|
name = f"{self.model.base_model_prefix}.{name}"
|
||||||
|
if name in params_dict:
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
EntryClass = [TransformersForCausalLM]
|
||||||
@@ -61,6 +61,7 @@ class ServerArgs:
|
|||||||
is_embedding: bool = False
|
is_embedding: bool = False
|
||||||
enable_multimodal: Optional[bool] = None
|
enable_multimodal: Optional[bool] = None
|
||||||
revision: Optional[str] = None
|
revision: Optional[str] = None
|
||||||
|
impl: str = "auto"
|
||||||
|
|
||||||
# Port for the HTTP server
|
# Port for the HTTP server
|
||||||
host: str = "127.0.0.1"
|
host: str = "127.0.0.1"
|
||||||
@@ -726,6 +727,18 @@ class ServerArgs:
|
|||||||
default=ServerArgs.page_size,
|
default=ServerArgs.page_size,
|
||||||
help="The number of tokens in a page.",
|
help="The number of tokens in a page.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--impl",
|
||||||
|
type=str,
|
||||||
|
default=ServerArgs.impl,
|
||||||
|
help="Which implementation of the model to use.\n\n"
|
||||||
|
'* "auto" will try to use the SGLang implementation if it exists '
|
||||||
|
"and fall back to the Transformers implementation if no SGLang "
|
||||||
|
"implementation is available.\n"
|
||||||
|
'* "sglang" will use the SGLang model implementation.\n'
|
||||||
|
'* "transformers" will use the Transformers model '
|
||||||
|
"implementation.\n",
|
||||||
|
)
|
||||||
|
|
||||||
# Other runtime options
|
# Other runtime options
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
@@ -455,6 +455,7 @@ class SRTRunner:
|
|||||||
torch_dtype: torch.dtype,
|
torch_dtype: torch.dtype,
|
||||||
model_type: str,
|
model_type: str,
|
||||||
tp_size: int = 1,
|
tp_size: int = 1,
|
||||||
|
impl: str = "auto",
|
||||||
port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
|
port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
|
||||||
lora_paths: List[str] = None,
|
lora_paths: List[str] = None,
|
||||||
max_loras_per_batch: int = 4,
|
max_loras_per_batch: int = 4,
|
||||||
@@ -475,6 +476,7 @@ class SRTRunner:
|
|||||||
speculative_num_draft_tokens: Optional[int] = None,
|
speculative_num_draft_tokens: Optional[int] = None,
|
||||||
disable_overlap_schedule: bool = False,
|
disable_overlap_schedule: bool = False,
|
||||||
disable_custom_all_reduce: bool = False,
|
disable_custom_all_reduce: bool = False,
|
||||||
|
torchao_config: Optional[str] = None,
|
||||||
):
|
):
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
self.is_generation = model_type == "generation"
|
self.is_generation = model_type == "generation"
|
||||||
@@ -493,6 +495,8 @@ class SRTRunner:
|
|||||||
tp_size=tp_size,
|
tp_size=tp_size,
|
||||||
dtype=get_dtype_str(torch_dtype),
|
dtype=get_dtype_str(torch_dtype),
|
||||||
port=port,
|
port=port,
|
||||||
|
impl=impl,
|
||||||
|
torchao_config=torchao_config,
|
||||||
mem_fraction_static=mem_fraction_static,
|
mem_fraction_static=mem_fraction_static,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
is_embedding=not self.is_generation,
|
is_embedding=not self.is_generation,
|
||||||
|
|||||||
181
test/srt/models/test_transformers_models.py
Normal file
181
test/srt/models/test_transformers_models.py
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
import dataclasses
|
||||||
|
import multiprocessing as mp
|
||||||
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.runners import DEFAULT_PROMPTS, SRTRunner, check_close_model_outputs
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
CustomTestCase,
|
||||||
|
is_in_ci,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTransformersFallbackEndpoint(CustomTestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=["--impl", "transformers"],
|
||||||
|
)
|
||||||
|
cls.mmlu_lower_bound = 0.65
|
||||||
|
cls.gsm8k_lower_bound = 0.65
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_process_tree(cls.process.pid)
|
||||||
|
|
||||||
|
def test_mmlu(self):
|
||||||
|
args = SimpleNamespace(
|
||||||
|
base_url=self.base_url,
|
||||||
|
model=self.model,
|
||||||
|
eval_name="mmlu",
|
||||||
|
num_examples=64,
|
||||||
|
num_threads=32,
|
||||||
|
)
|
||||||
|
from sglang.test.run_eval import run_eval
|
||||||
|
|
||||||
|
metrics = run_eval(args)
|
||||||
|
self.assertGreaterEqual(metrics["score"], self.mmlu_lower_bound)
|
||||||
|
|
||||||
|
def test_gsm8k(self):
|
||||||
|
args = SimpleNamespace(
|
||||||
|
num_shots=5,
|
||||||
|
data_path=None,
|
||||||
|
num_questions=200,
|
||||||
|
max_new_tokens=512,
|
||||||
|
parallel=128,
|
||||||
|
host="http://127.0.0.1",
|
||||||
|
port=int(self.base_url.split(":")[-1]),
|
||||||
|
)
|
||||||
|
from sglang.test.few_shot_gsm8k import run_eval
|
||||||
|
|
||||||
|
metrics = run_eval(args)
|
||||||
|
print(f"{metrics=}")
|
||||||
|
self.assertGreater(metrics["accuracy"], self.gsm8k_lower_bound)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTransformersFallbackTorchAO(TestTransformersFallbackEndpoint):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=[
|
||||||
|
"--impl",
|
||||||
|
"transformers",
|
||||||
|
"--torchao-config",
|
||||||
|
"int4wo-128",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
cls.mmlu_lower_bound = 0.65
|
||||||
|
cls.gsm8k_lower_bound = 0.65
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class ModelCase:
|
||||||
|
model_path: str
|
||||||
|
tp_size: int = 1
|
||||||
|
prefill_tolerance: float = 5e-2
|
||||||
|
decode_tolerance: float = 5e-2
|
||||||
|
rouge_l_tolerance: float = 1
|
||||||
|
skip_long_prompt: bool = False
|
||||||
|
trust_remote_code: bool = False
|
||||||
|
torchao_config: str = None
|
||||||
|
torch_dtype: torch.dtype = torch.float16
|
||||||
|
|
||||||
|
|
||||||
|
# Popular models that run on the CI
|
||||||
|
CI_MODELS = [
|
||||||
|
ModelCase(DEFAULT_MODEL_NAME_FOR_TEST),
|
||||||
|
]
|
||||||
|
|
||||||
|
ALL_OTHER_MODELS = [
|
||||||
|
ModelCase(DEFAULT_MODEL_NAME_FOR_TEST, tp_size=2),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TestTransformersFallbackEngine(CustomTestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
mp.set_start_method("spawn", force=True)
|
||||||
|
|
||||||
|
def assert_close_logits_and_output_strs(
|
||||||
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
model_case: ModelCase,
|
||||||
|
) -> None:
|
||||||
|
model_path = model_case.model_path
|
||||||
|
max_new_tokens = 32
|
||||||
|
# force to use transformers impl
|
||||||
|
with SRTRunner(
|
||||||
|
model_path,
|
||||||
|
tp_size=model_case.tp_size,
|
||||||
|
torch_dtype=model_case.torch_dtype,
|
||||||
|
model_type="generation",
|
||||||
|
impl="transformers",
|
||||||
|
trust_remote_code=model_case.trust_remote_code,
|
||||||
|
torchao_config=model_case.torchao_config,
|
||||||
|
) as srt_runner:
|
||||||
|
srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
|
||||||
|
|
||||||
|
with SRTRunner(
|
||||||
|
model_path,
|
||||||
|
tp_size=model_case.tp_size,
|
||||||
|
torch_dtype=model_case.torch_dtype,
|
||||||
|
model_type="generation",
|
||||||
|
trust_remote_code=model_case.trust_remote_code,
|
||||||
|
torchao_config=model_case.torchao_config,
|
||||||
|
) as srt_runner:
|
||||||
|
srt_transformers_outputs = srt_runner.forward(
|
||||||
|
prompts, max_new_tokens=max_new_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
check_close_model_outputs(
|
||||||
|
hf_outputs=srt_transformers_outputs,
|
||||||
|
srt_outputs=srt_outputs,
|
||||||
|
prefill_tolerance=model_case.prefill_tolerance,
|
||||||
|
decode_tolerance=model_case.decode_tolerance,
|
||||||
|
rouge_l_tolerance=model_case.rouge_l_tolerance,
|
||||||
|
debug_text=f"model_path={model_path} prompts={prompts}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_ci_models(self):
|
||||||
|
for model_case in CI_MODELS:
|
||||||
|
# Skip long prompts for models that do not have a long context
|
||||||
|
prompts = DEFAULT_PROMPTS
|
||||||
|
if model_case.skip_long_prompt:
|
||||||
|
prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
|
||||||
|
# Assert the logits and output strs are close
|
||||||
|
self.assert_close_logits_and_output_strs(prompts, model_case)
|
||||||
|
|
||||||
|
def test_others(self):
|
||||||
|
if is_in_ci():
|
||||||
|
return
|
||||||
|
|
||||||
|
# Skip long prompts for models that do not have a long context
|
||||||
|
prompts = DEFAULT_PROMPTS
|
||||||
|
for model_case in ALL_OTHER_MODELS:
|
||||||
|
if model_case.skip_long_prompt:
|
||||||
|
prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
|
||||||
|
|
||||||
|
# Assert the logits and output strs are close
|
||||||
|
self.assert_close_logits_and_output_strs(prompts, model_case)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -26,6 +26,7 @@ suites = {
|
|||||||
TestFile("models/test_qwen_models.py", 82),
|
TestFile("models/test_qwen_models.py", 82),
|
||||||
TestFile("models/test_reward_models.py", 132),
|
TestFile("models/test_reward_models.py", 132),
|
||||||
TestFile("models/test_vlm_models.py", 437),
|
TestFile("models/test_vlm_models.py", 437),
|
||||||
|
TestFile("models/test_transformers_models.py", 320),
|
||||||
TestFile("test_abort.py", 51),
|
TestFile("test_abort.py", 51),
|
||||||
TestFile("test_block_int8.py", 22),
|
TestFile("test_block_int8.py", 22),
|
||||||
TestFile("test_create_kvindices.py", 2),
|
TestFile("test_create_kvindices.py", 2),
|
||||||
|
|||||||
Reference in New Issue
Block a user