diff --git a/docs/supported_models/support_new_models.md b/docs/supported_models/support_new_models.md
index 06a884239..511a8f398 100644
--- a/docs/supported_models/support_new_models.md
+++ b/docs/supported_models/support_new_models.md
@@ -135,6 +135,182 @@ ModelRegistry.models.update(import_new_model_classes())
 launch_server(server_args)
 ```
+## Example: Implementing and Serving a Llama Wrapper Model
+
+Below is an introductory, step-by-step walkthrough of how to implement a new model end-to-end in SGLang and then run it via the [Offline Engine](https://github.com/sgl-project/sglang/blob/main/docs/basic_usage/offline_engine_api.ipynb).
+
+### Implementing Our Model
+
+To keep things simple, this new model will be a thin wrapper around [Llama 3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct), and our only goal will be to bias the output logits on each `forward` call by taking the square root of every positive logit.
+
+Let's start by defining our model in a file called `llama_wrapper.py`.
+The first step is to import the necessary modules from SRT, SGLang's internal backend.
+
+```python
+# In the file `llama_wrapper.py`
+
+import torch
+from transformers import LlamaConfig
+from typing import Optional
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+
+from sglang.srt.models.llama import LlamaForCausalLM
+```
+
+Next, we declare a new `class` for our model and have it inherit from `LlamaForCausalLM`, which gives our model access to `LlamaForCausalLM`'s predefined modules and layers, such as `LlamaAttention` and `LlamaMLP`.
+Note that almost all model implementations take `config` and `quant_config` as arguments to their `__init__` method; both are passed in via [`model_loader/loader.py`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_loader/loader.py#L219).
+Because we inherit from `LlamaForCausalLM`, we can pass these parameters directly to its constructor, which sets the member variables for us.
+
+```python
+class LlamaWrapper(LlamaForCausalLM):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config=config, quant_config=quant_config, prefix=prefix)
+```
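+
+Before moving on, it may help to see the logit transformation from the introduction in isolation.
+The snippet below is a standalone sketch with made-up toy values (it is not part of `llama_wrapper.py`), and it shows why we only transform the positive logits: taking the square root of a negative number would produce `NaN`.
+
+```python
+# Standalone illustration only, with hypothetical next-token logit values.
+import torch
+
+toy_logits = torch.tensor([4.0, 0.25, -1.5, 9.0])
+
+# Square-root only the positive entries and leave the rest untouched,
+# so we never take the square root of a negative number.
+biased = torch.where(toy_logits > 0, toy_logits.sqrt(), toy_logits)
+
+print(biased)  # prints something like tensor([2.0000, 0.5000, -1.5000, 3.0000])
+```
+
+Roughly speaking, this compresses the gaps between the largest logits, so the sampling distribution over likely next tokens becomes a bit flatter.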
+
+Now, we want to define the `forward` method, which is what gets called at inference time.
+Note that the signature of `forward` is essentially the same for every model; you can take a look at the other models defined in the [`models` directory](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/) for reference.
+To see exactly where `forward` is called in the SGLang runtime's internals, take a look at [`forward_decode`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_executor/model_runner.py#L1705) and [`forward_extend`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_executor/model_runner.py#L1724) in the [`ModelRunner` class](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/model_executor/model_runner.py).
+
+```python
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+        input_embeds: Optional[torch.Tensor] = None,
+        get_embedding: bool = False,
+    ) -> LogitsProcessorOutput:
+```
+
+Inside `forward`, we first call `self.model` (a `LlamaModel` instance that `LlamaForCausalLM` creates in its `__init__` method), which runs `LlamaModel`'s own `forward` pass and returns the hidden states.
+After that, we feed the `hidden_states` into our model's logits processor, `self.logits_processor` (also created in `LlamaForCausalLM`'s `__init__`).
+
+```python
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
+        res: LogitsProcessorOutput = self.logits_processor(
+            input_ids,
+            hidden_states,
+            self.lm_head,
+            forward_batch,
+        )
+```
+
+After receiving the logits for the next token, we can finally perform our biasing step.
+
+```python
+        orig_logits = res.next_token_logits
+        res.next_token_logits = torch.where(
+            orig_logits > 0,
+            orig_logits.sqrt(),
+            orig_logits
+        )
+
+        return res
+```
+
+Now, our `LlamaWrapper` model is created and ready to be served!
+
+### Serving Our Model Via SGLang's Offline Engine
+
+The next step of this walkthrough is hosting our new model offline, so that it can be served locally without an HTTP server.
+
+First, create a new file called `run.py`.
+We must ensure that SGLang's `ModelRegistry` can find our model.
+To do this, we first download the model's configuration and weights from Hugging Face.
+
+```python
+# In the file `run.py`
+
+import asyncio
+from functools import lru_cache
+from huggingface_hub import snapshot_download
+from llama_wrapper import LlamaWrapper  # Make sure to import our new model!
+import sglang as sgl
+from sglang.srt.models.registry import ModelRegistry
+
+# Make sure to request access to this model on Hugging Face, then export your
+# `HF_TOKEN` to download the model snapshot
+llama_dir = snapshot_download(
+    repo_id="meta-llama/Llama-3.1-8B-Instruct",
+    local_dir="./llama_ckpt",
+)
+```
+
+Now that we have our model on disk, we point it at `LlamaWrapper` by changing the `architectures` field in `./llama_ckpt/config.json` from `LlamaForCausalLM` to `LlamaWrapper`.
+That way, when we pass the path of our model checkpoint to SGLang, it will know that we want to use "LlamaWrapper" instead of "LlamaForCausalLM" as our model.
+
+```json
+{
+    "architectures": [
+        "LlamaWrapper"
+    ],
+    ...
+}
+```
+
+However, if we don't link our `LlamaWrapper` class to the "LlamaWrapper" registry keyword, SGLang won't be able to find our model.
+Thus, to register `LlamaWrapper`, we follow the steps in the section above titled "Registering an External Model Implementation".
+
+```python
+@lru_cache()
+def import_new_model_classes():
+    model_arch_name_to_cls = {"LlamaWrapper": LlamaWrapper}
+    return model_arch_name_to_cls
+
+ModelRegistry.models.update(import_new_model_classes())
+```
+
+Lastly, when we create our `Engine`, we just pass in the path to the local model directory.
+Then, our `LlamaWrapper` is ready to be served; for this walkthrough, we will use the SGLang `Engine`'s non-streaming asynchronous generation endpoint.
+
+```python
+def main():
+    llm = sgl.Engine(model_path="./llama_ckpt")
+    sampling_params = {"temperature": 0.2, "top_k": 5}
+    prompts = [
+        "Write a short, neutral self-introduction for a fictional character. Hello, my name is",
+        "Provide a concise factual statement about France’s capital city. The capital of France is",
+        "Explain possible future trends in artificial intelligence. The future of AI is",
+    ]
+
+    asyncio.run(run_llm(llm, sampling_params, prompts))
+
+    llm.shutdown()
+
+async def run_llm(
+    llm,
+    sampling_params,
+    prompts,
+) -> None:
+    outputs = await llm.async_generate(prompts, sampling_params)
+
+    for prompt, output in zip(prompts, outputs):
+        print(f"\nPrompt: {prompt}")
+        print(f"Generated text: {output['text']}")
+
+if __name__ == "__main__":
+    main()
+```
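+
+As a quick aside, if you would rather avoid `asyncio` entirely, the `Engine` also exposes a synchronous `generate` method (see the Offline Engine notebook linked above).
+A minimal sketch of what that could look like, using a hypothetical `main_sync` entry point with the same checkpoint and sampling parameters, is shown below.
+
+```python
+def main_sync():
+    # Same setup as `main`, but calling the Engine's blocking endpoint instead.
+    llm = sgl.Engine(model_path="./llama_ckpt")
+    sampling_params = {"temperature": 0.2, "top_k": 5}
+    prompts = [
+        "Provide a concise factual statement about France's capital city. The capital of France is",
+    ]
+
+    outputs = llm.generate(prompts, sampling_params)
+    for prompt, output in zip(prompts, outputs):
+        print(f"\nPrompt: {prompt}")
+        print(f"Generated text: {output['text']}")
+
+    llm.shutdown()
+```
+
+Either entry point works; here we stick with the asynchronous version.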
+
+Now, when we call `python run.py`, we will get the outputs of our newly created model!
+
+
 ## Documentation
 
 Add to table of supported models in [generative_models.md](https://github.com/sgl-project/sglang/blob/main/docs/supported_models/generative_models.md) or [multimodal_language_models.md](https://github.com/sgl-project/sglang/blob/main/docs/supported_models/multimodal_language_models.md)