Move mem_fraction_static adjustment for multimodal models to server_args.py & Fix session control & Other cleanups (#7748)

This commit is contained in:
Lianmin Zheng
2025-07-04 16:33:33 -07:00
committed by GitHub
parent 975a5ec69c
commit 14229ccf8f
16 changed files with 339 additions and 137 deletions

View File

@@ -42,7 +42,7 @@ from sglang.srt.configs import (
)
from sglang.srt.configs.internvl import InternVLChatConfig
from sglang.srt.connector import create_remote_connector
from sglang.srt.utils import is_remote_url
from sglang.srt.utils import is_remote_url, lru_cache_frozenset
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
ChatGLMConfig.model_type: ChatGLMConfig,
@@ -103,6 +103,7 @@ def get_hf_text_config(config: PretrainedConfig):
return config
@lru_cache_frozenset(maxsize=32)
def get_config(
model: str,
trust_remote_code: bool,

View File

@@ -46,11 +46,11 @@ _is_cpu = is_cpu()
if _is_cuda:
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
logger = logging.getLogger(__name__)
if is_npu():
import torch_npu
logger = logging.getLogger(__name__)
class SiluAndMul(CustomOp):
def forward_native(self, x: torch.Tensor) -> torch.Tensor:

View File

@@ -39,6 +39,7 @@ class SessionParams:
rid: Optional[str] = None
offset: Optional[int] = None
replace: Optional[bool] = None
drop_previous_output: Optional[bool] = None
AudioDataItem = Union[str, Dict]

View File

@@ -203,7 +203,7 @@ class MultimodalDataItem:
# the real data, pixel_values or audio_features
# data: Union[List[torch.Tensor], List[np.ndarray]]
pixel_values: Union[torch.Tensor, np.ndarray] = None
pixel_values: Union[torch.Tensor, np.ndarray, "PIL.Image"] = None
audio_features: Union[torch.Tensor, np.ndarray] = None
audio_feature_lens: Optional[List[torch.Tensor]] = None
audio_offsets: Optional[List[Tuple[int, int]]] = None
@@ -244,15 +244,16 @@ class MultimodalDataItem:
"""
from sglang.srt.managers.mm_utils import hash_feature
if self.precomputed_features is not None:
self.hash = hash_feature(self.precomputed_features)
elif self.is_audio():
if self.audio_features is not None:
self.hash = hash_feature(self.audio_features)
elif self.input_features is not None:
self.hash = hash_feature(self.input_features)
else:
self.hash = hash_feature(self.pixel_values)
if self.hash is None:
if self.precomputed_features is not None:
self.hash = hash_feature(self.precomputed_features)
elif self.is_audio():
if self.audio_features is not None:
self.hash = hash_feature(self.audio_features)
elif self.input_features is not None:
self.hash = hash_feature(self.input_features)
else:
self.hash = hash_feature(self.pixel_values)
assert self.hash is not None
self.pad_value = self.hash % (1 << 30)
@@ -295,6 +296,13 @@ class MultimodalDataItem:
ret.validate()
return ret
def merge(self, other):
    """Merge another item's image data into this one in place.

    Extends pixel_values / image_sizes / image_offsets with the other
    item's data via `+=` (assumes these fields are list-like here —
    TODO confirm for tensor-valued pixel_values), then combines the two
    hashes and recomputes the pad value from the new hash.
    """
    self.pixel_values += other.pixel_values
    self.image_sizes += other.image_sizes
    self.image_offsets += other.image_offsets
    # Combine hashes deterministically so the merged item gets a stable id.
    self.hash = hash((self.hash, other.hash))
    self.set_pad_value()
@dataclasses.dataclass
class MultimodalInputs:

View File

@@ -1100,7 +1100,7 @@ class Scheduler(
recv_req.session_params is not None
and recv_req.session_params.id is not None
):
req.finished_reason = FINISH_ABORT(
req.set_finish_with_abort(
f"Invalid request: session id {recv_req.session_params.id} does not exist"
)
self._add_request_to_queue(req)

View File

@@ -54,7 +54,7 @@ class SessionReqNode:
prefix += " -- " + self.childs[0].req.rid
ret = self.childs[0]._str_helper(prefix)
for child in self.childs[1:]:
prefix = " " * len(origin_prefix) + r" \- " + child.req.rid
prefix = " " * len(origin_prefix) + " \- " + child.req.rid
ret += child._str_helper(prefix)
return ret
@@ -106,14 +106,22 @@ class Session:
last_req.origin_input_ids
+ last_req.output_ids[: last_req.sampling_params.max_new_tokens]
)
if session_params.drop_previous_output:
input_ids = last_req.origin_input_ids[:]
if session_params.offset and session_params.offset != 0:
input_ids = input_ids[: session_params.offset] + req.input_ids
else:
input_ids += req.input_ids
input_ids_unpadded = (
last_req.origin_input_ids_unpadded
+ last_req.output_ids[: last_req.sampling_params.max_new_tokens]
)
if session_params.drop_previous_output:
input_ids_unpadded = last_req.origin_input_ids_unpadded[:]
if session_params.offset and session_params.offset != 0:
input_ids_unpadded = (
input_ids_unpadded[: session_params.offset] + req.input_ids
@@ -138,10 +146,11 @@ class Session:
token_ids_logprob=req.token_ids_logprob,
)
if last_req is not None:
new_req.multimodal_inputs = last_req.mm_inputs
new_req.multimodal_inputs = last_req.multimodal_inputs
new_req.tokenizer = tokenizer
if abort:
new_req.to_abort = True
new_req.set_finish_with_abort("Invalid request session id")
else:
new_req_node = SessionReqNode(new_req, last_req_node)
self.req_nodes[req.rid] = new_req_node

View File

@@ -1148,6 +1148,7 @@ class TokenizerManager:
[
"text",
"output_ids",
"embedding",
]
)
elif self.log_requests_level == 1:
@@ -1166,6 +1167,7 @@ class TokenizerManager:
[
"text",
"output_ids",
"embedding",
]
)
elif self.log_requests_level == 2:

View File

@@ -24,6 +24,9 @@ class MultiModalCache:
self.current_size += data_size
return True
def has(self, mm_hash: int) -> bool:
    """Return True when a feature with this hash is present in the cache."""
    return mm_hash in self.mm_cache
def get(self, mm_hash: int) -> torch.Tensor:
    """Return the cached tensor for `mm_hash`.

    NOTE(review): dict.get returns None on a miss even though the
    annotation says torch.Tensor — callers must handle None.
    """
    return self.mm_cache.get(mm_hash)

View File

@@ -451,11 +451,6 @@ class ModelRunner:
self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
if self.is_multimodal:
self.mem_fraction_static *= 0.90
logger.info(
f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
f"because this is a multimodal model."
)
if not self.is_multimodal_chunked_prefill_supported:
server_args.chunked_prefill_size = -1
logger.info(

View File

@@ -11,8 +11,6 @@ from sglang.srt.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
)
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size

View File

@@ -3,11 +3,9 @@ import math
import re
from typing import Dict, List, Union
import torch
from PIL import Image
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
from sglang.srt.multimodal.processors.base_processor import (

View File

@@ -319,6 +319,14 @@ class ServerArgs:
else:
self.mem_fraction_static = 0.88
# Lazy init to avoid circular import
from sglang.srt.configs.model_config import ModelConfig
# Multimodal models need more memory for the image processor
model_config = ModelConfig.from_server_args(self)
if model_config.is_multimodal:
self.mem_fraction_static *= 0.90
# Set chunked prefill size, which depends on the gpu memory capacity
if self.chunked_prefill_size is None:
if gpu_mem is not None:

View File

@@ -42,7 +42,7 @@ import threading
import time
import traceback
import warnings
from collections import defaultdict
from collections import OrderedDict, defaultdict
from contextlib import contextmanager
from enum import Enum
from functools import lru_cache
@@ -97,35 +97,6 @@ time_infos = {}
HIP_FP8_E4M3_FNUZ_MAX = 224.0
_warned_bool_env_var_keys = set()
def get_bool_env_var(name: str, default: str = "false") -> bool:
    """Interpret environment variable `name` as a boolean.

    Accepts "true"/"1" (truthy) and "false"/"0" (falsy), case-insensitively.
    Any other value is treated as false and logs a one-time warning.

    Args:
        name: Environment variable name.
        default: Fallback string used when the variable is unset.

    Returns:
        True iff the (lowercased) value is "true" or "1".
    """
    value = os.getenv(name, default)
    value = value.lower()

    truthy_values = ("true", "1")
    falsy_values = ("false", "0")

    if (value not in truthy_values) and (value not in falsy_values):
        # Warn once per variable *name*. The previous code added the value
        # to `_warned_bool_env_var_keys`, which contradicts the set's name
        # and made the dedup key the bad value instead of the variable.
        if name not in _warned_bool_env_var_keys:
            logger.warning(
                f"get_bool_env_var({name}) see non-understandable value={value} and treat as false"
            )
            _warned_bool_env_var_keys.add(name)

    return value in truthy_values
def get_int_env_var(name: str, default: int = 0) -> int:
    """Read environment variable `name` as an int.

    Returns `default` when the variable is unset, blank/whitespace-only,
    or not a valid integer literal.
    """
    raw = os.getenv(name)
    if raw and raw.strip():
        try:
            return int(raw)
        except ValueError:
            # Fall through to the default on a malformed value.
            pass
    return default
# https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
def is_hip() -> bool:
@@ -176,6 +147,82 @@ def is_cpu() -> bool:
return os.getenv("SGLANG_USE_CPU_ENGINE", "0") == "1" and is_host_cpu_x86()
def get_cuda_version():
    """Return the CUDA toolkit version torch was built with, as an int tuple.

    Returns (0, 0) for non-CUDA builds (CPU-only or ROCm torch).
    """
    cuda_str = torch.version.cuda
    if not cuda_str:
        return (0, 0)
    return tuple(int(part) for part in cuda_str.split("."))
def _check(cc_major):
    """Return True iff running on a CUDA device whose major compute
    capability equals `cc_major` and the CUDA toolkit is >= 12.3."""
    if not is_cuda():
        return False
    return torch.cuda.get_device_capability()[0] == cc_major and tuple(
        map(int, torch.version.cuda.split(".")[:2])
    ) >= (12, 3)


# PEP 8 (E731): use `def` rather than assigning lambdas to names, so these
# predicates carry proper names in tracebacks/profiles. Call sites are
# unchanged — they remain zero-argument callables.
def is_ampere_with_cuda_12_3():
    """True on Ampere (SM 8.x) GPUs with CUDA toolkit >= 12.3."""
    return _check(8)


def is_hopper_with_cuda_12_3():
    """True on Hopper (SM 9.x) GPUs with CUDA toolkit >= 12.3."""
    return _check(9)


def is_blackwell():
    """True on Blackwell GPUs (major compute capability 10)."""
    if not is_cuda():
        return False
    return torch.cuda.get_device_capability()[0] == 10
_warned_bool_env_var_keys = set()
def get_bool_env_var(name: str, default: str = "false") -> bool:
    """Interpret env var `name` as a boolean ("true"/"1" vs "false"/"0").

    Unrecognized values are treated as false and trigger a warning.
    """
    value = os.getenv(name, default)
    value = value.lower()

    truthy_values = ("true", "1")
    falsy_values = ("false", "0")

    if (value not in truthy_values) and (value not in falsy_values):
        if value not in _warned_bool_env_var_keys:
            logger.warning(
                f"get_bool_env_var({name}) see non-understandable value={value} and treat as false"
            )
        # NOTE(review): this records the *value*, though the set is named
        # `_warned_bool_env_var_keys` — dedup is per value, not per name.
        _warned_bool_env_var_keys.add(value)

    return value in truthy_values
def get_int_env_var(name: str, default: int = 0) -> int:
    """Read env var `name` as an int; return `default` when the variable is
    unset, blank/whitespace-only, or not a valid integer literal."""
    value = os.getenv(name)
    if value is None or not value.strip():
        return default
    try:
        return int(value)
    except ValueError:
        return default
def support_triton(backend: str) -> bool:
    """Return whether the given compute backend can run Triton kernels."""
    non_triton_backends = ("torch_native", "intel_amx")
    return backend not in non_triton_backends
try:
import sgl_kernel
is_intel_amx_backend_available = hasattr(
torch.ops.sgl_kernel, "convert_weight_packed"
)
except:
is_intel_amx_backend_available = False
def cpu_has_amx_support():
    # Requires both: this torch build reports AMX tile support, and the
    # sgl_kernel AMX ops were importable (module-level
    # `is_intel_amx_backend_available` flag).
    return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
def use_intel_amx_backend(layer):
    """Report whether `layer` was flagged to use the Intel AMX backend."""
    # Layers that were never tagged default to the non-AMX path.
    return getattr(layer, "use_intel_amx_backend", False)
def is_flashinfer_available():
"""
Check whether flashinfer is available.
@@ -503,6 +550,46 @@ def set_random_seed(seed: int) -> None:
torch.cuda.manual_seed_all(seed)
def find_process_using_port(port: int) -> Optional[psutil.Process]:
    """Return the process holding `port`, or None if it cannot be found."""
    for connection in psutil.net_connections(kind="inet"):
        if connection.laddr.port != port:
            continue
        try:
            return psutil.Process(connection.pid)
        except psutil.NoSuchProcess:
            # Race: the owner can exit between the scan and the lookup.
            continue
    return None
def wait_port_available(
    port: int, port_name: str, timeout_s: int = 30, raise_exception: bool = True
) -> bool:
    """Poll until `port` becomes available.

    NOTE(review): each iteration sleeps 0.1s, so the real wall-clock wait is
    ~timeout_s * 0.1 seconds — confirm whether `timeout_s` is meant literally.

    Args:
        port: TCP port to wait for.
        port_name: Human-readable name used in log/error messages.
        timeout_s: Number of polling iterations.
        raise_exception: Raise ValueError on timeout instead of returning False.

    Returns:
        True once the port is free; False on timeout when `raise_exception`
        is False.

    Raises:
        ValueError: On timeout when `raise_exception` is True.
    """
    # Pre-initialize so the timeout message is well-defined even if the
    # diagnostic branch below never ran (fixes a possible UnboundLocalError
    # at the `raise` for small timeouts).
    error_message = f"{port_name} is used by an unknown process."
    for i in range(timeout_s):
        if is_port_available(port):
            return True

        if i > 10 and i % 5 == 0:
            process = find_process_using_port(port)
            if process is None:
                # Fix: the old code fell through and dereferenced `process`
                # here even when it was None, crashing with AttributeError.
                logger.warning(
                    f"The port {port} is in use, but we could not find the process that uses it."
                )
            else:
                pid = process.pid
                # Fix: removed a stray quote after {process.name()=}.
                error_message = f"{port_name} is used by a process already. {process.name()=} {process.cmdline()=} {process.status()=} {pid=}"
            logger.info(
                f"port {port} is in use. Waiting for {i} seconds for {port_name} to be available. {error_message}"
            )
        time.sleep(0.1)

    if raise_exception:
        raise ValueError(
            f"{port_name} at {port} is not available in {timeout_s} seconds. {error_message}"
        )
    return False
def is_port_available(port):
"""Return whether a port is available."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -517,6 +604,19 @@ def is_port_available(port):
return False
def get_free_port():
    """Bind an ephemeral port and return its number.

    Tries IPv4 first; on OSError falls back to IPv6 (errors there
    propagate). The port is released on return, so a race with other
    processes is theoretically possible.
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("", 0))
            return sock.getsockname()[1]
    except OSError:
        # IPv4 unavailable on this host; use an IPv6 wildcard bind instead.
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
            sock.bind(("", 0))
            return sock.getsockname()[1]
