diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index b64c05580..92cf0388e 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -37,7 +37,7 @@ from sglang.srt.model_executor.forward_batch_info import ( from sglang.srt.patch_torch import monkey_patch_torch_compile from sglang.srt.utils import ( get_available_gpu_memory, - get_whatever_gpu_memory_capacity, + get_device_memory_capacity, is_hip, ) @@ -133,14 +133,10 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner): list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16)) ) - if _is_hip: + gpu_mem = get_device_memory_capacity() + if gpu_mem is not None and gpu_mem > 81920: capture_bs += list(range(160, 257, 8)) - gpu_mem = get_whatever_gpu_memory_capacity() / 1024 - - if gpu_mem is not None and gpu_mem > 120: - capture_bs += list(range(160, 256, 8)) - if max(capture_bs) > model_runner.req_to_token_pool.size: # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests # is very small. We add more values here to make sure we capture the maximum bs. @@ -152,10 +148,8 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner): assert len(capture_bs) > 0 and capture_bs[0] > 0 capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size] - if server_args.cuda_graph_max_bs: capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs] - compile_bs = ( [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs] if server_args.enable_torch_compile diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index ed0c6bf51..4c2b0122f 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -27,7 +27,7 @@ from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.utils import ( configure_ipv6, get_device, - get_whatever_gpu_memory_capacity, + get_device_memory_capacity, is_flashinfer_available, is_hip, is_port_available, @@ -218,7 +218,7 @@ class ServerArgs: if self.random_seed is None: self.random_seed = random.randint(0, 1 << 30) - gpu_mem = get_whatever_gpu_memory_capacity(self.device) + gpu_mem = get_device_memory_capacity(self.device) # Set mem fraction static, which depends on the tensor parallelism size if self.mem_fraction_static is None: diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 69f0c4ff0..1e9e66441 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1170,7 +1170,7 @@ def get_hpu_memory_capacity(): ) -def get_whatever_gpu_memory_capacity(device: str = None): +def get_device_memory_capacity(device: str = None): if is_cuda(): gpu_mem = get_nvgpu_memory_capacity() elif is_hip():