Add get_amdgpu_memory_capacity() (#2049)
This commit is contained in:
@@ -23,8 +23,10 @@ import tempfile
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_gpu_memory_capacity,
|
get_amdgpu_memory_capacity,
|
||||||
|
get_nvgpu_memory_capacity,
|
||||||
is_flashinfer_available,
|
is_flashinfer_available,
|
||||||
|
is_hip,
|
||||||
is_ipv6,
|
is_ipv6,
|
||||||
is_port_available,
|
is_port_available,
|
||||||
)
|
)
|
||||||
@@ -165,7 +167,10 @@ class ServerArgs:
|
|||||||
self.mem_fraction_static = 0.88
|
self.mem_fraction_static = 0.88
|
||||||
|
|
||||||
# Adjust for GPUs with small memory capacities
|
# Adjust for GPUs with small memory capacities
|
||||||
gpu_mem = get_gpu_memory_capacity()
|
if is_hip():
|
||||||
|
gpu_mem = get_amdgpu_memory_capacity()
|
||||||
|
else:
|
||||||
|
gpu_mem = get_nvgpu_memory_capacity()
|
||||||
if gpu_mem < 25000:
|
if gpu_mem < 25000:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Automatically adjust --chunked-prefill-size for small GPUs."
|
"Automatically adjust --chunked-prefill-size for small GPUs."
|
||||||
|
|||||||
@@ -794,7 +794,39 @@ def add_prometheus_middleware(app):
|
|||||||
app.routes.append(metrics_route)
|
app.routes.append(metrics_route)
|
||||||
|
|
||||||
|
|
||||||
def get_amdgpu_memory_capacity():
    """Return the smallest VRAM capacity (in MiB) among visible AMD GPUs.

    Shells out to ``rocm-smi --showmeminfo vram`` and parses the
    "Total Memory" rows (reported in bytes), converting each value to MiB
    and returning the minimum across devices so that memory planning is
    safe on heterogeneous multi-GPU nodes.

    Returns:
        float: the minimum per-device VRAM capacity, in MiB.

    Raises:
        RuntimeError: if rocm-smi exits non-zero or cannot be launched.
        ValueError: if the output contains no parsable memory values.
    """
    try:
        # Run rocm-smi and capture the output.
        # NOTE: shell=True is required for the grep/awk pipeline; the command
        # string is a fixed literal, so no untrusted input reaches the shell.
        result = subprocess.run(
            ["rocm-smi --showmeminfo vram | grep 'Total Memory' | awk '{print $NF}'"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            shell=True,
            text=True,
        )
        if result.returncode != 0:
            raise RuntimeError(f"rocm-smi error: {result.stderr.strip()}")

        # Keep only lines that are purely numeric (one "Total Memory" value
        # per GPU, in bytes) and convert bytes -> MiB.
        memory_values = [
            float(mem) / 1024 / 1024
            for mem in result.stdout.strip().split("\n")
            if re.match(r"^\d+(\.\d+)?$", mem.strip())
        ]

        if not memory_values:
            raise ValueError("No GPU memory values found.")

        # The minimum across all devices is the safe capacity to plan against.
        return min(memory_values)

    except FileNotFoundError as e:
        # With shell=True a missing rocm-smi usually surfaces as a non-zero
        # return code rather than FileNotFoundError, but keep this guard for
        # environments without a shell. Chain the cause (PEP 3134) so the
        # original error is preserved for debugging.
        raise RuntimeError(
            "rocm-smi not found. Ensure AMD ROCm drivers are installed and accessible."
        ) from e
|
def get_nvgpu_memory_capacity():
|
||||||
try:
|
try:
|
||||||
# Run nvidia-smi and capture the output
|
# Run nvidia-smi and capture the output
|
||||||
result = subprocess.run(
|
result = subprocess.run(
|
||||||
|
|||||||
Reference in New Issue
Block a user