[FEAT] Add transformers backend support (#5929)

This commit is contained in:
Marc Sun
2025-06-04 06:05:29 +02:00
committed by GitHub
parent 8a5480528d
commit 37f1547587
11 changed files with 636 additions and 3 deletions

View File

@@ -16,7 +16,7 @@ import json
import logging
import math
import os
from enum import IntEnum, auto
from enum import Enum, IntEnum, auto
from typing import List, Optional, Set, Union
import torch
@@ -39,6 +39,12 @@ class AttentionArch(IntEnum):
MHA = auto()
class ModelImpl(str, Enum):
AUTO = "auto"
SGLANG = "sglang"
TRANSFORMERS = "transformers"
class ModelConfig:
def __init__(
self,
@@ -53,11 +59,13 @@ class ModelConfig:
quantization: Optional[str] = None,
override_config_file: Optional[str] = None,
is_draft_model: bool = False,
impl: Union[str, ModelImpl] = ModelImpl.AUTO,
) -> None:
self.model_path = model_path
self.revision = revision
self.quantization = quantization
self.impl = impl
# Parse args
self.maybe_pull_model_tokenizer_from_remote()
@@ -256,6 +264,7 @@ class ModelConfig:
enable_multimodal=server_args.enable_multimodal,
dtype=server_args.dtype,
quantization=server_args.quantization,
impl=server_args.impl,
**kwargs,
)

View File

@@ -2,12 +2,17 @@
"""Utilities for selecting and loading models."""
import contextlib
import logging
from typing import Tuple, Type
import torch
import transformers
from torch import nn
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.configs.model_config import ModelConfig, ModelImpl
logger = logging.getLogger(__name__)
@contextlib.contextmanager
@@ -19,6 +24,61 @@ def set_default_torch_dtype(dtype: torch.dtype):
torch.set_default_dtype(old_dtype)
def resolve_transformers_arch(model_config: ModelConfig, architectures: list[str]):
for i, arch in enumerate(architectures):
if arch == "TransformersForCausalLM":
continue
auto_map: dict[str, str] = (
getattr(model_config.hf_config, "auto_map", None) or dict()
)
# Make sure that config class is always initialized before model class,
# otherwise the model class won't be able to access the config class,
# the expected auto_map should have correct order like:
# "auto_map": {
# "AutoConfig": "<your-repo-name>--<config-name>",
# "AutoModel": "<your-repo-name>--<config-name>",
# "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
# },
auto_modules = {
name: get_class_from_dynamic_module(
module, model_config.model_path, revision=model_config.revision
)
for name, module in sorted(auto_map.items(), key=lambda x: x[0])
}
model_module = getattr(transformers, arch, None)
if model_module is None:
if "AutoModel" not in auto_map:
raise ValueError(
f"Cannot find model module. '{arch}' is not a registered "
"model in the Transformers library (only relevant if the "
"model is meant to be in Transformers) and 'AutoModel' is "
"not present in the model config's 'auto_map' (relevant "
"if the model is custom)."
)
model_module = auto_modules["AutoModel"]
if model_config.impl == ModelImpl.TRANSFORMERS:
if not model_module.is_backend_compatible():
raise ValueError(
f"The Transformers implementation of {arch} is not "
"compatible with vLLM."
)
architectures[i] = "TransformersForCausalLM"
if model_config.impl == ModelImpl.AUTO:
if not model_module.is_backend_compatible():
raise ValueError(
f"{arch} has no SGlang implementation and the Transformers "
"implementation is not compatible with SGLang."
)
logger.warning(
"%s has no SGLang implementation, falling back to Transformers "
"implementation. Some features may not be supported and "
"performance may not be optimal.",
arch,
)
architectures[i] = "TransformersForCausalLM"
return architectures
def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
from sglang.srt.models.registry import ModelRegistry
@@ -34,6 +94,12 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module],
):
architectures = ["QuantMixtralForCausalLM"]
supported_archs = ModelRegistry.get_supported_archs()
is_native_supported = any(arch in supported_archs for arch in architectures)
if not is_native_supported or model_config.impl == ModelImpl.TRANSFORMERS:
architectures = resolve_transformers_arch(model_config, architectures)
return ModelRegistry.resolve_model_cls(architectures)

View File