def decode_video_base64(video_base64):
from PIL import Image
@@ -819,6 +919,7 @@ def maybe_set_triton_cache_manager() -> None:
class CustomCacheManager(FileCacheManager):
# Adapted from: https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py
def __init__(self, key, override=False, dump=False):
from sglang.srt.distributed.parallel_state import get_tp_group
self.key = key
self.lock_path = None
@@ -836,7 +937,10 @@ class CustomCacheManager(FileCacheManager):
os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
)
if self.cache_dir:
self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
try:
self.cache_dir = f"{self.cache_dir}_{get_tp_group().local_rank}"
except:
self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
@@ -1939,12 +2043,6 @@ def rank0_log(msg: str):
logger.info(msg)
def get_cuda_version():
if torch.version.cuda:
return tuple(map(int, torch.version.cuda.split(".")))
return (0, 0)
def launch_dummy_health_check_server(host, port):
import asyncio
@@ -2131,35 +2229,12 @@ def fast_topk(values, topk, dim):
return torch.topk(values, topk, dim=dim)
def _check(cc_major):
    # True iff the CUDA device's major compute capability == cc_major AND
    # the CUDA toolkit version is >= 12.3.
    if not is_cuda():
        return False
    return torch.cuda.get_device_capability()[0] == cc_major and tuple(
        map(int, torch.version.cuda.split(".")[:2])
    ) >= (12, 3)


