Misc clean up; Remove the support of jump forward (#4032)
This commit is contained in:
@@ -385,7 +385,7 @@
|
|||||||
"print(gen_response)\n",
|
"print(gen_response)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# parse the response\n",
|
"# parse the response\n",
|
||||||
"parse_url = f\"http://localhost:{port}/function_call\"\n",
|
"parse_url = f\"http://localhost:{port}/parse_function_call\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"function_call_input = {\n",
|
"function_call_input = {\n",
|
||||||
" \"text\": gen_response,\n",
|
" \"text\": gen_response,\n",
|
||||||
|
|||||||
@@ -1,72 +1,284 @@
|
|||||||
# Sampling Parameters
|
# Sampling Parameters in SGLang Runtime
|
||||||
|
|
||||||
This doc describes the sampling parameters of the SGLang Runtime.
|
This doc describes the sampling parameters of the SGLang Runtime.
|
||||||
It is the low-level endpoint of the runtime.
|
It is the low-level endpoint of the runtime.
|
||||||
If you want a high-level endpoint that can automatically handle chat templates, consider using the [OpenAI Compatible API](https://docs.sglang.ai/backend/openai_api_completions.html).
|
If you want a high-level endpoint that can automatically handle chat templates, consider using the [OpenAI Compatible API](../backend/openai_api_completions.ipynb).
|
||||||
|
|
||||||
## `/generate` Endpoint
|
The `/generate` endpoint accepts the following arguments in the JSON format. You can find code examples below.
|
||||||
|
|
||||||
The `/generate` endpoint accepts the following parameters in JSON format. For detailed usage, see the [native API doc](https://docs.sglang.ai/backend/native_api.html).
|
```python
|
||||||
|
@dataclass
|
||||||
|
class GenerateReqInput:
|
||||||
|
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||||
|
text: Optional[Union[List[str], str]] = None
|
||||||
|
# The token ids for text; one can specify either text or input_ids
|
||||||
|
input_ids: Optional[Union[List[List[int]], List[int]]] = None
|
||||||
|
# The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
|
||||||
|
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
|
||||||
|
# The image input. It can be a file name, a url, or base64 encoded string.
|
||||||
|
# See also python/sglang/srt/utils.py:load_image.
|
||||||
|
image_data: Optional[Union[List[str], str]] = None
|
||||||
|
# The sampling_params. See descriptions below.
|
||||||
|
sampling_params: Optional[Union[List[Dict], Dict]] = None
|
||||||
|
# The request id.
|
||||||
|
rid: Optional[Union[List[str], str]] = None
|
||||||
|
# Whether to return logprobs.
|
||||||
|
return_logprob: Optional[Union[List[bool], bool]] = None
|
||||||
|
# If return logprobs, the start location in the prompt for returning logprobs.
|
||||||
|
# By default, this value is "-1", which means it will only return logprobs for output tokens.
|
||||||
|
logprob_start_len: Optional[Union[List[int], int]] = None
|
||||||
|
# If return logprobs, the number of top logprobs to return at each position.
|
||||||
|
top_logprobs_num: Optional[Union[List[int], int]] = None
|
||||||
|
# If return logprobs, the token ids to return logprob for.
|
||||||
|
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None
|
||||||
|
# Whether to detokenize tokens in text in the returned logprobs.
|
||||||
|
return_text_in_logprobs: bool = False
|
||||||
|
# Whether to stream output.
|
||||||
|
stream: bool = False
|
||||||
|
|
||||||
* `text`: The input prompt. Can be a single prompt or a batch of prompts. `Optional[Union[List[str], str]] = None`
|
# The modalities of the image data [image, multi-images, video]
|
||||||
* `input_ids`: Alternative to `text`. Specify the input as token IDs instead of text. `Optional[Union[List[List[int]], List[int]]] = None`
|
modalities: Optional[List[str]] = None
|
||||||
* `sampling_params`: The sampling parameters as described in the sections below. `Optional[Union[List[Dict], Dict]] = None`
|
# LoRA related
|
||||||
* `return_logprob`: Whether to return log probabilities for tokens. `Optional[Union[List[bool], bool]] = None`
|
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||||
* `logprob_start_len`: If returning log probabilities, specifies the start position in the prompt. Default is "-1" which returns logprobs only for output tokens. `Optional[Union[List[int], int]] = None`
|
|
||||||
* `top_logprobs_num`: If returning log probabilities, specifies the number of top logprobs to return at each position. `Optional[Union[List[int], int]] = None`
|
|
||||||
* `stream`: Whether to stream the output. `bool = False`
|
|
||||||
* `lora_path`: Path to LoRA weights. `Optional[Union[List[Optional[str]], Optional[str]]] = None`
|
|
||||||
* `custom_logit_processor`: Custom logit processor for advanced sampling control. For usage see below. `Optional[Union[List[Optional[str]], str]] = None`
|
|
||||||
* `return_hidden_states`: Whether to return hidden states of the model. Note that each time it changes, the cuda graph will be recaptured, which might lead to a performance hit. See the [examples](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/hidden_states.py) for more information. `bool = False`
|
|
||||||
|
|
||||||
## Sampling params
|
# Custom logit processor for advanced sampling control. Must be a serialized instance
|
||||||
|
# of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
|
||||||
|
# Use the processor's `to_str()` method to generate the serialized string.
|
||||||
|
custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
|
||||||
|
|
||||||
### Core Parameters
|
# Whether to return hidden states
|
||||||
|
return_hidden_states: bool = False
|
||||||
|
```
|
||||||
|
|
||||||
* `max_new_tokens`: The maximum output length measured in tokens. `int = 128`
|
The `sampling_params` argument follows this format:
|
||||||
* `stop`: One or multiple [stop words](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop). Generation will stop if one of these words is sampled. `Optional[Union[str, List[str]]] = None`
|
|
||||||
* `stop_token_ids`: Provide stop words in form of token ids. Generation will stop if one of these token ids is sampled. `Optional[List[int]] = []`
|
|
||||||
* `temperature`: [Temperature](https://platform.openai.com/docs/api-reference/chat/create#chat-create-temperature) when sampling the next token. `temperature = 0` corresponds to greedy sampling, higher temperature leads to more diversity. `float = 1.0`
|
|
||||||
* `top_p`: [Top-p](https://platform.openai.com/docs/api-reference/chat/create#chat-create-top_p) selects tokens from the smallest sorted set whose cumulative probability exceeds `top_p`. When `top_p = 1`, this reduces to unrestricted sampling from all tokens. `top_p: float = 1.0`
|
|
||||||
* `top_k`: [Top-k](https://developer.nvidia.com/blog/how-to-get-better-outputs-from-your-large-language-model/#predictability_vs_creativity) randomly selects from the `k` highest-probability tokens. `int = -1`
|
|
||||||
* `min_p`: [Min-p](https://github.com/huggingface/transformers/issues/27670) samples from tokens with probability larger than `min_p * highest_token_probability`. `float = 0.0`
|
|
||||||
|
|
||||||
### Penalizers
|
```python
|
||||||
|
# The maximum number of output tokens
|
||||||
|
max_new_tokens: int = 128,
|
||||||
|
# Stop when hitting any of the strings in this list
|
||||||
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
|
# Stop when hitting any of the token_ids in this list
|
||||||
|
stop_token_ids: Optional[List[int]] = [],
|
||||||
|
# Sampling temperature
|
||||||
|
temperature: float = 1.0,
|
||||||
|
# Top-p sampling
|
||||||
|
top_p: float = 1.0,
|
||||||
|
# Top-k sampling
|
||||||
|
top_k: int = -1,
|
||||||
|
# Min-p sampling
|
||||||
|
min_p: float = 0.0,
|
||||||
|
# Do parallel sampling and return `n` outputs.
|
||||||
|
n: int = 1,
|
||||||
|
|
||||||
To use penalizers, you will need to launch the server with `--disable-overlap`. Please note that this might degrade performance.
|
## Structured Outputs
|
||||||
|
# Only one of the below three can be set for a request.
|
||||||
|
|
||||||
* `frequency_penalty`: Penalizes tokens based on their frequency in the generation so far. Must be between `-2` and `2` where negative numbers encourage repetition of tokens and positive numbers encourage sampling of new tokens. The penalization grows linearly with each appearance of a token. `float = 0.0`
|
# Constrain the output to follow a given JSON schema.
|
||||||
* `presence_penalty`: Penalizes tokens if they appeared in the generation so far. Must be between `-2` and `2` where negative numbers encourage repetition of tokens and positive numbers encourage sampling of new tokens. The penalization is constant once a token has occurred. `float = 0.0`
|
json_schema: Optional[str] = None,
|
||||||
* `repetition_penalty`: Penalizes tokens if they appeared in the prompt or generation so far. Must be between `0` and `2` where numbers smaller than `1` encourage repetition of tokens and numbers larger than `1` encourage sampling of new tokens. The penalization scales multiplicatively. `float = 1.0`
|
# Constrain the output to follow a given regular expression.
|
||||||
* `min_new_tokens`: Forces the model to generate at least `min_new_tokens` until a stop word or EOS token is sampled. Note that this might lead to unintended behavior for example if the distribution is highly skewed towards these tokens. `int = 0`
|
regex: Optional[str] = None,
|
||||||
|
# Constrain the output to follow a given EBNF grammar.
|
||||||
|
ebnf: Optional[str] = None,
|
||||||
|
|
||||||
### Constrained decoding
|
## Penalties
|
||||||
|
|
||||||
Please refer to our dedicated guide on [constrained decoding](https://docs.sglang.ai/backend/structured_outputs.html#Native-API-and-SGLang-Runtime-(SRT)) for the following parameters.
|
# Float that penalizes new tokens based on their frequency in the generated text so far.
|
||||||
|
# Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to
|
||||||
|
# repeat tokens. Must be -2 <= value <= 2. Setting to 0 (default) will disable this penalty.
|
||||||
|
frequency_penalty: float = 0.0,
|
||||||
|
# Float that penalizes new tokens based on whether they appear in the generated text so far.
|
||||||
|
# Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to repeat
|
||||||
|
# tokens. Must be -2 <= value <= 2. Setting to 0 (default) will disable this penalty.
|
||||||
|
presence_penalty: float = 0.0,
|
||||||
|
# Guides inference to generate at least this number of tokens by penalizing logits of tokenizer's
|
||||||
|
# EOS token and `stop_token_ids` to -inf, until the output token reaches given length.
|
||||||
|
# Note that any of the `stop` string can be generated before reaching `min_new_tokens`, as it is
|
||||||
|
# difficult to infer the correct token ID by given `stop` strings.
|
||||||
|
# Must be 0 <= value < max_new_tokens. Setting to 0 (default) will disable this penalty.
|
||||||
|
min_new_tokens: int = 0,
|
||||||
|
|
||||||
* `json_schema`: `Optional[str] = None`
|
# Whether to ignore EOS token
|
||||||
* `regex`: `Optional[str] = None`
|
ignore_eos: bool = False,
|
||||||
* `ebnf`: `Optional[str] = None`
|
# Whether to skip the special tokens during detokenization
|
||||||
|
skip_special_tokens: bool = True,
|
||||||
|
# Whether to add spaces between special tokens during detokenization
|
||||||
|
spaces_between_special_tokens: bool = True,
|
||||||
|
|
||||||
### Other options
|
## Custom Parameters for Custom Logit Processor.
|
||||||
|
# A dictionary of custom parameters for the custom logit processor.
|
||||||
|
# The custom logit processor takes a list of dictionaries as input, where each
|
||||||
|
# dictionary is the custom parameters for one token in a batch of the input.
|
||||||
|
# See also python/sglang/srt/sampling/custom_logit_processor.py
|
||||||
|
custom_params: Optional[Dict[str, Any]] = None,
|
||||||
|
```
|
||||||
|
|
||||||
* `n`: Specifies the number of output sequences to generate per request. (Generating multiple outputs in one request (`n > 1`) is discouraged; repeating the same prompt several times offers better control and efficiency.) `int = 1`
|
## Examples
|
||||||
* `spaces_between_special_tokens`: Whether or not to add spaces between special tokens during detokenization. `bool = True`
|
|
||||||
* `no_stop_trim`: Don't trim stop words or EOS token from the generated text. `bool = False`
|
|
||||||
* `ignore_eos`: Don't stop generation when EOS token is sampled. `bool = False`
|
|
||||||
* `skip_special_tokens`: Remove special tokens during decoding. `bool = True`
|
|
||||||
* `custom_params`: Used when employing `CustomLogitProcessor`. For usage see below. `Optional[List[Optional[Dict[str, Any]]]] = None`
|
|
||||||
|
|
||||||
|
### Normal
|
||||||
|
Launch a server
|
||||||
|
```
|
||||||
|
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000
|
||||||
|
```
|
||||||
|
|
||||||
|
Send a request
|
||||||
|
```python
|
||||||
|
import requests
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
"http://localhost:30000/generate",
|
||||||
|
json={
|
||||||
|
"text": "The capital of France is",
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": 32,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
print(response.json())
|
||||||
|
```
|
||||||
|
|
||||||
|
### Streaming
|
||||||
|
Send a request and stream the output
|
||||||
|
```python
|
||||||
|
import requests, json
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
"http://localhost:30000/generate",
|
||||||
|
json={
|
||||||
|
"text": "The capital of France is",
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": 32,
|
||||||
|
},
|
||||||
|
"stream": True,
|
||||||
|
},
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
prev = 0
|
||||||
|
for chunk in response.iter_lines(decode_unicode=False):
|
||||||
|
chunk = chunk.decode("utf-8")
|
||||||
|
if chunk and chunk.startswith("data:"):
|
||||||
|
if chunk == "data: [DONE]":
|
||||||
|
break
|
||||||
|
data = json.loads(chunk[5:].strip("\n"))
|
||||||
|
output = data["text"].strip()
|
||||||
|
print(output[prev:], end="", flush=True)
|
||||||
|
prev = len(output)
|
||||||
|
print("")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multi modal
|
||||||
|
|
||||||
|
Launch a server
|
||||||
|
```
|
||||||
|
python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-7b-ov --chat-template chatml-llava
|
||||||
|
```
|
||||||
|
|
||||||
|
Download an image
|
||||||
|
```
|
||||||
|
curl -o example_image.png -L https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true
|
||||||
|
```
|
||||||
|
|
||||||
|
Send a request
|
||||||
|
```python
|
||||||
|
import requests
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
"http://localhost:30000/generate",
|
||||||
|
json={
|
||||||
|
"text": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
|
||||||
|
"<|im_start|>user\n<image>\nDescribe this image in a very short sentence.<|im_end|>\n"
|
||||||
|
"<|im_start|>assistant\n",
|
||||||
|
"image_data": "example_image.png",
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": 32,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
print(response.json())
|
||||||
|
```
|
||||||
|
|
||||||
|
The `image_data` can be a file name, a URL, or a base64 encoded string. See also `python/sglang/srt/utils.py:load_image`.
|
||||||
|
Streaming is supported in a similar manner as [above](#streaming).
|
||||||
|
|
||||||
|
### Structured Outputs (JSON, Regex, EBNF)
|
||||||
|
You can specify a JSON schema, regular expression or [EBNF](https://en.wikipedia.org/wiki/Extended_Backus%E2%80%93Naur_form) to constrain the model output. The model output will be guaranteed to follow the given constraints. Only one constraint parameter (`json_schema`, `regex`, or `ebnf`) can be specified for a request.
|
||||||
|
|
||||||
|
SGLang supports two grammar backends:
|
||||||
|
|
||||||
|
- [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints.
|
||||||
|
- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema, regular expression, and EBNF constraints.
|
||||||
|
- XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md)
|
||||||
|
|
||||||
|
Initialize the XGrammar backend using `--grammar-backend xgrammar` flag
|
||||||
|
```bash
|
||||||
|
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
|
||||||
|
--port 30000 --host 0.0.0.0 --grammar-backend [xgrammar|outlines] # xgrammar or outlines (default: outlines)
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
import json
|
||||||
|
import requests
|
||||||
|
|
||||||
|
json_schema = json.dumps({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {"type": "string", "pattern": "^[\\w]+$"},
|
||||||
|
"population": {"type": "integer"},
|
||||||
|
},
|
||||||
|
"required": ["name", "population"],
|
||||||
|
})
|
||||||
|
|
||||||
|
# JSON (works with both Outlines and XGrammar)
|
||||||
|
response = requests.post(
|
||||||
|
"http://localhost:30000/generate",
|
||||||
|
json={
|
||||||
|
"text": "Here is the information of the capital of France in the JSON format.\n",
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": 64,
|
||||||
|
"json_schema": json_schema,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
print(response.json())
|
||||||
|
|
||||||
|
# Regular expression (Outlines backend only)
|
||||||
|
response = requests.post(
|
||||||
|
"http://localhost:30000/generate",
|
||||||
|
json={
|
||||||
|
"text": "Paris is the capital of",
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": 64,
|
||||||
|
"regex": "(France|England)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
print(response.json())
|
||||||
|
|
||||||
|
# EBNF (XGrammar backend only)
|
||||||
|
response = requests.post(
|
||||||
|
"http://localhost:30000/generate",
|
||||||
|
json={
|
||||||
|
"text": "Write a greeting.",
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": 64,
|
||||||
|
"ebnf": 'root ::= "Hello" | "Hi" | "Hey"',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
print(response.json())
|
||||||
|
```
|
||||||
### Custom Logit Processor
|
### Custom Logit Processor
|
||||||
|
|
||||||
Launch a server with `--enable-custom-logit-processor` flag on.
|
Launch a server with `--enable-custom-logit-processor` flag on.
|
||||||
|
|
||||||
```
|
```
|
||||||
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --enable-custom-logit-processor
|
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --enable-custom-logit-processor
|
||||||
```
|
```
|
||||||
|
|
||||||
Define a custom logit processor that will always sample a specific token id.
|
Define a custom logit processor that will always sample a specific token id.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
||||||
|
|
||||||
@@ -89,7 +301,6 @@ class DeterministicLogitProcessor(CustomLogitProcessor):
|
|||||||
```
|
```
|
||||||
|
|
||||||
Send a request
|
Send a request
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ Please consult the documentation below to learn more about the parameters you ma
|
|||||||
### API configuration
|
### API configuration
|
||||||
|
|
||||||
* `api_key`: Sets an API key for the server and the OpenAI-compatible API.
|
* `api_key`: Sets an API key for the server and the OpenAI-compatible API.
|
||||||
* `file_storage_pth`: Directory for storing uploaded or generated files from API calls.
|
* `file_storage_path`: Directory for storing uploaded or generated files from API calls.
|
||||||
* `enable_cache_report`: If set, includes detailed usage of cached tokens in the response usage.
|
* `enable_cache_report`: If set, includes detailed usage of cached tokens in the response usage.
|
||||||
|
|
||||||
## Parallelism
|
## Parallelism
|
||||||
@@ -162,7 +162,6 @@ Please consult the documentation below to learn more about the parameters you ma
|
|||||||
*Note: We recommend to stay with the defaults and only use these options for debugging for best possible performance.*
|
*Note: We recommend to stay with the defaults and only use these options for debugging for best possible performance.*
|
||||||
|
|
||||||
* `disable_radix_cache`: Disable [Radix](https://lmsys.org/blog/2024-01-17-sglang/) backend for prefix caching.
|
* `disable_radix_cache`: Disable [Radix](https://lmsys.org/blog/2024-01-17-sglang/) backend for prefix caching.
|
||||||
* `disable_jump_forward`: Disable [jump-forward](https://lmsys.org/blog/2024-02-05-compressed-fsm/#our-method-jump-forward-decoding-with-a-compressed-finite-state-machine) for outlines grammar backend.
|
|
||||||
* `disable_cuda_graph`: Disable [cuda graph](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/) for model forward. Use if encountering uncorrectable CUDA ECC errors.
|
* `disable_cuda_graph`: Disable [cuda graph](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/) for model forward. Use if encountering uncorrectable CUDA ECC errors.
|
||||||
* `disable_cuda_graph_padding`: Disable cuda graph when padding is needed. In other case still use cuda graph.
|
* `disable_cuda_graph_padding`: Disable cuda graph when padding is needed. In other case still use cuda graph.
|
||||||
* `disable_outlines_disk_cache`: Disable disk cache for outlines grammar backend.
|
* `disable_outlines_disk_cache`: Disable disk cache for outlines grammar backend.
|
||||||
|
|||||||
@@ -47,7 +47,7 @@
|
|||||||
"server_process, port = launch_server_cmd(\n",
|
"server_process, port = launch_server_cmd(\n",
|
||||||
" \"\"\"\n",
|
" \"\"\"\n",
|
||||||
"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \\\n",
|
"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \\\n",
|
||||||
" --speculative-draft-model-path lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
|
" --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
|
||||||
" --speculative-eagle-topk 8 --speculative-num-draft-tokens 64\n",
|
" --speculative-eagle-topk 8 --speculative-num-draft-tokens 64\n",
|
||||||
"\"\"\"\n",
|
"\"\"\"\n",
|
||||||
")\n",
|
")\n",
|
||||||
@@ -104,7 +104,7 @@
|
|||||||
"server_process, port = launch_server_cmd(\n",
|
"server_process, port = launch_server_cmd(\n",
|
||||||
" \"\"\"\n",
|
" \"\"\"\n",
|
||||||
"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \\\n",
|
"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \\\n",
|
||||||
" --speculative-draft-model-path lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
|
" --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
|
||||||
" --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.6 \\\n",
|
" --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.6 \\\n",
|
||||||
" --enable-torch-compile --cuda-graph-max-bs 2\n",
|
" --enable-torch-compile --cuda-graph-max-bs 2\n",
|
||||||
"\"\"\"\n",
|
"\"\"\"\n",
|
||||||
@@ -175,7 +175,7 @@
|
|||||||
"server_process, port = launch_server_cmd(\n",
|
"server_process, port = launch_server_cmd(\n",
|
||||||
" \"\"\"\n",
|
" \"\"\"\n",
|
||||||
"python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B-Instruct --speculative-algorithm EAGLE \\\n",
|
"python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B-Instruct --speculative-algorithm EAGLE \\\n",
|
||||||
" --speculative-draft-model-path lmzheng/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 5 \\\n",
|
" --speculative-draft-model-path lmsys/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 5 \\\n",
|
||||||
" --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --speculative-token-map thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt \\\n",
|
" --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --speculative-token-map thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt \\\n",
|
||||||
" --mem-fraction 0.7 --cuda-graph-max-bs 2 --dtype float16 \n",
|
" --mem-fraction 0.7 --cuda-graph-max-bs 2 --dtype float16 \n",
|
||||||
"\"\"\"\n",
|
"\"\"\"\n",
|
||||||
|
|||||||
@@ -43,4 +43,4 @@ If you want to contribute but don’t have a specific idea in mind, pick issues
|
|||||||
|
|
||||||
If you have any questions or want to start a discussion, please feel free to ask in our [Slack channel](https://join.slack.com/t/sgl-fru7574/shared_invite/zt-2um0ad92q-LkU19KQTxCGzlCgRiOiQEw).
|
If you have any questions or want to start a discussion, please feel free to ask in our [Slack channel](https://join.slack.com/t/sgl-fru7574/shared_invite/zt-2um0ad92q-LkU19KQTxCGzlCgRiOiQEw).
|
||||||
|
|
||||||
Thank you for your interest in SGLang—**happy coding**!
|
Thank you for your interest in SGLang. Happy coding!
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ srun --ntasks=2 --nodes=2 --output="SLURM_Logs/%x_%j_node$SLURM_NODEID.out" \
|
|||||||
--model-path "$model" \
|
--model-path "$model" \
|
||||||
--grammar-backend "xgrammar" \
|
--grammar-backend "xgrammar" \
|
||||||
--tp "$tp_size" \
|
--tp "$tp_size" \
|
||||||
--nccl-init-addr "$NCCL_INIT_ADDR" \
|
--dist-init-addr "$NCCL_INIT_ADDR" \
|
||||||
--nnodes 2 \
|
--nnodes 2 \
|
||||||
--node-rank "$SLURM_NODEID" &
|
--node-rank "$SLURM_NODEID" &
|
||||||
|
|
||||||
|
|||||||
@@ -2,9 +2,10 @@
|
|||||||
|
|
||||||
You can install SGLang using any of the methods below.
|
You can install SGLang using any of the methods below.
|
||||||
|
|
||||||
For running DeepSeek V3/R1, refer to [DeepSeek V3 Support](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3). It is recommended to use the [latest version](https://pypi.org/project/sglang/#history) and deploy it with [Docker](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#using-docker-recommended) to avoid environment-related problems.
|
For running DeepSeek V3/R1, refer to [DeepSeek V3 Support](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3). It is recommended to use the latest version and deploy it with [Docker](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#using-docker-recommended) to avoid environment-related issues.
|
||||||
|
|
||||||
|
It is recommended to use uv to install the dependencies for faster installation:
|
||||||
|
|
||||||
We recommend using uv to install the dependencies with a higher installation speed:
|
|
||||||
## Method 1: With pip or uv
|
## Method 1: With pip or uv
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -13,12 +14,11 @@ pip install uv
|
|||||||
uv pip install "sglang[all]>=0.4.3.post2" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python
|
uv pip install "sglang[all]>=0.4.3.post2" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python
|
||||||
```
|
```
|
||||||
|
|
||||||
**Quick Fixes to Installation**
|
**Quick Fixes to Common Problems**
|
||||||
|
|
||||||
- SGLang currently uses torch 2.5, so you need to install flashinfer for torch 2.5. If you want to install flashinfer separately, please refer to [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html). Please note that the FlashInfer pypi package is called `flashinfer-python` instead of `flashinfer`.
|
- SGLang currently uses torch 2.5, so you need to install flashinfer for torch 2.5. If you want to install flashinfer separately, please refer to [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html). Please note that the FlashInfer pypi package is called `flashinfer-python` instead of `flashinfer`.
|
||||||
|
|
||||||
- If you encounter `OSError: CUDA_HOME environment variable is not set. Please set it to your CUDA install root`, please try either of the following solutions:
|
- If you encounter `OSError: CUDA_HOME environment variable is not set`, please set it to your CUDA install root with either of the following solutions:
|
||||||
|
|
||||||
1. Use `export CUDA_HOME=/usr/local/cuda-<your-cuda-version>` to set the `CUDA_HOME` environment variable.
|
1. Use `export CUDA_HOME=/usr/local/cuda-<your-cuda-version>` to set the `CUDA_HOME` environment variable.
|
||||||
2. Install FlashInfer first following [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html), then install SGLang as described above.
|
2. Install FlashInfer first following [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html), then install SGLang as described above.
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ def main():
|
|||||||
llm = sgl.Engine(
|
llm = sgl.Engine(
|
||||||
model_path="meta-llama/Llama-2-7b-chat-hf",
|
model_path="meta-llama/Llama-2-7b-chat-hf",
|
||||||
speculative_algorithm="EAGLE",
|
speculative_algorithm="EAGLE",
|
||||||
speculative_draft_model_path="lmzheng/sglang-EAGLE-llama2-chat-7B",
|
speculative_draft_model_path="lmsys/sglang-EAGLE-llama2-chat-7B",
|
||||||
speculative_num_steps=3,
|
speculative_num_steps=3,
|
||||||
speculative_eagle_topk=4,
|
speculative_eagle_topk=4,
|
||||||
speculative_num_draft_tokens=16,
|
speculative_num_draft_tokens=16,
|
||||||
@@ -52,7 +52,7 @@ srt = [
|
|||||||
|
|
||||||
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
||||||
# => base docker rocm/vllm-dev:20241022, not from public vllm whl
|
# => base docker rocm/vllm-dev:20241022, not from public vllm whl
|
||||||
srt_hip = ["sglang[runtime_common]", "sgl-kernel>=0.0.3.post1", "torch", "vllm==0.6.7.dev2", "outlines==0.1.11"]
|
srt_hip = ["sglang[runtime_common]", "sgl-kernel==0.0.3.post6", "torch", "vllm==0.6.7.dev2", "outlines==0.1.11"]
|
||||||
|
|
||||||
# xpu is not enabled in public vllm and torch whl,
|
# xpu is not enabled in public vllm and torch whl,
|
||||||
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.htmlinstall vllm
|
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html to install vllm
|
||||||
|
|||||||
@@ -12,6 +12,5 @@
|
|||||||
- `global_config.py`: The global configs and constants.
|
- `global_config.py`: The global configs and constants.
|
||||||
- `launch_server.py`: The entry point for launching the local server.
|
- `launch_server.py`: The entry point for launching the local server.
|
||||||
- `llama3_eval.py`: Evaluation of Llama 3 using the Meta Llama dataset.
|
- `llama3_eval.py`: Evaluation of Llama 3 using the Meta Llama dataset.
|
||||||
- `profiler.py`: Profile a running server.
|
|
||||||
- `utils.py`: Common utilities.
|
- `utils.py`: Common utilities.
|
||||||
- `version.py`: Version info.
|
- `version.py`: Version info.
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
raise ValueError("bench_latency.py has been renamed to bench_one_batch.py")
|
|
||||||
@@ -4,6 +4,13 @@ import os
|
|||||||
|
|
||||||
|
|
||||||
class GlobalConfig:
|
class GlobalConfig:
|
||||||
|
"""
|
||||||
|
Store some global constants.
|
||||||
|
|
||||||
|
See also python/sglang/srt/managers/schedule_batch.py::global_server_args_dict, which stores
|
||||||
|
many global runtime arguments as well.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# Verbosity level
|
# Verbosity level
|
||||||
# 0: do not output anything
|
# 0: do not output anything
|
||||||
|
|||||||
@@ -80,7 +80,6 @@ def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
|
|||||||
grammar_backend = OutlinesGrammarBackend(
|
grammar_backend = OutlinesGrammarBackend(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
whitespace_pattern=server_args.constrained_json_whitespace_pattern,
|
whitespace_pattern=server_args.constrained_json_whitespace_pattern,
|
||||||
allow_jump_forward=not server_args.disable_jump_forward,
|
|
||||||
)
|
)
|
||||||
elif server_args.grammar_backend == "xgrammar":
|
elif server_args.grammar_backend == "xgrammar":
|
||||||
from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend
|
from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend
|
||||||
|
|||||||
@@ -115,7 +115,6 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
|
|||||||
self,
|
self,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
whitespace_pattern: bool,
|
whitespace_pattern: bool,
|
||||||
allow_jump_forward: bool,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -140,7 +139,6 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
|
|||||||
self.outlines_tokenizer.vocabulary = (
|
self.outlines_tokenizer.vocabulary = (
|
||||||
self.outlines_tokenizer.tokenizer.get_vocab()
|
self.outlines_tokenizer.tokenizer.get_vocab()
|
||||||
)
|
)
|
||||||
self.allow_jump_forward = allow_jump_forward
|
|
||||||
self.whitespace_pattern = whitespace_pattern
|
self.whitespace_pattern = whitespace_pattern
|
||||||
|
|
||||||
def init_value_impl(self, key: Tuple[str, str]) -> OutlinesGrammar:
|
def init_value_impl(self, key: Tuple[str, str]) -> OutlinesGrammar:
|
||||||
@@ -172,9 +170,6 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
|
|||||||
logger.warning(f"skip invalid regex schema: {regex=}, {e=}")
|
logger.warning(f"skip invalid regex schema: {regex=}, {e=}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if self.allow_jump_forward:
|
|
||||||
jump_forward_map = OutlinesJumpForwardMap(regex)
|
|
||||||
else:
|
|
||||||
jump_forward_map = None
|
jump_forward_map = None
|
||||||
return OutlinesGrammar(guide, jump_forward_map)
|
return OutlinesGrammar(guide, jump_forward_map)
|
||||||
|
|
||||||
|
|||||||
@@ -438,8 +438,8 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request):
|
|||||||
return Response(status_code=200)
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/function_call")
|
@app.post("/parse_function_call")
|
||||||
async def function_call_request(obj: ParseFunctionCallReq, request: Request):
|
async def parse_function_call_request(obj: ParseFunctionCallReq, request: Request):
|
||||||
"""
|
"""
|
||||||
A native API endpoint to parse function calls from a text.
|
A native API endpoint to parse function calls from a text.
|
||||||
"""
|
"""
|
||||||
@@ -492,7 +492,7 @@ def available_models():
|
|||||||
@app.post("/v1/files")
|
@app.post("/v1/files")
|
||||||
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
|
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
|
||||||
return await v1_files_create(
|
return await v1_files_create(
|
||||||
file, purpose, _global_state.tokenizer_manager.server_args.file_storage_pth
|
file, purpose, _global_state.tokenizer_manager.server_args.file_storage_path
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.layers.attention import AttentionBackend
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
|
||||||
|
|||||||
@@ -19,9 +19,8 @@ import triton
|
|||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.layers.attention import AttentionBackend
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.utils import is_flashinfer_available
|
from sglang.srt.utils import is_flashinfer_available
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.layers.attention import AttentionBackend
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||||
create_flashinfer_kv_indices_triton,
|
create_flashinfer_kv_indices_triton,
|
||||||
)
|
)
|
||||||
@@ -34,7 +34,6 @@ if is_flashinfer_available():
|
|||||||
BatchMLAPagedAttentionWrapper,
|
BatchMLAPagedAttentionWrapper,
|
||||||
BatchPrefillWithRaggedKVCacheWrapper,
|
BatchPrefillWithRaggedKVCacheWrapper,
|
||||||
)
|
)
|
||||||
from flashinfer.cascade import merge_state
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
|
|||||||
import torch
|
import torch
|
||||||
from torch.nn.functional import scaled_dot_product_attention
|
from torch.nn.functional import scaled_dot_product_attention
|
||||||
|
|
||||||
from sglang.srt.layers.attention import AttentionBackend
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
|
|
||||||
from sglang.srt.layers.attention import AttentionBackend
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||||
create_flashinfer_kv_indices_triton,
|
create_flashinfer_kv_indices_triton,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|||||||
requantize_with_max_scale,
|
requantize_with_max_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sglang.srt.layers.attention import AttentionBackend
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
|
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
|
||||||
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
|
|||||||
@@ -57,7 +57,6 @@ DETOKENIZER_MAX_STATES = int(os.environ.get("SGLANG_DETOKENIZER_MAX_STATES", 1 <
|
|||||||
class DecodeStatus:
|
class DecodeStatus:
|
||||||
"""Store the status of incremental decoding."""
|
"""Store the status of incremental decoding."""
|
||||||
|
|
||||||
vid: int
|
|
||||||
decoded_text: str
|
decoded_text: str
|
||||||
decode_ids: List[int]
|
decode_ids: List[int]
|
||||||
surr_offset: int
|
surr_offset: int
|
||||||
@@ -143,10 +142,8 @@ class DetokenizerManager:
|
|||||||
read_ids, surr_ids = [], []
|
read_ids, surr_ids = [], []
|
||||||
for i in range(bs):
|
for i in range(bs):
|
||||||
rid = recv_obj.rids[i]
|
rid = recv_obj.rids[i]
|
||||||
vid = recv_obj.vids[i]
|
if rid not in self.decode_status:
|
||||||
if rid not in self.decode_status or self.decode_status[rid].vid != vid:
|
|
||||||
s = DecodeStatus(
|
s = DecodeStatus(
|
||||||
vid=vid,
|
|
||||||
decoded_text=recv_obj.decoded_texts[i],
|
decoded_text=recv_obj.decoded_texts[i],
|
||||||
decode_ids=recv_obj.decode_ids[i],
|
decode_ids=recv_obj.decode_ids[i],
|
||||||
surr_offset=0,
|
surr_offset=0,
|
||||||
|
|||||||
@@ -376,8 +376,6 @@ class BatchTokenIDOut:
|
|||||||
# The finish reason
|
# The finish reason
|
||||||
finished_reasons: List[BaseFinishReason]
|
finished_reasons: List[BaseFinishReason]
|
||||||
# For incremental decoding
|
# For incremental decoding
|
||||||
# The version id to sync decode status with in detokenizer_manager
|
|
||||||
vids: List[int]
|
|
||||||
decoded_texts: List[str]
|
decoded_texts: List[str]
|
||||||
decode_ids: List[int]
|
decode_ids: List[int]
|
||||||
read_offsets: List[int]
|
read_offsets: List[int]
|
||||||
|
|||||||
@@ -296,7 +296,6 @@ class Req:
|
|||||||
# 1: surr_offset
|
# 1: surr_offset
|
||||||
# 2: read_offset
|
# 2: read_offset
|
||||||
# 3: last token
|
# 3: last token
|
||||||
self.vid = 0 # version id to sync decode status with in detokenizer_manager
|
|
||||||
self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm
|
self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm
|
||||||
self.read_offset = None
|
self.read_offset = None
|
||||||
self.decoded_text = ""
|
self.decoded_text = ""
|
||||||
@@ -357,11 +356,6 @@ class Req:
|
|||||||
) = None
|
) = None
|
||||||
self.hidden_states = []
|
self.hidden_states = []
|
||||||
|
|
||||||
# Logprobs (internal values)
|
|
||||||
# The tokens is prefilled but need to be considered as decode tokens
|
|
||||||
# and should be updated for the decode logprobs
|
|
||||||
self.last_update_decode_tokens = 0
|
|
||||||
|
|
||||||
# Embedding (return values)
|
# Embedding (return values)
|
||||||
self.embedding = None
|
self.embedding = None
|
||||||
|
|
||||||
@@ -500,68 +494,6 @@ class Req:
|
|||||||
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
|
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
|
||||||
return
|
return
|
||||||
|
|
||||||
def jump_forward_and_retokenize(self, jump_forward_str, next_state):
|
|
||||||
if self.origin_input_text is None:
|
|
||||||
# Recovering text can only use unpadded ids
|
|
||||||
self.origin_input_text = self.tokenizer.decode(
|
|
||||||
self.origin_input_ids_unpadded
|
|
||||||
)
|
|
||||||
|
|
||||||
all_text = self.origin_input_text + self.decoded_text + jump_forward_str
|
|
||||||
all_ids = self.tokenizer.encode(all_text)
|
|
||||||
if not all_ids:
|
|
||||||
logger.warning("Encoded all_text resulted in empty all_ids")
|
|
||||||
return False
|
|
||||||
|
|
||||||
prompt_tokens = len(self.origin_input_ids_unpadded)
|
|
||||||
if prompt_tokens > len(all_ids):
|
|
||||||
logger.warning("prompt_tokens is larger than encoded all_ids")
|
|
||||||
return False
|
|
||||||
|
|
||||||
if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
|
|
||||||
# TODO(lsyin): fix token fusion
|
|
||||||
logger.warning(
|
|
||||||
"Token fusion between input and output, try to avoid this by removing the space at the end of the input."
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
old_output_ids = self.output_ids
|
|
||||||
self.output_ids = all_ids[prompt_tokens:]
|
|
||||||
self.decoded_text = self.decoded_text + jump_forward_str
|
|
||||||
self.surr_offset = prompt_tokens
|
|
||||||
self.read_offset = len(all_ids)
|
|
||||||
|
|
||||||
# NOTE: A trick to reduce the surrouding tokens decoding overhead
|
|
||||||
for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
|
|
||||||
surr_text_ = self.tokenizer.decode(
|
|
||||||
all_ids[self.read_offset - i : self.read_offset]
|
|
||||||
)
|
|
||||||
if not surr_text_.endswith("<EFBFBD>"):
|
|
||||||
self.surr_offset = self.read_offset - i
|
|
||||||
break
|
|
||||||
|
|
||||||
# update the inner state of the grammar
|
|
||||||
self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)
|
|
||||||
|
|
||||||
if self.return_logprob:
|
|
||||||
# For fast-forward part's logprobs
|
|
||||||
k = 0
|
|
||||||
for i, old_id in enumerate(old_output_ids):
|
|
||||||
if old_id == self.output_ids[i]:
|
|
||||||
k = k + 1
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
|
|
||||||
self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
|
|
||||||
self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
|
|
||||||
self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
|
|
||||||
self.output_token_ids_logprobs_val = self.output_token_ids_logprobs_val[:k]
|
|
||||||
self.output_token_ids_logprobs_idx = self.output_token_ids_logprobs_idx[:k]
|
|
||||||
self.logprob_start_len = prompt_tokens + k
|
|
||||||
self.last_update_decode_tokens = len(self.output_ids) - k
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def reset_for_retract(self):
|
def reset_for_retract(self):
|
||||||
self.prefix_indices = []
|
self.prefix_indices = []
|
||||||
self.last_node = None
|
self.last_node = None
|
||||||
@@ -574,8 +506,6 @@ class Req:
|
|||||||
self.is_chunked = 0
|
self.is_chunked = 0
|
||||||
self.req_pool_idx = None
|
self.req_pool_idx = None
|
||||||
|
|
||||||
self.last_update_decode_tokens = 0
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return (
|
return (
|
||||||
f"Req(rid={self.rid}, "
|
f"Req(rid={self.rid}, "
|
||||||
@@ -672,7 +602,6 @@ class ScheduleBatch:
|
|||||||
enable_overlap: bool,
|
enable_overlap: bool,
|
||||||
spec_algorithm: SpeculativeAlgorithm,
|
spec_algorithm: SpeculativeAlgorithm,
|
||||||
enable_custom_logit_processor: bool,
|
enable_custom_logit_processor: bool,
|
||||||
return_hidden_states: bool = False,
|
|
||||||
):
|
):
|
||||||
return cls(
|
return cls(
|
||||||
reqs=reqs,
|
reqs=reqs,
|
||||||
@@ -687,7 +616,7 @@ class ScheduleBatch:
|
|||||||
device=req_to_token_pool.device,
|
device=req_to_token_pool.device,
|
||||||
spec_algorithm=spec_algorithm,
|
spec_algorithm=spec_algorithm,
|
||||||
enable_custom_logit_processor=enable_custom_logit_processor,
|
enable_custom_logit_processor=enable_custom_logit_processor,
|
||||||
return_hidden_states=return_hidden_states,
|
return_hidden_states=any(req.return_hidden_states for req in reqs),
|
||||||
)
|
)
|
||||||
|
|
||||||
def batch_size(self):
|
def batch_size(self):
|
||||||
@@ -1091,59 +1020,6 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
return retracted_reqs, new_estimate_ratio
|
return retracted_reqs, new_estimate_ratio
|
||||||
|
|
||||||
def check_for_jump_forward(self, pad_input_ids_func):
|
|
||||||
jump_forward_reqs = []
|
|
||||||
keep_indices = set(i for i in range(len(self.reqs)))
|
|
||||||
|
|
||||||
for i, req in enumerate(self.reqs):
|
|
||||||
if req.grammar is not None:
|
|
||||||
jump_helper = req.grammar.try_jump_forward(req.tokenizer)
|
|
||||||
if jump_helper:
|
|
||||||
suffix_ids, _ = jump_helper
|
|
||||||
|
|
||||||
# Current ids, for cache and revert
|
|
||||||
cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
|
|
||||||
cur_output_ids = req.output_ids
|
|
||||||
|
|
||||||
req.output_ids.extend(suffix_ids)
|
|
||||||
decode_res, new_text = req.get_next_inc_detokenization()
|
|
||||||
if not decode_res:
|
|
||||||
req.output_ids = cur_output_ids
|
|
||||||
continue
|
|
||||||
|
|
||||||
(
|
|
||||||
jump_forward_str,
|
|
||||||
next_state,
|
|
||||||
) = req.grammar.jump_forward_str_state(jump_helper)
|
|
||||||
|
|
||||||
# Make the incrementally decoded text part of jump_forward_str
|
|
||||||
# so that the UTF-8 will not corrupt
|
|
||||||
jump_forward_str = new_text + jump_forward_str
|
|
||||||
if not req.jump_forward_and_retokenize(
|
|
||||||
jump_forward_str, next_state
|
|
||||||
):
|
|
||||||
req.output_ids = cur_output_ids
|
|
||||||
continue
|
|
||||||
|
|
||||||
# The decode status has diverged from detokenizer_manager
|
|
||||||
req.vid += 1
|
|
||||||
|
|
||||||
# insert the old request into tree_cache
|
|
||||||
self.tree_cache.cache_finished_req(req, cur_all_ids)
|
|
||||||
|
|
||||||
# re-applying image padding
|
|
||||||
if req.image_inputs is not None:
|
|
||||||
req.origin_input_ids = pad_input_ids_func(
|
|
||||||
req.origin_input_ids_unpadded, req.image_inputs
|
|
||||||
)
|
|
||||||
|
|
||||||
jump_forward_reqs.append(req)
|
|
||||||
keep_indices.remove(i)
|
|
||||||
|
|
||||||
self.filter_batch(keep_indices=list(keep_indices))
|
|
||||||
|
|
||||||
return jump_forward_reqs
|
|
||||||
|
|
||||||
def prepare_encoder_info_decode(self):
|
def prepare_encoder_info_decode(self):
|
||||||
# Reset the encoder cached status
|
# Reset the encoder cached status
|
||||||
self.encoder_cached = [True] * len(self.reqs)
|
self.encoder_cached = [True] * len(self.reqs)
|
||||||
|
|||||||
@@ -150,7 +150,6 @@ class Scheduler:
|
|||||||
self.tp_rank = tp_rank
|
self.tp_rank = tp_rank
|
||||||
self.tp_size = server_args.tp_size
|
self.tp_size = server_args.tp_size
|
||||||
self.schedule_policy = server_args.schedule_policy
|
self.schedule_policy = server_args.schedule_policy
|
||||||
self.disable_jump_forward = server_args.disable_jump_forward
|
|
||||||
self.lora_paths = server_args.lora_paths
|
self.lora_paths = server_args.lora_paths
|
||||||
self.max_loras_per_batch = server_args.max_loras_per_batch
|
self.max_loras_per_batch = server_args.max_loras_per_batch
|
||||||
self.enable_overlap = not server_args.disable_overlap_schedule
|
self.enable_overlap = not server_args.disable_overlap_schedule
|
||||||
@@ -251,9 +250,6 @@ class Scheduler:
|
|||||||
self.enable_overlap = False
|
self.enable_overlap = False
|
||||||
logger.info("Overlap scheduler is disabled for multimodal models.")
|
logger.info("Overlap scheduler is disabled for multimodal models.")
|
||||||
|
|
||||||
if self.enable_overlap:
|
|
||||||
self.disable_jump_forward = True
|
|
||||||
|
|
||||||
# Launch a tensor parallel worker
|
# Launch a tensor parallel worker
|
||||||
if self.enable_overlap:
|
if self.enable_overlap:
|
||||||
TpWorkerClass = TpModelWorkerClient
|
TpWorkerClass = TpModelWorkerClient
|
||||||
@@ -1024,11 +1020,8 @@ class Scheduler:
|
|||||||
if self.running_batch is not None
|
if self.running_batch is not None
|
||||||
else set([])
|
else set([])
|
||||||
)
|
)
|
||||||
return_hidden_states = False
|
|
||||||
# Get requests from the waiting queue to a new prefill batch
|
# Get requests from the waiting queue to a new prefill batch
|
||||||
for req in self.waiting_queue:
|
for req in self.waiting_queue:
|
||||||
if req.return_hidden_states:
|
|
||||||
return_hidden_states = True
|
|
||||||
if (
|
if (
|
||||||
self.lora_paths
|
self.lora_paths
|
||||||
and len(
|
and len(
|
||||||
@@ -1114,7 +1107,6 @@ class Scheduler:
|
|||||||
self.enable_overlap,
|
self.enable_overlap,
|
||||||
self.spec_algorithm,
|
self.spec_algorithm,
|
||||||
self.server_args.enable_custom_logit_processor,
|
self.server_args.enable_custom_logit_processor,
|
||||||
return_hidden_states,
|
|
||||||
)
|
)
|
||||||
new_batch.prepare_for_extend()
|
new_batch.prepare_for_extend()
|
||||||
|
|
||||||
@@ -1168,14 +1160,6 @@ class Scheduler:
|
|||||||
self.min_new_token_ratio,
|
self.min_new_token_ratio,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check for jump-forward
|
|
||||||
if not self.disable_jump_forward and batch.has_grammar:
|
|
||||||
jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
|
|
||||||
self._extend_requests_to_queue(jump_forward_reqs)
|
|
||||||
if batch.is_empty():
|
|
||||||
self.batch_is_full = False
|
|
||||||
return None
|
|
||||||
|
|
||||||
if batch.batch_size() < initial_bs:
|
if batch.batch_size() < initial_bs:
|
||||||
self.batch_is_full = False
|
self.batch_is_full = False
|
||||||
|
|
||||||
@@ -1530,8 +1514,6 @@ class Scheduler:
|
|||||||
prefill (e.g., computing input token logprobs).
|
prefill (e.g., computing input token logprobs).
|
||||||
"""
|
"""
|
||||||
assert output.input_token_logprobs is not None
|
assert output.input_token_logprobs is not None
|
||||||
# It is for jump decoding that will be deprecated.
|
|
||||||
assert req.last_update_decode_tokens == 0
|
|
||||||
if req.input_token_logprobs is None:
|
if req.input_token_logprobs is None:
|
||||||
req.input_token_logprobs = []
|
req.input_token_logprobs = []
|
||||||
if req.temp_input_top_logprobs_val is None:
|
if req.temp_input_top_logprobs_val is None:
|
||||||
@@ -1658,50 +1640,12 @@ class Scheduler:
|
|||||||
self.add_input_logprob_return_values(
|
self.add_input_logprob_return_values(
|
||||||
i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
|
i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
|
||||||
)
|
)
|
||||||
if req.last_update_decode_tokens != 0:
|
|
||||||
# Some decode tokens are re-computed in an extend batch
|
|
||||||
req.output_token_logprobs_val.extend(
|
|
||||||
output.input_token_logprobs[
|
|
||||||
pt
|
|
||||||
+ num_input_logprobs
|
|
||||||
- 1
|
|
||||||
- req.last_update_decode_tokens : pt
|
|
||||||
+ num_input_logprobs
|
|
||||||
- 1
|
|
||||||
],
|
|
||||||
)
|
|
||||||
req.output_token_logprobs_idx.extend(
|
|
||||||
req.fill_ids[
|
|
||||||
len(req.fill_ids)
|
|
||||||
- req.last_update_decode_tokens : len(req.fill_ids)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
if req.top_logprobs_num > 0:
|
if req.top_logprobs_num > 0:
|
||||||
if req.last_update_decode_tokens != 0:
|
|
||||||
req.output_top_logprobs_val.extend(
|
|
||||||
output.input_top_logprobs_val[i][-req.last_update_decode_tokens :]
|
|
||||||
)
|
|
||||||
req.output_top_logprobs_idx.extend(
|
|
||||||
output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
|
|
||||||
)
|
|
||||||
|
|
||||||
req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
|
req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
|
||||||
req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
|
req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
|
||||||
|
|
||||||
if req.token_ids_logprob is not None:
|
if req.token_ids_logprob is not None:
|
||||||
if req.last_update_decode_tokens != 0:
|
|
||||||
req.output_token_ids_logprobs_val.extend(
|
|
||||||
output.input_token_ids_logprobs_val[i][
|
|
||||||
-req.last_update_decode_tokens :
|
|
||||||
]
|
|
||||||
)
|
|
||||||
req.output_token_ids_logprobs_idx.extend(
|
|
||||||
output.input_token_ids_logprobs_idx[i][
|
|
||||||
-req.last_update_decode_tokens :
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
req.output_token_ids_logprobs_val.append(
|
req.output_token_ids_logprobs_val.append(
|
||||||
output.next_token_token_ids_logprobs_val[i]
|
output.next_token_token_ids_logprobs_val[i]
|
||||||
)
|
)
|
||||||
@@ -1719,7 +1663,6 @@ class Scheduler:
|
|||||||
finished_reasons: List[BaseFinishReason] = []
|
finished_reasons: List[BaseFinishReason] = []
|
||||||
|
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
vids = []
|
|
||||||
decoded_texts = []
|
decoded_texts = []
|
||||||
decode_ids_list = []
|
decode_ids_list = []
|
||||||
read_offsets = []
|
read_offsets = []
|
||||||
@@ -1786,7 +1729,6 @@ class Scheduler:
|
|||||||
finished_reasons.append(
|
finished_reasons.append(
|
||||||
req.finished_reason.to_json() if req.finished_reason else None
|
req.finished_reason.to_json() if req.finished_reason else None
|
||||||
)
|
)
|
||||||
vids.append(req.vid)
|
|
||||||
decoded_texts.append(req.decoded_text)
|
decoded_texts.append(req.decoded_text)
|
||||||
decode_ids, read_offset = req.init_incremental_detokenize()
|
decode_ids, read_offset = req.init_incremental_detokenize()
|
||||||
decode_ids_list.append(decode_ids)
|
decode_ids_list.append(decode_ids)
|
||||||
@@ -1842,7 +1784,6 @@ class Scheduler:
|
|||||||
BatchTokenIDOut(
|
BatchTokenIDOut(
|
||||||
rids,
|
rids,
|
||||||
finished_reasons,
|
finished_reasons,
|
||||||
vids,
|
|
||||||
decoded_texts,
|
decoded_texts,
|
||||||
decode_ids_list,
|
decode_ids_list,
|
||||||
read_offsets,
|
read_offsets,
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
|||||||
from sglang.srt.utils import get_compiler_backend
|
from sglang.srt.utils import get_compiler_backend
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.attention import AttentionBackend
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
|
from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
|
||||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
|||||||
@@ -26,8 +26,6 @@ from fastapi import HTTPException, Request, UploadFile
|
|||||||
from fastapi.responses import ORJSONResponse, StreamingResponse
|
from fastapi.responses import ORJSONResponse, StreamingResponse
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from sglang.lang.chat_template import get_chat_template_by_model_path
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from outlines.fsm.json_schema import convert_json_schema_to_str
|
from outlines.fsm.json_schema import convert_json_schema_to_str
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -165,24 +163,19 @@ def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, mode
|
|||||||
else:
|
else:
|
||||||
chat_template_name = chat_template_arg
|
chat_template_name = chat_template_arg
|
||||||
|
|
||||||
# check chat-template
|
# Check chat-template
|
||||||
chat_template = get_chat_template_by_model_path(model_path)
|
# TODO:
|
||||||
if chat_template is not None:
|
# 1. Do not import any code from sglang.lang
|
||||||
official_chat_template = chat_template.name
|
# 2. For VLM, when chat_template_arg is None, set it automatically by guessing from model_path.
|
||||||
used_chat_template = chat_template_name
|
|
||||||
if official_chat_template != used_chat_template:
|
|
||||||
logger.warning(
|
|
||||||
f"Using a chat_template: '{used_chat_template}', "
|
|
||||||
f"which is different from official chat template: '{official_chat_template}', "
|
|
||||||
f"This discrepancy may lead to performance degradation."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def v1_files_create(file: UploadFile, purpose: str, file_storage_pth: str = None):
|
async def v1_files_create(
|
||||||
|
file: UploadFile, purpose: str, file_storage_path: str = None
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
global storage_dir
|
global storage_dir
|
||||||
if file_storage_pth:
|
if file_storage_path:
|
||||||
storage_dir = file_storage_pth
|
storage_dir = file_storage_path
|
||||||
# Read the file content
|
# Read the file content
|
||||||
file_content = await file.read()
|
file_content = await file.read()
|
||||||
|
|
||||||
|
|||||||
@@ -40,17 +40,23 @@ class SamplingParams:
|
|||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
repetition_penalty: float = 1.0,
|
repetition_penalty: float = 1.0,
|
||||||
min_new_tokens: int = 0,
|
min_new_tokens: int = 0,
|
||||||
spaces_between_special_tokens: bool = True,
|
|
||||||
n: int = 1,
|
n: int = 1,
|
||||||
json_schema: Optional[str] = None,
|
json_schema: Optional[str] = None,
|
||||||
regex: Optional[str] = None,
|
regex: Optional[str] = None,
|
||||||
ebnf: Optional[str] = None,
|
ebnf: Optional[str] = None,
|
||||||
structural_tag: Optional[str] = None,
|
structural_tag: Optional[str] = None,
|
||||||
no_stop_trim: bool = False,
|
|
||||||
ignore_eos: bool = False,
|
ignore_eos: bool = False,
|
||||||
skip_special_tokens: bool = True,
|
skip_special_tokens: bool = True,
|
||||||
|
spaces_between_special_tokens: bool = True,
|
||||||
|
no_stop_trim: bool = False,
|
||||||
custom_params: Optional[Dict[str, Any]] = None,
|
custom_params: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
self.max_new_tokens = max_new_tokens
|
||||||
|
self.stop_strs = stop
|
||||||
|
if stop_token_ids:
|
||||||
|
self.stop_token_ids = set(stop_token_ids)
|
||||||
|
else:
|
||||||
|
self.stop_token_ids = None
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.top_p = top_p
|
self.top_p = top_p
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
@@ -58,26 +64,21 @@ class SamplingParams:
|
|||||||
self.frequency_penalty = frequency_penalty
|
self.frequency_penalty = frequency_penalty
|
||||||
self.presence_penalty = presence_penalty
|
self.presence_penalty = presence_penalty
|
||||||
self.repetition_penalty = repetition_penalty
|
self.repetition_penalty = repetition_penalty
|
||||||
self.stop_strs = stop
|
|
||||||
if stop_token_ids:
|
|
||||||
self.stop_token_ids = set(stop_token_ids)
|
|
||||||
else:
|
|
||||||
self.stop_token_ids = None
|
|
||||||
self.max_new_tokens = max_new_tokens
|
|
||||||
self.min_new_tokens = min_new_tokens
|
self.min_new_tokens = min_new_tokens
|
||||||
self.ignore_eos = ignore_eos
|
|
||||||
self.skip_special_tokens = skip_special_tokens
|
|
||||||
self.spaces_between_special_tokens = spaces_between_special_tokens
|
|
||||||
self.regex = regex
|
self.regex = regex
|
||||||
self.n = n
|
self.n = n
|
||||||
self.json_schema = json_schema
|
self.json_schema = json_schema
|
||||||
self.ebnf = ebnf
|
self.ebnf = ebnf
|
||||||
self.structural_tag = structural_tag
|
self.structural_tag = structural_tag
|
||||||
|
self.ignore_eos = ignore_eos
|
||||||
|
self.skip_special_tokens = skip_special_tokens
|
||||||
|
self.spaces_between_special_tokens = spaces_between_special_tokens
|
||||||
self.no_stop_trim = no_stop_trim
|
self.no_stop_trim = no_stop_trim
|
||||||
self.custom_params = custom_params
|
self.custom_params = custom_params
|
||||||
|
|
||||||
# Process some special cases
|
# Process some special cases
|
||||||
if self.temperature < _SAMPLING_EPS:
|
if self.temperature < _SAMPLING_EPS:
|
||||||
|
# top_k = 1 means greedy sampling
|
||||||
self.temperature = 1.0
|
self.temperature = 1.0
|
||||||
self.top_k = 1
|
self.top_k = 1
|
||||||
if self.top_k == -1:
|
if self.top_k == -1:
|
||||||
|
|||||||
@@ -15,21 +15,15 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import random
|
import random
|
||||||
import subprocess
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import uuid
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.hf_transformers_utils import check_gguf_file
|
from sglang.srt.hf_transformers_utils import check_gguf_file
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
create_checksum,
|
|
||||||
get_amdgpu_memory_capacity,
|
get_amdgpu_memory_capacity,
|
||||||
get_hpu_memory_capacity,
|
get_hpu_memory_capacity,
|
||||||
get_nvgpu_memory_capacity,
|
get_nvgpu_memory_capacity,
|
||||||
@@ -101,7 +95,7 @@ class ServerArgs:
|
|||||||
|
|
||||||
# API related
|
# API related
|
||||||
api_key: Optional[str] = None
|
api_key: Optional[str] = None
|
||||||
file_storage_pth: str = "sglang_storage"
|
file_storage_path: str = "sglang_storage"
|
||||||
enable_cache_report: bool = False
|
enable_cache_report: bool = False
|
||||||
|
|
||||||
# Data parallelism
|
# Data parallelism
|
||||||
@@ -149,7 +143,6 @@ class ServerArgs:
|
|||||||
|
|
||||||
# Optimization/debug options
|
# Optimization/debug options
|
||||||
disable_radix_cache: bool = False
|
disable_radix_cache: bool = False
|
||||||
disable_jump_forward: bool = False
|
|
||||||
disable_cuda_graph: bool = False
|
disable_cuda_graph: bool = False
|
||||||
disable_cuda_graph_padding: bool = False
|
disable_cuda_graph_padding: bool = False
|
||||||
enable_nccl_nvls: bool = False
|
enable_nccl_nvls: bool = False
|
||||||
@@ -627,9 +620,9 @@ class ServerArgs:
|
|||||||
help="Set API key of the server. It is also used in the OpenAI API compatible server.",
|
help="Set API key of the server. It is also used in the OpenAI API compatible server.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--file-storage-pth",
|
"--file-storage-path",
|
||||||
type=str,
|
type=str,
|
||||||
default=ServerArgs.file_storage_pth,
|
default=ServerArgs.file_storage_path,
|
||||||
help="The path of the file storage in backend.",
|
help="The path of the file storage in backend.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -836,11 +829,6 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Disable RadixAttention for prefix caching.",
|
help="Disable RadixAttention for prefix caching.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--disable-jump-forward",
|
|
||||||
action="store_true",
|
|
||||||
help="Disable jump-forward for grammar-guided decoding.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable-cuda-graph",
|
"--disable-cuda-graph",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8
|
|||||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||||
|
|
||||||
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
|
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
|
||||||
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmzheng/sglang-EAGLE-llama2-chat-7B"
|
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"
|
||||||
|
|
||||||
|
|
||||||
def is_in_ci():
|
def is_in_ci():
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
# single GPU
|
# single GPU
|
||||||
python3 bench_speculative.py --model-path meta-llama/Llama-2-7b-chat-hf --speculative-draft-model-path lmzheng/sglang-EAGLE-llama2-chat-7B
|
python3 bench_speculative.py --model-path meta-llama/Llama-2-7b-chat-hf --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|||||||
@@ -17,3 +17,59 @@ For CUDA 12.1 or CUDA 12.4:
|
|||||||
```bash
|
```bash
|
||||||
pip3 install sgl-kernel
|
pip3 install sgl-kernel
|
||||||
```
|
```
|
||||||
|
|
||||||
|
# Developer Guide
|
||||||
|
|
||||||
|
## Development Environment Setup
|
||||||
|
|
||||||
|
Use Docker to set up the development environment. See [Docker setup guide](https://github.com/sgl-project/sglang/blob/main/docs/developer/development_guide_using_docker.md#setup-docker-container).
|
||||||
|
|
||||||
|
Create and enter development container:
|
||||||
|
```bash
|
||||||
|
docker run -itd --shm-size 32g --gpus all -v $HOME/.cache:/root/.cache --ipc=host --name sglang_zhyncs lmsysorg/sglang:dev /bin/zsh
|
||||||
|
docker exec -it sglang_zhyncs /bin/zsh
|
||||||
|
```
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
### Dependencies
|
||||||
|
|
||||||
|
Third-party libraries:
|
||||||
|
|
||||||
|
- [CCCL](https://github.com/NVIDIA/cccl)
|
||||||
|
- [CUTLASS](https://github.com/NVIDIA/cutlass)
|
||||||
|
- [FlashInfer](https://github.com/flashinfer-ai/flashinfer)
|
||||||
|
- [TurboMind](https://github.com/InternLM/turbomind)
|
||||||
|
|
||||||
|
### Kernel Development
|
||||||
|
|
||||||
|
Steps to add a new kernel:
|
||||||
|
|
||||||
|
1. Implement in [src/sgl-kernel/csrc/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/src/sgl-kernel/csrc)
|
||||||
|
2. Expose interface in [src/sgl-kernel/include/sgl_kernels_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h)
|
||||||
|
3. Create torch extension in [src/sgl-kernel/torch_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/torch_extension.cc)
|
||||||
|
4. Create Python wrapper in [src/sgl-kernel/ops/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py)
|
||||||
|
5. Expose Python interface in [src/sgl-kernel/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py)
|
||||||
|
6. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source
|
||||||
|
|
||||||
|
### Build & Install
|
||||||
|
|
||||||
|
Development build:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make build
|
||||||
|
```
|
||||||
|
|
||||||
|
Note:
|
||||||
|
|
||||||
|
The `sgl-kernel` is rapidly evolving. If you experience a compilation failure, try using `make rebuild`.
|
||||||
|
|
||||||
|
### Testing & Benchmarking
|
||||||
|
|
||||||
|
1. Add pytest tests in [tests/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/tests)
|
||||||
|
2. Add benchmarks using [triton benchmark](https://triton-lang.org/main/python-api/generated/triton.testing.Benchmark.html) in [benchmark/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/benchmark)
|
||||||
|
3. Run test suite
|
||||||
|
|
||||||
|
### Release new version
|
||||||
|
|
||||||
|
Update version in [pyproject.toml](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/pyproject.toml) and [version.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/version.py)
|
||||||
|
|||||||
@@ -1,55 +0,0 @@
|
|||||||
# Developer Guide for sgl-kernel
|
|
||||||
|
|
||||||
## Development Environment Setup
|
|
||||||
|
|
||||||
Use Docker to set up the development environment. See [Docker setup guide](https://github.com/sgl-project/sglang/blob/main/docs/developer/development_guide_using_docker.md#setup-docker-container).
|
|
||||||
|
|
||||||
Create and enter development container:
|
|
||||||
```bash
|
|
||||||
docker run -itd --shm-size 32g --gpus all -v $HOME/.cache:/root/.cache --ipc=host --name sglang_zhyncs lmsysorg/sglang:dev /bin/zsh
|
|
||||||
docker exec -it sglang_zhyncs /bin/zsh
|
|
||||||
```
|
|
||||||
|
|
||||||
## Project Structure
|
|
||||||
|
|
||||||
### Dependencies
|
|
||||||
|
|
||||||
Third-party libraries:
|
|
||||||
|
|
||||||
- [CCCL](https://github.com/NVIDIA/cccl)
|
|
||||||
- [CUTLASS](https://github.com/NVIDIA/cutlass)
|
|
||||||
- [FlashInfer](https://github.com/flashinfer-ai/flashinfer)
|
|
||||||
- [TurboMind](https://github.com/InternLM/turbomind)
|
|
||||||
|
|
||||||
### Kernel Development
|
|
||||||
|
|
||||||
Steps to add a new kernel:
|
|
||||||
|
|
||||||
1. Implement in [src/sgl-kernel/csrc/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/src/sgl-kernel/csrc)
|
|
||||||
2. Expose interface in [src/sgl-kernel/include/sgl_kernels_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h)
|
|
||||||
3. Create torch extension in [src/sgl-kernel/torch_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/torch_extension.cc)
|
|
||||||
4. Create Python wrapper in [src/sgl-kernel/ops/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py)
|
|
||||||
5. Expose Python interface in [src/sgl-kernel/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py)
|
|
||||||
6. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source
|
|
||||||
|
|
||||||
### Build & Install
|
|
||||||
|
|
||||||
Development build:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
make build
|
|
||||||
```
|
|
||||||
|
|
||||||
Note:
|
|
||||||
|
|
||||||
The `sgl-kernel` is rapidly evolving. If you experience a compilation failure, try using `make rebuild`.
|
|
||||||
|
|
||||||
### Testing & Benchmarking
|
|
||||||
|
|
||||||
1. Add pytest tests in [tests/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/tests)
|
|
||||||
2. Add benchmarks using [triton benchmark](https://triton-lang.org/main/python-api/generated/triton.testing.Benchmark.html) in [benchmark/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/benchmark)
|
|
||||||
3. Run test suite
|
|
||||||
|
|
||||||
### Release new version
|
|
||||||
|
|
||||||
Update version in [pyproject.toml](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/pyproject.toml) and [version.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/version.py)
|
|
||||||
@@ -100,6 +100,7 @@ sources = [
|
|||||||
"src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu",
|
"src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu",
|
||||||
"src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu",
|
"src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu",
|
||||||
"src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu",
|
"src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu",
|
||||||
|
"src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu",
|
||||||
"src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu",
|
"src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu",
|
||||||
"src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu",
|
"src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu",
|
||||||
"src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu",
|
"src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu",
|
||||||
@@ -108,7 +109,6 @@ sources = [
|
|||||||
"src/sgl-kernel/csrc/moe/moe_align_kernel.cu",
|
"src/sgl-kernel/csrc/moe/moe_align_kernel.cu",
|
||||||
"src/sgl-kernel/csrc/speculative/eagle_utils.cu",
|
"src/sgl-kernel/csrc/speculative/eagle_utils.cu",
|
||||||
"src/sgl-kernel/csrc/speculative/speculative_sampling.cu",
|
"src/sgl-kernel/csrc/speculative/speculative_sampling.cu",
|
||||||
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
|
|
||||||
"3rdparty/flashinfer/csrc/activation.cu",
|
"3rdparty/flashinfer/csrc/activation.cu",
|
||||||
"3rdparty/flashinfer/csrc/bmm_fp8.cu",
|
"3rdparty/flashinfer/csrc/bmm_fp8.cu",
|
||||||
"3rdparty/flashinfer/csrc/norm.cu",
|
"3rdparty/flashinfer/csrc/norm.cu",
|
||||||
|
|||||||
@@ -62,6 +62,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
|
|||||||
m.def("register_graph_buffers(int fa, int[][] handles, int[][] offsets) -> ()");
|
m.def("register_graph_buffers(int fa, int[][] handles, int[][] offsets) -> ()");
|
||||||
m.impl("register_graph_buffers", torch::kCUDA, ®ister_graph_buffers);
|
m.impl("register_graph_buffers", torch::kCUDA, ®ister_graph_buffers);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* From csrc/attention
|
||||||
|
*/
|
||||||
|
m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From csrc/gemm
|
* From csrc/gemm
|
||||||
*/
|
*/
|
||||||
@@ -163,11 +168,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
|
|||||||
"apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
|
"apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
|
||||||
"Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
|
"Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
|
||||||
m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
|
m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
|
||||||
|
|
||||||
/*
|
|
||||||
* Other
|
|
||||||
*/
|
|
||||||
m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_EXTENSION(_kernels)
|
REGISTER_EXTENSION(_kernels)
|
||||||
|
|||||||
@@ -46,7 +46,6 @@ class TestEBNFConstrained(unittest.TestCase):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
setup_class(cls, "xgrammar", disable_overlap=False)
|
setup_class(cls, "xgrammar", disable_overlap=False)
|
||||||
cls.check_jump_forward = False
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
@@ -238,12 +237,5 @@ class TestEBNFConstrained(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestEBNFConstrainedLLGuidance(TestEBNFConstrained):
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
setup_class(cls, "llguidance", disable_overlap=False)
|
|
||||||
cls.check_jump_forward = False
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -57,7 +57,6 @@ class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
setup_class(cls, backend="outlines", disable_overlap=False)
|
setup_class(cls, backend="outlines", disable_overlap=False)
|
||||||
cls.check_jump_forward = False
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
@@ -134,26 +133,5 @@ class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
|
|||||||
list(executor.map(self.run_decode, json_schemas))
|
list(executor.map(self.run_decode, json_schemas))
|
||||||
|
|
||||||
|
|
||||||
class TestJumpForwardOutlinesBackend(unittest.TestCase):
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
setup_class(cls, backend="outlines", disable_overlap=True)
|
|
||||||
cls.check_jump_forward = True
|
|
||||||
|
|
||||||
|
|
||||||
class TestJSONConstrainedXGrammarBackend(TestJSONConstrainedOutlinesBackend):
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
setup_class(cls, backend="xgrammar", disable_overlap=False)
|
|
||||||
cls.check_jump_forward = False
|
|
||||||
|
|
||||||
|
|
||||||
class TestJSONConstrainedLLGuidanceBackend(TestJSONConstrainedOutlinesBackend):
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
setup_class(cls, backend="llguidance", disable_overlap=False)
|
|
||||||
cls.check_jump_forward = False
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -12,7 +12,9 @@ from sglang.test.test_utils import (
|
|||||||
DEFAULT_MOE_MODEL_NAME_FOR_TEST,
|
DEFAULT_MOE_MODEL_NAME_FOR_TEST,
|
||||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
DEFAULT_URL_FOR_TEST,
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
is_in_ci,
|
||||||
popen_launch_server,
|
popen_launch_server,
|
||||||
|
write_github_step_summary,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -49,6 +51,9 @@ class TestMoEEvalAccuracyLarge(unittest.TestCase):
|
|||||||
metrics = run_eval(args)
|
metrics = run_eval(args)
|
||||||
self.assertGreater(metrics["score"], 0.62)
|
self.assertGreater(metrics["score"], 0.62)
|
||||||
|
|
||||||
|
if is_in_ci():
|
||||||
|
write_github_step_summary(f"### test_mmlu\n" f'{metrics["score"]=:.4f}\n')
|
||||||
|
|
||||||
def test_human_eval(self):
|
def test_human_eval(self):
|
||||||
args = SimpleNamespace(
|
args = SimpleNamespace(
|
||||||
base_url=self.base_url,
|
base_url=self.base_url,
|
||||||
@@ -61,6 +66,11 @@ class TestMoEEvalAccuracyLarge(unittest.TestCase):
|
|||||||
metrics = run_eval(args)
|
metrics = run_eval(args)
|
||||||
self.assertGreater(metrics["score"], 0.40)
|
self.assertGreater(metrics["score"], 0.40)
|
||||||
|
|
||||||
|
if is_in_ci():
|
||||||
|
write_github_step_summary(
|
||||||
|
f"### test_human_eval\n" f'{metrics["score"]=:.4f}\n'
|
||||||
|
)
|
||||||
|
|
||||||
def test_mgsm_en(self):
|
def test_mgsm_en(self):
|
||||||
args = SimpleNamespace(
|
args = SimpleNamespace(
|
||||||
base_url=self.base_url,
|
base_url=self.base_url,
|
||||||
@@ -73,6 +83,11 @@ class TestMoEEvalAccuracyLarge(unittest.TestCase):
|
|||||||
metrics = run_eval(args)
|
metrics = run_eval(args)
|
||||||
self.assertGreater(metrics["score"], 0.61)
|
self.assertGreater(metrics["score"], 0.61)
|
||||||
|
|
||||||
|
if is_in_ci():
|
||||||
|
write_github_step_summary(
|
||||||
|
f"### test_mgsm_en\n" f'{metrics["score"]=:.4f}\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -3,8 +3,6 @@ python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_gener
|
|||||||
python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_greeting
|
python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_greeting
|
||||||
python3 -m unittest test_regex_constrained.TestRegexConstrainedLLGuidance.test_regex_generate_email
|
python3 -m unittest test_regex_constrained.TestRegexConstrainedLLGuidance.test_regex_generate_email
|
||||||
python3 -m unittest test_regex_constrained.TestRegexConstrainedLLGuidance.test_regex_generate_greeting
|
python3 -m unittest test_regex_constrained.TestRegexConstrainedLLGuidance.test_regex_generate_greeting
|
||||||
python3 -m unittest test_regex_constrained.TestJumpForwardLLGuidance.test_regex_generate_email
|
|
||||||
python3 -m unittest test_regex_constrained.TestJumpForwardLLGuidance.test_regex_generate_greeting
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
@@ -47,7 +45,6 @@ class TestRegexConstrained(unittest.TestCase):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
setup_class(cls, "xgrammar", disable_overlap=False)
|
setup_class(cls, "xgrammar", disable_overlap=False)
|
||||||
cls.check_jump_forward = False
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
@@ -179,20 +176,6 @@ class TestRegexConstrained(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestJumpForward(TestRegexConstrained):
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
setup_class(cls, "xgrammar", disable_overlap=True)
|
|
||||||
cls.check_jump_forward = True
|
|
||||||
|
|
||||||
|
|
||||||
class TestJumpForwardLLGuidance(TestRegexConstrained):
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
setup_class(cls, "llguidance", disable_overlap=True)
|
|
||||||
cls.check_jump_forward = True
|
|
||||||
|
|
||||||
|
|
||||||
class TestRegexConstrainedLLGuidance(TestRegexConstrained):
|
class TestRegexConstrainedLLGuidance(TestRegexConstrained):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
|
|||||||
Reference in New Issue
Block a user