@@ -49,7 +49,15 @@ class _ModelRegistry:
if not architectures:
logger.warning("No model architectures are specified")
return architectures
# filter out support architectures
normalized_arch = list(
filter(lambda model: model in self.models, architectures)
)
# make sure Transformers backend is put at the last as a fallback
if len(normalized_arch) != len(architectures):
normalized_arch.append("TransformersForCausalLM")
return normalized_arch
def resolve_model_cls(
self,

View File

@@ -0,0 +1,291 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from
# https://github.com/vllm-project/vllm/blob/a1a2aaadb9122f05667140e39cf67e5736c8b6d6/vllm/model_executor/models/transformers.py
"""Wrapper around `transformers` models"""
import logging
import re
from typing import Iterable, Literal, Optional, Tuple, Union
import torch
from torch import nn
from transformers import AutoModel, PretrainedConfig, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size
from sglang.srt.layers.linear import (
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
logger = logging.getLogger(__name__)
def maybe_prefix(prefix: str, name: str) -> str:
"""Add a prefix to a name if the prefix is non-empty.
Args:
prefix: The prefix to add. If empty, no prefix will be added.
name: The name to potentially prefix.
Returns:
The string "prefix.name" if prefix was non-empty, otherwise just "name".
"""
return name if not prefix else f"{prefix}.{name}"
def sglang_flash_attention_forward(
# Transformers args
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor,
# sglang kwargs
forward_batch: ForwardBatch,
# Transformers kwargs
scaling: float = None,
attention_instances: list[RadixAttention] = None,
**kwargs,
):
self_attn: RadixAttention = attention_instances[module.layer_idx]
if scaling is not None:
self_attn.scaling = float(scaling)
hidden = query.shape[-2]
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
query, key, value = (x.reshape(hidden, -1) for x in (query, key, value))
return self_attn.forward(query, key, value, forward_batch=forward_batch), None
ALL_ATTENTION_FUNCTIONS["sglang"] = sglang_flash_attention_forward
class HFColumnParallelLinear(ColumnParallelLinear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return super().forward(input)[0]
class HFRowParallelLinear(RowParallelLinear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return super().forward(input)[0]
def replace_linear_class(
linear: nn.Linear,
style: Literal["colwise", "rowwise"],
quant_config: QuantizationConfig,
) -> Union[ColumnParallelLinear, RowParallelLinear]:
"""
Replace nn.Linear with one of vLLM's tensor parallel linear classes.
Args:
linear (nn.Linear): `nn.Linear` to be replaced.
style (str): Tensor parallel style of the new linear, e.g. "colwise".
quant_config (QuantConfig): Quantization config for the new linear.
Returns:
Union[ColumnParallelLinear, RowParallelLinear]: The new linear.
"""
if not isinstance(style, str):
raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")
sglang_linear_cls = {
"colwise": ColumnParallelLinear,
"rowwise": RowParallelLinear,
}.get(style, ReplicatedLinear)
class HFCompatibleLinear(sglang_linear_cls):
"""
Wrapper class that removes `output_bias` from returned output.
"""
@property
def parent_cls(self) -> type:
return sglang_linear_cls
def forward(self, input: torch.Tensor) -> torch.Tensor:
return super().forward(input)[0]
return HFCompatibleLinear(
input_size=linear.in_features,
output_size=linear.out_features,
bias=linear.bias is not None,
quant_config=quant_config,
)
class TransformersForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
logger.info("Using Transformers backend.")
self.quant_config = quant_config
self.config = config
self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size
# model is loaded under set_default_torch_dtype(model_config.dtype)
self.model: PreTrainedModel = AutoModel.from_config(
self.config,
torch_dtype=torch.get_default_dtype(),
attn_implementation="sglang",
trust_remote_code=True,
)
# Attention modifications (assumes 1 attention op per hidden layer)
tp_size = get_tensor_model_parallel_world_size()
# MLP modifications
self.tensor_parallel(tp_size)
head_dim = (
(config.hidden_size // config.num_attention_heads)
if not hasattr(config, "head_dim")
else config.head_dim
)
self.attention_instances = [
RadixAttention(
num_heads=divide(config.num_attention_heads, tp_size),
head_dim=head_dim,
# NOTE: We use Llama scale as default, if it's set by
# Transformers, it's updated in sglang_flash_attention_forward
scaling=head_dim**-0.5,
num_kv_heads=divide(config.num_key_value_heads, tp_size),
layer_id=i,
quant_config=self.quant_config,
prefix=f"{i}.attn",
)
for i in range(config.num_hidden_layers)
]
# Model modifications
self.replace_vocab_embed_class(self.model)
# ForCausalLM modifications
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if config.tie_word_embeddings:
self.lm_head.weight = self.model.get_input_embeddings().weight
self.logits_processor = LogitsProcessor(config)
def log_replacement(self, name: str, old_module: nn.Module, new_module: nn.Module):
logger.debug("%s: %s -> %s", name, old_module, new_module)
def tensor_parallel(self, tp_size: int):
"""
Apply the model's tensor parallelization plan.
Currently only supports linear layers.
"""
if not self.model.supports_tp_plan:
if tp_size <= 1:
return
raise ValueError(
f"{type(self.model)} does not support tensor parallel yet!"
)
tp_plan = self.model._tp_plan
def _tensor_parallel(module: nn.Module, prefix: str = ""):
for child_name, child_module in module.named_children():
qual_name = maybe_prefix(prefix, child_name)
for pattern, style in tp_plan.items():
if re.match(pattern, qual_name) and isinstance(
child_module, nn.Linear
):
new_module = replace_linear_class(
child_module, style, self.quant_config
)
setattr(module, child_name, new_module)
self.log_replacement(qual_name, child_module, new_module)
else:
_tensor_parallel(child_module, prefix=qual_name)
_tensor_parallel(self.model)
def replace_vocab_embed_class(self, module: nn.Module):
# Use native set input embeddings
new_module = VocabParallelEmbedding(
self.vocab_size,
self.config.hidden_size,
org_num_embeddings=self.config.vocab_size,
quant_config=None,
)
self.log_replacement(
"input embedding", self.model.get_input_embeddings(), new_module
)
self.model.set_input_embeddings(new_module)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
get_embedding: bool = False,
) -> LogitsProcessorOutput:
assert get_embedding is False, "embedding is not supported yet"
aux_hidden_states = None
hidden_states = self.model(
input_ids[None, ...],
use_cache=False,
position_ids=positions[None, ...],
forward_batch=forward_batch,
attention_instances=self.attention_instances,
return_dict=False,
)[0][
0, ...
] # we remove batch dimension for now
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if name not in params_dict:
name = f"{self.model.base_model_prefix}.{name}"
if name in params_dict:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
EntryClass = [TransformersForCausalLM]

View File

@@ -61,6 +61,7 @@ class ServerArgs:
is_embedding: bool = False
enable_multimodal: Optional[bool] = None
revision: Optional[str] = None
impl: str = "auto"
# Port for the HTTP server
host: str = "127.0.0.1"
@@ -726,6 +727,18 @@ class ServerArgs:
default=ServerArgs.page_size,
help="The number of tokens in a page.",
)
parser.add_argument(
"--impl",
type=str,
default=ServerArgs.impl,
help="Which implementation of the model to use.\n\n"
'* "auto" will try to use the SGLang implementation if it exists '
"and fall back to the Transformers implementation if no SGLang "
"implementation is available.\n"
'* "sglang" will use the SGLang model implementation.\n'
'* "transformers" will use the Transformers model '
"implementation.\n",
)
# Other runtime options
parser.add_argument(

View File

@@ -455,6 +455,7 @@ class SRTRunner:
torch_dtype: torch.dtype,
model_type: str,
tp_size: int = 1,
impl: str = "auto",
port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
lora_paths: List[str] = None,
max_loras_per_batch: int = 4,
@@ -475,6 +476,7 @@ class SRTRunner:
speculative_num_draft_tokens: Optional[int] = None,
disable_overlap_schedule: bool = False,
disable_custom_all_reduce: bool = False,
torchao_config: Optional[str] = None,
):
self.model_type = model_type
self.is_generation = model_type == "generation"
@@ -493,6 +495,8 @@ class SRTRunner:
tp_size=tp_size,
dtype=get_dtype_str(torch_dtype),
port=port,
impl=impl,
torchao_config=torchao_config,
mem_fraction_static=mem_fraction_static,
trust_remote_code=trust_remote_code,
is_embedding=not self.is_generation,