Simplify eagle tests and TP sync in grammar backend (#4066)
This commit is contained in:
@@ -1886,33 +1886,22 @@ class Scheduler:
|
|||||||
break
|
break
|
||||||
|
|
||||||
if self.server_args.enable_dp_attention:
|
if self.server_args.enable_dp_attention:
|
||||||
if self.attn_tp_size > 1:
|
tp_size = self.attn_tp_size
|
||||||
# Sync across attn TP ranks to make sure they have the same number of ready requests
|
tp_group = self.attn_tp_cpu_group
|
||||||
tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
|
|
||||||
torch.distributed.all_reduce(
|
|
||||||
tensor,
|
|
||||||
op=torch.distributed.ReduceOp.MAX,
|
|
||||||
group=self.attn_tp_cpu_group,
|
|
||||||
)
|
|
||||||
num_ready_reqs_max = tensor.item()
|
|
||||||
for i in range(num_ready_reqs, num_ready_reqs_max):
|
|
||||||
self.grammar_queue[i].grammar = self.grammar_queue[
|
|
||||||
i
|
|
||||||
].grammar.result()
|
|
||||||
num_ready_reqs = num_ready_reqs_max
|
|
||||||
else:
|
else:
|
||||||
if self.tp_size > 1:
|
tp_size = self.tp_size
|
||||||
# Sync across TP ranks to make sure they have the same number of ready requests
|
tp_group = self.tp_cpu_group
|
||||||
tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
|
|
||||||
torch.distributed.all_reduce(
|
if tp_size > 1:
|
||||||
tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group
|
# Sync across TP ranks to make sure they have the same number of ready requests
|
||||||
)
|
tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
|
||||||
num_ready_reqs_max = tensor.item()
|
torch.distributed.all_reduce(
|
||||||
for i in range(num_ready_reqs, num_ready_reqs_max):
|
tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
|
||||||
self.grammar_queue[i].grammar = self.grammar_queue[
|
)
|
||||||
i
|
num_ready_reqs_max = tensor.item()
|
||||||
].grammar.result()
|
for i in range(num_ready_reqs, num_ready_reqs_max):
|
||||||
num_ready_reqs = num_ready_reqs_max
|
self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
|
||||||
|
num_ready_reqs = num_ready_reqs_max
|
||||||
|
|
||||||
self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
|
self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
|
||||||
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
|
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
|
||||||
|
|||||||
@@ -31,16 +31,6 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def load_token_map(token_map_path: str) -> List[int]:
|
|
||||||
if not os.path.exists(token_map_path):
|
|
||||||
cache_dir = snapshot_download(
|
|
||||||
os.path.dirname(token_map_path),
|
|
||||||
ignore_patterns=["*.bin", "*.safetensors"],
|
|
||||||
)
|
|
||||||
token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
|
|
||||||
return torch.load(token_map_path)
|
|
||||||
|
|
||||||
|
|
||||||
class EAGLEWorker(TpModelWorker):
|
class EAGLEWorker(TpModelWorker):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -57,6 +47,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
backup_disable_cuda_graph = server_args.disable_cuda_graph
|
backup_disable_cuda_graph = server_args.disable_cuda_graph
|
||||||
server_args.disable_cuda_graph = True
|
server_args.disable_cuda_graph = True
|
||||||
|
|
||||||
|
# Load hot token ids
|
||||||
if server_args.speculative_token_map is not None:
|
if server_args.speculative_token_map is not None:
|
||||||
self.hot_token_id = load_token_map(server_args.speculative_token_map)
|
self.hot_token_id = load_token_map(server_args.speculative_token_map)
|
||||||
server_args.json_model_override_args = (
|
server_args.json_model_override_args = (
|
||||||
@@ -65,6 +56,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
else:
|
else:
|
||||||
self.hot_token_id = None
|
self.hot_token_id = None
|
||||||
|
|
||||||
|
# Init target worker
|
||||||
super().__init__(
|
super().__init__(
|
||||||
gpu_id=gpu_id,
|
gpu_id=gpu_id,
|
||||||
tp_rank=tp_rank,
|
tp_rank=tp_rank,
|
||||||
@@ -88,9 +80,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
|
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
|
||||||
if self.hot_token_id is not None:
|
if self.hot_token_id is not None:
|
||||||
head = head.clone()
|
head = head.clone()
|
||||||
self.hot_token_id = torch.tensor(
|
self.hot_token_id = self.hot_token_id.to(head.device)
|
||||||
self.hot_token_id, dtype=torch.int32, device=head.device
|
|
||||||
)
|
|
||||||
head.data = head.data[self.hot_token_id]
|
head.data = head.data[self.hot_token_id]
|
||||||
self.model_runner.model.set_embed_and_head(embed, head)
|
self.model_runner.model.set_embed_and_head(embed, head)
|
||||||
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
|
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
|
||||||
@@ -369,3 +359,14 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
][:req_len]
|
][:req_len]
|
||||||
self.model_runner.token_to_kv_pool.free(kv_indices)
|
self.model_runner.token_to_kv_pool.free(kv_indices)
|
||||||
self.model_runner.req_to_token_pool.free(req.req_pool_idx)
|
self.model_runner.req_to_token_pool.free(req.req_pool_idx)
|
||||||
|
|
||||||
|
|
||||||
|
def load_token_map(token_map_path: str) -> List[int]:
|
||||||
|
if not os.path.exists(token_map_path):
|
||||||
|
cache_dir = snapshot_download(
|
||||||
|
os.path.dirname(token_map_path),
|
||||||
|
ignore_patterns=["*.bin", "*.safetensors"],
|
||||||
|
)
|
||||||
|
token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
|
||||||
|
hot_token_id = torch.load(token_map_path)
|
||||||
|
return torch.tensor(hot_token_id, dtype=torch.int32)
|
||||||
|
|||||||
@@ -501,6 +501,7 @@ def get_benchmark_args(
|
|||||||
request_rate=float("inf"),
|
request_rate=float("inf"),
|
||||||
disable_stream=False,
|
disable_stream=False,
|
||||||
disable_ignore_eos=False,
|
disable_ignore_eos=False,
|
||||||
|
seed: int = 0,
|
||||||
pd_seperated: bool = False,
|
pd_seperated: bool = False,
|
||||||
):
|
):
|
||||||
return SimpleNamespace(
|
return SimpleNamespace(
|
||||||
@@ -524,7 +525,7 @@ def get_benchmark_args(
|
|||||||
disable_tqdm=False,
|
disable_tqdm=False,
|
||||||
disable_stream=disable_stream,
|
disable_stream=disable_stream,
|
||||||
return_logprob=False,
|
return_logprob=False,
|
||||||
seed=0,
|
seed=seed,
|
||||||
disable_ignore_eos=disable_ignore_eos,
|
disable_ignore_eos=disable_ignore_eos,
|
||||||
extra_request_body=None,
|
extra_request_body=None,
|
||||||
apply_chat_template=False,
|
apply_chat_template=False,
|
||||||
@@ -549,6 +550,7 @@ def run_bench_serving(
|
|||||||
disable_stream=False,
|
disable_stream=False,
|
||||||
disable_ignore_eos=False,
|
disable_ignore_eos=False,
|
||||||
need_warmup=False,
|
need_warmup=False,
|
||||||
|
seed: int = 0,
|
||||||
):
|
):
|
||||||
# Launch the server
|
# Launch the server
|
||||||
base_url = DEFAULT_URL_FOR_TEST
|
base_url = DEFAULT_URL_FOR_TEST
|
||||||
@@ -572,6 +574,7 @@ def run_bench_serving(
|
|||||||
request_rate=request_rate,
|
request_rate=request_rate,
|
||||||
disable_stream=disable_stream,
|
disable_stream=disable_stream,
|
||||||
disable_ignore_eos=disable_ignore_eos,
|
disable_ignore_eos=disable_ignore_eos,
|
||||||
|
seed=seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import unittest
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from utils import *
|
from utils import BACKENDS, TORCH_DTYPES, LoRAAdaptor, LoRAModelCase
|
||||||
|
|
||||||
from sglang.test.runners import HFRunner, SRTRunner
|
from sglang.test.runners import HFRunner, SRTRunner
|
||||||
from sglang.test.test_utils import calculate_rouge_l, is_in_ci
|
from sglang.test.test_utils import calculate_rouge_l, is_in_ci
|
||||||
|
|||||||
@@ -13,15 +13,13 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from utils import *
|
from utils import BACKENDS, TORCH_DTYPES, LoRAAdaptor, LoRAModelCase
|
||||||
|
|
||||||
from sglang.test.runners import HFRunner, SRTRunner
|
from sglang.test.test_utils import is_in_ci
|
||||||
from sglang.test.test_utils import calculate_rouge_l, is_in_ci
|
|
||||||
|
|
||||||
MULTI_LORA_MODELS = [
|
MULTI_LORA_MODELS = [
|
||||||
LoRAModelCase(
|
LoRAModelCase(
|
||||||
|
|||||||
@@ -136,8 +136,8 @@ class TestBenchServing(unittest.TestCase):
|
|||||||
def test_online_latency_eagle(self):
|
def test_online_latency_eagle(self):
|
||||||
res = run_bench_serving(
|
res = run_bench_serving(
|
||||||
model=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
|
model=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
|
||||||
num_prompts=50,
|
num_prompts=300,
|
||||||
request_rate=1,
|
request_rate=8,
|
||||||
sharegpt_context_len=3072,
|
sharegpt_context_len=3072,
|
||||||
disable_ignore_eos=True,
|
disable_ignore_eos=True,
|
||||||
dataset_name="sharegpt",
|
dataset_name="sharegpt",
|
||||||
@@ -156,6 +156,7 @@ class TestBenchServing(unittest.TestCase):
|
|||||||
"0.7",
|
"0.7",
|
||||||
],
|
],
|
||||||
need_warmup=True,
|
need_warmup=True,
|
||||||
|
seed=42,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_in_ci():
|
if is_in_ci():
|
||||||
@@ -164,8 +165,8 @@ class TestBenchServing(unittest.TestCase):
|
|||||||
f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} ms\n'
|
f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} ms\n'
|
||||||
f'accept_length : {res["accept_length"]:.2f} \n'
|
f'accept_length : {res["accept_length"]:.2f} \n'
|
||||||
)
|
)
|
||||||
self.assertLess(res["median_e2e_latency_ms"], 700)
|
self.assertLess(res["median_e2e_latency_ms"], 1100)
|
||||||
self.assertGreater(res["accept_length"], 2.50)
|
self.assertGreater(res["accept_length"], 3.0)
|
||||||
|
|
||||||
def test_moe_offline_throughput_default(self):
|
def test_moe_offline_throughput_default(self):
|
||||||
res = run_bench_serving(
|
res = run_bench_serving(
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ class TestEAGLEEngine(unittest.TestCase):
|
|||||||
self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
|
self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
|
||||||
ref_engine.shutdown()
|
ref_engine.shutdown()
|
||||||
|
|
||||||
def test_eagle_accuracy(self):
|
def test_correctness(self):
|
||||||
configs = [
|
configs = [
|
||||||
self.BASE_CONFIG,
|
self.BASE_CONFIG,
|
||||||
{**self.BASE_CONFIG, "disable_cuda_graph": True},
|
{**self.BASE_CONFIG, "disable_cuda_graph": True},
|
||||||
@@ -95,67 +95,6 @@ class TestEAGLEEngine(unittest.TestCase):
|
|||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
|
|
||||||
|
|
||||||
class TestEAGLEEngineTokenMap(unittest.TestCase):
|
|
||||||
BASE_CONFIG = {
|
|
||||||
"model_path": "meta-llama/Meta-Llama-3-8B-Instruct",
|
|
||||||
"speculative_draft_model_path": "lmzheng/sglang-EAGLE-LLaMA3-Instruct-8B",
|
|
||||||
"speculative_algorithm": "EAGLE",
|
|
||||||
"speculative_num_steps": 5,
|
|
||||||
"speculative_eagle_topk": 8,
|
|
||||||
"speculative_num_draft_tokens": 64,
|
|
||||||
"mem_fraction_static": 0.7,
|
|
||||||
"cuda_graph_max_bs": 4,
|
|
||||||
"dtype": "float16",
|
|
||||||
}
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.prompt = "Today is a sunny day and I like"
|
|
||||||
self.sampling_params = {"temperature": 0, "max_new_tokens": 8}
|
|
||||||
|
|
||||||
ref_engine = sgl.Engine(model_path=self.BASE_CONFIG["model_path"])
|
|
||||||
self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
|
|
||||||
ref_engine.shutdown()
|
|
||||||
|
|
||||||
def test_token_map_accuracy(self):
|
|
||||||
configs = [
|
|
||||||
self.BASE_CONFIG,
|
|
||||||
{
|
|
||||||
**self.BASE_CONFIG,
|
|
||||||
"speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
for config in configs:
|
|
||||||
print("testing config: ", config)
|
|
||||||
with self.subTest(cuda_graph="enabled"):
|
|
||||||
engine = sgl.Engine(**config)
|
|
||||||
try:
|
|
||||||
self._test_basic_generation(engine)
|
|
||||||
self._test_batch_generation(engine)
|
|
||||||
finally:
|
|
||||||
engine.shutdown()
|
|
||||||
|
|
||||||
def _test_basic_generation(self, engine):
|
|
||||||
output = engine.generate(self.prompt, self.sampling_params)["text"]
|
|
||||||
print(f"{output=}, {self.ref_output=}")
|
|
||||||
self.assertEqual(output, self.ref_output)
|
|
||||||
|
|
||||||
def _test_batch_generation(self, engine):
|
|
||||||
prompts = [
|
|
||||||
"Hello, my name is",
|
|
||||||
"The president of the United States is",
|
|
||||||
"The capital of France is",
|
|
||||||
"The future of AI is",
|
|
||||||
]
|
|
||||||
params = {"temperature": 0, "max_new_tokens": 30}
|
|
||||||
|
|
||||||
outputs = engine.generate(prompts, params)
|
|
||||||
for prompt, output in zip(prompts, outputs):
|
|
||||||
print(f"Prompt: {prompt}")
|
|
||||||
print(f"Generated: {output['text']}")
|
|
||||||
print("-" * 40)
|
|
||||||
|
|
||||||
|
|
||||||
prompts = [
|
prompts = [
|
||||||
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like[/INST]"
|
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like[/INST]"
|
||||||
'[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]',
|
'[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]',
|
||||||
@@ -222,7 +161,7 @@ class TestEAGLEServer(unittest.TestCase):
|
|||||||
"max_new_tokens": 1024,
|
"max_new_tokens": 1024,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
# set timeout = 1s,mock disconnected
|
# set timeout = 1s, mock disconnected
|
||||||
requests.post(url, json=data, timeout=1)
|
requests.post(url, json=data, timeout=1)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
@@ -273,18 +212,71 @@ class TestEAGLEServerTriton(TestEAGLEServer):
|
|||||||
"--speculative-num-steps",
|
"--speculative-num-steps",
|
||||||
"5",
|
"5",
|
||||||
"--speculative-eagle-topk",
|
"--speculative-eagle-topk",
|
||||||
"8",
|
"4",
|
||||||
"--speculative-num-draft-tokens",
|
"--speculative-num-draft-tokens",
|
||||||
"64",
|
"8",
|
||||||
"--mem-fraction-static",
|
"--mem-fraction-static",
|
||||||
"0.7",
|
"0.7",
|
||||||
"--attention-backend",
|
"--attention-backend",
|
||||||
"triton",
|
"triton",
|
||||||
"--cuda-graph-max-bs",
|
"--cuda-graph-max-bs",
|
||||||
"32",
|
"16",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestEAGLEEngineTokenMap(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.prompt = "Today is a sunny day and I like"
|
||||||
|
self.sampling_params = {"temperature": 0, "max_new_tokens": 8}
|
||||||
|
|
||||||
|
ref_engine = sgl.Engine(
|
||||||
|
model_path="meta-llama/Meta-Llama-3-8B-Instruct", cuda_graph_max_bs=2
|
||||||
|
)
|
||||||
|
self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
|
||||||
|
ref_engine.shutdown()
|
||||||
|
|
||||||
|
def test_correctness(self):
|
||||||
|
config = {
|
||||||
|
"model_path": "meta-llama/Meta-Llama-3-8B-Instruct",
|
||||||
|
"speculative_draft_model_path": "lmsys/sglang-EAGLE-LLaMA3-Instruct-8B",
|
||||||
|
"speculative_algorithm": "EAGLE",
|
||||||
|
"speculative_num_steps": 5,
|
||||||
|
"speculative_eagle_topk": 4,
|
||||||
|
"speculative_num_draft_tokens": 8,
|
||||||
|
"speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
|
||||||
|
"mem_fraction_static": 0.7,
|
||||||
|
"cuda_graph_max_bs": 4,
|
||||||
|
"dtype": "bfloat16",
|
||||||
|
}
|
||||||
|
|
||||||
|
engine = sgl.Engine(**config)
|
||||||
|
try:
|
||||||
|
self._test_basic_generation(engine)
|
||||||
|
self._test_batch_generation(engine)
|
||||||
|
finally:
|
||||||
|
engine.shutdown()
|
||||||
|
|
||||||
|
def _test_basic_generation(self, engine):
|
||||||
|
output = engine.generate(self.prompt, self.sampling_params)["text"]
|
||||||
|
print(f"{output=}, {self.ref_output=}")
|
||||||
|
self.assertEqual(output, self.ref_output)
|
||||||
|
|
||||||
|
def _test_batch_generation(self, engine):
|
||||||
|
prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
"The president of the United States is",
|
||||||
|
"The capital of France is",
|
||||||
|
"The future of AI is",
|
||||||
|
]
|
||||||
|
params = {"temperature": 0, "max_new_tokens": 30}
|
||||||
|
|
||||||
|
outputs = engine.generate(prompts, params)
|
||||||
|
for prompt, output in zip(prompts, outputs):
|
||||||
|
print(f"Prompt: {prompt}")
|
||||||
|
print(f"Generated: {output['text']}")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ class TestGGUF(unittest.TestCase):
|
|||||||
filename="qwen2-1_5b-instruct-q4_k_m.gguf",
|
filename="qwen2-1_5b-instruct-q4_k_m.gguf",
|
||||||
)
|
)
|
||||||
|
|
||||||
engine = sgl.Engine(model_path=model_path, random_seed=42)
|
engine = sgl.Engine(model_path=model_path, random_seed=42, cuda_graph_max_bs=2)
|
||||||
outputs = engine.generate(prompt, sampling_params)["text"]
|
outputs = engine.generate(prompt, sampling_params)["text"]
|
||||||
engine.shutdown()
|
engine.shutdown()
|
||||||
|
|
||||||
|
|||||||
@@ -4,13 +4,13 @@ import torch
|
|||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
import sglang as sgl
|
import sglang as sgl
|
||||||
from sglang.test.test_utils import is_in_ci
|
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
|
|
||||||
|
|
||||||
class TestHiddenState(unittest.TestCase):
|
class TestHiddenState(unittest.TestCase):
|
||||||
def test_return_hidden_states(self):
|
def test_return_hidden_states(self):
|
||||||
prompts = ["Today is", "Today is a sunny day and I like"]
|
prompts = ["Today is", "Today is a sunny day and I like"]
|
||||||
model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||||
input_ids = tokenizer(prompts).input_ids
|
input_ids = tokenizer(prompts).input_ids
|
||||||
|
|
||||||
@@ -80,7 +80,7 @@ class TestHiddenState(unittest.TestCase):
|
|||||||
|
|
||||||
def test_repeatedly_changes_hidden_states(self):
|
def test_repeatedly_changes_hidden_states(self):
|
||||||
prompts = ["Today is", "Today is a sunny day and I like"]
|
prompts = ["Today is", "Today is a sunny day and I like"]
|
||||||
model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||||
input_ids = tokenizer(prompts).input_ids
|
input_ids = tokenizer(prompts).input_ids
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ class TestInputEmbeds(unittest.TestCase):
|
|||||||
cls.model,
|
cls.model,
|
||||||
cls.base_url,
|
cls.base_url,
|
||||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
other_args=["--disable-radix"],
|
other_args=["--disable-radix", "--cuda-graph-max-bs", 4],
|
||||||
)
|
)
|
||||||
cls.texts = [
|
cls.texts = [
|
||||||
"The capital of France is",
|
"The capital of France is",
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from sglang.test.test_utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def setup_class(cls, backend: str, disable_overlap: bool):
|
def setup_class(cls, backend: str):
|
||||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
cls.json_schema = json.dumps(
|
cls.json_schema = json.dumps(
|
||||||
@@ -42,9 +42,6 @@ def setup_class(cls, backend: str, disable_overlap: bool):
|
|||||||
backend,
|
backend,
|
||||||
]
|
]
|
||||||
|
|
||||||
if disable_overlap:
|
|
||||||
other_args += ["--disable-overlap-schedule"]
|
|
||||||
|
|
||||||
cls.process = popen_launch_server(
|
cls.process = popen_launch_server(
|
||||||
cls.model,
|
cls.model,
|
||||||
cls.base_url,
|
cls.base_url,
|
||||||
@@ -56,7 +53,7 @@ def setup_class(cls, backend: str, disable_overlap: bool):
|
|||||||
class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
|
class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
setup_class(cls, backend="outlines", disable_overlap=False)
|
setup_class(cls, backend="outlines")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
@@ -133,5 +130,17 @@ class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
|
|||||||
list(executor.map(self.run_decode, json_schemas))
|
list(executor.map(self.run_decode, json_schemas))
|
||||||
|
|
||||||
|
|
||||||
|
class TestJSONConstrainedXGrammarBackend(TestJSONConstrainedOutlinesBackend):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
setup_class(cls, backend="xgrammar")
|
||||||
|
|
||||||
|
|
||||||
|
class TestJSONConstrainedLLGuidanceBackend(TestJSONConstrainedOutlinesBackend):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
setup_class(cls, backend="llguidance")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ class TestEnableMetrics(unittest.TestCase):
|
|||||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
DEFAULT_URL_FOR_TEST,
|
DEFAULT_URL_FOR_TEST,
|
||||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
other_args=["--enable-metrics"],
|
other_args=["--enable-metrics", "--cuda-graph-max-bs", 2],
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ class TestTritonAttnBackend(unittest.TestCase):
|
|||||||
"--attention-backend",
|
"--attention-backend",
|
||||||
"triton",
|
"triton",
|
||||||
"--enable-torch-compile",
|
"--enable-torch-compile",
|
||||||
|
"--cuda-graph-max-bs",
|
||||||
|
16,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ class TestVertexEndpoint(unittest.TestCase):
|
|||||||
cls.model,
|
cls.model,
|
||||||
cls.base_url,
|
cls.base_url,
|
||||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=["--cuda-graph-max-bs", 2],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user