[EAGLE] many fixes for eagle (#4195)
Co-authored-by: SangBin Cho <rkooo567@gmail.com> Co-authored-by: Sehoon Kim <sehoon@x.ai>
This commit is contained in:
@@ -18,12 +18,15 @@ dependencies = ["requests", "tqdm", "numpy", "IPython", "setproctitle"]
|
|||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
runtime_common = [
|
runtime_common = [
|
||||||
"aiohttp",
|
"aiohttp",
|
||||||
|
"datasets",
|
||||||
"decord",
|
"decord",
|
||||||
"fastapi",
|
"fastapi",
|
||||||
"hf_transfer",
|
"hf_transfer",
|
||||||
"huggingface_hub",
|
"huggingface_hub",
|
||||||
"interegular",
|
"interegular",
|
||||||
|
"llguidance>=0.6.15",
|
||||||
"modelscope",
|
"modelscope",
|
||||||
|
"ninja",
|
||||||
"orjson",
|
"orjson",
|
||||||
"packaging",
|
"packaging",
|
||||||
"pillow",
|
"pillow",
|
||||||
@@ -33,13 +36,10 @@ runtime_common = [
|
|||||||
"python-multipart",
|
"python-multipart",
|
||||||
"pyzmq>=25.1.2",
|
"pyzmq>=25.1.2",
|
||||||
"torchao>=0.7.0",
|
"torchao>=0.7.0",
|
||||||
|
"transformers==4.48.3",
|
||||||
"uvicorn",
|
"uvicorn",
|
||||||
"uvloop",
|
"uvloop",
|
||||||
"xgrammar==0.1.14",
|
"xgrammar==0.1.14",
|
||||||
"ninja",
|
|
||||||
"transformers==4.48.3",
|
|
||||||
"llguidance>=0.6.15",
|
|
||||||
"datasets"
|
|
||||||
]
|
]
|
||||||
|
|
||||||
srt = [
|
srt = [
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ class ModelConfig:
|
|||||||
if context_length is not None:
|
if context_length is not None:
|
||||||
if context_length > derived_context_len:
|
if context_length > derived_context_len:
|
||||||
if get_bool_env_var(
|
if get_bool_env_var(
|
||||||
"SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="False"
|
"SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="True"
|
||||||
):
|
):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
|
f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
|
||||||
|
|||||||
@@ -106,6 +106,8 @@ class Engine:
|
|||||||
tokenizer_manager, scheduler_info = _launch_subprocesses(
|
tokenizer_manager, scheduler_info = _launch_subprocesses(
|
||||||
server_args=server_args
|
server_args=server_args
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.server_args = server_args
|
||||||
self.tokenizer_manager = tokenizer_manager
|
self.tokenizer_manager = tokenizer_manager
|
||||||
self.scheduler_info = scheduler_info
|
self.scheduler_info = scheduler_info
|
||||||
|
|
||||||
|
|||||||
@@ -42,7 +42,6 @@ class Sampler(nn.Module):
|
|||||||
return_logprob: bool,
|
return_logprob: bool,
|
||||||
top_logprobs_nums: List[int],
|
top_logprobs_nums: List[int],
|
||||||
token_ids_logprobs: List[List[int]],
|
token_ids_logprobs: List[List[int]],
|
||||||
batch_next_token_ids: Optional[torch.Tensor] = None,
|
|
||||||
):
|
):
|
||||||
"""Run a sampler & compute logprobs and update logits_output accordingly.
|
"""Run a sampler & compute logprobs and update logits_output accordingly.
|
||||||
|
|
||||||
@@ -72,7 +71,6 @@ class Sampler(nn.Module):
|
|||||||
|
|
||||||
if sampling_info.is_all_greedy:
|
if sampling_info.is_all_greedy:
|
||||||
# Use torch.argmax if all requests use greedy sampling
|
# Use torch.argmax if all requests use greedy sampling
|
||||||
if batch_next_token_ids is None:
|
|
||||||
batch_next_token_ids = torch.argmax(logits, -1)
|
batch_next_token_ids = torch.argmax(logits, -1)
|
||||||
if return_logprob:
|
if return_logprob:
|
||||||
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||||
@@ -94,7 +92,6 @@ class Sampler(nn.Module):
|
|||||||
top_p_normalize_probs_torch(probs, sampling_info.top_ps)
|
top_p_normalize_probs_torch(probs, sampling_info.top_ps)
|
||||||
).clamp(min=torch.finfo(probs.dtype).min)
|
).clamp(min=torch.finfo(probs.dtype).min)
|
||||||
|
|
||||||
if batch_next_token_ids is None:
|
|
||||||
max_top_k_round, batch_size = 32, probs.shape[0]
|
max_top_k_round, batch_size = 32, probs.shape[0]
|
||||||
uniform_samples = torch.rand(
|
uniform_samples = torch.rand(
|
||||||
(max_top_k_round, batch_size), device=probs.device
|
(max_top_k_round, batch_size), device=probs.device
|
||||||
@@ -116,12 +113,9 @@ class Sampler(nn.Module):
|
|||||||
|
|
||||||
if self.use_nan_detection and not torch.all(success):
|
if self.use_nan_detection and not torch.all(success):
|
||||||
logger.warning("Detected errors during sampling!")
|
logger.warning("Detected errors during sampling!")
|
||||||
batch_next_token_ids = torch.zeros_like(
|
batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
|
||||||
batch_next_token_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
elif global_server_args_dict["sampling_backend"] == "pytorch":
|
elif global_server_args_dict["sampling_backend"] == "pytorch":
|
||||||
if batch_next_token_ids is None:
|
|
||||||
# A slower fallback implementation with torch native operations.
|
# A slower fallback implementation with torch native operations.
|
||||||
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
|
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
|
||||||
probs,
|
probs,
|
||||||
|
|||||||
@@ -957,11 +957,13 @@ class Scheduler:
|
|||||||
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
|
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
|
||||||
self.batch_is_full = False
|
self.batch_is_full = False
|
||||||
|
|
||||||
|
# Filter batch
|
||||||
last_bs = self.last_batch.batch_size()
|
last_bs = self.last_batch.batch_size()
|
||||||
self.last_batch.filter_batch()
|
self.last_batch.filter_batch()
|
||||||
if self.last_batch.batch_size() < last_bs:
|
if self.last_batch.batch_size() < last_bs:
|
||||||
self.batch_is_full = False
|
self.batch_is_full = False
|
||||||
|
|
||||||
|
# Merge the new batch into the running batch
|
||||||
if not self.last_batch.is_empty():
|
if not self.last_batch.is_empty():
|
||||||
if self.running_batch is None:
|
if self.running_batch is None:
|
||||||
self.running_batch = self.last_batch
|
self.running_batch = self.last_batch
|
||||||
|
|||||||
@@ -300,10 +300,11 @@ class CudaGraphRunner:
|
|||||||
def capture(self):
|
def capture(self):
|
||||||
with graph_capture() as graph_capture_context:
|
with graph_capture() as graph_capture_context:
|
||||||
self.stream = graph_capture_context.stream
|
self.stream = graph_capture_context.stream
|
||||||
|
# Reverse the order to enable better memory sharing across cuda graphs.
|
||||||
capture_range = (
|
capture_range = (
|
||||||
tqdm.tqdm(self.capture_bs)
|
tqdm.tqdm(reversed(self.capture_bs))
|
||||||
if get_tensor_model_parallel_rank() == 0
|
if get_tensor_model_parallel_rank() == 0
|
||||||
else self.capture_bs
|
else reversed(self.capture_bs)
|
||||||
)
|
)
|
||||||
for bs in capture_range:
|
for bs in capture_range:
|
||||||
with patch_model(
|
with patch_model(
|
||||||
|
|||||||
@@ -928,45 +928,6 @@ class ModelRunner:
|
|||||||
sampling_info.update_regex_vocab_mask()
|
sampling_info.update_regex_vocab_mask()
|
||||||
sampling_info.apply_logits_bias(logits_output.next_token_logits)
|
sampling_info.apply_logits_bias(logits_output.next_token_logits)
|
||||||
|
|
||||||
def update_output_logprobs(
    self,
    logits_output: LogitsProcessorOutput,
    sampling_info: SamplingBatchInfo,
    top_logprobs_nums: List[int],
    token_ids_logprobs: List[List[int]],
    next_token_ids: torch.Tensor,
    *,
    num_tokens_per_req: List[int],
):
    """Fill in the output logprobs of ``logits_output`` for already-chosen tokens.

    The per-request logprob settings are expanded so that each request's
    entry is repeated once per token of that request, and the sampler is
    then invoked with ``next_token_ids`` supplied as ``batch_next_token_ids``.

    Args:
        logits_output: The logits output from the model forward.
        sampling_info: Sampling info for logprob calculation.
        top_logprobs_nums: Number of top logprobs, one entry per request.
        token_ids_logprobs: Token ids whose logprobs are requested, one
            entry per request.
        next_token_ids: Next token ids, handed to the sampler as
            ``batch_next_token_ids``.
        num_tokens_per_req: The number of tokens per request; controls how
            often each per-request entry is repeated.

    Returns:
        None. This method has no return statement; results are presumably
        written onto ``logits_output`` by the sampler (confirm against the
        sampler implementation).
    """
    self._preprocess_logits(logits_output, sampling_info)

    # Repeat every per-request setting once per token of that request so the
    # lists line up with the flattened token dimension.
    top_logprobs_nums_repeat_interleaved = [
        num
        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req)
        for _ in range(num_tokens)
    ]
    token_ids_logprobs_repeat_interleaved = [
        token_ids
        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req)
        for _ in range(num_tokens)
    ]

    self.sampler(
        logits_output,
        sampling_info,
        True,  # presumably the sampler's return_logprob flag
        top_logprobs_nums_repeat_interleaved,
        token_ids_logprobs_repeat_interleaved,
        batch_next_token_ids=next_token_ids,
    )
|
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
logits_output: LogitsProcessorOutput,
|
logits_output: LogitsProcessorOutput,
|
||||||
|
|||||||
@@ -56,7 +56,6 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def _merge(self, their: "BatchedFrequencyPenalizer"):
|
def _merge(self, their: "BatchedFrequencyPenalizer"):
|
||||||
print(f"{self.frequency_penalties.shape=}, {their.frequency_penalties.shape=}")
|
|
||||||
self.frequency_penalties = torch.cat(
|
self.frequency_penalties = torch.cat(
|
||||||
[self.frequency_penalties, their.frequency_penalties], dim=0
|
[self.frequency_penalties, their.frequency_penalties], dim=0
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -56,7 +56,6 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def _merge(self, their: "BatchedPresencePenalizer"):
|
def _merge(self, their: "BatchedPresencePenalizer"):
|
||||||
print(f"{self.presence_penalties.shape=}, {their.presence_penalties.shape=}")
|
|
||||||
self.presence_penalties = torch.cat(
|
self.presence_penalties = torch.cat(
|
||||||
[self.presence_penalties, their.presence_penalties], dim=0
|
[self.presence_penalties, their.presence_penalties], dim=0
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import torch
|
|||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
|
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
|
||||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||||
from sglang.srt.model_executor.forward_batch_info import (
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
@@ -302,13 +303,10 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
|
|
||||||
# Set inputs
|
# Set inputs
|
||||||
forward_batch.input_ids = input_ids
|
forward_batch.input_ids = input_ids
|
||||||
|
out_cache_loc = out_cache_loc.view(forward_batch.batch_size, -1)
|
||||||
forward_batch.out_cache_loc = out_cache_loc[
|
forward_batch.out_cache_loc = out_cache_loc[
|
||||||
forward_batch.batch_size
|
:, self.topk * i : self.topk * (i + 1)
|
||||||
* self.topk
|
].flatten()
|
||||||
* i : forward_batch.batch_size
|
|
||||||
* self.topk
|
|
||||||
* (i + 1)
|
|
||||||
]
|
|
||||||
forward_batch.positions.add_(1)
|
forward_batch.positions.add_(1)
|
||||||
forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
|
forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
|
||||||
spec_info.hidden_states = hidden_states
|
spec_info.hidden_states = hidden_states
|
||||||
@@ -353,31 +351,61 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
batch.spec_info = res.draft_input
|
batch.spec_info = res.draft_input
|
||||||
|
|
||||||
if batch.return_logprob:
|
if batch.return_logprob:
|
||||||
# Compute output logprobs using the sampler.
|
self.add_logprob_values(batch, res, logits_output)
|
||||||
num_tokens_per_req = [
|
|
||||||
accept + 1 for accept in res.accept_length_per_req_cpu
|
return logits_output, res, model_worker_batch
|
||||||
]
|
|
||||||
self.target_worker.model_runner.update_output_logprobs(
|
def add_logprob_values(
|
||||||
logits_output,
|
self,
|
||||||
batch.sampling_info,
|
batch: ScheduleBatch,
|
||||||
batch.top_logprobs_nums,
|
res: EagleVerifyOutput,
|
||||||
batch.token_ids_logprobs,
|
logits_output: LogitsProcessorOutput,
|
||||||
res.verified_id,
|
):
|
||||||
# +1 for bonus token.
|
# Extract args
|
||||||
num_tokens_per_req=num_tokens_per_req,
|
logits_output = res.logits_output
|
||||||
|
top_logprobs_nums = batch.top_logprobs_nums
|
||||||
|
token_ids_logprobs = batch.token_ids_logprobs
|
||||||
|
logprobs = torch.nn.functional.log_softmax(
|
||||||
|
logits_output.next_token_logits, dim=-1
|
||||||
)
|
)
|
||||||
|
batch_next_token_ids = res.verified_id
|
||||||
|
num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
|
||||||
|
|
||||||
|
# We should repeat top_logprobs_nums to match num_tokens_per_req.
|
||||||
|
top_logprobs_nums_repeat_interleaved = []
|
||||||
|
token_ids_logprobs_repeat_interleaved = []
|
||||||
|
for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
|
||||||
|
top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
|
||||||
|
for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
|
||||||
|
token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
|
||||||
|
|
||||||
|
# Extract logprobs
|
||||||
|
if any(x > 0 for x in top_logprobs_nums):
|
||||||
|
(
|
||||||
|
logits_output.next_token_top_logprobs_val,
|
||||||
|
logits_output.next_token_top_logprobs_idx,
|
||||||
|
) = get_top_logprobs(logprobs, top_logprobs_nums_repeat_interleaved)
|
||||||
|
|
||||||
|
if any(x is not None for x in token_ids_logprobs):
|
||||||
|
(
|
||||||
|
logits_output.next_token_token_ids_logprobs_val,
|
||||||
|
logits_output.next_token_token_ids_logprobs_idx,
|
||||||
|
) = get_token_ids_logprobs(logprobs, token_ids_logprobs_repeat_interleaved)
|
||||||
|
|
||||||
|
logits_output.next_token_logprobs = logprobs[
|
||||||
|
torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device),
|
||||||
|
batch_next_token_ids,
|
||||||
|
]
|
||||||
|
|
||||||
# Add output logprobs to the request.
|
# Add output logprobs to the request.
|
||||||
pt = 0
|
pt = 0
|
||||||
# NOTE: tolist() of these values are skipped when output is processed
|
next_token_logprobs = logits_output.next_token_logprobs.tolist()
|
||||||
next_token_logprobs = res.logits_output.next_token_logprobs.tolist()
|
verified_ids = batch_next_token_ids.tolist()
|
||||||
verified_ids = res.verified_id.tolist()
|
|
||||||
for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
|
for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
|
||||||
for _ in range(num_tokens):
|
for _ in range(num_tokens):
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
token_id = verified_ids[pt]
|
|
||||||
req.output_token_logprobs_val.append(next_token_logprobs[pt])
|
req.output_token_logprobs_val.append(next_token_logprobs[pt])
|
||||||
req.output_token_logprobs_idx.append(token_id)
|
req.output_token_logprobs_idx.append(verified_ids[pt])
|
||||||
if req.top_logprobs_num > 0:
|
if req.top_logprobs_num > 0:
|
||||||
req.output_top_logprobs_val.append(
|
req.output_top_logprobs_val.append(
|
||||||
res.logits_output.next_token_top_logprobs_val[pt]
|
res.logits_output.next_token_top_logprobs_val[pt]
|
||||||
@@ -387,8 +415,6 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
)
|
)
|
||||||
pt += 1
|
pt += 1
|
||||||
|
|
||||||
return logits_output, res, model_worker_batch
|
|
||||||
|
|
||||||
def forward_draft_extend(
|
def forward_draft_extend(
|
||||||
self,
|
self,
|
||||||
batch: ScheduleBatch,
|
batch: ScheduleBatch,
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ class TestSRTBackend(unittest.TestCase):
|
|||||||
# Run twice to capture more bugs
|
# Run twice to capture more bugs
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
accuracy, latency = test_hellaswag_select()
|
accuracy, latency = test_hellaswag_select()
|
||||||
self.assertGreater(accuracy, 0.65)
|
self.assertGreater(accuracy, 0.60)
|
||||||
|
|
||||||
def test_gen_min_new_tokens(self):
|
def test_gen_min_new_tokens(self):
|
||||||
test_gen_min_new_tokens()
|
test_gen_min_new_tokens()
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ class TestEAGLEEngine(unittest.TestCase):
|
|||||||
def _test_acc_length(self, engine):
|
def _test_acc_length(self, engine):
|
||||||
prompt = [
|
prompt = [
|
||||||
"Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
|
"Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
|
||||||
]
|
] * 5
|
||||||
sampling_params = {"temperature": 0, "max_new_tokens": 512}
|
sampling_params = {"temperature": 0, "max_new_tokens": 512}
|
||||||
output = engine.generate(prompt, sampling_params)
|
output = engine.generate(prompt, sampling_params)
|
||||||
output = output[0]
|
output = output[0]
|
||||||
@@ -141,10 +141,14 @@ class TestEAGLEEngine(unittest.TestCase):
|
|||||||
/ output["meta_info"]["e2e_latency"]
|
/ output["meta_info"]["e2e_latency"]
|
||||||
)
|
)
|
||||||
print(f"{acc_length=}")
|
print(f"{acc_length=}")
|
||||||
|
|
||||||
|
if engine.server_args.model_path == DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST:
|
||||||
self.assertGreater(acc_length, 3.6)
|
self.assertGreater(acc_length, 3.6)
|
||||||
|
else:
|
||||||
|
self.assertGreater(acc_length, 2.6)
|
||||||
|
|
||||||
|
|
||||||
class TestEAGLEEngineTokenMap(unittest.TestCase):
|
class TestEAGLEEngineTokenMap(TestEAGLEEngine):
|
||||||
BASE_CONFIG = {
|
BASE_CONFIG = {
|
||||||
"model_path": "meta-llama/Meta-Llama-3-8B-Instruct",
|
"model_path": "meta-llama/Meta-Llama-3-8B-Instruct",
|
||||||
"speculative_draft_model_path": "lmsys/sglang-EAGLE-LLaMA3-Instruct-8B",
|
"speculative_draft_model_path": "lmsys/sglang-EAGLE-LLaMA3-Instruct-8B",
|
||||||
@@ -155,6 +159,7 @@ class TestEAGLEEngineTokenMap(unittest.TestCase):
|
|||||||
"speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
|
"speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
|
||||||
"mem_fraction_static": 0.7,
|
"mem_fraction_static": 0.7,
|
||||||
"cuda_graph_max_bs": 5,
|
"cuda_graph_max_bs": 5,
|
||||||
|
"dtype": "float16",
|
||||||
}
|
}
|
||||||
NUM_CONFIGS = 1
|
NUM_CONFIGS = 1
|
||||||
|
|
||||||
@@ -245,8 +250,25 @@ class TestEAGLEServer(unittest.TestCase):
|
|||||||
for p in threads:
|
for p in threads:
|
||||||
p.join()
|
p.join()
|
||||||
|
|
||||||
|
def test_max_token_one(self):
    """Run the evaluation with a single new token per request.

    Flushes the server cache, runs ``run_eval`` with ``max_new_tokens=1``
    and checks that the reported output throughput exceeds 50.
    """
    flush_url = self.base_url + "/flush_cache"
    requests.get(flush_url)

    # The port is whatever follows the last ":" of the base URL.
    server_port = int(self.base_url.rsplit(":", 1)[-1])
    eval_args = SimpleNamespace(
        num_shots=5,
        data_path=None,
        num_questions=200,
        max_new_tokens=1,
        parallel=128,
        host="http://127.0.0.1",
        port=server_port,
    )

    # Just run and check it does not hang
    eval_metrics = run_eval(eval_args)
    throughput = eval_metrics["output_throughput"]
    self.assertGreater(throughput, 50)
|
||||||
|
|
||||||
def test_gsm8k(self):
|
def test_gsm8k(self):
|
||||||
server_info = requests.get(self.base_url + "/flush_cache")
|
requests.get(self.base_url + "/flush_cache")
|
||||||
|
|
||||||
args = SimpleNamespace(
|
args = SimpleNamespace(
|
||||||
num_shots=5,
|
num_shots=5,
|
||||||
@@ -391,6 +413,53 @@ class TestEAGLEServer(unittest.TestCase):
|
|||||||
with ThreadPoolExecutor(8) as executor:
|
with ThreadPoolExecutor(8) as executor:
|
||||||
list(executor.map(func, args))
|
list(executor.map(func, args))
|
||||||
|
|
||||||
|
def run_decode(self, sampling_params):
    """Send one ``/generate`` request and check that it succeeds.

    Args:
        sampling_params: Extra sampling parameters; they are merged over
            the defaults below and override them on key collisions.

    The request asks for logprobs (top 5, with text, starting at position
    0). The response must have HTTP status 200; its JSON body is printed.
    """
    merged_sampling_params = {
        "max_new_tokens": 48,
        "n": 1,
        "temperature": 0.7,
        **sampling_params,
    }
    payload = {
        "text": "Human: Write a travel blog post to Hawaii.\n\nAssistant:",
        "sampling_params": merged_sampling_params,
        "return_logprob": True,
        "top_logprobs_num": 5,
        "return_text_in_logprobs": True,
        "logprob_start_len": 0,
    }

    generate_url = self.base_url + "/generate"
    response = requests.post(generate_url, json=payload)
    self.assertEqual(response.status_code, 200)

    body = response.json()
    print(json.dumps(body))
    print("=" * 100)
|
||||||
|
|
||||||
|
def test_penalty_mixed(self):
    """Issue a shuffled mix of penalty / min_new_tokens requests concurrently.

    Each parameter set is repeated five times, the combined list is
    shuffled, and the requests are sent through ``self.run_decode`` from a
    pool of 8 threads.
    """
    args = [
        {},
        {},
        {},
        {"frequency_penalty": 2},
        {"presence_penalty": 1},
        {"min_new_tokens": 16},
        {"frequency_penalty": 0.2},
        {"presence_penalty": 0.4},
        {"min_new_tokens": 8},
        {"frequency_penalty": 0.4, "presence_penalty": 0.8},
        {"frequency_penalty": 0.4, "min_new_tokens": 12},
        {"presence_penalty": 0.8, "min_new_tokens": 12},
        {"presence_penalty": -0.3, "frequency_penalty": 1.3, "min_new_tokens": 32},
        {"presence_penalty": 0.3, "frequency_penalty": -1.3, "min_new_tokens": 32},
    ]
    # random.shuffle works in place and returns None, so the repeated list
    # has to be bound to a name first; shuffling the temporary `args * 5`
    # would discard both the repetition and the shuffle.
    args = args * 5
    random.shuffle(args)
    with ThreadPoolExecutor(8) as executor:
        # list(...) drains the iterator so exceptions raised in workers surface here.
        list(executor.map(self.run_decode, args))
|
||||||
|
|
||||||
|
|
||||||
class TestEAGLERetract(TestEAGLEServer):
|
class TestEAGLERetract(TestEAGLEServer):
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -44,11 +44,12 @@ class TestEvalAccuracyLarge(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
metrics = run_eval(args)
|
metrics = run_eval(args)
|
||||||
self.assertGreater(metrics["score"], 0.71)
|
|
||||||
|
|
||||||
if is_in_ci():
|
if is_in_ci():
|
||||||
write_github_step_summary(f"### test_mmlu\n" f'{metrics["score"]=:.4f}\n')
|
write_github_step_summary(f"### test_mmlu\n" f'{metrics["score"]=:.4f}\n')
|
||||||
|
|
||||||
|
self.assertGreater(metrics["score"], 0.71)
|
||||||
|
|
||||||
def test_human_eval(self):
|
def test_human_eval(self):
|
||||||
args = SimpleNamespace(
|
args = SimpleNamespace(
|
||||||
base_url=self.base_url,
|
base_url=self.base_url,
|
||||||
@@ -59,13 +60,14 @@ class TestEvalAccuracyLarge(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
metrics = run_eval(args)
|
metrics = run_eval(args)
|
||||||
self.assertGreater(metrics["score"], 0.64)
|
|
||||||
|
|
||||||
if is_in_ci():
|
if is_in_ci():
|
||||||
write_github_step_summary(
|
write_github_step_summary(
|
||||||
f"### test_human_eval\n" f'{metrics["score"]=:.4f}\n'
|
f"### test_human_eval\n" f'{metrics["score"]=:.4f}\n'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.assertGreater(metrics["score"], 0.64)
|
||||||
|
|
||||||
def test_mgsm_en(self):
|
def test_mgsm_en(self):
|
||||||
args = SimpleNamespace(
|
args = SimpleNamespace(
|
||||||
base_url=self.base_url,
|
base_url=self.base_url,
|
||||||
@@ -76,13 +78,14 @@ class TestEvalAccuracyLarge(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
metrics = run_eval(args)
|
metrics = run_eval(args)
|
||||||
self.assertGreater(metrics["score"], 0.835)
|
|
||||||
|
|
||||||
if is_in_ci():
|
if is_in_ci():
|
||||||
write_github_step_summary(
|
write_github_step_summary(
|
||||||
f"### test_mgsm_en\n" f'{metrics["score"]=:.4f}\n'
|
f"### test_mgsm_en\n" f'{metrics["score"]=:.4f}\n'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.assertGreater(metrics["score"], 0.835)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import requests
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.utils import kill_process_tree
|
from sglang.srt.utils import kill_process_tree
|
||||||
@@ -129,6 +130,8 @@ class TestDeepseekV3MTP(unittest.TestCase):
|
|||||||
kill_process_tree(cls.process.pid)
|
kill_process_tree(cls.process.pid)
|
||||||
|
|
||||||
def test_gsm8k(self):
|
def test_gsm8k(self):
|
||||||
|
requests.get(self.base_url + "/flush_cache")
|
||||||
|
|
||||||
args = SimpleNamespace(
|
args = SimpleNamespace(
|
||||||
num_shots=5,
|
num_shots=5,
|
||||||
data_path=None,
|
data_path=None,
|
||||||
@@ -143,6 +146,11 @@ class TestDeepseekV3MTP(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertGreater(metrics["accuracy"], 0.60)
|
self.assertGreater(metrics["accuracy"], 0.60)
|
||||||
|
|
||||||
|
server_info = requests.get(self.base_url + "/get_server_info")
|
||||||
|
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
|
||||||
|
print(f"{avg_spec_accept_length=}")
|
||||||
|
self.assertGreater(avg_spec_accept_length, 2.5)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ class TestPenalty(unittest.TestCase):
|
|||||||
# prompt that is supposed to generate < 32 tokens
|
# prompt that is supposed to generate < 32 tokens
|
||||||
"text": "<|start_header_id|>user<|end_header_id|>\n\nWhat is the answer for 1 + 1 = ?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
|
"text": "<|start_header_id|>user<|end_header_id|>\n\nWhat is the answer for 1 + 1 = ?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
"max_new_tokens": 32,
|
"max_new_tokens": 48,
|
||||||
"n": n,
|
"n": n,
|
||||||
**sampling_params,
|
**sampling_params,
|
||||||
},
|
},
|
||||||
@@ -68,19 +68,22 @@ class TestPenalty(unittest.TestCase):
|
|||||||
def test_presence_penalty(self):
|
def test_presence_penalty(self):
|
||||||
self.run_decode({"presence_penalty": 2})
|
self.run_decode({"presence_penalty": 2})
|
||||||
|
|
||||||
def test_mixed(self):
|
def test_penalty_mixed(self):
|
||||||
args = [
|
args = [
|
||||||
{},
|
{},
|
||||||
{},
|
{},
|
||||||
{},
|
{},
|
||||||
{"frequency_penalty": 2},
|
{"frequency_penalty": 2},
|
||||||
{"min_new_tokens": 16},
|
|
||||||
{"presence_penalty": 1},
|
{"presence_penalty": 1},
|
||||||
|
{"min_new_tokens": 16},
|
||||||
{"frequency_penalty": 0.2},
|
{"frequency_penalty": 0.2},
|
||||||
{"min_new_tokens": 8},
|
|
||||||
{"presence_penalty": 0.4},
|
{"presence_penalty": 0.4},
|
||||||
{"presence_penalty": 0.4, "frequency_penalty": 2},
|
{"min_new_tokens": 8},
|
||||||
{"min_new_tokens": 12, "frequency_penalty": 2},
|
{"frequency_penalty": 0.4, "presence_penalty": 0.8},
|
||||||
|
{"frequency_penalty": 0.4, "min_new_tokens": 12},
|
||||||
|
{"presence_penalty": 0.8, "min_new_tokens": 12},
|
||||||
|
{"presence_penalty": -0.3, "frequency_penalty": 1.3, "min_new_tokens": 32},
|
||||||
|
{"presence_penalty": 0.3, "frequency_penalty": -1.3, "min_new_tokens": 32},
|
||||||
]
|
]
|
||||||
random.shuffle(args * 5)
|
random.shuffle(args * 5)
|
||||||
with ThreadPoolExecutor(8) as executor:
|
with ThreadPoolExecutor(8) as executor:
|
||||||
|
|||||||
Reference in New Issue
Block a user