def get_amdgpu_memory_capacity():
    """Return the smallest per-GPU total VRAM reported by ``rocm-smi``.

    Runs ``rocm-smi --showmeminfo vram``, keeps the lines that contain
    ``Total Memory``, and converts the last field of each such line by
    dividing by 1024 twice (presumably bytes -> MiB; confirm against the
    rocm-smi version in use).

    Returns:
        float: The minimum converted memory value across all reported GPUs.

    Raises:
        RuntimeError: If ``rocm-smi`` is not installed / not on ``PATH``, or
            if it exits with a non-zero status.
        ValueError: If no numeric total-memory value can be parsed from the
            output.
    """
    try:
        # Invoke rocm-smi directly instead of through a shell pipeline: with
        # `shell=True` the exit status was awk's (masking rocm-smi failures)
        # and a missing rocm-smi binary could never raise FileNotFoundError.
        result = subprocess.run(
            ["rocm-smi", "--showmeminfo", "vram"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
    except FileNotFoundError as err:
        raise RuntimeError(
            "rocm-smi not found. Ensure AMD ROCm drivers are installed and accessible."
        ) from err

    if result.returncode != 0:
        raise RuntimeError(f"rocm-smi error: {result.stderr.strip()}")

    # Python equivalent of `grep 'Total Memory' | awk '{print $NF}'`:
    # take the last whitespace-separated field of every matching line and
    # keep it only when it is a plain non-negative number.
    memory_values = []
    for line in result.stdout.splitlines():
        if "Total Memory" not in line:
            continue
        fields = line.split()
        if fields and re.match(r"^\d+(\.\d+)?$", fields[-1]):
            memory_values.append(float(fields[-1]) / 1024 / 1024)

    if not memory_values:
        raise ValueError("No GPU memory values found.")

    # The smallest GPU bounds what every device can hold.
    return min(memory_values)