Enable torch.compile for triton backend (#1422)
This commit is contained in:
23
README.md
23
README.md
@@ -205,28 +205,29 @@ print(response)
|
||||
It supports streaming, vision, and most features of the Chat/Completions/Models/Batch endpoints specified by the [OpenAI API Reference](https://platform.openai.com/docs/api-reference/).
|
||||
|
||||
### Additional Server Arguments
|
||||
- Add `--tp 2` to enable multi-GPU tensor parallelism. If it reports the error "peer access is not supported between these two devices", add `--enable-p2p-check` to the server launch command.
|
||||
- To enable multi-GPU tensor parallelism, add `--tp 2`. If it reports the error "peer access is not supported between these two devices", add `--enable-p2p-check` to the server launch command.
|
||||
```
|
||||
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --tp 2
|
||||
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 2
|
||||
```
|
||||
- Add `--dp 2` to enable multi-GPU data parallelism. Data parallelism is better for throughput if there is enough memory. It can also be used together with tensor parallelism. The following command uses 4 GPUs in total.
|
||||
- To enable multi-GPU data parallelism, add `--dp 2`. Data parallelism is better for throughput if there is enough memory. It can also be used together with tensor parallelism. The following command uses 4 GPUs in total.
|
||||
```
|
||||
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --dp 2 --tp 2
|
||||
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --dp 2 --tp 2
|
||||
```
|
||||
- If you see out-of-memory errors during serving, try to reduce the memory usage of the KV cache pool by setting a smaller value of `--mem-fraction-static`. The default value is `0.9`.
|
||||
```
|
||||
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --mem-fraction-static 0.7
|
||||
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --mem-fraction-static 0.7
|
||||
```
|
||||
- See [hyperparameter_tuning.md](docs/en/hyperparameter_tuning.md) on tuning hyperparameters for better performance.
|
||||
- If you see out-of-memory errors during prefill for long prompts, try to set a smaller chunked prefill size.
|
||||
```
|
||||
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --chunked-prefill-size 4096
|
||||
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --chunked-prefill-size 4096
|
||||
```
|
||||
- To enable torch.compile support, you can add `--enable-torch-compile`. It accelerates small models on small batch sizes.
|
||||
- To enable fp8 weight quantization, you can add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments.
|
||||
- To enable fp8 kv cache quanzation, you can add `--kv-cache-dtype fp8_e5m2`.
|
||||
- If the model does not have a template in the Hugging Face tokenizer, you can specify a [custom chat template](docs/en/custom_chat_template.md).
|
||||
- Add `--nnodes 2` to run tensor parallelism on multiple nodes. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port.
|
||||
- To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes.
|
||||
- To enable fp8 weight quantization, add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments.
|
||||
- To enable fp8 kv cache quantization, add `--kv-cache-dtype fp8_e5m2`.
|
||||
- To enable DeepSeek MLA acceleration, add `--enable-mla`.
|
||||
- If the model does not have a chat template in the Hugging Face tokenizer, you can specify a [custom chat template](docs/en/custom_chat_template.md).
|
||||
- To run tensor parallelism on multiple nodes, add `--nnodes 2`. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port.
|
||||
```
|
||||
# Node 0
|
||||
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --nccl-init sgl-dev-0:50000 --nnodes 2 --node-rank 0
|
||||
|
||||
@@ -479,7 +479,8 @@ def main(server_args, bench_args):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# TODO(kevin85421): Make the parser setup unit testable.
|
||||
multiprocessing.set_start_method("spawn", force=True)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
ServerArgs.add_cli_args(parser)
|
||||
BenchArgs.add_cli_args(parser)
|
||||
|
||||
@@ -22,7 +22,7 @@ from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -332,7 +332,6 @@ class TritonAttnBackend(AttentionBackend):
|
||||
def __init__(self, model_runner: ModelRunner):
|
||||
# Lazy import to avoid the initialization of cuda context
|
||||
from sglang.srt.layers.triton_attention.decode_attention import (
|
||||
REDUCE_TORCH_TYPE,
|
||||
decode_attention_fwd,
|
||||
)
|
||||
from sglang.srt.layers.triton_attention.extend_attention import (
|
||||
@@ -343,9 +342,13 @@ class TritonAttnBackend(AttentionBackend):
|
||||
|
||||
self.decode_attention_fwd = decode_attention_fwd
|
||||
self.extend_attention_fwd = extend_attention_fwd
|
||||
self.REDUCE_TORCH_TYPE = REDUCE_TORCH_TYPE
|
||||
self.num_head = model_runner.model_config.num_attention_heads
|
||||
|
||||
if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
|
||||
self.reduce_dtype = torch.float32
|
||||
else:
|
||||
self.reduce_dtype = torch.float16
|
||||
|
||||
self.forward_metadata = None
|
||||
|
||||
self.cuda_graph_max_seq_len = model_runner.model_config.context_len
|
||||
@@ -362,7 +365,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
total_num_tokens = torch.sum(input_metadata.seq_lens).item()
|
||||
attn_logits = torch.empty(
|
||||
(self.num_head, total_num_tokens),
|
||||
dtype=self.REDUCE_TORCH_TYPE,
|
||||
dtype=self.reduce_dtype,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
@@ -382,8 +385,11 @@ class TritonAttnBackend(AttentionBackend):
|
||||
(max_bs,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
self.cuda_graph_attn_logits = torch.empty(
|
||||
(self.num_head, self.cuda_graph_max_total_num_tokens),
|
||||
dtype=self.REDUCE_TORCH_TYPE,
|
||||
(
|
||||
self.num_head,
|
||||
self.cuda_graph_max_total_num_tokens,
|
||||
),
|
||||
dtype=self.reduce_dtype,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
@@ -403,13 +409,6 @@ class TritonAttnBackend(AttentionBackend):
|
||||
self.cuda_graph_start_loc.zero_()
|
||||
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
|
||||
|
||||
self.forward_metadata = (
|
||||
self.cuda_graph_start_loc,
|
||||
self.cuda_graph_attn_logits,
|
||||
self.cuda_graph_max_seq_len,
|
||||
None,
|
||||
)
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
return 1
|
||||
|
||||
@@ -444,6 +443,10 @@ class TritonAttnBackend(AttentionBackend):
|
||||
return o
|
||||
|
||||
def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
|
||||
# During torch.compile, there is a bug in rotary_emb that causes the
|
||||
# output value to have a 3D tensor shape. This reshapes the output correctly.
|
||||
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
|
||||
|
||||
# TODO: reuse the buffer across layers
|
||||
if layer.qk_head_dim != layer.v_head_dim:
|
||||
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
|
||||
|
||||
@@ -21,19 +21,9 @@ It supports page size = 1.
|
||||
# Adapted from
|
||||
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
|
||||
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
|
||||
if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
|
||||
REDUCE_TRITON_TYPE = tl.float32
|
||||
REDUCE_TORCH_TYPE = torch.float32
|
||||
else:
|
||||
REDUCE_TRITON_TYPE = tl.float16
|
||||
REDUCE_TORCH_TYPE = torch.float16
|
||||
|
||||
|
||||
@triton.jit
|
||||
def tanh(x):
|
||||
@@ -67,6 +57,7 @@ def _fwd_kernel_stage1(
|
||||
cur_head = tl.program_id(1)
|
||||
start_n = tl.program_id(2)
|
||||
|
||||
reduce_dtype = Att_Out.dtype.element_ty
|
||||
cur_kv_head = cur_head // kv_group_num
|
||||
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
@@ -85,7 +76,7 @@ def _fwd_kernel_stage1(
|
||||
block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
|
||||
|
||||
for start_mark in range(0, block_mask, 1):
|
||||
q = tl.load(Q + off_q + start_mark).to(REDUCE_TRITON_TYPE)
|
||||
q = tl.load(Q + off_q + start_mark).to(reduce_dtype)
|
||||
offs_n_new = cur_batch_start_index + offs_n
|
||||
k_loc = tl.load(
|
||||
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
|
||||
@@ -101,7 +92,7 @@ def _fwd_kernel_stage1(
|
||||
K_Buffer + offs_buf_k,
|
||||
mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < Lk),
|
||||
other=0.0,
|
||||
).to(REDUCE_TRITON_TYPE)
|
||||
).to(reduce_dtype)
|
||||
att_value = tl.sum(q[None, :] * k, 1)
|
||||
att_value *= sm_scale
|
||||
|
||||
@@ -198,7 +189,7 @@ def _decode_att_m_fwd(
|
||||
logit_cap,
|
||||
):
|
||||
BLOCK = 32
|
||||
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
|
||||
Lk = k_buffer.shape[-1]
|
||||
|
||||
batch, head_num = B_req_idx.shape[0], q.shape[1]
|
||||
|
||||
@@ -308,6 +299,7 @@ def _fwd_grouped_kernel_stage1(
|
||||
cur_kv_head = tl.program_id(1)
|
||||
start_n = tl.program_id(2)
|
||||
|
||||
reduce_dtype = Att_Out.dtype.element_ty
|
||||
cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
|
||||
mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
|
||||
mask_h = mask_h & (cur_head < q_head_num)
|
||||
@@ -336,7 +328,7 @@ def _fwd_grouped_kernel_stage1(
|
||||
for start_mark in range(0, block_mask, 1):
|
||||
q = tl.load(
|
||||
Q + offs_q + start_mark, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk)
|
||||
).to(REDUCE_TRITON_TYPE)
|
||||
).to(reduce_dtype)
|
||||
offs_n_new = cur_batch_start_index + offs_n
|
||||
k_loc = tl.load(
|
||||
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
|
||||
@@ -352,11 +344,11 @@ def _fwd_grouped_kernel_stage1(
|
||||
K_Buffer + offs_buf_k,
|
||||
mask=(offs_n_new[None, :] < cur_batch_end_index) & (offs_d[:, None] < Lk),
|
||||
other=0.0,
|
||||
).to(REDUCE_TRITON_TYPE)
|
||||
).to(reduce_dtype)
|
||||
qk = tl.dot(q, k)
|
||||
if BLOCK_DPE > 0:
|
||||
qpe = tl.load(Q + off_qpe + start_mark, mask=mask_h[:, None]).to(
|
||||
REDUCE_TRITON_TYPE
|
||||
reduce_dtype
|
||||
)
|
||||
offs_buf_kpe = (
|
||||
k_loc[None, :] * stride_buf_kbs
|
||||
@@ -367,7 +359,7 @@ def _fwd_grouped_kernel_stage1(
|
||||
K_Buffer + offs_buf_kpe,
|
||||
mask=offs_n_new[None, :] < cur_batch_end_index,
|
||||
other=0.0,
|
||||
).to(REDUCE_TRITON_TYPE)
|
||||
).to(reduce_dtype)
|
||||
qk += tl.dot(qpe, kpe)
|
||||
qk *= sm_scale
|
||||
|
||||
@@ -477,8 +469,8 @@ def _decode_grouped_att_m_fwd(
|
||||
sm_scale,
|
||||
logit_cap,
|
||||
):
|
||||
BLOCK = 32
|
||||
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
|
||||
BLOCK = 64
|
||||
Lk = k_buffer.shape[-1]
|
||||
|
||||
if Lk == 576:
|
||||
BLOCK_DMODEL = 512
|
||||
|
||||
@@ -30,7 +30,13 @@ DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Meta-Llama-3.1-70B-Instruc
|
||||
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8"
|
||||
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8"
|
||||
|
||||
if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
|
||||
|
||||
def is_in_ci():
|
||||
"""Return whether it is in CI runner."""
|
||||
return os.getenv("SGLANG_IS_IN_CI", "false") == "true"
|
||||
|
||||
|
||||
if is_in_ci():
|
||||
DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 5157
|
||||
DEFAULT_URL_FOR_TEST = "http://127.0.0.1:6157"
|
||||
else:
|
||||
@@ -547,3 +553,35 @@ def run_bench_serving(model, num_prompts, request_rate, other_server_args):
|
||||
|
||||
assert res["completed"] == num_prompts
|
||||
return res
|
||||
|
||||
|
||||
def run_bench_latency(model, other_args):
|
||||
command = [
|
||||
"python3",
|
||||
"-m",
|
||||
"sglang.bench_latency",
|
||||
"--model-path",
|
||||
model,
|
||||
"--batch-size",
|
||||
"1",
|
||||
"--input",
|
||||
"128",
|
||||
"--output",
|
||||
"8",
|
||||
*other_args,
|
||||
]
|
||||
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
|
||||
try:
|
||||
stdout, stderr = process.communicate()
|
||||
output = stdout.decode()
|
||||
error = stderr.decode()
|
||||
print(f"Output: {output}", flush=True)
|
||||
print(f"Error: {error}", flush=True)
|
||||
|
||||
lastline = output.split("\n")[-3]
|
||||
output_throughput = float(lastline.split(" ")[-2])
|
||||
finally:
|
||||
kill_child_process(process.pid)
|
||||
|
||||
return output_throughput
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
import subprocess
|
||||
import unittest
|
||||
|
||||
@@ -6,77 +5,25 @@ from sglang.srt.utils import kill_child_process
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_MOE_MODEL_NAME_FOR_TEST,
|
||||
is_in_ci,
|
||||
run_bench_latency,
|
||||
)
|
||||
|
||||
|
||||
class TestBenchLatency(unittest.TestCase):
|
||||
def test_default(self):
|
||||
command = [
|
||||
"python3",
|
||||
"-m",
|
||||
"sglang.bench_latency",
|
||||
"--model-path",
|
||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
"--batch-size",
|
||||
"1",
|
||||
"--input",
|
||||
"128",
|
||||
"--output",
|
||||
"8",
|
||||
]
|
||||
process = subprocess.Popen(
|
||||
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
)
|
||||
output_throughput = run_bench_latency(DEFAULT_MODEL_NAME_FOR_TEST, [])
|
||||
|
||||
try:
|
||||
stdout, stderr = process.communicate()
|
||||
output = stdout.decode()
|
||||
error = stderr.decode()
|
||||
print(f"Output: {output}")
|
||||
print(f"Error: {error}")
|
||||
|
||||
lastline = output.split("\n")[-3]
|
||||
value = float(lastline.split(" ")[-2])
|
||||
|
||||
if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
|
||||
assert value > 130
|
||||
finally:
|
||||
kill_child_process(process.pid)
|
||||
if is_in_ci():
|
||||
assert output_throughput > 130, f"{output_throughput=}"
|
||||
|
||||
def test_moe_default(self):
|
||||
command = [
|
||||
"python3",
|
||||
"-m",
|
||||
"sglang.bench_latency",
|
||||
"--model",
|
||||
DEFAULT_MOE_MODEL_NAME_FOR_TEST,
|
||||
"--batch-size",
|
||||
"1",
|
||||
"--input",
|
||||
"128",
|
||||
"--output",
|
||||
"8",
|
||||
"--tp",
|
||||
"2",
|
||||
]
|
||||
process = subprocess.Popen(
|
||||
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
output_throughput = run_bench_latency(
|
||||
DEFAULT_MOE_MODEL_NAME_FOR_TEST, ["--tp", "2"]
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = process.communicate()
|
||||
output = stdout.decode()
|
||||
error = stderr.decode()
|
||||
print(f"Output: {output}")
|
||||
print(f"Error: {error}")
|
||||
|
||||
lastline = output.split("\n")[-3]
|
||||
value = float(lastline.split(" ")[-2])
|
||||
|
||||
if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
|
||||
assert value > 125
|
||||
finally:
|
||||
kill_child_process(process.pid)
|
||||
if is_in_ci():
|
||||
assert output_throughput > 125, f"{output_throughput=}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_MOE_MODEL_NAME_FOR_TEST,
|
||||
is_in_ci,
|
||||
run_bench_serving,
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ class TestBenchServing(unittest.TestCase):
|
||||
other_server_args=[],
|
||||
)
|
||||
|
||||
if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
|
||||
if is_in_ci():
|
||||
assert res["output_throughput"] > 2600
|
||||
|
||||
def test_offline_throughput_without_radix_cache(self):
|
||||
@@ -29,7 +29,7 @@ class TestBenchServing(unittest.TestCase):
|
||||
other_server_args=["--disable-radix-cache"],
|
||||
)
|
||||
|
||||
if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
|
||||
if is_in_ci():
|
||||
assert res["output_throughput"] > 2800
|
||||
|
||||
def test_offline_throughput_without_chunked_prefill(self):
|
||||
@@ -40,7 +40,7 @@ class TestBenchServing(unittest.TestCase):
|
||||
other_server_args=["--chunked-prefill-size", "-1"],
|
||||
)
|
||||
|
||||
if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
|
||||
if is_in_ci():
|
||||
assert res["output_throughput"] > 2600
|
||||
|
||||
def test_offline_throughput_with_triton_attention_backend(self):
|
||||
@@ -56,7 +56,7 @@ class TestBenchServing(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
|
||||
if is_in_ci():
|
||||
assert res["output_throughput"] > 2600
|
||||
|
||||
def test_online_latency_default(self):
|
||||
@@ -67,7 +67,7 @@ class TestBenchServing(unittest.TestCase):
|
||||
other_server_args=[],
|
||||
)
|
||||
|
||||
if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
|
||||
if is_in_ci():
|
||||
assert res["median_e2e_latency_ms"] < 12000
|
||||
assert res["median_ttft_ms"] < 80
|
||||
assert res["median_itl_ms"] < 12
|
||||
@@ -80,7 +80,7 @@ class TestBenchServing(unittest.TestCase):
|
||||
other_server_args=["--tp", "2"],
|
||||
)
|
||||
|
||||
if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
|
||||
if is_in_ci():
|
||||
assert res["output_throughput"] > 1850
|
||||
|
||||
def test_moe_offline_throughput_without_radix_cache(self):
|
||||
@@ -91,7 +91,7 @@ class TestBenchServing(unittest.TestCase):
|
||||
other_server_args=["--tp", "2", "--disable-radix-cache"],
|
||||
)
|
||||
|
||||
if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
|
||||
if is_in_ci():
|
||||
assert res["output_throughput"] > 1950
|
||||
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
assert metrics["score"] >= 0.625, f"{metrics}"
|
||||
assert metrics["score"] >= 0.62, f"{metrics}"
|
||||
|
||||
def test_human_eval(self):
|
||||
args = SimpleNamespace(
|
||||
@@ -54,7 +54,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
assert metrics["score"] >= 0.425, f"{metrics}"
|
||||
assert metrics["score"] >= 0.42, f"{metrics}"
|
||||
|
||||
def test_mgsm_en(self):
|
||||
args = SimpleNamespace(
|
||||
@@ -66,7 +66,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
assert metrics["score"] >= 0.625, f"{metrics}"
|
||||
assert metrics["score"] >= 0.62, f"{metrics}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import subprocess
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
@@ -7,37 +8,49 @@ from sglang.test.test_utils import (
|
||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
is_in_ci,
|
||||
popen_launch_server,
|
||||
run_bench_latency,
|
||||
)
|
||||
|
||||
|
||||
class TestTritonAttnBackend(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
def test_latency(self):
|
||||
output_throughput = run_bench_latency(
|
||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
[
|
||||
"--attention-backend",
|
||||
"triton",
|
||||
"--enable-torch-compile",
|
||||
],
|
||||
)
|
||||
|
||||
if is_in_ci():
|
||||
assert output_throughput > 155, f"{output_throughput=}"
|
||||
|
||||
def test_mmlu(self):
|
||||
model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||
base_url = DEFAULT_URL_FOR_TEST
|
||||
process = popen_launch_server(
|
||||
model,
|
||||
base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=["--attention-backend", "triton"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid)
|
||||
try:
|
||||
args = SimpleNamespace(
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
eval_name="mmlu",
|
||||
num_examples=64,
|
||||
num_threads=32,
|
||||
)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mmlu",
|
||||
num_examples=64,
|
||||
num_threads=32,
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
assert metrics["score"] >= 0.65
|
||||
metrics = run_eval(args)
|
||||
assert metrics["score"] >= 0.65
|
||||
finally:
|
||||
kill_child_process(process.pid)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user