From 37f1547587adbec1f7b582366425085452dcb10a Mon Sep 17 00:00:00 2001
From: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Date: Wed, 4 Jun 2025 06:05:29 +0200
Subject: [PATCH] [FEAT] Add transformers backend support (#5929)

---
 docs/backend/server_arguments.md               |   1 +
 docs/index.rst                                 |   1 +
 .../supported_models/transformers_fallback.md  |  58 ++++
 python/sglang/srt/configs/model_config.py      |  11 +-
 python/sglang/srt/model_loader/utils.py        |  68 +++-
 python/sglang/srt/models/registry.py           |  10 +-
 python/sglang/srt/models/transformers.py       | 291 ++++++++++++++++++
 python/sglang/srt/server_args.py               |  13 +
 python/sglang/test/runners.py                  |   4 +
 test/srt/models/test_transformers_models.py    | 181 +++++++++++
 test/srt/run_suite.py                          |   1 +
 11 files changed, 636 insertions(+), 3 deletions(-)
 create mode 100644 docs/supported_models/transformers_fallback.md
 create mode 100644 python/sglang/srt/models/transformers.py
 create mode 100644 test/srt/models/test_transformers_models.py

diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md
index 50b888cbf..b36be6487 100644
--- a/docs/backend/server_arguments.md
+++ b/docs/backend/server_arguments.md
@@ -63,6 +63,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `kv_cache_dtype` | Dtype of the kv cache. | `auto` |
 | `context_length` | The model's maximum context length. Defaults to None (will use the value from the model's config.json instead). Note that extending the default might lead to strange behavior. | None |
 | `device` | The device we put the model. | None |
+| `impl` | The implementation of the model to use. Defaults to the SGLang implementation and falls back to the Transformers implementation if needed. | `auto` |
 | `served_model_name` | Override the model name returned by the v1/models endpoint in OpenAI API server.| None |
 | `is_embedding` | Set to `true` to perform [embedding](./openai_api_embeddings.ipynb) / [encode](https://docs.sglang.ai/backend/native_api#Encode-(embedding-model)) and [reward](https://docs.sglang.ai/backend/native_api#Classify-(reward-model)) tasks. | `False` |
 | `revision` | Adjust if a specific version of the model should be used. | None |

diff --git a/docs/index.rst b/docs/index.rst
index 40f3b6655..d63455ba6 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -47,6 +47,7 @@ The core features include:
    supported_models/embedding_models.md
    supported_models/reward_models.md
    supported_models/support_new_models.md
+   supported_models/transformers_fallback.md

 .. toctree::
    :maxdepth: 1

diff --git a/docs/supported_models/transformers_fallback.md b/docs/supported_models/transformers_fallback.md
new file mode 100644
index 000000000..bca279fd3
--- /dev/null
+++ b/docs/supported_models/transformers_fallback.md
@@ -0,0 +1,58 @@
# Transformers fallback in SGLang

`sglang` can fall back to models that are available in `transformers`. This works for most decoder-style language models, and support for vision-language models is coming soon!

## Example launch command

By default, we use the SGLang implementation of a model when one is available and fall back to the Transformers implementation otherwise. You can also force the Transformers implementation by setting `impl` to `transformers`.

```shell
python3 -m sglang.launch_server \
  --model-path meta-llama/Llama-3.2-1B-Instruct \
  --host 0.0.0.0 \
  --port 30000 \
  --impl transformers
```
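Once the server is up, you can query it like any other SGLang deployment; the fallback is invisible to clients. A minimal sketch, assuming the server launched above is listening on port 30000:

```shell
curl http://localhost:30000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "meta-llama/Llama-3.2-1B-Instruct",
    "messages": [{"role": "user", "content": "What is the capital of France?"}],
    "max_tokens": 64
  }'
```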
#### Supported features

##### Quantization

The Transformers fallback supports most of the quantization methods available in SGLang (except GGUF). See the [Quantization page](https://docs.sglang.ai/backend/quantization.html) for more information about supported quantization in SGLang.

##### Remote code

This fallback also means that any model on the Hub that can be used in `transformers` with `trust_remote_code=True` and that correctly implements attention can be used in production!

A model just needs the following two things:

```python
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


class MyAttention(nn.Module):

    def forward(self, hidden_states, **kwargs):  # <- kwargs are required
        ...
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            **kwargs,
        )
        ...


class MyModel(PreTrainedModel):
    _supports_attention_backend = True
```

Here is what happens in the background:

1. The config is loaded.
2. The `MyModel` Python class is loaded from the `auto_map`, and we check that the model declares `_supports_attention_backend`.
3. The `TransformersForCausalLM` backend is used. See `/srt/models/transformers`, which sets `self.config._attn_implementation = "sglang"`, hence the need to use `ALL_ATTENTION_FUNCTIONS`.

That's it!
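For example, to serve such a custom model you can combine the fallback with remote code loading. A minimal sketch, where `your-org/your-custom-model` is a placeholder for a Hub repository that ships its own modeling code:

```shell
python3 -m sglang.launch_server \
  --model-path your-org/your-custom-model \
  --trust-remote-code \
  --impl transformers
```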
diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index 7a90cf413..0b641d344 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -16,7 +16,7 @@
 import json
 import logging
 import math
 import os
-from enum import IntEnum, auto
+from enum import Enum, IntEnum, auto
 from typing import List, Optional, Set, Union

 import torch
@@ -39,6 +39,12 @@ class AttentionArch(IntEnum):
     MHA = auto()


+class ModelImpl(str, Enum):
+    AUTO = "auto"
+    SGLANG = "sglang"
+    TRANSFORMERS = "transformers"
+
+
 class ModelConfig:
     def __init__(
         self,
@@ -53,11 +59,13 @@ class ModelConfig:
         quantization: Optional[str] = None,
         override_config_file: Optional[str] = None,
         is_draft_model: bool = False,
+        impl: Union[str, ModelImpl] = ModelImpl.AUTO,
     ) -> None:
         self.model_path = model_path
         self.revision = revision
         self.quantization = quantization
+        self.impl = impl

         # Parse args
         self.maybe_pull_model_tokenizer_from_remote()
@@ -256,6 +264,7 @@ class ModelConfig:
             enable_multimodal=server_args.enable_multimodal,
             dtype=server_args.dtype,
             quantization=server_args.quantization,
+            impl=server_args.impl,
             **kwargs,
         )

diff --git a/python/sglang/srt/model_loader/utils.py b/python/sglang/srt/model_loader/utils.py
index daad1e67f..4f65ad5fe 100644
--- a/python/sglang/srt/model_loader/utils.py
+++ b/python/sglang/srt/model_loader/utils.py
@@ -2,12 +2,17 @@
 """Utilities for selecting and loading models."""

 import contextlib
+import logging
 from typing import Tuple, Type

 import torch
+import transformers
 from torch import nn
+from transformers.dynamic_module_utils import get_class_from_dynamic_module

-from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.configs.model_config import ModelConfig, ModelImpl
+
+logger = logging.getLogger(__name__)


 @contextlib.contextmanager
@@ -19,6 +24,61 @@ def set_default_torch_dtype(dtype: torch.dtype):
     torch.set_default_dtype(old_dtype)


+def resolve_transformers_arch(model_config: ModelConfig, architectures: list[str]):
+    for i, arch in enumerate(architectures):
+        if arch == "TransformersForCausalLM":
+            continue
+        auto_map: dict[str, str] = (
+            getattr(model_config.hf_config, "auto_map", None) or dict()
+        )
+        # Make sure that the config class is always initialized before the model
+        # class, otherwise the model class won't be able to access the config
+        # class. The expected auto_map should have the correct order, like:
+        # "auto_map": {
+        #     "AutoConfig": "--",
+        #     "AutoModel": "--",
+        #     "AutoModelFor": "--",
+        # },
+        auto_modules = {
+            name: get_class_from_dynamic_module(
+                module, model_config.model_path, revision=model_config.revision
+            )
+            for name, module in sorted(auto_map.items(), key=lambda x: x[0])
+        }
+        model_module = getattr(transformers, arch, None)
+        if model_module is None:
+            if "AutoModel" not in auto_map:
+                raise ValueError(
+                    f"Cannot find model module. '{arch}' is not a registered "
+                    "model in the Transformers library (only relevant if the "
+                    "model is meant to be in Transformers) and 'AutoModel' is "
+                    "not present in the model config's 'auto_map' (relevant "
+                    "if the model is custom)."
+                )
+            model_module = auto_modules["AutoModel"]
+        if model_config.impl == ModelImpl.TRANSFORMERS:
+            if not model_module.is_backend_compatible():
+                raise ValueError(
+                    f"The Transformers implementation of {arch} is not "
+                    "compatible with SGLang."
+                )
+            architectures[i] = "TransformersForCausalLM"
+        if model_config.impl == ModelImpl.AUTO:
+            if not model_module.is_backend_compatible():
+                raise ValueError(
+                    f"{arch} has no SGLang implementation and the Transformers "
+                    "implementation is not compatible with SGLang."
+                )
+            logger.warning(
+                "%s has no SGLang implementation, falling back to the Transformers "
+                "implementation. Some features may not be supported and "
+                "performance may not be optimal.",
+                arch,
+            )
+            architectures[i] = "TransformersForCausalLM"
+    return architectures
+
+
 def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
     from sglang.srt.models.registry import ModelRegistry

@@ -34,6 +94,12 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module],
     ):
         architectures = ["QuantMixtralForCausalLM"]

+    supported_archs = ModelRegistry.get_supported_archs()
+    is_native_supported = any(arch in supported_archs for arch in architectures)
+
+    if not is_native_supported or model_config.impl == ModelImpl.TRANSFORMERS:
+        architectures = resolve_transformers_arch(model_config, architectures)
+
     return ModelRegistry.resolve_model_cls(architectures)


diff --git a/python/sglang/srt/models/registry.py b/python/sglang/srt/models/registry.py
index fc63bf125..f81d3c76e 100644
--- a/python/sglang/srt/models/registry.py
+++ b/python/sglang/srt/models/registry.py
@@ -49,7 +49,15 @@ class _ModelRegistry:
         if not architectures:
             logger.warning("No model architectures are specified")

-        return architectures
+        # filter out unsupported architectures
+        normalized_arch = list(
+            filter(lambda model: model in self.models, architectures)
+        )
+
+        # make sure the Transformers backend is appended last as a fallback
+        if len(normalized_arch) != len(architectures):
+            normalized_arch.append("TransformersForCausalLM")
+        return normalized_arch

     def resolve_model_cls(
         self,

diff --git a/python/sglang/srt/models/transformers.py b/python/sglang/srt/models/transformers.py
new file mode 100644
index 000000000..9ee2a14b2
--- /dev/null
+++ b/python/sglang/srt/models/transformers.py
@@ -0,0 +1,291 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Adapted from
# https://github.com/vllm-project/vllm/blob/a1a2aaadb9122f05667140e39cf67e5736c8b6d6/vllm/model_executor/models/transformers.py
"""Wrapper around `transformers` models."""
import logging
import re
from typing import Iterable, Literal, Optional, Tuple, Union

import torch
from torch import nn
from transformers import AutoModel, PretrainedConfig, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader

logger = logging.getLogger(__name__)


def maybe_prefix(prefix: str, name: str) -> str:
    """Add a prefix to a name if the prefix is non-empty.

    Args:
        prefix: The prefix to add. If empty, no prefix will be added.
        name: The name to potentially prefix.

    Returns:
        The string "prefix.name" if prefix was non-empty, otherwise just "name".
    """
    return name if not prefix else f"{prefix}.{name}"


def sglang_flash_attention_forward(
    # Transformers args
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor,
    # sglang kwargs
    forward_batch: ForwardBatch,
    # Transformers kwargs
    scaling: Optional[float] = None,
    attention_instances: Optional[list[RadixAttention]] = None,
    **kwargs,
):
    self_attn: RadixAttention = attention_instances[module.layer_idx]
    if scaling is not None:
        self_attn.scaling = float(scaling)
    hidden = query.shape[-2]
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    query, key, value = (x.reshape(hidden, -1) for x in (query, key, value))
    return self_attn.forward(query, key, value, forward_batch=forward_batch), None


ALL_ATTENTION_FUNCTIONS["sglang"] = sglang_flash_attention_forward


class HFColumnParallelLinear(ColumnParallelLinear):

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return super().forward(input)[0]


class HFRowParallelLinear(RowParallelLinear):

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return super().forward(input)[0]


def replace_linear_class(
    linear: nn.Linear,
    style: Literal["colwise", "rowwise"],
    quant_config: QuantizationConfig,
) -> Union[ColumnParallelLinear, RowParallelLinear]:
    """
    Replace nn.Linear with one of SGLang's tensor parallel linear classes.

    Args:
        linear (nn.Linear): `nn.Linear` to be replaced.
        style (str): Tensor parallel style of the new linear, e.g. "colwise".
        quant_config (QuantizationConfig): Quantization config for the new linear.

    Returns:
        Union[ColumnParallelLinear, RowParallelLinear]: The new linear.
    """

    if not isinstance(style, str):
        raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")

    sglang_linear_cls = {
        "colwise": ColumnParallelLinear,
        "rowwise": RowParallelLinear,
    }.get(style, ReplicatedLinear)

    class HFCompatibleLinear(sglang_linear_cls):
        """
        Wrapper class that removes `output_bias` from the returned output.
        """

        @property
        def parent_cls(self) -> type:
            return sglang_linear_cls

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            return super().forward(input)[0]

    return HFCompatibleLinear(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=linear.bias is not None,
        quant_config=quant_config,
    )


class TransformersForCausalLM(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        logger.info("Using Transformers backend.")

        self.quant_config = quant_config
        self.config = config
        self.vocab_size = config.vocab_size
        self.unpadded_vocab_size = config.vocab_size

        # The model is loaded under set_default_torch_dtype(model_config.dtype)
        self.model: PreTrainedModel = AutoModel.from_config(
            self.config,
            torch_dtype=torch.get_default_dtype(),
            attn_implementation="sglang",
            trust_remote_code=True,
        )

        # Attention modifications (assumes 1 attention op per hidden layer)
        tp_size = get_tensor_model_parallel_world_size()

        # MLP modifications
        self.tensor_parallel(tp_size)

        head_dim = (
            (config.hidden_size // config.num_attention_heads)
            if not hasattr(config, "head_dim")
            else config.head_dim
        )
        self.attention_instances = [
            RadixAttention(
                num_heads=divide(config.num_attention_heads, tp_size),
                head_dim=head_dim,
                # NOTE: We use the Llama scale as the default; if it is set by
                # Transformers, it is updated in sglang_flash_attention_forward
                scaling=head_dim**-0.5,
                num_kv_heads=divide(config.num_key_value_heads, tp_size),
                layer_id=i,
                quant_config=self.quant_config,
                prefix=f"{i}.attn",
            )
            for i in range(config.num_hidden_layers)
        ]

        # Model modifications
        self.replace_vocab_embed_class(self.model)

        # ForCausalLM modifications
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.get_input_embeddings().weight

        self.logits_processor = LogitsProcessor(config)

    def log_replacement(self, name: str, old_module: nn.Module, new_module: nn.Module):
        logger.debug("%s: %s -> %s", name, old_module, new_module)

    def tensor_parallel(self, tp_size: int):
        """
        Apply the model's tensor parallelization plan.
        Currently only supports linear layers.
        """
        if not self.model.supports_tp_plan:
            if tp_size <= 1:
                return

            raise ValueError(
                f"{type(self.model)} does not support tensor parallel yet!"
+ ) + + tp_plan = self.model._tp_plan + + def _tensor_parallel(module: nn.Module, prefix: str = ""): + for child_name, child_module in module.named_children(): + qual_name = maybe_prefix(prefix, child_name) + for pattern, style in tp_plan.items(): + if re.match(pattern, qual_name) and isinstance( + child_module, nn.Linear + ): + new_module = replace_linear_class( + child_module, style, self.quant_config + ) + setattr(module, child_name, new_module) + self.log_replacement(qual_name, child_module, new_module) + else: + _tensor_parallel(child_module, prefix=qual_name) + + _tensor_parallel(self.model) + + def replace_vocab_embed_class(self, module: nn.Module): + # Use native set input embeddings + new_module = VocabParallelEmbedding( + self.vocab_size, + self.config.hidden_size, + org_num_embeddings=self.config.vocab_size, + quant_config=None, + ) + self.log_replacement( + "input embedding", self.model.get_input_embeddings(), new_module + ) + self.model.set_input_embeddings(new_module) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + get_embedding: bool = False, + ) -> LogitsProcessorOutput: + assert get_embedding is False, "embedding is not supported yet" + aux_hidden_states = None + hidden_states = self.model( + input_ids[None, ...], + use_cache=False, + position_ids=positions[None, ...], + forward_batch=forward_batch, + attention_instances=self.attention_instances, + return_dict=False, + )[0][ + 0, ... + ] # we remove batch dimension for now + + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if name not in params_dict: + name = f"{self.model.base_model_prefix}.{name}" + if name in params_dict: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +EntryClass = [TransformersForCausalLM] diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index ab48627d0..97bcad86f 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -61,6 +61,7 @@ class ServerArgs: is_embedding: bool = False enable_multimodal: Optional[bool] = None revision: Optional[str] = None + impl: str = "auto" # Port for the HTTP server host: str = "127.0.0.1" @@ -726,6 +727,18 @@ class ServerArgs: default=ServerArgs.page_size, help="The number of tokens in a page.", ) + parser.add_argument( + "--impl", + type=str, + default=ServerArgs.impl, + help="Which implementation of the model to use.\n\n" + '* "auto" will try to use the SGLang implementation if it exists ' + "and fall back to the Transformers implementation if no SGLang " + "implementation is available.\n" + '* "sglang" will use the SGLang model implementation.\n' + '* "transformers" will use the Transformers model ' + "implementation.\n", + ) # Other runtime options parser.add_argument( diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 5318e5206..f5c8365a7 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -455,6 +455,7 @@ class SRTRunner: torch_dtype: torch.dtype, model_type: str, tp_size: int = 1, + impl: str = "auto", port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER, lora_paths: List[str] = None, max_loras_per_batch: int = 4, @@ 
-475,6 +476,7 @@ class SRTRunner: speculative_num_draft_tokens: Optional[int] = None, disable_overlap_schedule: bool = False, disable_custom_all_reduce: bool = False, + torchao_config: Optional[str] = None, ): self.model_type = model_type self.is_generation = model_type == "generation" @@ -493,6 +495,8 @@ class SRTRunner: tp_size=tp_size, dtype=get_dtype_str(torch_dtype), port=port, + impl=impl, + torchao_config=torchao_config, mem_fraction_static=mem_fraction_static, trust_remote_code=trust_remote_code, is_embedding=not self.is_generation, diff --git a/test/srt/models/test_transformers_models.py b/test/srt/models/test_transformers_models.py new file mode 100644 index 000000000..7e92b49d1 --- /dev/null +++ b/test/srt/models/test_transformers_models.py @@ -0,0 +1,181 @@ +import dataclasses +import multiprocessing as mp +import unittest +from types import SimpleNamespace +from typing import List + +import torch + +from sglang.srt.utils import kill_process_tree +from sglang.test.runners import DEFAULT_PROMPTS, SRTRunner, check_close_model_outputs +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + is_in_ci, + popen_launch_server, +) + + +class TestTransformersFallbackEndpoint(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--impl", "transformers"], + ) + cls.mmlu_lower_bound = 0.65 + cls.gsm8k_lower_bound = 0.65 + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + from sglang.test.run_eval import run_eval + + metrics = run_eval(args) + self.assertGreaterEqual(metrics["score"], self.mmlu_lower_bound) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + from sglang.test.few_shot_gsm8k import run_eval + + metrics = run_eval(args) + print(f"{metrics=}") + self.assertGreater(metrics["accuracy"], self.gsm8k_lower_bound) + + +class TestTransformersFallbackTorchAO(TestTransformersFallbackEndpoint): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--impl", + "transformers", + "--torchao-config", + "int4wo-128", + ], + ) + cls.mmlu_lower_bound = 0.65 + cls.gsm8k_lower_bound = 0.65 + + +@dataclasses.dataclass +class ModelCase: + model_path: str + tp_size: int = 1 + prefill_tolerance: float = 5e-2 + decode_tolerance: float = 5e-2 + rouge_l_tolerance: float = 1 + skip_long_prompt: bool = False + trust_remote_code: bool = False + torchao_config: str = None + torch_dtype: torch.dtype = torch.float16 + + +# Popular models that run on the CI +CI_MODELS = [ + ModelCase(DEFAULT_MODEL_NAME_FOR_TEST), +] + +ALL_OTHER_MODELS = [ + ModelCase(DEFAULT_MODEL_NAME_FOR_TEST, tp_size=2), +] + + +class TestTransformersFallbackEngine(CustomTestCase): + @classmethod + def setUpClass(cls): + mp.set_start_method("spawn", force=True) + + def assert_close_logits_and_output_strs( + 
        self,
        prompts: List[str],
        model_case: ModelCase,
    ) -> None:
        model_path = model_case.model_path
        max_new_tokens = 32
        # force the Transformers implementation
        with SRTRunner(
            model_path,
            tp_size=model_case.tp_size,
            torch_dtype=model_case.torch_dtype,
            model_type="generation",
            impl="transformers",
            trust_remote_code=model_case.trust_remote_code,
            torchao_config=model_case.torchao_config,
        ) as srt_runner:
            transformers_outputs = srt_runner.forward(
                prompts, max_new_tokens=max_new_tokens
            )

        # default impl: the native SGLang implementation, used as the reference
        with SRTRunner(
            model_path,
            tp_size=model_case.tp_size,
            torch_dtype=model_case.torch_dtype,
            model_type="generation",
            trust_remote_code=model_case.trust_remote_code,
            torchao_config=model_case.torchao_config,
        ) as srt_runner:
            sglang_outputs = srt_runner.forward(
                prompts, max_new_tokens=max_new_tokens
            )

        check_close_model_outputs(
            hf_outputs=sglang_outputs,
            srt_outputs=transformers_outputs,
            prefill_tolerance=model_case.prefill_tolerance,
            decode_tolerance=model_case.decode_tolerance,
            rouge_l_tolerance=model_case.rouge_l_tolerance,
            debug_text=f"model_path={model_path} prompts={prompts}",
        )

    def test_ci_models(self):
        for model_case in CI_MODELS:
            # Skip long prompts for models that do not have a long context
            prompts = DEFAULT_PROMPTS
            if model_case.skip_long_prompt:
                prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
            # Assert the logits and output strs are close
            self.assert_close_logits_and_output_strs(prompts, model_case)

    def test_others(self):
        if is_in_ci():
            return

        for model_case in ALL_OTHER_MODELS:
            # Skip long prompts for models that do not have a long context
            prompts = DEFAULT_PROMPTS
            if model_case.skip_long_prompt:
                prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]

            # Assert the logits and output strs are close
            self.assert_close_logits_and_output_strs(prompts, model_case)


if __name__ == "__main__":
    unittest.main()

diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 323aeb1eb..197cf3349 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -26,6 +26,7 @@ suites = {
         TestFile("models/test_qwen_models.py", 82),
         TestFile("models/test_reward_models.py", 132),
         TestFile("models/test_vlm_models.py", 437),
+        TestFile("models/test_transformers_models.py", 320),
         TestFile("test_abort.py", 51),
         TestFile("test_block_int8.py", 22),
         TestFile("test_create_kvindices.py", 2),