Move mem_fraction_static adjustment for multimodal models to server_args.py & Fix session control & Other cleanups (#7748)

This commit is contained in:
Lianmin Zheng
2025-07-04 16:33:33 -07:00
committed by GitHub
parent 975a5ec69c
commit 14229ccf8f
16 changed files with 339 additions and 137 deletions

View File

@@ -39,6 +39,7 @@ class SessionParams:
rid: Optional[str] = None
offset: Optional[int] = None
replace: Optional[bool] = None
drop_previous_output: Optional[bool] = None
AudioDataItem = Union[str, Dict]

View File

@@ -203,7 +203,7 @@ class MultimodalDataItem:
# the real data, pixel_values or audio_features
# data: Union[List[torch.Tensor], List[np.ndarray]]
pixel_values: Union[torch.Tensor, np.ndarray] = None
pixel_values: Union[torch.Tensor, np.ndarray, "PIL.Image"] = None
audio_features: Union[torch.Tensor, np.ndarray] = None
audio_feature_lens: Optional[List[torch.Tensor]] = None
audio_offsets: Optional[List[Tuple[int, int]]] = None
@@ -244,15 +244,16 @@ class MultimodalDataItem:
"""
from sglang.srt.managers.mm_utils import hash_feature
if self.precomputed_features is not None:
self.hash = hash_feature(self.precomputed_features)
elif self.is_audio():
if self.audio_features is not None:
self.hash = hash_feature(self.audio_features)
elif self.input_features is not None:
self.hash = hash_feature(self.input_features)
else:
self.hash = hash_feature(self.pixel_values)
if self.hash is None:
if self.precomputed_features is not None:
self.hash = hash_feature(self.precomputed_features)
elif self.is_audio():
if self.audio_features is not None:
self.hash = hash_feature(self.audio_features)
elif self.input_features is not None:
self.hash = hash_feature(self.input_features)
else:
self.hash = hash_feature(self.pixel_values)
assert self.hash is not None
self.pad_value = self.hash % (1 << 30)
@@ -295,6 +296,13 @@ class MultimodalDataItem:
ret.validate()
return ret
def merge(self, other):
    """Absorb ``other`` into this item in place.

    Applies ``+=`` to ``pixel_values``, ``image_sizes`` and ``image_offsets``
    using the matching attributes of ``other``, replaces ``self.hash`` with
    the built-in ``hash`` of the ``(self.hash, other.hash)`` pair, and then
    calls ``self.set_pad_value()``.

    NOTE(review): presumably these fields hold lists, so ``+=`` concatenates;
    for tensors/arrays it would add element-wise instead -- confirm with
    callers.
    """
    for field_name in ("pixel_values", "image_sizes", "image_offsets"):
        # Same semantics as ``self.<field> += other.<field>``: in-place add
        # when supported, then store the result back on ``self``.
        combined = getattr(self, field_name)
        combined += getattr(other, field_name)
        setattr(self, field_name, combined)

    own_hash, other_hash = self.hash, other.hash
    self.hash = hash((own_hash, other_hash))
    # Must run after the hash is replaced.
    self.set_pad_value()
@dataclasses.dataclass
class MultimodalInputs:

View File

@@ -1100,7 +1100,7 @@ class Scheduler(
recv_req.session_params is not None
and recv_req.session_params.id is not None
):
req.finished_reason = FINISH_ABORT(
req.set_finish_with_abort(
f"Invalid request: session id {recv_req.session_params.id} does not exist"
)
self._add_request_to_queue(req)

View File

@@ -54,7 +54,7 @@ class SessionReqNode:
prefix += " -- " + self.childs[0].req.rid
ret = self.childs[0]._str_helper(prefix)
for child in self.childs[1:]:
prefix = " " * len(origin_prefix) + r" \- " + child.req.rid
prefix = " " * len(origin_prefix) + " \- " + child.req.rid
ret += child._str_helper(prefix)
return ret
@@ -106,14 +106,22 @@ class Session:
last_req.origin_input_ids
+ last_req.output_ids[: last_req.sampling_params.max_new_tokens]
)
if session_params.drop_previous_output:
input_ids = last_req.origin_input_ids[:]
if session_params.offset and session_params.offset != 0:
input_ids = input_ids[: session_params.offset] + req.input_ids
else:
input_ids += req.input_ids
input_ids_unpadded = (
last_req.origin_input_ids_unpadded
+ last_req.output_ids[: last_req.sampling_params.max_new_tokens]
)
if session_params.drop_previous_output:
input_ids_unpadded = last_req.origin_input_ids_unpadded[:]
if session_params.offset and session_params.offset != 0:
input_ids_unpadded = (
input_ids_unpadded[: session_params.offset] + req.input_ids
@@ -138,10 +146,11 @@ class Session:
token_ids_logprob=req.token_ids_logprob,
)
if last_req is not None:
new_req.multimodal_inputs = last_req.mm_inputs
new_req.multimodal_inputs = last_req.multimodal_inputs
new_req.tokenizer = tokenizer
if abort:
new_req.to_abort = True
new_req.set_finish_with_abort("Invalid request session id")
else:
new_req_node = SessionReqNode(new_req, last_req_node)
self.req_nodes[req.rid] = new_req_node

View File

@@ -1148,6 +1148,7 @@ class TokenizerManager:
[
"text",
"output_ids",
"embedding",
]
)
elif self.log_requests_level == 1:
@@ -1166,6 +1167,7 @@ class TokenizerManager:
[
"text",
"output_ids",
"embedding",
]
)
elif self.log_requests_level == 2: