# How to Support New Models

This document explains how to add support for new language models and multimodal large language models (MLLMs) in
SGLang. It also covers how to test new models and register external implementations.

## How to Support a New Language Model

To support a new model in SGLang, you only need to add a single file under
the [SGLang Models Directory](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/models). You can learn
from existing model implementations and create a new file for your model. For most models, you should be able to find a
similar model to start with (e.g., starting from Llama). Also see how
to [port a model from vLLM to SGLang](#port-a-model-from-vllm-to-sglang).

## How to Support a New Multimodal Large Language Model

To support a new multimodal large language model (MLLM) in SGLang, there are several key components in addition to the
standard LLM support:

1. **Register your new model as multimodal**:
   Extend `is_multimodal_model`
   in [model_config.py](https://github.com/sgl-project/sglang/blob/0ab3f437aba729b348a683ab32b35b214456efc7/python/sglang/srt/configs/model_config.py#L561)
   to return `True` for your model.

2. **Register a new chat template**:
   This is needed only if your default chat template cannot accept images as input. In that case, register a new chat template in [conversation.py](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/conversation.py) along with the corresponding matching function.

3. **Multimodal Data Processor**:
   Define a new `Processor` class that inherits from `BaseMultimodalProcessor` and register this processor as your
   model’s dedicated processor.
   See [multimodal_processor.py](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/multimodal/processors)
   for more details.

4. **Handle Multimodal Tokens**:
   Implement a `pad_input_ids` function for your new model. In this function, multimodal tokens in the prompt should be
   expanded (if necessary) and padded with hashes of the multimodal data so that SGLang's `RadixAttention` can tell
   different multimodal inputs apart.

5. **Handle Image Feature Extraction**:
   Implement a `get_image_feature` function for your new model, which extracts image features from raw image data and converts them into the embeddings used by the language model.

6. **Adapt to Vision Attention**:
   Replace the multi-headed `Attention` of the ViT with SGLang’s `VisionAttention`.

You can refer to [Qwen2VL](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/qwen2_vl.py) or
other MLLM implementations. These models demonstrate how to correctly handle both multimodal and textual inputs.

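To make the token-handling step (item 4) more concrete, here is a toy, framework-free sketch of the padding idea. It is not SGLang's implementation: the function names, the placeholder token ID, the vocabulary size, and the hash-to-pad-value mapping are all assumptions chosen for illustration. The point is that two prompts with identical text but different images end up with different token sequences, so a prefix cache keyed on token IDs does not confuse them.

```python
import hashlib
from typing import List

IMAGE_TOKEN_ID = 9999  # hypothetical placeholder token ID, for illustration only
VOCAB_SIZE = 32000     # hypothetical vocabulary size


def pad_value_from_bytes(data: bytes) -> int:
    """Map raw multimodal data to a pad value outside the normal vocabulary range."""
    digest = hashlib.sha256(data).digest()
    return VOCAB_SIZE + int.from_bytes(digest[:4], "big")


def pad_with_hashes(
    input_ids: List[int], images: List[bytes], tokens_per_image: int
) -> List[int]:
    """Expand each image placeholder into `tokens_per_image` copies of that image's pad value."""
    padded: List[int] = []
    image_iter = iter(images)
    for token in input_ids:
        if token == IMAGE_TOKEN_ID:
            pad_value = pad_value_from_bytes(next(image_iter))
            padded.extend([pad_value] * tokens_per_image)
        else:
            padded.append(token)
    return padded


# Same text tokens, different image data.
prompt = [1, 15, IMAGE_TOKEN_ID, 42]
ids_cat = pad_with_hashes(prompt, [b"cat image bytes"], tokens_per_image=3)
ids_dog = pad_with_hashes(prompt, [b"dog image bytes"], tokens_per_image=3)
```

Here the single placeholder is expanded to three positions, and `ids_cat` and `ids_dog` differ only in those positions. Repeating the call with the same image bytes gives the same sequence, which is what allows cached prefixes to be reused.
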
## Testing and Debugging

Please note all your testing and benchmarking results in the PR description.

### Interactive Debugging

For interactive debugging, compare the outputs of Hugging Face/Transformers and SGLang. The following two commands
should give the same text output and very similar prefill logits:

- Get the reference output:
  ```bash
  python3 scripts/playground/reference_hf.py --model-path [new model] --model-type {text,mllm}
  ```
- Get the SGLang output:
  ```bash
  python3 -m sglang.bench_one_batch --correct --model [new model]
  ```

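"Very similar" is easiest to judge with a number. The sketch below is one way you might check this by hand, under stated assumptions: it takes two lists of logits that you have copied out of each command's output (neither script is assumed to export them in this form), then reports the largest absolute difference and whether the top-scoring token agrees. The helper name `compare_logits` and the default tolerance are hypothetical.

```python
from typing import Dict, Sequence


def compare_logits(
    reference: Sequence[float], candidate: Sequence[float], atol: float = 1e-2
) -> Dict[str, object]:
    """Summarize how closely two logit vectors agree."""
    if len(reference) != len(candidate):
        raise ValueError("logit vectors must have the same length")
    max_abs_diff = max(abs(r - c) for r, c in zip(reference, candidate))
    # Index of the highest logit in each vector (the greedy next token).
    ref_top = max(range(len(reference)), key=reference.__getitem__)
    cand_top = max(range(len(candidate)), key=candidate.__getitem__)
    return {
        "max_abs_diff": max_abs_diff,
        "same_top_token": ref_top == cand_top,
        "within_tolerance": max_abs_diff <= atol,
    }


# Made-up values standing in for the first few prefill logits from each run.
hf_logits = [1.250, -0.500, 3.125, 0.000]
sgl_logits = [1.248, -0.497, 3.121, 0.002]
report = compare_logits(hf_logits, sgl_logits)
```

A mismatch in `same_top_token` on the first prefill position usually deserves attention before small numerical differences do, since it changes the generated text.
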
### Add the Model to the Test Suite

To ensure the new model is well maintained, add it to the test suite by including it in the `ALL_OTHER_MODELS` list in
the [test_generation_models.py](https://github.com/sgl-project/sglang/blob/main/test/srt/models/test_generation_models.py)
file. Then test the new model on your local machine and report the results on representative benchmarks (GSM8K, MMLU, MMMU,
MMMU-Pro, etc.) in your PR.

For VLMs, also include a test in `test_vision_openai_server_{x}.py` (e.g. [test_vision_openai_server_a.py](https://github.com/sgl-project/sglang/blob/main/test/srt/test_vision_openai_server_a.py), [test_vision_openai_server_b.py](https://github.com/sgl-project/sglang/blob/main/test/srt/test_vision_openai_server_b.py)).

This is an example command to test a new model on your local machine:

```bash
ONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels.test_others
```

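The `ONLY_RUN` variable in the command above restricts the run to a single model. The sketch below shows the general filtering pattern under stated assumptions: `ModelCase` here is a stand-in dataclass with only one field, the list contents are examples, and `select_cases` is a hypothetical helper. Check the actual test file for the real definitions before relying on any detail.

```python
import os
from dataclasses import dataclass
from typing import List, Mapping


@dataclass
class ModelCase:
    """Stand-in for a test-suite entry; the real class may carry more fields."""

    model_path: str


ALL_OTHER_MODELS = [
    ModelCase("Qwen/Qwen2-1.5B"),
    ModelCase("my-org/my-new-model"),  # hypothetical entry for the model you are adding
]


def select_cases(
    cases: List[ModelCase], env: Mapping[str, str] = os.environ
) -> List[ModelCase]:
    """Return every case, or only the one whose path matches ONLY_RUN when it is set."""
    only_run = env.get("ONLY_RUN")
    if only_run is None:
        return list(cases)
    return [case for case in cases if case.model_path == only_run]


# Passing a dict instead of os.environ keeps the example independent of the shell.
selected = select_cases(ALL_OTHER_MODELS, {"ONLY_RUN": "Qwen/Qwen2-1.5B"})
```

With this pattern, a value that matches no entry selects nothing, so a typo in `ONLY_RUN` can make a test pass without running any model.
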
### Benchmark

- **(Required) MMMU**: Follow the MMMU benchmark [README.md](https://github.com/sgl-project/sglang/blob/main/benchmark/mmmu/README.md) to get an accuracy comparison between SGLang and HF Transformers. The accuracy score from the SGLang run should not be much lower than that from the HF Transformers run. Similarly, follow https://docs.sglang.ai/developer_guide/benchmark_and_profiling.html to get a performance comparison: TTFT and throughput must meet or exceed the baselines (e.g., HF Transformers).
- **(Optional) Other evals**: If you ran other evals, please note the results in the PR description.

## Port a Model from vLLM to SGLang

The [vLLM Models Directory](https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models) is a valuable
resource, as vLLM covers many models. SGLang reuses vLLM’s interface and some layers, making it easier to port models
from vLLM to SGLang.

To port a model from vLLM to SGLang:

- Compare these two files for guidance:
    - [SGLang Llama Implementation](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama.py)
    - [vLLM Llama Implementation](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py)
- The major differences include:
    - **Replace vLLM’s `Attention` with `RadixAttention`** (ensure you pass `layer_id` to `RadixAttention`).
    - **Replace vLLM’s `LogitsProcessor` with SGLang’s `LogitsProcessor`.**
    - **Replace the multi-headed `Attention` of ViT with SGLang’s `VisionAttention`.**
    - **Replace other vLLM layers** (such as `RMSNorm`, `SiluAndMul`) with SGLang layers.
    - **Remove `Sample`.**
    - **Change the `forward()` functions** and add a `forward_batch()` method.
    - **Add `EntryClass`** at the end.
    - **Ensure that the new implementation uses only SGLang components** and does not rely on any vLLM components.

Note: make sure you add your new model to the supported models list in the supported models documentation.

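The `EntryClass` item deserves a short illustration, since it is what lets a model file be looked up by architecture name. The following is a simplified, self-contained sketch of that idea, assuming the loader maps each exported class's `__name__` to the class. The module objects, class names, and the `collect_entry_classes` helper are hypothetical and do not reproduce SGLang's registry code.

```python
import types
from typing import Dict, Iterable


class MyModelForCausalLM:
    """Hypothetical model class."""


class MyModelForSequenceClassification:
    """Hypothetical second class exported from the same file."""


# Stand-ins for model files: one exports a single class, one exports a list,
# and one is a helper module with no EntryClass at all.
single = types.ModuleType("my_model")
single.EntryClass = MyModelForCausalLM

multi = types.ModuleType("my_model_variants")
multi.EntryClass = [MyModelForCausalLM, MyModelForSequenceClassification]

helper = types.ModuleType("utils")


def collect_entry_classes(modules: Iterable[types.ModuleType]) -> Dict[str, type]:
    """Build a name-to-class mapping from every module that defines EntryClass."""
    found: Dict[str, type] = {}
    for module in modules:
        entry = getattr(module, "EntryClass", None)
        if entry is None:
            continue  # Not a model file; skip it.
        for cls in (entry if isinstance(entry, list) else [entry]):
            found[cls.__name__] = cls
    return found


registry = collect_entry_classes([single, helper, multi])
```

Under this assumption, the class name is the lookup key, which is why the name you export should match the architecture string the checkpoint declares.
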
## Registering an External Model Implementation

In addition to the methods above, you can register your new model with the `ModelRegistry` before launching the server.
This allows you to integrate your model without modifying the source code.

For example:

```python
from sglang.srt.models.registry import ModelRegistry
from sglang.srt.entrypoints.http_server import launch_server

# For a single model, add it to the registry.
# `model_name` and `model_class` are placeholders for your architecture name and class:
ModelRegistry.models[model_name] = model_class

# For multiple models, you can imitate the import_model_classes() function:
from functools import lru_cache

@lru_cache()
def import_new_model_classes():
    model_arch_name_to_cls = {}
    # Populate model_arch_name_to_cls with your new model classes.
    ...
    return model_arch_name_to_cls

ModelRegistry.models.update(import_new_model_classes())

# Launch the server with your server arguments (`server_args` is a placeholder):
launch_server(server_args)
```

## Example: Implementing and Serving a Llama Wrapper Model

Below is an introductory, step-by-step walkthrough on how to implement a new model end-to-end in SGLang and then run it via the [Offline Engine](https://github.com/sgl-project/sglang/blob/main/docs/basic_usage/offline_engine_api.ipynb).

### Implementing Our Model

To keep things simple, this new model will be a simple wrapper around [Llama 3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct), and our goal will be just to bias the output logits for each `forward` call by taking the square root of each positive logit (non-positive logits are left unchanged).

Let's start by defining our model in a file called `llama_wrapper.py`.
The first step is to import the necessary libraries from SRT, which is SGLang's internal backend.

```python
# In the file `llama_wrapper.py`

import torch
from transformers import LlamaConfig
from typing import Optional
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors

from sglang.srt.models.llama import LlamaForCausalLM
```

Next, we declare a new `class` for our model and have it inherit from `LlamaForCausalLM`, which allows our model to access `LlamaForCausalLM`'s predefined modules and layers, such as `LlamaAttention` and `LlamaMLP`.
Note that almost all model implementations take in `config` and `quant_config` as arguments for their `__init__` method; `config` and `quant_config` are passed in via [`model_loader/loader.py`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_loader/loader.py#L219).
Because we have inherited from `LlamaForCausalLM`, we can pass our parameters directly to its constructor, which will set the member variables for us.

```python
class LlamaWrapper(LlamaForCausalLM):
    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config, quant_config=quant_config, prefix=prefix)
```

Now, we want to define the `forward` method, which is what will be called at inference time.
Note that the signature for `forward` is essentially the same for any model; you can take a look at the other models defined in the [`models` directory](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/) for references.
To see where exactly `forward` is called in the SGLang runtime's internals, take a look at [`forward_decode`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_executor/model_runner.py#L1705) and [`forward_extend`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_executor/model_runner.py#L1724) in the [`ModelRunner` class](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/model_executor/model_runner.py).

```python
    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        input_embeds: Optional[torch.Tensor] = None,
        get_embedding: bool = False,
    ) -> LogitsProcessorOutput:
```

We now call the `__call__` method for `self.model` (which is a member variable that `LlamaForCausalLM` defines in its `__init__` method), which eventually calls the `forward` method of the underlying `LlamaModel`.
After that, we feed the `hidden_states` into our model's `LogitsProcessor` (again defined in `LlamaForCausalLM`).

```python
        hidden_states = self.model(
            input_ids,
            positions,
            forward_batch,
            input_embeds,
            pp_proxy_tensors=pp_proxy_tensors,
        )

        res: LogitsProcessorOutput = self.logits_processor(
            input_ids,
            hidden_states,
            self.lm_head,
            forward_batch,
        )
```

After receiving the logits for the next token, we can finally perform our biasing step.

```python
        orig_logits = res.next_token_logits
        res.next_token_logits = torch.where(
            orig_logits > 0,
            orig_logits.sqrt(),
            orig_logits
        )

        return res
```

Now, our `LlamaWrapper` model is created and ready to be served!

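To see what the biasing step does to individual values, here is the same element-wise rule written in plain Python, without tensors. The helper names `bias_logit` and `bias_logits` are hypothetical. The example also shows why the `torch.where` guard matters: the square root of a negative number is not a real number, so negative logits are passed through unchanged rather than transformed.

```python
import math
from typing import List


def bias_logit(logit: float) -> float:
    """Square-root positive logits; leave zero and negative logits as they are."""
    return math.sqrt(logit) if logit > 0 else logit


def bias_logits(logits: List[float]) -> List[float]:
    """Apply the rule to every logit, mirroring the element-wise tensor operation."""
    return [bias_logit(x) for x in logits]


biased = bias_logits([16.0, 4.0, 0.0, -2.5])
print(biased)  # [4.0, 2.0, 0.0, -2.5]
```

In this example the gap between the two largest logits shrinks from 12 to 2, so the wrapper makes the model less decisive between its top candidates.
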
### Serving Our Model Via SGLang's Offline Engine

The next step of this walkthrough involves hosting our new model offline, so that it can be served locally and without an HTTP server.

First, create a new file called `run.py`.
Now, we must ensure that SGLang's `ModelRegistry` can find our model.
To do this, we first download the model's configuration and weights from Hugging Face.

```python
# In the file `run.py`

import asyncio
from functools import lru_cache
from huggingface_hub import snapshot_download
from llama_wrapper import LlamaWrapper  # Make sure to import our new model!
import sglang as sgl
from sglang.srt.models.registry import ModelRegistry

# Make sure to request access to this model on Hugging Face, then export your
# `HF_TOKEN` to download the model snapshot
llama_dir = snapshot_download(
    repo_id="meta-llama/Llama-3.1-8B-Instruct",
    local_dir="./llama_ckpt",
)
```

Now that we have our model on disk, we want to point it to `LlamaWrapper` by changing the `architectures` field in `./llama_ckpt/config.json` to be `LlamaWrapper`.
That way, when we pass in the path of our model checkpoint to SGLang, it will know that we want to use "LlamaWrapper" instead of "LlamaForCausalLM" as our model.

```python
{
  "architectures": [
    # "LlamaForCausalLM"
    "LlamaWrapper"
  ],
  ...
}
```

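If you prefer to make this edit from code rather than by hand, a small helper along these lines would work. It is a sketch and not part of SGLang: `set_architecture` is a hypothetical name, and the demonstration writes to a temporary directory instead of `./llama_ckpt` so that it runs without the checkpoint present.

```python
import json
import tempfile
from pathlib import Path


def set_architecture(config_path: Path, architecture: str) -> None:
    """Rewrite the `architectures` field of a config.json, keeping all other fields."""
    config = json.loads(config_path.read_text())
    config["architectures"] = [architecture]
    config_path.write_text(json.dumps(config, indent=2) + "\n")


# Demonstration on a throwaway config with made-up contents.
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_config = Path(tmp_dir) / "config.json"
    demo_config.write_text(
        json.dumps({"architectures": ["LlamaForCausalLM"], "model_type": "llama"})
    )
    set_architecture(demo_config, "LlamaWrapper")
    updated = json.loads(demo_config.read_text())

# For the walkthrough, the equivalent call would be:
# set_architecture(Path("./llama_ckpt/config.json"), "LlamaWrapper")
```
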
However, if we don't link our `LlamaWrapper` class to the "LlamaWrapper" registry keyword, then SGLang won't be able to find our model.
Thus, to register our `LlamaWrapper`, we want to follow the steps in the above section titled "Registering an External Model Implementation".

```python
@lru_cache()
def import_new_model_classes():
    model_arch_name_to_cls = {"LlamaWrapper": LlamaWrapper}
    return model_arch_name_to_cls

ModelRegistry.models.update(import_new_model_classes())
```

Lastly, when we create our `Engine`, we just pass in the path to the local model directory.
Then, our `LlamaWrapper` is ready to be served; for this walkthrough, we will use SGLang `Engine`'s non-streaming asynchronous generation endpoint.

```python
def main():
    llm = sgl.Engine(model_path="./llama_ckpt")
    sampling_params = {"temperature": 0.2, "top_k": 5}
    prompts = [
        "Write a short, neutral self-introduction for a fictional character. Hello, my name is",
        "Provide a concise factual statement about France’s capital city. The capital of France is",
        "Explain possible future trends in artificial intelligence. The future of AI is",
    ]

    asyncio.run(run_llm(llm, sampling_params, prompts))

    llm.shutdown()

async def run_llm(
    llm,
    sampling_params,
    prompts,
) -> None:
    outputs = await llm.async_generate(prompts, sampling_params)

    for prompt, output in zip(prompts, outputs):
        print(f"\nPrompt: {prompt}")
        print(f"Generated text: {output['text']}")

if __name__ == "__main__":
    main()
```

Now, when we call `python run.py`, we will get the outputs of our newly created model!

## Documentation

Add your model to the table of supported models in [generative_models.md](https://github.com/sgl-project/sglang/blob/main/docs/supported_models/generative_models.md) or [multimodal_language_models.md](https://github.com/sgl-project/sglang/blob/main/docs/supported_models/multimodal_language_models.md).

---

By following these guidelines, you can add support for new language models and multimodal large language models in
SGLang and ensure they are thoroughly tested and easily integrated into the system.