[Fix] Adjust default chunked prefill size and cuda graph max bs according to GPU memory capacity (#2044)
python/sglang/srt/server_args.py
@@ -22,7 +22,12 @@ import random
 import tempfile
 from typing import List, Optional

-from sglang.srt.utils import is_flashinfer_available, is_ipv6, is_port_available
+from sglang.srt.utils import (
+    get_gpu_memory_capacity,
+    is_flashinfer_available,
+    is_ipv6,
+    is_port_available,
+)

 logger = logging.getLogger(__name__)

@@ -143,6 +148,9 @@ class ServerArgs:
         # Disable chunked prefill
         self.chunked_prefill_size = None

+        if self.random_seed is None:
+            self.random_seed = random.randint(0, 1 << 30)
+
         # Mem fraction depends on the tensor parallelism size
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
@@ -156,8 +164,14 @@
             else:
                 self.mem_fraction_static = 0.88

-        if self.random_seed is None:
-            self.random_seed = random.randint(0, 1 << 30)
+        # Adjust for GPUs with small memory capacities
+        gpu_mem = get_gpu_memory_capacity()
+        if gpu_mem < 25000:
+            logger.warning(
+                "Automatically adjust --chunked-prefill-size for small GPUs."
+            )
+            self.chunked_prefill_size //= 4  # make it 2048
+            self.cuda_graph_max_bs = 4

         # Deprecation warnings
         if self.disable_flashinfer:
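For intuition, here is the adjustment in isolation as a minimal, runnable sketch, not the actual ServerArgs code. The pre-adjustment value of 8192 is inferred from the "make it 2048" comment in the hunk above, and the prior cuda_graph_max_bs value is purely illustrative; the 25000 threshold is in MiB, since nvidia-smi with "nounits" reports MiB.

# Minimal sketch of the small-GPU adjustment, runnable on its own.
# Assumption: 8192 is the default chunked_prefill_size before adjustment,
# inferred from the "make it 2048" comment in the diff above.
chunked_prefill_size = 8192
cuda_graph_max_bs = 160  # hypothetical larger default, for illustration only

gpu_mem = 24576  # e.g. a 24 GB card; nvidia-smi reports MiB with "nounits"
if gpu_mem < 25000:
    chunked_prefill_size //= 4  # 8192 // 4 == 2048
    cuda_graph_max_bs = 4

print(chunked_prefill_size, cuda_graph_max_bs)  # prints: 2048 4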
python/sglang/srt/utils.py
@@ -27,6 +27,7 @@ import resource
 import shutil
 import signal
 import socket
+import subprocess
 import tempfile
 import time
 import warnings
@@ -791,3 +792,35 @@ def add_prometheus_middleware(app):
     # Workaround for 307 Redirect for /metrics
     metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
     app.routes.append(metrics_route)
+
+
+def get_gpu_memory_capacity():
+    try:
+        # Run nvidia-smi and capture the output
+        result = subprocess.run(
+            ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            text=True,
+        )
+
+        if result.returncode != 0:
+            raise RuntimeError(f"nvidia-smi error: {result.stderr.strip()}")
+
+        # Parse the output to extract memory values
+        memory_values = [
+            float(mem)
+            for mem in result.stdout.strip().split("\n")
+            if re.match(r"^\d+(\.\d+)?$", mem.strip())
+        ]
+
+        if not memory_values:
+            raise ValueError("No GPU memory values found.")
+
+        # Return the minimum memory value
+        return min(memory_values)
+
+    except FileNotFoundError:
+        raise RuntimeError(
+            "nvidia-smi not found. Ensure NVIDIA drivers are installed and accessible."
+        )
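To see what the parsing step does, the following self-contained sketch applies the same regex filter and min() reduction to canned nvidia-smi output; the two memory values and the machine they describe are made up for illustration.

import re

# Simulated stdout of:
#   nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits
# on a hypothetical machine with one 80 GB and one 24 GB GPU (MiB, one per line).
sample_stdout = "81559\n24576\n"

memory_values = [
    float(mem)
    for mem in sample_stdout.strip().split("\n")
    if re.match(r"^\d+(\.\d+)?$", mem.strip())  # drop any non-numeric lines
]

# min() sizes the defaults to the smallest visible GPU, the conservative
# choice on machines with mixed GPU models.
print(min(memory_values))  # 24576.0, i.e. below the 25000 MiB cutoff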
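And a short usage sketch, assuming a build that includes this commit so the helper is importable from sglang.srt.utils; it requires an NVIDIA driver and raises RuntimeError when nvidia-smi is missing.

from sglang.srt.utils import get_gpu_memory_capacity

gpu_mem = get_gpu_memory_capacity()  # MiB of the smallest visible GPU
print(f"Smallest GPU memory: {gpu_mem:.0f} MiB")
if gpu_mem < 25000:  # the same cutoff ServerArgs applies above
    print("Small-GPU defaults apply: chunked_prefill_size=2048, cuda_graph_max_bs=4")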