# Convenience predicates: Ampere (SM 8.x) / Hopper (SM 9.x) with CUDA >= 12.3.
is_ampere_with_cuda_12_3 = lambda: _check(8)
is_hopper_with_cuda_12_3 = lambda: _check(9)


def is_blackwell():
    # Blackwell GPUs report major compute capability 10.
    if not is_cuda():
        return False
    return torch.cuda.get_device_capability()[0] == 10
def get_free_port():
    """Return an OS-assigned free TCP port number (IPv4 first, IPv6 fallback).

    The socket is closed on return, so the port could be taken by another
    process before the caller binds it.
    """
    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]
    except OSError:
        # try ipv6
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]
def bind_or_assign(target, source):
    """Copy `source` into `target` in place when a target tensor exists.

    Returns `target` (after copy_) when it is not None; otherwise returns
    `source` unchanged.
    """
    if target is None:
        return source
    target.copy_(source)
    return target
def get_local_ip_auto() -> str:
@@ -2412,26 +2487,75 @@ def bind_or_assign(target, source):
return source
def support_triton(backend: str) -> bool:
    """Return whether the given compute backend can run Triton kernels."""
    non_triton_backends = ("torch_native", "intel_amx", "ascend")
    return backend not in non_triton_backends
def prepack_weight_if_needed(weight):
    """Pack a CPU weight for the sgl_kernel AMX GEMM path when available.

    Non-CPU tensors — and hosts without AMX support — are returned
    unchanged.
    """
    on_cpu = weight.device == torch.device("cpu")
    if not on_cpu or not cpu_has_amx_support():
        return weight
    return torch.ops.sgl_kernel.convert_weight_packed(weight)
