[doc] add walkthrough for implementing and hosting a simple llama wrapper m… (#10093)
@@ -135,6 +135,182 @@ ModelRegistry.models.update(import_new_model_classes())
launch_server(server_args)
```

## Example: Implementing and Serving a Llama Wrapper Model

Below is an introductory, step-by-step walkthrough of how to implement a new model end-to-end in SGLang and then run it via the [Offline Engine](https://github.com/sgl-project/sglang/blob/main/docs/basic_usage/offline_engine_api.ipynb).

### Implementing Our Model

To keep things simple, this new model will be a thin wrapper around [Llama 3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct), and our goal will just be to bias the output logits of each `forward` call by taking the square root of every positive logit.

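To make the biasing step concrete before wiring it into a model, here is a small standalone sketch of the same transformation on a toy tensor (the values are made up purely for illustration):

```python
import torch

# Toy "logits" over a 4-token vocabulary; values are arbitrary.
logits = torch.tensor([4.0, -2.0, 9.0, 0.25])

# Square-root the positive logits, leave non-positive ones untouched.
biased = torch.where(logits > 0, logits.sqrt(), logits)

print(biased)  # tensor([ 2.0000, -2.0000,  3.0000,  0.5000])
```

Because square-rooting pulls positive logits toward 1, the spread between them shrinks, so the resulting next-token distribution is somewhat flatter than the original.
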
Let's start by defining our model in a file called `llama_wrapper.py`.
The first step is to import the necessary libraries from SRT, which is SGLang's internal backend.

```python
# In the file `llama_wrapper.py`

import torch
from transformers import LlamaConfig
from typing import Optional
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors

from sglang.srt.models.llama import LlamaForCausalLM
```
Next, we declare a new `class` for our model and have it inherit from `LlamaForCausalLM`, which gives our model access to `LlamaForCausalLM`'s predefined modules and layers, such as `LlamaAttention` and `LlamaMLP`.
Note that almost all model implementations take `config` and `quant_config` as arguments to their `__init__` method; both are passed in via [`model_loader/loader.py`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_loader/loader.py#L219).
Because we have inherited from `LlamaForCausalLM`, we can pass our parameters directly to its constructor, which will set the member variables for us.

```python
class LlamaWrapper(LlamaForCausalLM):
    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config, quant_config=quant_config, prefix=prefix)
```
Now, we want to define the `forward` method, which is what will be called at inference time.
Note that the signature of `forward` is essentially the same for every model; you can take a look at the other models defined in the [`models` directory](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/) for reference.
To see where exactly `forward` is called in the SGLang runtime's internals, take a look at [`forward_decode`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_executor/model_runner.py#L1705) and [`forward_extend`](https://github.com/sgl-project/sglang/blob/bf72b80122fd888bf619d17b96fa3e323ab809fc/python/sglang/srt/model_executor/model_runner.py#L1724) in the [`ModelRunner` class](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/model_executor/model_runner.py).

```python
    # Still inside `class LlamaWrapper`:
    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        input_embeds: Optional[torch.Tensor] = None,
        get_embedding: bool = False,
    ) -> LogitsProcessorOutput:
```
We now call `self.model` (a member variable that `LlamaForCausalLM` defines in its `__init__` method); its `__call__` method eventually runs the underlying Llama backbone's `forward`.
After that, we feed the resulting `hidden_states` into our model's `LogitsProcessor` (again, defined in `LlamaForCausalLM`'s `__init__`).

```python
        # Inside `LlamaWrapper.forward`:
        hidden_states = self.model(
            input_ids,
            positions,
            forward_batch,
            input_embeds,
            pp_proxy_tensors=pp_proxy_tensors,
        )

        res: LogitsProcessorOutput = self.logits_processor(
            input_ids,
            hidden_states,
            self.lm_head,
            forward_batch,
        )
```
After receiving the logits for the next token, we can finally perform our biasing step.

```python
        # Square-root the positive logits; negative logits pass through unchanged.
        orig_logits = res.next_token_logits
        res.next_token_logits = torch.where(
            orig_logits > 0,
            orig_logits.sqrt(),
            orig_logits,
        )

        return res
```

Now, our `LlamaWrapper` model is created and ready to be served!

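For reference, here is the complete `llama_wrapper.py`, assembled from the snippets above:

```python
# llama_wrapper.py: the full file, assembled from the snippets above.
import torch
from transformers import LlamaConfig
from typing import Optional
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors

from sglang.srt.models.llama import LlamaForCausalLM


class LlamaWrapper(LlamaForCausalLM):
    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config, quant_config=quant_config, prefix=prefix)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        input_embeds: Optional[torch.Tensor] = None,
        get_embedding: bool = False,
    ) -> LogitsProcessorOutput:
        # Run the wrapped Llama backbone to get hidden states.
        hidden_states = self.model(
            input_ids,
            positions,
            forward_batch,
            input_embeds,
            pp_proxy_tensors=pp_proxy_tensors,
        )

        # Project hidden states to next-token logits.
        res: LogitsProcessorOutput = self.logits_processor(
            input_ids,
            hidden_states,
            self.lm_head,
            forward_batch,
        )

        # Bias: square-root every positive logit, leave the rest unchanged.
        orig_logits = res.next_token_logits
        res.next_token_logits = torch.where(
            orig_logits > 0,
            orig_logits.sqrt(),
            orig_logits,
        )
        return res
```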
### Serving Our Model Via SGLang's Offline Engine

The next step of this walkthrough is to host our new model offline, so that it can be served locally without an HTTP server.

First, create a new file called `run.py`.
Now, we must ensure that SGLang's `ModelRegistry` can find our model.
To do this, we first download the model's configuration and weights from Hugging Face.

```python
# In the file `run.py`

import asyncio
from functools import lru_cache
from huggingface_hub import snapshot_download
from llama_wrapper import LlamaWrapper  # Make sure to import our new model!
import sglang as sgl
from sglang.srt.models.registry import ModelRegistry

# Make sure to request access to this model on Hugging Face, then export your
# `HF_TOKEN` so that you can download the model snapshot.
llama_dir = snapshot_download(
    repo_id="meta-llama/Llama-3.1-8B-Instruct",
    local_dir="./llama_ckpt",
)
```
Now that we have the model on disk, we point it at `LlamaWrapper` by changing the `architectures` field in `./llama_ckpt/config.json` from `"LlamaForCausalLM"` to `"LlamaWrapper"`.
That way, when we pass the path of our model checkpoint to SGLang, it will know that we want to use `LlamaWrapper` instead of `LlamaForCausalLM` as our model.

```json
{
    "architectures": [
        "LlamaWrapper"
    ],
    ...
}
```
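If you would rather make this edit programmatically (say, in `run.py` right after the `snapshot_download` call), here is a minimal standard-library sketch:

```python
import json

# Point the checkpoint's `architectures` entry at our wrapper class.
config_path = "./llama_ckpt/config.json"
with open(config_path) as f:
    config = json.load(f)

config["architectures"] = ["LlamaWrapper"]

with open(config_path, "w") as f:
    json.dump(config, f, indent=2)
```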
However, if we don't link our `LlamaWrapper` class to the `"LlamaWrapper"` registry keyword, SGLang won't be able to find our model.
Thus, to register `LlamaWrapper`, we follow the steps from the earlier section titled "Registering an External Model Implementation".

```python
@lru_cache()
def import_new_model_classes():
    # Map the architecture string from `config.json` to our model class.
    model_arch_name_to_cls = {"LlamaWrapper": LlamaWrapper}
    return model_arch_name_to_cls

ModelRegistry.models.update(import_new_model_classes())
```
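As an optional sanity check (not part of the original walkthrough), you can confirm that the architecture string from `config.json` now resolves to a registered class:

```python
# The key must match the "architectures" entry in config.json exactly.
assert "LlamaWrapper" in ModelRegistry.models
```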
Lastly, when we create our `Engine`, we just pass in the path to the local model directory.
Then, our `LlamaWrapper` is ready to be served; for this walkthrough, we will use the SGLang `Engine`'s non-streaming asynchronous generation endpoint.

```python
def main():
    llm = sgl.Engine(model_path="./llama_ckpt")
    sampling_params = {"temperature": 0.2, "top_k": 5}
    prompts = [
        "Write a short, neutral self-introduction for a fictional character. Hello, my name is",
        "Provide a concise factual statement about France’s capital city. The capital of France is",
        "Explain possible future trends in artificial intelligence. The future of AI is",
    ]

    asyncio.run(run_llm(llm, sampling_params, prompts))

    llm.shutdown()


async def run_llm(
    llm,
    sampling_params,
    prompts,
) -> None:
    # Generate completions for all prompts in one batched call.
    outputs = await llm.async_generate(prompts, sampling_params)

    for prompt, output in zip(prompts, outputs):
        print(f"\nPrompt: {prompt}")
        print(f"Generated text: {output['text']}")


if __name__ == "__main__":
    main()
```
Now, when we call `python run.py`, we will get the outputs of our newly created model!

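The same registration also works for the HTTP server path shown earlier; below is a minimal sketch, assuming the `launch_server` and `ServerArgs` imports from the "Registering an External Model Implementation" section above:

```python
# Sketch only: serve LlamaWrapper over HTTP instead of the offline Engine.
# Assumes `launch_server`, `ServerArgs`, and the registration code above are
# imported/defined as in the earlier section.
ModelRegistry.models.update(import_new_model_classes())

server_args = ServerArgs(model_path="./llama_ckpt")
launch_server(server_args)
```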
## Documentation

Finally, add your model to the table of supported models in [generative_models.md](https://github.com/sgl-project/sglang/blob/main/docs/supported_models/generative_models.md) or [multimodal_language_models.md](https://github.com/sgl-project/sglang/blob/main/docs/supported_models/multimodal_language_models.md).