Improve benchmark scripts and error message printing (#2922)
@@ -39,14 +39,15 @@ class BenchArgs:
     dataset_path: str = ""
     num_prompts: int = 1000
     sharegpt_output_len: Optional[int] = None
+    sharegpt_context_len: Optional[int] = None
     random_input_len: int = 1024
     random_output_len: int = 1024
     random_range_ratio: float = 0.0
-    gen_num_groups: int = 64
-    gen_prompts_per_group: int = 16
-    gen_system_prompt_len: int = 2048
-    gen_question_len: int = 128
-    gen_output_len: int = 256
+    gsp_num_groups: int = 64
+    gsp_prompts_per_group: int = 16
+    gsp_system_prompt_len: int = 2048
+    gsp_question_len: int = 128
+    gsp_output_len: int = 256
     disable_ignore_eos: bool = False
     extra_request_body: Optional[str] = None
     seed: int = 1
@@ -82,6 +83,12 @@ class BenchArgs:
             default=BenchArgs.sharegpt_output_len,
             help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
         )
+        parser.add_argument(
+            "--sharegpt-context-len",
+            type=int,
+            default=BenchArgs.sharegpt_context_len,
+            help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
+        )
         parser.add_argument(
             "--random-input-len",
             type=int,
@@ -102,35 +109,35 @@ class BenchArgs:
             "used only for random dataset.",
         )
         parser.add_argument(
-            "--gen-num-groups",
+            "--gsp-num-groups",
             type=int,
-            default=BenchArgs.gen_num_groups,
+            default=BenchArgs.gsp_num_groups,
             help="Number of groups with shared prefix, used"
             "only for generate-shared-prefix",
         )
         parser.add_argument(
-            "--gen-prompts-per-group",
+            "--gsp-prompts-per-group",
             type=int,
-            default=BenchArgs.gen_prompts_per_group,
+            default=BenchArgs.gsp_prompts_per_group,
             help="Number of prompts per group of shared prefix, used"
             "only for generate-shared-prefix",
         )
         parser.add_argument(
-            "--gen-system-prompt-len",
+            "--gsp-system-prompt-len",
             type=int,
-            default=BenchArgs.gen_system_prompt_len,
+            default=BenchArgs.gsp_system_prompt_len,
             help="System prompt length, used" "only for generate-shared-prefix",
         )
         parser.add_argument(
-            "--gen-question-len",
+            "--gsp-question-len",
             type=int,
-            default=BenchArgs.gen_question_len,
+            default=BenchArgs.gsp_question_len,
             help="Question length, used" "only for generate-shared-prefix",
         )
         parser.add_argument(
-            "--gen-output-len",
+            "--gsp-output-len",
             type=int,
-            default=BenchArgs.gen_output_len,
+            default=BenchArgs.gsp_output_len,
             help="Target length in tokens for outputs in generated-shared-prefix dataset",
         )
         parser.add_argument(
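The hunks above rename the generated-shared-prefix options from `gen_*` to the less ambiguous `gsp_*`. A minimal runnable sketch of the pattern they maintain, with the dataclass fields acting as the single source of truth for argparse defaults (only two of the five renamed fields are shown):

```python
import argparse
from dataclasses import dataclass


@dataclass
class BenchArgs:
    gsp_num_groups: int = 64
    gsp_prompts_per_group: int = 16

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        # The flag string and the default reference must stay in sync with
        # the field name, which is why the rename touches both places.
        parser.add_argument(
            "--gsp-num-groups", type=int, default=BenchArgs.gsp_num_groups
        )
        parser.add_argument(
            "--gsp-prompts-per-group",
            type=int,
            default=BenchArgs.gsp_prompts_per_group,
        )


parser = argparse.ArgumentParser()
BenchArgs.add_cli_args(parser)
print(parser.parse_args(["--gsp-num-groups", "8"]).gsp_num_groups)  # -> 8
```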
@@ -452,6 +452,7 @@ def get_dataset(args, tokenizer):
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
             fixed_output_len=args.sharegpt_output_len,
+            context_len=args.sharegpt_context_len,
         )
     elif args.dataset_name == "random":
         input_requests = sample_random_requests(
@@ -464,11 +465,11 @@ def get_dataset(args, tokenizer):
         )
     elif args.dataset_name == "generated-shared-prefix":
         input_requests = sample_generated_shared_prefix_requests(
-            num_groups=args.gen_num_groups,
-            prompts_per_group=args.gen_prompts_per_group,
-            system_prompt_len=args.gen_system_prompt_len,
-            question_len=args.gen_question_len,
-            output_len=args.gen_output_len,
+            num_groups=args.gsp_num_groups,
+            prompts_per_group=args.gsp_prompts_per_group,
+            system_prompt_len=args.gsp_system_prompt_len,
+            question_len=args.gsp_question_len,
+            output_len=args.gsp_output_len,
             tokenizer=tokenizer,
         )
     else:
@@ -560,6 +561,7 @@ def sample_sharegpt_requests(
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
     fixed_output_len: Optional[int] = None,
+    context_len: Optional[int] = None,
 ) -> List[Tuple[str, int, int]]:
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
@@ -597,14 +599,15 @@ def sample_sharegpt_requests(
         output_len = (
             len(completion_token_ids) if fixed_output_len is None else fixed_output_len
         )
-        if prompt_len < 4 or output_len < 4:
+        if prompt_len < 1 or output_len < 1:
             # Prune too short sequences.
             continue
-        if prompt_len > 1024 or (
-            prompt_len + output_len > 2048 and fixed_output_len is None
-        ):
+        if context_len and prompt_len + output_len > context_len:
             # Prune too long sequences.
             continue
 
         filtered_dataset.append((prompt, prompt_len, output_len))
 
     print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
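With the hard-coded 1024/2048 cutoffs removed, the filter keeps any non-empty request that fits an optional caller-supplied context length. A self-contained sketch of that pruning logic (the `prune` helper and the sample data are illustrative, not part of the patch):

```python
from typing import List, Optional, Tuple


def prune(
    requests: List[Tuple[str, int, int]], context_len: Optional[int] = None
) -> List[Tuple[str, int, int]]:
    kept = []
    for prompt, prompt_len, output_len in requests:
        if prompt_len < 1 or output_len < 1:
            continue  # prune empty sequences
        if context_len and prompt_len + output_len > context_len:
            continue  # prune requests that would not fit the model context
        kept.append((prompt, prompt_len, output_len))
    return kept


# ("b", 900, 400) is dropped only when a context length is enforced.
data = [("a", 10, 20), ("b", 900, 400), ("c", 0, 5)]
print(len(prune(data)))        # -> 2
print(len(prune(data, 1024)))  # -> 1
```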
@@ -706,8 +709,8 @@ def get_gen_prefix_cache_path(args, tokenizer):
 
     # Create a unique cache filename based on the generation parameters
     cache_key = (
-        f"gen_prefix_{args.gen_num_groups}_{args.gen_prompts_per_group}_"
-        f"{args.gen_system_prompt_len}_{args.gen_question_len}_{args.gen_output_len}_"
+        f"gen_shared_prefix_{args.gsp_num_groups}_{args.gsp_prompts_per_group}_"
+        f"{args.gsp_system_prompt_len}_{args.gsp_question_len}_{args.gsp_output_len}_"
         f"{tokenizer.__class__.__name__}.pkl"
     )
     return cache_dir / cache_key
@@ -1374,6 +1377,12 @@ if __name__ == "__main__":
         default=None,
         help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
     )
+    parser.add_argument(
+        "--sharegpt-context-len",
+        type=int,
+        default=None,
+        help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
+    )
     parser.add_argument(
         "--random-input-len",
         type=int,
@@ -1453,38 +1462,6 @@ if __name__ == "__main__":
         help="Append given JSON object to the request payload. You can use this to specify"
         "additional generate params like sampling params.",
     )
-
-    group = parser.add_argument_group("generated-shared-prefix dataset arguments")
-    group.add_argument(
-        "--gen-num-groups",
-        type=int,
-        default=64,
-        help="Number of system prompt groups for generated-shared-prefix dataset",
-    )
-    group.add_argument(
-        "--gen-prompts-per-group",
-        type=int,
-        default=16,
-        help="Number of prompts per system prompt group for generated-shared-prefix dataset",
-    )
-    group.add_argument(
-        "--gen-system-prompt-len",
-        type=int,
-        default=2048,
-        help="Target length in tokens for system prompts in generated-shared-prefix dataset",
-    )
-    group.add_argument(
-        "--gen-question-len",
-        type=int,
-        default=128,
-        help="Target length in tokens for questions in generated-shared-prefix dataset",
-    )
-    group.add_argument(
-        "--gen-output-len",
-        type=int,
-        default=256,
-        help="Target length in tokens for outputs in generated-shared-prefix dataset",
-    )
     parser.add_argument(
         "--profile",
         action="store_true",
@@ -1497,5 +1474,37 @@ if __name__ == "__main__":
         default=None,
         help="The name of LoRA adapter",
     )
+
+    group = parser.add_argument_group("generated-shared-prefix dataset arguments")
+    group.add_argument(
+        "--gsp-num-groups",
+        type=int,
+        default=64,
+        help="Number of system prompt groups for generated-shared-prefix dataset",
+    )
+    group.add_argument(
+        "--gsp-prompts-per-group",
+        type=int,
+        default=16,
+        help="Number of prompts per system prompt group for generated-shared-prefix dataset",
+    )
+    group.add_argument(
+        "--gsp-system-prompt-len",
+        type=int,
+        default=2048,
+        help="Target length in tokens for system prompts in generated-shared-prefix dataset",
+    )
+    group.add_argument(
+        "--gsp-question-len",
+        type=int,
+        default=128,
+        help="Target length in tokens for questions in generated-shared-prefix dataset",
+    )
+    group.add_argument(
+        "--gsp-output-len",
+        type=int,
+        default=256,
+        help="Target length in tokens for outputs in generated-shared-prefix dataset",
+    )
     args = parser.parse_args()
     run_benchmark(args)
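The same five options return here under their `--gsp-*` names, grouped after the serving arguments. Note that `add_argument_group` only changes how `--help` output is organized; parsed attribute names are unaffected, as this small sketch shows:

```python
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
group.add_argument("--gsp-num-groups", type=int, default=64)

# The option parses exactly as if it were added directly to the parser.
args = parser.parse_args([])
print(args.gsp_num_groups)  # -> 64
```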
@@ -59,6 +59,9 @@ class GenerateReqInput:
     return_text_in_logprobs: bool = False
     # Whether to stream output.
     stream: bool = False
+    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
+    log_metrics: bool = True
+
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
     # LoRA related
@@ -196,6 +199,7 @@ class GenerateReqInput:
             top_logprobs_num=self.top_logprobs_num[i],
             return_text_in_logprobs=self.return_text_in_logprobs,
             stream=self.stream,
+            log_metrics=self.log_metrics,
             modalities=self.modalities[i] if self.modalities else None,
             lora_path=self.lora_path[i] if self.lora_path is not None else None,
         )
@@ -243,6 +247,8 @@ class EmbeddingReqInput:
     sampling_params: Union[List[Dict], Dict] = None
     # Dummy input embeds for compatibility
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
+    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
+    log_metrics: bool = True
 
     def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
@@ -631,7 +631,8 @@ class Scheduler:
         if len(req.origin_input_ids) > self.max_req_input_len:
             logger.warning(
                 "Request length is longer than the KV cache pool size or "
-                "the max context length. Truncated!!!"
+                "the max context length. Truncated. "
+                f"{len(req.origin_input_ids)=}, {self.max_req_input_len=}."
             )
             req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
 
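The added line uses the f-string `=` specifier (Python 3.8+), which prints both the expression and its value, so the warning now reports the offending lengths:

```python
# A quick illustration of the self-documenting f-string form used above:
origin_input_ids, max_req_input_len = list(range(5000)), 4096
print(f"{len(origin_input_ids)=}, {max_req_input_len=}")
# -> len(origin_input_ids)=5000, max_req_input_len=4096
```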
@@ -79,6 +79,7 @@ from sglang.srt.utils import (
     get_zmq_socket,
     kill_process_tree,
 )
+from sglang.utils import get_exception_traceback
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -640,7 +641,9 @@ class TokenizerManager:
 
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
-        self.asyncio_tasks.add(loop.create_task(self.handle_loop()))
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.handle_loop))
+        )
 
         # We cannot add signal handler when the tokenizer manager is not in
         # the main thread due to the CPython limitation.
@@ -653,7 +656,9 @@ class TokenizerManager:
                 "not in the main thread. This disables graceful shutdown of the "
                 "tokenizer manager when SIGTERM is received."
             )
-        self.asyncio_tasks.add(loop.create_task(self.sigterm_watchdog()))
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
+        )
 
     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
@@ -738,9 +743,13 @@ class TokenizerManager:
                 state.finished = recv_obj.finished_reasons[i] is not None
                 state.event.set()
 
-                if self.enable_metrics:
+                if self.enable_metrics and state.obj.log_metrics:
                     self.collect_metrics(state, recv_obj, i)
-                if self.dump_requests_folder and state.finished:
+                if (
+                    self.dump_requests_folder
+                    and state.finished
+                    and state.obj.log_metrics
+                ):
                     self.dump_requests(state, out_dict)
         elif isinstance(recv_obj, OpenSessionReqOutput):
             self.session_futures[recv_obj.session_id].set_result(
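Metrics collection and request dumping are now gated on the per-request flag as well as the server-wide switches. A runnable sketch of the combined condition (the `Req` class and `should_collect` helper are illustrative stand-ins, not the real manager):

```python
from dataclasses import dataclass


@dataclass
class Req:
    log_metrics: bool = True


def should_collect(enable_metrics: bool, req: Req) -> bool:
    # The server-wide switch AND the per-request opt-in must both hold.
    return enable_metrics and req.log_metrics


print(should_collect(True, Req()))                   # -> True
print(should_collect(True, Req(log_metrics=False)))  # -> False (health probe)
print(should_collect(False, Req()))                  # -> False
```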
@@ -887,20 +896,38 @@ class TokenizerManager:
             )
 
         if len(self.dump_request_list) >= self.dump_requests_threshold:
+            filename = os.path.join(
+                self.dump_requests_folder,
+                datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
+            )
+            logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")
+
             to_dump = self.dump_request_list
             self.dump_request_list = []
 
             def background_task():
                 os.makedirs(self.dump_requests_folder, exist_ok=True)
-                current_time = datetime.now()
-                filename = current_time.strftime("%Y-%m-%d_%H-%M-%S") + ".pkl"
-                with open(os.path.join(self.dump_requests_folder, filename), "wb") as f:
+                with open(filename, "wb") as f:
                     pickle.dump(to_dump, f)
 
             # Schedule the task to run in the background without awaiting it
             asyncio.create_task(asyncio.to_thread(background_task))
 
 
+async def print_exception_wrapper(func):
+    """
+    Sometimes an asyncio function does not print exception.
+    We do another wrapper to handle the exception.
+    """
+    try:
+        await func()
+    except Exception:
+        traceback = get_exception_traceback()
+        logger.error(f"TokenizerManager hit an exception: {traceback}")
+        kill_process_tree(os.getpid(), include_parent=True)
+        sys.exit(1)
+
+
 class SignalHandler:
     def __init__(self, tokenizer_manager):
         self.tokenizer_manager = tokenizer_manager
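An exception raised inside a task created with `loop.create_task` can go unreported: by default it only surfaces when the task object is garbage collected, and never if a reference keeps the task alive. A self-contained sketch of the wrap-and-log pattern applied above to `handle_loop` and `sigterm_watchdog`, here logging and re-raising instead of killing the process tree:

```python
import asyncio
import logging
import traceback

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)


async def print_exception_wrapper(func):
    try:
        await func()
    except Exception:
        # Log immediately instead of waiting for task garbage collection.
        logger.error("task hit an exception:\n%s", traceback.format_exc())
        raise


async def crashing_loop():
    raise RuntimeError("boom")


async def main():
    task = asyncio.create_task(print_exception_wrapper(crashing_loop))
    await asyncio.gather(task, return_exceptions=True)


asyncio.run(main())
```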
@@ -135,9 +135,13 @@ async def health_generate(request: Request) -> Response:
     sampling_params = {"max_new_tokens": 1, "temperature": 0.7}
 
     if tokenizer_manager.is_generation:
-        gri = GenerateReqInput(input_ids=[0], sampling_params=sampling_params)
+        gri = GenerateReqInput(
+            input_ids=[0], sampling_params=sampling_params, log_metrics=False
+        )
     else:
-        gri = EmbeddingReqInput(input_ids=[0], sampling_params=sampling_params)
+        gri = EmbeddingReqInput(
+            input_ids=[0], sampling_params=sampling_params, log_metrics=False
+        )
 
     try:
         async for _ in tokenizer_manager.generate_request(gri, request):
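Because the probe requests above are constructed with `log_metrics=False`, periodic health checks are excluded from metrics collection and from the request dump files; only real traffic is measured.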
@@ -560,6 +560,7 @@ def run_bench_serving(
         tokenizer=tokenizer,
         num_prompts=num_prompts,
         sharegpt_output_len=None,
+        sharegpt_context_len=None,
         random_input_len=random_input_len,
         random_output_len=random_output_len,
         random_range_ratio=0.0,
@@ -44,7 +44,7 @@ class TestEpMoE(unittest.TestCase):
         )
 
         metrics = run_eval(args)
-        assert metrics["score"] >= 0.5
+        self.assertGreater(metrics["score"], 0.5)
 
     def test_mgsm_en(self):
         args = SimpleNamespace(
@@ -56,7 +56,7 @@ class TestEpMoE(unittest.TestCase):
         )
 
         metrics = run_eval(args)
-        assert metrics["score"] >= 0.8
+        self.assertGreater(metrics["score"], 0.8)
 
 
 class TestEpMoEFP8(unittest.TestCase):
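`self.assertGreater` reports both operands on failure (e.g. `AssertionError: 0.42 not greater than 0.5`), while a bare `assert` gives no such detail and is removed entirely under `python -O`.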