try:
import sgl_kernel
# TODO: currently gemm kernel has the below requirements:
# OC % TILE_N == 0, where TILE_N = 16
# IC % TILE_K == 0, where TILE_K = 32
def dim_is_supported(weight):
    """AMX GEMM needs OC % TILE_N(=16) == 0 and IC % TILE_K(=32) == 0."""
    out_channels = weight.size(0)
    in_channels = weight.size(1)
    return out_channels % 16 == 0 and in_channels % 32 == 0
is_intel_amx_backend_available = hasattr(
torch.ops.sgl_kernel, "convert_weight_packed"
def _process_weight_after_loading(module, weight_names, transpose_dims=None) -> None:
# Pack weight for get better performance on CPU
devices = {getattr(module, weight_name).device for weight_name in weight_names}
assert len(devices) == 1, f"Expects all weights to be on the same device"
device = devices.pop()
if transpose_dims:
assert len(weight_names) == len(
transpose_dims
), "len(weight_names) should be equal to len(transpose_dims)"
for i, weight_name in enumerate(weight_names):
weight_tensor = getattr(module, weight_name)
# We don't pack weight or use intel amx backend if any weight of this module has unsupported dim.
if not dim_is_supported(weight_tensor):
logger.warning(
f"Expects weight.size(0) % 16 == 0 and weight.size(1) % 32 == 0 "
f"but {weight_tensor.size(0)=} and {weight_tensor.size(1)=} in {module}. "
f"{module} won't use intel amx backend."
)
module.use_intel_amx_backend = False
return
if transpose_dims and transpose_dims[i]:
weight_tensor = weight_tensor.transpose(*transpose_dims[i])
packed_weight = torch.nn.Parameter(
prepack_weight_if_needed(weight_tensor),
requires_grad=False,
)
packed_weight.__dict__ = weight_tensor.__dict__
setattr(module, weight_name, packed_weight)
module.use_intel_amx_backend = (
device == torch.device("cpu") and cpu_has_amx_support()
)
except:
is_intel_amx_backend_available = False
if (
module.use_intel_amx_backend
and hasattr(module, "bias")
and module.bias is not None
):
module.bias = torch.nn.Parameter(module.bias.data.float(), requires_grad=False)
def cpu_has_amx_support():
return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
class PackWeightMethod:
def __init__(self, weight_names, transpose_dims=None):
self.weight_names = weight_names
self.transpose_dims = transpose_dims
def use_intel_amx_backend(layer):
return getattr(layer, "use_intel_amx_backend", False)
def process_weights_after_loading(self, module) -> None:
_process_weight_after_loading(module, self.weight_names, self.transpose_dims)
class LazyValue:
@@ -2568,3 +2692,48 @@ def is_shm_available(dtype, world_size, local_size):
and world_size >= 1
and world_size == local_size
)
def lru_cache_frozenset(maxsize=128):
    """LRU cache decorator that also accepts unhashable arguments.

    Dicts, sets, and sequences inside the arguments are recursively
    converted to frozenset/tuple equivalents to form the cache key. Pass
    maxsize=None for an unbounded cache. The wrapped function gains a
    `cache_clear()` method.
    """

    def _to_hashable(o):
        try:
            hash(o)
        except TypeError:
            # Not hashable; convert containers recursively by type.
            if isinstance(o, dict):
                return frozenset(
                    (_to_hashable(k), _to_hashable(v)) for k, v in o.items()
                )
            if isinstance(o, set):
                return frozenset(_to_hashable(item) for item in o)
            if isinstance(o, (list, tuple)) or (
                isinstance(o, Sequence) and not isinstance(o, (str, bytes))
            ):
                return tuple(_to_hashable(item) for item in o)
            raise TypeError(f"Cannot make hashable: {type(o)}")
        return o

    def decorator(func):
        lru = OrderedDict()

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            hashable_args = tuple(_to_hashable(a) for a in args)
            hashable_kwargs = frozenset(
                (_to_hashable(k), _to_hashable(v)) for k, v in kwargs.items()
            )
            key = (hashable_args, hashable_kwargs)
            if key in lru:
                # Hit: refresh recency, return the cached value.
                lru.move_to_end(key)
                return lru[key]
            value = func(*args, **kwargs)
            lru[key] = value
            if maxsize is not None and len(lru) > maxsize:
                # Evict the least-recently-used entry.
                lru.popitem(last=False)
            return value

        wrapper.cache_clear = lru.clear  # For manual cache clearing
        return wrapper

    return decorator