Add DeepSeek V3/R1 shared experts fusion (#4918)
This commit is contained in:
@@ -993,13 +993,16 @@ async def benchmark(
|
||||
return await request_func(request_func_input=request_func_input, pbar=pbar)
|
||||
|
||||
# Warmup
|
||||
print("Starting initial single prompt test run...")
|
||||
print(f"Starting warmup with {args.warmup_requests} sequences...")
|
||||
|
||||
# Use the first request for all warmup iterations
|
||||
test_prompt, test_prompt_len, test_output_len = input_requests[0]
|
||||
if lora_names != None and len(lora_names) != 0:
|
||||
lora_name = lora_names[0]
|
||||
else:
|
||||
lora_name = None
|
||||
|
||||
# Create the test input once
|
||||
test_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
prompt=test_prompt,
|
||||
@@ -1009,14 +1012,26 @@ async def benchmark(
|
||||
lora_name=lora_name,
|
||||
extra_request_body=extra_request_body,
|
||||
)
|
||||
test_output = await request_func(request_func_input=test_input)
|
||||
if not test_output.success:
|
||||
|
||||
# Run warmup requests
|
||||
warmup_tasks = []
|
||||
for _ in range(args.warmup_requests):
|
||||
warmup_tasks.append(
|
||||
asyncio.create_task(request_func(request_func_input=test_input))
|
||||
)
|
||||
|
||||
warmup_outputs = await asyncio.gather(*warmup_tasks)
|
||||
|
||||
# Check if at least one warmup request succeeded
|
||||
if not any(output.success for output in warmup_outputs):
|
||||
raise ValueError(
|
||||
"Initial test run failed - Please make sure benchmark arguments "
|
||||
f"are correctly specified. Error: {test_output.error}"
|
||||
"Warmup failed - Please make sure benchmark arguments "
|
||||
f"are correctly specified. Error: {warmup_outputs[0].error}"
|
||||
)
|
||||
else:
|
||||
print("Initial test run completed. Starting main benchmark run...")
|
||||
print(
|
||||
f"Warmup completed with {args.warmup_requests} sequences. Starting main benchmark run..."
|
||||
)
|
||||
|
||||
# Flush cache
|
||||
if ("sglang" in backend and _get_bool_env_var("SGLANG_IS_IN_CI")) or flush_cache:
|
||||
@@ -1253,6 +1268,10 @@ def run_benchmark(args_: argparse.Namespace):
|
||||
if not hasattr(args, "max_concurrency"):
|
||||
args.max_concurrency = None
|
||||
|
||||
# Set default value for warmup_requests if not present
|
||||
if not hasattr(args, "warmup_requests"):
|
||||
args.warmup_requests = 1
|
||||
|
||||
print(f"benchmark_args={args}")
|
||||
|
||||
# Set global environments
|
||||
@@ -1560,6 +1579,12 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
help="Flush the cache before running the benchmark",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warmup-requests",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of warmup requests to run before the benchmark",
|
||||
)
|
||||
|
||||
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
|
||||
group.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user