Improve profiler and integrate profiler in bench_one_batch_server (#6787)

This commit is contained in:
Lianmin Zheng
2025-05-31 15:53:55 -07:00
committed by GitHub
parent b520d02888
commit 2d72fc47cf
25 changed files with 481 additions and 223 deletions

View File

@@ -8,6 +8,7 @@ Usage:
python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage
"""
import argparse
@@ -19,10 +20,10 @@ import os
import time
from typing import Tuple
import numpy as np
import requests
from sglang.bench_serving import get_tokenizer, sample_random_requests
from sglang.profiler import run_profile
from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_process_tree
@@ -42,6 +43,8 @@ class BenchArgs:
base_url: str = ""
skip_warmup: bool = False
show_report: bool = False
profile: bool = False
profile_by_stage: bool = False
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
@@ -68,6 +71,8 @@ class BenchArgs:
parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
parser.add_argument("--skip-warmup", action="store_true")
parser.add_argument("--show-report", action="store_true")
parser.add_argument("--profile", action="store_true")
parser.add_argument("--profile-by-stage", action="store_true")
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
@@ -93,8 +98,8 @@ def launch_server_process(server_args: ServerArgs):
base_url = f"http://{server_args.host}:{server_args.port}"
timeout = 600
start_time = time.perf_counter()
while time.perf_counter() - start_time < timeout:
start_time = time.time()
while time.time() - start_time < timeout:
try:
headers = {
"Content-Type": "application/json; charset=utf-8",
@@ -119,6 +124,8 @@ def run_one_case(
run_name: str,
result_filename: str,
tokenizer,
profile: bool = False,
profile_by_stage: bool = False,
):
requests.post(url + "/flush_cache")
input_requests = sample_random_requests(
@@ -145,6 +152,12 @@ def run_one_case(
else:
json_schema = None
profile_link = None
if profile:
profile_link: str = run_profile(
url, 3, ["CPU", "GPU"], None, None, profile_by_stage
)
tic = time.perf_counter()
response = requests.post(
url + "/generate",
@@ -194,8 +207,8 @@ def run_one_case(
print(f"output_len: {output_len}")
print(f"latency: {latency:.2f} s")
print(f"ttft: {ttft:.2f} s")
print(f"Last generation throughput: {last_gen_throughput:.2f} tok/s")
print(f"Input throughput: {input_throughput:.2f} tok/s")
print(f"last generation throughput: {last_gen_throughput:.2f} tok/s")
print(f"input throughput: {input_throughput:.2f} tok/s")
if output_len != 1:
print(f"output throughput: {output_throughput:.2f} tok/s")
@@ -222,6 +235,7 @@ def run_one_case(
overall_throughput,
last_gen_throughput,
acc_length,
profile_link if profile else None,
)
@@ -253,6 +267,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
# benchmark
result = []
bench_result = []
try:
for bs, il, ol in itertools.product(
bench_args.batch_size, bench_args.input_len, bench_args.output_len
@@ -271,6 +286,33 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
tokenizer=tokenizer,
)
)
if bench_args.profile:
try:
for bs, il, ol in itertools.product(
bench_args.batch_size, bench_args.input_len, bench_args.output_len
):
bench_result.append(
(
run_one_case(
base_url,
bs,
il,
ol,
temperature=bench_args.temperature,
return_logprob=bench_args.return_logprob,
input_len_step_percentage=bench_args.input_len_step_percentage,
run_name=bench_args.run_name,
result_filename=bench_args.result_filename,
tokenizer=tokenizer,
profile=bench_args.profile,
profile_by_stage=bench_args.profile_by_stage,
)[-1],
)
)
result = [t1[:-1] + t2 for t1, t2 in zip(result, bench_result)]
except Exception as e:
print(f"Error profiling, there will be no profile trace dump: {e}")
finally:
if proc:
kill_process_tree(proc.pid)
@@ -280,8 +322,20 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
if not bench_args.show_report:
return
summary = " | batch size | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input price ($/1M) | output price ($/1M) |\n"
summary += "| ---------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ------------------ | ------------------- |\n"
summary = (
f"\nInput lens: {bench_args.input_len}. Output lens: {bench_args.output_len}.\n"
)
summary += "| batch size | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input cost ($/1M) | output cost ($/1M) |"
if bench_args.profile:
summary += " profile |"
summary += "\n"
summary += "| ---------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ----------------- | ------------------ |"
if bench_args.profile:
summary += "-------------|"
summary += "\n"
for (
batch_size,
@@ -292,6 +346,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
overall_throughput,
last_gen_throughput,
acc_length,
trace_link,
) in result:
hourly_cost = 2 * server_args.tp_size # $2/hour for one H100
input_util = 0.7
@@ -304,17 +359,18 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
f"{accept_length} | "
f"{1 / (output_throughput/batch_size) * 1000:.2f} | "
f"{1e6 / (input_throughput * input_util) / 3600 * hourly_cost:.2f} | "
f"{1e6 / output_throughput / 3600 * hourly_cost:.2f} |\n"
f"{1e6 / output_throughput / 3600 * hourly_cost:.2f} |"
)
if trace_link:
line += f" [Profile]({trace_link}) |"
line += "\n"
summary += line
# print metrics table
print(summary)
if is_in_ci():
write_github_step_summary(
f"### Test Nightly Benchmark (bench_one_batch) \n{summary}"
)
write_github_step_summary(summary)
if __name__ == "__main__":

167
python/sglang/profiler.py Normal file
View File

@@ -0,0 +1,167 @@
"""
Run live profiling.
Usage:
python3 -m sglang.profiler
"""
import argparse
import json
import os
import time
import urllib.parse
from argparse import ArgumentParser
from pathlib import Path
from typing import List, Optional
import requests
PARENT_FOLDER = "/tmp/sglang-profile"
def _run_profile(
    url: Optional[str],
    num_steps: int,
    activities: List[str],
    output_dir: Optional[str] = None,
    profile_name: Optional[str] = None,
    profile_by_stage: bool = False,
) -> str:
    """Ask a running sglang server to profile itself for ``num_steps`` steps.

    Args:
        url: Base URL of the server (e.g. ``http://localhost:30000``).
        num_steps: Number of forward steps to profile.
        activities: Profiler activities, e.g. ``["CPU", "GPU", "MEM", "RPD"]``.
        output_dir: Directory to dump traces into; defaults to ``PARENT_FOLDER``.
        profile_name: Optional run name inserted into the output path.
        profile_by_stage: Whether to profile prefill/decode stages separately.

    Returns:
        The path of the directory the traces are dumped to.

    Raises:
        requests.HTTPError: If the server rejects either request.
    """
    if output_dir is None:
        output_dir = PARENT_FOLDER

    # os.path.abspath already normalizes the path, so a separate
    # os.path.normpath call is redundant.
    output_dir = Path(os.path.abspath(output_dir))

    # Group traces as "<output_dir>[/<profile_name>]/<timestamp>/".
    if profile_name:
        output_dir = output_dir / profile_name
    output_dir = output_dir / str(time.time())
    output_dir.mkdir(exist_ok=True, parents=True)

    print(f"Dump profiling traces to {output_dir}")
    print(
        f"Waiting for {num_steps} steps and the trace to be flushed.... ({profile_by_stage=})"
    )

    # Dump the server args next to the traces so the run is reproducible.
    file_path = output_dir / "server_args.json"
    if not file_path.exists():
        response = requests.get(url + "/get_server_info")
        response.raise_for_status()
        server_args_data = response.json()
        with open(file_path, "w") as file:
            file.write(json.dumps(server_args_data))

    # Start profiler. The API replies when all steps are processed
    # and files are generated.
    json_data = {
        "output_dir": str(output_dir),
        "num_steps": str(num_steps),
        "activities": activities,
        "profile_by_stage": profile_by_stage,
    }

    response = requests.post(url=url + "/start_profile", json=json_data)
    response.raise_for_status()

    trace_link = str(output_dir)
    return trace_link
def run_profile(
    url: Optional[str],
    num_steps: int,
    activities: List[str],
    output_dir: Optional[str] = None,
    profile_name: Optional[str] = None,
    profile_by_stage: bool = False,
):
    """Public entry point that delegates to ``_run_profile``.

    A step-based profile terminates on its own once ``num_steps`` forward
    steps have been captured, so no explicit stop call is needed. Returns
    the path of the directory the traces are dumped to.
    """
    return _run_profile(
        url, num_steps, activities, output_dir, profile_name, profile_by_stage
    )
if __name__ == "__main__":
    # NOTE: description fixed — the old text ("Benchmark the online serving
    # throughput.") was copy-pasted from bench_serving and wrong for this tool.
    parser = ArgumentParser(description="Run live profiling on a running sglang server.")
    parser.add_argument(
        "--url",
        type=str,
        default="http://localhost:30000",
        help="Server or API base url if not using http host and port.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Profile directory to dump profile traces.",
    )
    parser.add_argument(
        "--profile-name",
        type=str,
        default=None,
        help="The name of this profile run.",
    )
    parser.add_argument(
        "--num-steps",
        type=int,
        default=5,
        help="The number of forward steps to profile.",
    )
    parser.add_argument(
        "--profile-by-stage",
        action=argparse.BooleanOptionalAction,
        type=bool,
        default=False,
        # Fixed: the old help text was copy-pasted from --num-steps.
        help="Whether to profile the prefill and decode stages separately.",
    )
    parser.add_argument(
        "--cpu",
        action=argparse.BooleanOptionalAction,
        type=bool,
        default=True,
        help="Whether to profile CPU activity",
    )
    parser.add_argument(
        "--gpu",
        action=argparse.BooleanOptionalAction,
        type=bool,
        default=True,
        help="Whether to profile GPU activity",
    )
    parser.add_argument(
        "--mem",
        action=argparse.BooleanOptionalAction,
        type=bool,
        default=False,
        # Fixed: the old help text was missing the verb "profile".
        help="Whether to profile memory usage (https://pytorch.org/memory_viz)",
    )
    parser.add_argument(
        "--rpd",
        action=argparse.BooleanOptionalAction,
        type=bool,
        default=False,
        help="Whether to use rpd profiler (https://github.com/ROCm/rocmProfileData)",
    )
    args = parser.parse_args()

    # Translate the boolean flags into the activity list the server expects.
    activities = []
    if args.cpu:
        activities.append("CPU")
    if args.gpu:
        activities.append("GPU")
    if args.mem:
        activities.append("MEM")
    if args.rpd:
        activities.append("RPD")

    run_profile(
        args.url,
        args.num_steps,
        activities,
        args.output_dir,
        args.profile_name,
        args.profile_by_stage,
    )

View File

@@ -514,9 +514,7 @@ def _set_envs_and_config(server_args: ServerArgs):
pid, exitcode = os.waitpid(0, os.WNOHANG)
if exitcode != 0:
logger.warning(
"Child process unexpectedly failed with an exit code %d. pid=%d",
exitcode,
pid,
f"Child process unexpectedly failed with {exitcode=}. {pid=}"
)
signal.signal(signal.SIGCHLD, sigchld_handler)

View File

@@ -350,6 +350,7 @@ async def start_profile_async(obj: Optional[ProfileReqInput] = None):
activities=obj.activities,
with_stack=obj.with_stack,
record_shapes=obj.record_shapes,
profile_by_stage=obj.profile_by_stage,
)
return Response(
content="Start profiling.\n",

View File

@@ -401,7 +401,6 @@ def compute_initial_expert_location_metadata(
) -> ExpertLocationMetadata:
data = server_args.init_expert_location
if data == "trivial":
logger.info("init_expert_location from trivial")
return ExpertLocationMetadata.init_trivial(server_args, model_config)
# TODO unify with the utils function

View File

@@ -848,7 +848,8 @@ class ProfileReqInput:
# If it is set, profiling is automatically stopped after this step, and
# the caller doesn't need to run stop_profile.
num_steps: Optional[int] = None
activities: Optional[List[Literal["CPU", "GPU", "MEM", "CUDA_PROFILER"]]] = None
activities: Optional[List[str]] = None
profile_by_stage: bool = False
with_stack: Optional[bool] = None
record_shapes: Optional[bool] = None
@@ -875,6 +876,7 @@ class ProfileReq:
output_dir: Optional[str] = None
num_steps: Optional[int] = None
activities: Optional[List[str]] = None
profile_by_stage: bool = False
with_stack: Optional[bool] = None
record_shapes: Optional[bool] = None
profile_id: Optional[str] = None

View File

@@ -34,7 +34,6 @@ import zmq
from torch.distributed import barrier
from sglang.global_config import global_config
from sglang.srt import two_batch_overlap
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
from sglang.srt.disaggregation.decode import (
@@ -63,7 +62,6 @@ from sglang.srt.hf_transformers_utils import (
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.expert_distribution import (
ExpertDistributionRecorder,
get_global_expert_distribution_recorder,
)
from sglang.srt.managers.io_struct import (
@@ -140,6 +138,7 @@ from sglang.srt.utils import (
broadcast_pyobj,
configure_logger,
disable_request_logging,
get_available_gpu_memory,
get_bool_env_var,
get_zmq_socket,
kill_itself_when_parent_died,
@@ -213,7 +212,6 @@ class Scheduler(
self.gpu_id = gpu_id
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
self.page_size = server_args.page_size
# Distributed rank info
self.dp_size = server_args.dp_size
self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
compute_dp_attention_world_info(
@@ -333,12 +331,16 @@ class Scheduler(
# Print debug info
if tp_rank == 0:
avail_mem = get_available_gpu_memory(
self.device, self.gpu_id, empty_cache=False
)
logger.info(
f"max_total_num_tokens={self.max_total_num_tokens}, "
f"chunked_prefill_size={server_args.chunked_prefill_size}, "
f"max_prefill_tokens={self.max_prefill_tokens}, "
f"max_running_requests={self.max_running_requests}, "
f"context_len={self.model_config.context_len}"
f"context_len={self.model_config.context_len}, "
f"available_gpu_mem={avail_mem:.2f} GB"
)
# Init memory pool and cache
@@ -362,6 +364,7 @@ class Scheduler(
self.current_stream = torch.get_device_module(self.device).current_stream()
if self.device == "cpu":
self.current_stream.synchronize = lambda: None # No-op for CPU
self.forward_sleep_time = None
# Init session info
self.sessions: Dict[str, Session] = {}
@@ -425,8 +428,14 @@ class Scheduler(
self.profiler_activities: Optional[List[str]] = None
self.profiler_id: Optional[str] = None
self.profiler_target_forward_ct: Optional[int] = None
self.forward_sleep_time = None
self.profiler_target_prefill_ct: Optional[int] = None
self.profiler_target_decode_ct: Optional[int] = None
self.profiler_prefill_ct: Optional[int] = None
self.profiler_decode_ct: Optional[int] = None
self.profile_by_stage: bool = False
self.profile_steps: Optional[int] = None
self.profile_in_progress: bool = False
self.rpd_profiler = None
# Init metrics stats
self.init_metrics()
@@ -1518,7 +1527,7 @@ class Scheduler(
self.new_token_ratio = new_token_ratio
logger.info(
"Decode out of memory happened. "
"KV cache pool is full. Retract requests. "
f"#retracted_reqs: {len(retracted_reqs)}, "
f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
)
@@ -1542,13 +1551,8 @@ class Scheduler(
"""Run a batch."""
self.forward_ct += 1
# Check profiler
if (
self.profiler_target_forward_ct
and self.profiler_target_forward_ct <= self.forward_ct
):
self.send_to_tokenizer.send_pyobj(self.stop_profile())
# Whether to run the profiler
self._profile_batch_predicate(batch)
if self.forward_sleep_time is not None:
logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
time.sleep(self.forward_sleep_time)
@@ -2121,46 +2125,82 @@ class Scheduler(
def profile(self, recv_req: ProfileReq):
if recv_req.type == ProfileReqType.START_PROFILE:
return self.start_profile(
recv_req.output_dir,
recv_req.num_steps,
recv_req.activities,
recv_req.with_stack,
recv_req.record_shapes,
recv_req.profile_id,
)
if recv_req.profile_by_stage:
return self.init_profile(
recv_req.output_dir,
recv_req.num_steps,
recv_req.activities,
recv_req.with_stack,
recv_req.record_shapes,
recv_req.profile_by_stage,
)
else:
self.init_profile(
recv_req.output_dir,
recv_req.num_steps,
recv_req.activities,
recv_req.with_stack,
recv_req.record_shapes,
recv_req.profile_by_stage,
)
return self.start_profile(True)
else:
return self.stop_profile()
def start_profile(
def init_profile(
self,
output_dir: Optional[str],
num_steps: Optional[int],
activities: Optional[List[str]],
with_stack: Optional[bool],
record_shapes: Optional[bool],
profile_id: Optional[str],
) -> None:
if self.profiler_activities:
profile_by_stage: bool,
) -> ProfileReqOutput:
if self.profile_in_progress:
return ProfileReqOutput(
success=False,
message="Profiling is already in progress. Call /stop_profile first.",
)
self.profile_by_stage = profile_by_stage
if output_dir is None:
output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
if activities is None:
activities = ["CPU", "GPU"]
self.torch_profiler_output_dir = output_dir
self.torch_profiler_with_stack = with_stack
self.torch_profiler_record_shapes = record_shapes
self.profiler_activities = activities
self.profiler_id = profile_id
if num_steps:
self.profile_steps = num_steps
if self.profile_by_stage:
self.profiler_target_prefill_ct = num_steps
self.profiler_target_decode_ct = num_steps
self.profiler_prefill_ct = 0
self.profiler_decode_ct = 0
else:
self.profiler_target_forward_ct = self.forward_ct + num_steps
# The caller will be notified when reaching profiler_target_forward_ct
else:
self.profiler_target_forward_ct = None
return ProfileReqOutput(success=True, message="Succeeded")
def start_profile(
self, stage: Optional[ForwardMode] = None
) -> ProfileReqOutput | None:
stage_str = f" for {stage.__str__()}" if stage else ""
logger.info(
"Profiling starts. Traces will be saved to: %s (with id %s)",
self.torch_profiler_output_dir,
self.profiler_id,
f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir}",
)
activities = self.profiler_activities
with_stack = self.torch_profiler_with_stack
record_shapes = self.torch_profiler_record_shapes
activity_map = {
"CPU": torch.profiler.ProfilerActivity.CPU,
"GPU": torch.profiler.ProfilerActivity.CUDA,
@@ -2169,48 +2209,97 @@ class Scheduler(
activity_map[a] for a in activities if a in activity_map
]
if torchprof_activities:
if "RPD" in activities:
from rpdTracerControl import rpdTracerControl
rpdTracerControl.skipCreate()
self.rpd_profile_path = os.path.join(
self.torch_profiler_output_dir,
"rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
)
if self.tp_rank == 0:
import sqlite3
from rocpd.schema import RocpdSchema
if os.path.exists("trace.rpd"):
os.unlink("trace.rpd")
schema = RocpdSchema()
connection = sqlite3.connect("trace.rpd")
schema.writeSchema(connection)
connection.commit()
del connection
torch.distributed.barrier(self.tp_cpu_group)
self.rpd_profiler = rpdTracerControl()
self.rpd_profiler.setPythonTrace(True)
self.rpd_profiler.start()
self.rpd_profiler.rangePush("", "rpd profile range", "")
self.profile_in_progress = True
elif torchprof_activities:
self.torch_profiler = torch.profiler.profile(
activities=torchprof_activities,
with_stack=with_stack if with_stack is not None else True,
record_shapes=record_shapes if record_shapes is not None else False,
)
self.torch_profiler.start()
self.profile_in_progress = True
if "MEM" in activities:
torch.cuda.memory._record_memory_history(max_entries=100000)
self.profile_in_progress = True
if "CUDA_PROFILER" in activities:
torch.cuda.cudart().cudaProfilerStart()
if num_steps:
self.profiler_target_forward_ct = self.forward_ct + num_steps
# The caller will be notified when reaching profiler_target_forward_ct
else:
self.profiler_target_forward_ct = None
return ProfileReqOutput(success=True, message="Succeeded")
return ProfileReqOutput(success=True, message="Succeeded")
def stop_profile(self) -> None:
if self.profiler_activities is None:
def stop_profile(
self, stage: Optional[ForwardMode] = None
) -> ProfileReqOutput | None:
if not self.profile_in_progress:
return ProfileReqOutput(
success=False,
message="Profiling is not in progress. Call /start_profile first.",
)
logger.info("Stop profiling...")
stage_suffix = f"-{stage.__str__()}" if stage else ""
logger.info("Stop profiling" + stage_suffix + "...")
if self.torch_profiler is not None:
self.torch_profiler.stop()
self.torch_profiler.export_chrome_trace(
os.path.join(
self.torch_profiler_output_dir,
self.profiler_id + f"-TP-{self.tp_rank}" + ".trace.json.gz",
str(time.time())
+ f"-TP-{self.tp_rank}"
+ stage_suffix
+ ".trace.json.gz",
)
)
torch.distributed.barrier(self.tp_cpu_group)
if "MEM" in self.profiler_activities:
if self.rpd_profiler is not None:
self.rpd_profiler.rangePop()
self.rpd_profiler.stop()
self.rpd_profiler.flush()
torch.distributed.barrier(self.tp_cpu_group)
if self.tp_rank == 0:
from sglang.srt.utils import rpd_to_chrome_trace
rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
self.rpd_profiler = None
self.rpd_profiler_path = None
if self.profiler_activities is not None and "MEM" in self.profiler_activities:
memory_profile_path = os.path.join(
self.torch_profiler_output_dir,
self.profiler_id + f"-TP-{self.tp_rank}-memory" + ".pickle",
str(time.time())
+ f"-TP-{self.tp_rank}-memory"
+ stage_suffix
+ ".pickle",
)
torch.cuda.memory._dump_snapshot(memory_profile_path)
torch.cuda.memory._record_memory_history(enabled=None)
@@ -2223,11 +2312,38 @@ class Scheduler(
self.torch_profiler_output_dir,
)
self.torch_profiler = None
self.torch_profiler_output_dir = None
self.profiler_activities = None
self.profiler_target_forward_ct = None
self.profile_in_progress = False
return ProfileReqOutput(success=True, message="Succeeded")
return ProfileReqOutput(success=True, message="Succeeded.")
def _profile_batch_predicate(self, batch):
    # Decide, before running `batch`, whether the profiler should be
    # started or stopped, based on the profiling mode and step counters.
    if self.profile_by_stage:
        # Stage-based mode: profile prefill and decode separately, each for
        # `profiler_target_*_ct` batches, writing one trace per stage.
        if batch.forward_mode.is_prefill():
            # First prefill batch seen: begin the prefill-stage trace.
            if self.profiler_prefill_ct == 0:
                self.start_profile(batch.forward_mode)
            self.profiler_prefill_ct += 1
            if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
                # Prefill quota reached: flush the prefill trace.
                # NOTE(review): prefill traces are labeled EXTEND — presumably
                # prefill is represented by ForwardMode.EXTEND; confirm.
                if self.profile_in_progress:
                    self.stop_profile(stage=ForwardMode.EXTEND)
        elif batch.forward_mode.is_decode():
            if self.profiler_decode_ct == 0:
                if self.profile_in_progress:
                    # force trace flush
                    self.stop_profile(ForwardMode.EXTEND)
                # Begin the decode-stage trace on the first decode batch.
                self.start_profile(batch.forward_mode)
            self.profiler_decode_ct += 1
            if self.profiler_decode_ct > self.profiler_target_decode_ct:
                # Decode quota reached: flush the decode trace.
                if self.profile_in_progress:
                    self.stop_profile(stage=ForwardMode.DECODE)
        else:
            # Only prefill and decode batches are supported in stage mode.
            raise RuntimeError("unsupported profile stage")
    else:
        # Check profiler
        # Step-based mode: stop once the global forward counter reaches the
        # target set when profiling started.
        if (
            self.profiler_target_forward_ct
            and self.profiler_target_forward_ct <= self.forward_ct
        ):
            self.stop_profile()
def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
if recv_req == ExpertDistributionReq.START_RECORD:

View File

@@ -796,6 +796,7 @@ class TokenizerManager:
activities: Optional[List[str]] = None,
with_stack: Optional[bool] = None,
record_shapes: Optional[bool] = None,
profile_by_stage: bool = False,
):
self.auto_create_handle_loop()
req = ProfileReq(
@@ -805,6 +806,7 @@ class TokenizerManager:
activities=activities,
with_stack=with_stack,
record_shapes=record_shapes,
profile_by_stage=profile_by_stage,
profile_id=str(time.time()),
)
return await self._execute_profile(req)

View File

@@ -39,10 +39,7 @@ from sglang.srt.model_executor.forward_batch_info import (
PPProxyTensors,
)
from sglang.srt.patch_torch import monkey_patch_torch_compile
from sglang.srt.two_batch_overlap import (
TboCudaGraphRunnerPlugin,
TboForwardBatchPreparer,
)
from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
from sglang.srt.utils import (
get_available_gpu_memory,
get_device_memory_capacity,

View File

@@ -77,11 +77,7 @@ from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.loader import (
DefaultModelLoader,
device_loading_context,
get_model_loader,
)
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.patch_torch import monkey_patch_torch_reductions

View File

@@ -1643,7 +1643,7 @@ def auto_choose_speculative_params(arch: str):
return (5, 4, 8)
elif arch in ["DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"]:
# The default value for deepseek
return (5, 4, 8)
return (3, 1, 4)
elif arch in ["Grok1ForCausalLM", "Grok1VForCausalLM"]:
return (5, 4, 8)
else:

View File

@@ -93,6 +93,11 @@ def is_in_ci():
return get_bool_env_var("SGLANG_IS_IN_CI")
def is_in_amd_ci():
"""Return whether it is in an AMD CI runner."""
return get_bool_env_var("SGLANG_AMD_CI")
if is_in_ci():
DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
5000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 100