in [model_config.py](https://github.com/sgl-project/sglang/blob/0ab3f437aba729b348a683ab32b35b214456efc7/python/sglang/srt/configs/model_config.py#L561)
Only if your default chat template cannot accept images as input: register a new chat template in [conversation.py](https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/conversation.py), along with the corresponding matching function.
Implement a `get_image_feature` function for your new model, which extracts image features from raw image data and converts them into the embeddings used by the language model.
For VLMs, also include a test in `test_vision_openai_server_{x}.py` (e.g. [test_vision_openai_server_a.py](https://github.com/sgl-project/sglang/blob/main/test/srt/test_vision_openai_server_a.py), [test_vision_openai_server_b.py](https://github.com/sgl-project/sglang/blob/main/test/srt/test_vision_openai_server_b.py)).
- **(Required) MMMU**: Follow the MMMU benchmark [README.md](https://github.com/sgl-project/sglang/blob/main/benchmark/mmmu/README.md) to obtain an accuracy comparison between SGLang and HF Transformers. The accuracy score from the SGLang run should not be much lower than that from the HF Transformers run. Similarly, follow https://docs.sglang.ai/developer_guide/benchmark_and_profiling.html to obtain a performance comparison: TTFT and throughput must meet or exceed the baselines (e.g., HF Transformers).
- **(Optional) Other evals**: If you ran other evals, please note the results in the PR description.
## Example: Implementing and Serving a Llama Wrapper Model
Below is an introductory, step-by-step walkthrough on how to implement a new model end-to-end in SGLang and then run it via the [Offline Engine](https://github.com/sgl-project/sglang/blob/main/docs/basic_usage/offline_engine_api.ipynb).
### Implementing Our Model
To keep things simple, this new model will be a thin wrapper around [Llama 3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct). Our only goal is to bias the output logits on each `forward` call by taking the square root of every positive logit; non-positive logits are left unchanged.
Let's start by defining our model in a file called `llama_wrapper.py`.
The first step is to import the necessary libraries from SRT, which is SGLang's internal backend.
```python
# In the file `llama_wrapper.py`
import torch
from transformers import LlamaConfig
from typing import Optional
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.models.llama import LlamaForCausalLM
```
Next, we declare a new `class` for our model and have it inherit from `LlamaForCausalLM`, which allows our model to access `LlamaForCausalLM`'s predefined modules and layers, such as `LlamaAttention` and `LlamaMLP`.
Note that almost all model implementations take in `config` and `quant_config` as arguments for their `__init__` method; `config` and `quant_config` are passed in via [`model_loader/loader.py`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_loader/loader.py#L219).
Because we have inherited from `LlamaForCausalLM`, we can pass our parameters directly to its constructor, which will set the member variables for us.
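The inheritance pattern described here can be sketched as follows. To keep the snippet runnable without SGLang installed, `BaseCausalLM` is a hypothetical stand-in for `LlamaForCausalLM`, and its two-argument constructor is an assumption; the real class may accept additional arguments. In `llama_wrapper.py` you would inherit from the imported `LlamaForCausalLM` instead.

```python
from typing import Any, Optional


class BaseCausalLM:
    """Hypothetical stand-in for `LlamaForCausalLM`, so this runs without SGLang."""

    def __init__(self, config: Any, quant_config: Optional[Any] = None) -> None:
        # The real base class builds its layers here; this stub only stores the arguments.
        self.config = config
        self.quant_config = quant_config


class LlamaWrapper(BaseCausalLM):
    """Subclass that forwards its constructor arguments to the parent class."""

    def __init__(self, config: Any, quant_config: Optional[Any] = None) -> None:
        # Delegating to the parent sets the member variables for us.
        super().__init__(config=config, quant_config=quant_config)


wrapper = LlamaWrapper(config={"hidden_size": 4096})
print(wrapper.config, wrapper.quant_config)
```

Because the subclass adds no state of its own, everything the parent constructor sets up is available on the wrapper unchanged.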
Now, we want to define the `forward` method, which is what will be called at inference time.
Note that the signature for `forward` is essentially the same for any model; you can take a look at the other models defined in the [`models` directory](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/) for references.
To see where exactly `forward` is called in the SGLang runtime's internals, take a look at [`forward_decode`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_executor/model_runner.py#L1705) and [`forward_extend`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_executor/model_runner.py#L1724) in the [`ModelRunner` class](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/model_executor/model_runner.py).
Inside `forward`, we first call `self.model` (a member variable that `LlamaForCausalLM` defines in its `__init__` method). Invoking it goes through `__call__`, which in turn runs the `forward` method of the underlying `LlamaModel` and returns the `hidden_states`.
After that, we feed the `hidden_states` into our model's `LogitsProcessor` (again defined in `LlamaForCausalLM`).
After receiving the logits for the next token, we can finally perform our biasing step.
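```python
# Inside `forward`, after the logits processor has returned `res`:
orig_logits = res.next_token_logits
# Take the square root of positive logits only; leave the rest unchanged.
res.next_token_logits = torch.where(
    orig_logits > 0,
    orig_logits.sqrt(),
    orig_logits,
)
return res
```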
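To make the effect of this step concrete, here is a minimal plain-Python sketch of the same elementwise rule, applied to a list of floats rather than a tensor. The helper name `bias_logits` is hypothetical and is not part of SGLang or PyTorch; it only mirrors what the `torch.where` call above computes.

```python
import math
from typing import List


def bias_logits(logits: List[float]) -> List[float]:
    """Return a copy in which each positive logit is replaced by its square root.

    Hypothetical helper for illustration only: non-positive logits pass through
    unchanged, mirroring the `torch.where(orig > 0, sqrt, orig)` rule.
    """
    return [math.sqrt(x) if x > 0 else x for x in logits]


example = [4.0, 0.25, 0.0, -1.0]
print(bias_logits(example))  # [2.0, 0.5, 0.0, -1.0]
```

Positive logits above 1 shrink while those between 0 and 1 grow, so the rule compresses the spread among positive logits and leaves the others alone.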
Now, our `LlamaWrapper` model is created and ready to be served!
### Serving Our Model Via SGLang's Offline Engine
The next step of this walkthrough involves hosting our new model offline, so that it can be served locally and without an HTTP server.
First, create a new file called `run.py`.
Now, we must ensure that SGLang's `ModelRegistry` can find our model.
To do this, we first download the model's configuration and weights from Hugging Face.
```python
# In the file `run.py`
import asyncio
from functools import lru_cache
from huggingface_hub import snapshot_download
from llama_wrapper import LlamaWrapper  # Make sure to import our new model!
import sglang as sgl
from sglang.srt.models.registry import ModelRegistry

# Make sure to request access to this model on Hugging Face, then export your
# `HF_TOKEN` to download the model snapshot
llama_dir = snapshot_download(
    repo_id="meta-llama/Llama-3.1-8B-Instruct",
    local_dir="./llama_ckpt",
)
```
Now that we have our model on disk, we want to point it to `LlamaWrapper` by changing the `architectures` field in `./llama_ckpt/config.json` to be `LlamaWrapper`.
That way, when we pass in the path of our model checkpoint to SGLang, it will know that we want to use "LlamaWrapper" instead of "LlamaForCausalLM" as our model.
```python
{
  "architectures": [
    # "LlamaForCausalLM"
    "LlamaWrapper"
  ],
  ...
}
```
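If you would rather make this edit programmatically than by hand, a small script along the following lines should work. This is a sketch that assumes `config.json` is plain JSON with a top-level `architectures` list; the helper name `set_architecture` is made up for this example.

```python
import json
from pathlib import Path


def set_architecture(config_path: str, arch_name: str) -> None:
    """Overwrite the `architectures` field of a Hugging Face style config.json."""
    path = Path(config_path)
    config = json.loads(path.read_text(encoding="utf-8"))
    # Replace whatever was listed (e.g. "LlamaForCausalLM") with our class name.
    config["architectures"] = [arch_name]
    # All other fields are written back untouched.
    path.write_text(json.dumps(config, indent=2) + "\n", encoding="utf-8")
```

For the checkpoint downloaded earlier, you would call `set_architecture("./llama_ckpt/config.json", "LlamaWrapper")` once after the download completes.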
However, if we don't link our `LlamaWrapper` class to the "LlamaWrapper" registry keyword, then SGLang won't be able to find our model.
Thus, to register our `LlamaWrapper`, we want to follow the steps in the above section titled "Registering an External Model Implementation".
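Conceptually, the lookup that has to succeed is a mapping from the architecture name in `config.json` to a Python class. The toy sketch below illustrates that idea with a plain dictionary. The names `toy_registry`, `resolve_architecture`, and `LlamaWrapperStub` are hypothetical, and this is not SGLang's actual `ModelRegistry` implementation.

```python
from typing import Dict, List, Type


class LlamaWrapperStub:
    """Stand-in for the real model class, so this sketch runs without SGLang."""


def resolve_architecture(architectures: List[str], registry: Dict[str, Type]) -> Type:
    """Return the class registered for the first known architecture name."""
    for name in architectures:
        if name in registry:
            return registry[name]
    raise ValueError(f"No registered model class for {architectures}")


# Before registration, the name taken from config.json cannot be resolved.
toy_registry: Dict[str, Type] = {}
try:
    resolve_architecture(["LlamaWrapper"], toy_registry)
except ValueError as err:
    print(err)

# Registering links the keyword "LlamaWrapper" to the class itself.
toy_registry.update({"LlamaWrapper": LlamaWrapperStub})
print(resolve_architecture(["LlamaWrapper"], toy_registry) is LlamaWrapperStub)
```

The point of the sketch is only that the string in `architectures` and the registry key must match exactly; a typo in either one leads to the failed lookup shown in the first half.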
Add your model to the table of supported models in [generative_models.md](https://github.com/sgl-project/sglang/blob/main/docs/supported_models/generative_models.md) or [multimodal_language_models.md](https://github.com/sgl-project/sglang/blob/main/docs/supported_models/multimodal_language_models.md).