From 78f139812a46c4b85dcf948663fd4f11230d6f43 Mon Sep 17 00:00:00 2001
From: Liangsheng Yin
Date: Mon, 8 Sep 2025 16:27:37 +0800
Subject: [PATCH] [1/N] DP-Refactor: move communicators into `tokenizer_communicator_mixin` (#10028)

---
 .../srt/managers/multi_tokenizer_mixin.py     |   3 +-
 .../managers/tokenizer_communicator_mixin.py  | 491 ++++++++++++++++++
 .../sglang/srt/managers/tokenizer_manager.py  | 462 +---------------
 python/sglang/srt/weight_sync/utils.py        |   2 +-
 python/sglang/utils.py                        |   4 +
 5 files changed, 503 insertions(+), 459 deletions(-)
 create mode 100644 python/sglang/srt/managers/tokenizer_communicator_mixin.py

diff --git a/python/sglang/srt/managers/multi_tokenizer_mixin.py b/python/sglang/srt/managers/multi_tokenizer_mixin.py
index 621989e03..e4f83c82b 100644
--- a/python/sglang/srt/managers/multi_tokenizer_mixin.py
+++ b/python/sglang/srt/managers/multi_tokenizer_mixin.py
@@ -36,7 +36,8 @@ from sglang.srt.managers.io_struct import (
     MultiTokenizerRegisterReq,
     MultiTokenizerWrapper,
 )
-from sglang.srt.managers.tokenizer_manager import TokenizerManager, _Communicator
+from sglang.srt.managers.tokenizer_communicator_mixin import _Communicator
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import get_zmq_socket, kill_process_tree
 from sglang.utils import get_exception_traceback
diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
new file mode 100644
index 000000000..e59d3f296
--- /dev/null
+++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py
@@ -0,0 +1,491 @@
+from __future__ import annotations
+
+import asyncio
+import logging
+import os
+import time
+from collections import deque
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Deque,
+    Dict,
+    Generic,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+)
+
+import fastapi
+
+from sglang.srt.managers.io_struct import (
+    ClearHiCacheReqInput,
+    ClearHiCacheReqOutput,
+    ExpertDistributionReq,
+    ExpertDistributionReqOutput,
+    FlushCacheReqInput,
+    FlushCacheReqOutput,
+    GetInternalStateReq,
+    GetInternalStateReqOutput,
+    GetWeightsByNameReqInput,
+    GetWeightsByNameReqOutput,
+    InitWeightsUpdateGroupReqInput,
+    InitWeightsUpdateGroupReqOutput,
+    LoadLoRAAdapterReqInput,
+    LoadLoRAAdapterReqOutput,
+    LoRAUpdateResult,
+    MultiTokenizerWrapper,
+    ProfileReq,
+    ProfileReqOutput,
+    ProfileReqType,
+    ReleaseMemoryOccupationReqInput,
+    ReleaseMemoryOccupationReqOutput,
+    ResumeMemoryOccupationReqInput,
+    ResumeMemoryOccupationReqOutput,
+    SetInternalStateReq,
+    SetInternalStateReqOutput,
+    SlowDownReqInput,
+    SlowDownReqOutput,
+    UnloadLoRAAdapterReqInput,
+    UnloadLoRAAdapterReqOutput,
+    UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromDistributedReqOutput,
+    UpdateWeightsFromTensorReqInput,
+    UpdateWeightsFromTensorReqOutput,
+)
+from sglang.srt.server_args import LoRARef, ServerArgs
+from sglang.srt.utils import get_bool_env_var
+from sglang.utils import TypeBasedDispatcher
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.tokenizer_manager import TokenizerManager
+
+T = TypeVar("T")
+
+logger = logging.getLogger(__name__)
+
+
+class _Communicator(Generic[T]):
+    """Note: The communicator runs at most one in-flight request at any time."""
+
+    enable_multi_tokenizer = False
+
+    def __init__(self, sender, fan_out: int):
+        self._sender = sender
+        self._fan_out = fan_out
+        self._result_event: Optional[asyncio.Event] = None
+        self._result_values: 
Optional[List[T]] = None + self._ready_queue: Deque[asyncio.Future] = deque() + + async def __call__(self, obj): + ready_event = asyncio.Event() + if self._result_event is not None or len(self._ready_queue) > 0: + self._ready_queue.append(ready_event) + await ready_event.wait() + assert self._result_event is None + assert self._result_values is None + + if obj: + if _Communicator.enable_multi_tokenizer: + obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj) + self._sender.send_pyobj(obj) + + self._result_event = asyncio.Event() + self._result_values = [] + await self._result_event.wait() + result_values = self._result_values + self._result_event = self._result_values = None + + if len(self._ready_queue) > 0: + self._ready_queue.popleft().set() + + return result_values + + def handle_recv(self, recv_obj: T): + self._result_values.append(recv_obj) + if len(self._result_values) == self._fan_out: + self._result_event.set() + + +class TokenizerCommunicatorMixin: + """Mixin class for TokenizerManager to handle communication with the scheduler.""" + + def init_communicators(self: TokenizerManager, server_args: ServerArgs): + # Communicators + self.init_weights_update_group_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.update_weights_from_distributed_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.update_weights_from_tensor_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.get_weights_by_name_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.release_memory_occupation_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.resume_memory_occupation_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.slow_down_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.flush_cache_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.clear_hicache_storage_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.profile_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.get_internal_state_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.set_internal_state_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.expert_distribution_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.update_lora_adapter_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + + self._result_dispatcher += self._get_communicator_dispatcher() + + def _get_communicator_dispatcher(self: TokenizerManager): + return TypeBasedDispatcher( + [ + ( + InitWeightsUpdateGroupReqOutput, + self.init_weights_update_group_communicator.handle_recv, + ), + ( + UpdateWeightsFromDistributedReqOutput, + self.update_weights_from_distributed_communicator.handle_recv, + ), + ( + UpdateWeightsFromTensorReqOutput, + self.update_weights_from_tensor_communicator.handle_recv, + ), + ( + GetWeightsByNameReqOutput, + self.get_weights_by_name_communicator.handle_recv, + ), + ( + ReleaseMemoryOccupationReqOutput, + self.release_memory_occupation_communicator.handle_recv, + ), + ( + ResumeMemoryOccupationReqOutput, + self.resume_memory_occupation_communicator.handle_recv, + ), + ( + SlowDownReqOutput, + self.slow_down_communicator.handle_recv, + ), + ( + ClearHiCacheReqOutput, + 
self.clear_hicache_storage_communicator.handle_recv, + ), + ( + FlushCacheReqOutput, + self.flush_cache_communicator.handle_recv, + ), + ( + ProfileReqOutput, + self.profile_communicator.handle_recv, + ), + ( + GetInternalStateReqOutput, + self.get_internal_state_communicator.handle_recv, + ), + ( + SetInternalStateReqOutput, + self.set_internal_state_communicator.handle_recv, + ), + ( + ExpertDistributionReqOutput, + self.expert_distribution_communicator.handle_recv, + ), + ( + LoRAUpdateResult, + self.update_lora_adapter_communicator.handle_recv, + ), + ] + ) + + async def flush_cache(self: TokenizerManager) -> FlushCacheReqOutput: + return (await self.flush_cache_communicator(FlushCacheReqInput()))[0] + + async def clear_hicache_storage(self: TokenizerManager) -> ClearHiCacheReqOutput: + """Clear the hierarchical cache storage.""" + # Delegate to the scheduler to handle HiCacheStorage clearing + return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[ + 0 + ] + + async def start_profile( + self: TokenizerManager, + output_dir: Optional[str] = None, + start_step: Optional[int] = None, + num_steps: Optional[int] = None, + activities: Optional[List[str]] = None, + with_stack: Optional[bool] = None, + record_shapes: Optional[bool] = None, + profile_by_stage: bool = False, + ): + self.auto_create_handle_loop() + env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true") + with_stack = False if with_stack is False or env_with_stack is False else True + req = ProfileReq( + type=ProfileReqType.START_PROFILE, + output_dir=output_dir, + start_step=start_step, + num_steps=num_steps, + activities=activities, + with_stack=with_stack, + record_shapes=record_shapes, + profile_by_stage=profile_by_stage, + profile_id=str(time.time()), + ) + return await self._execute_profile(req) + + async def stop_profile(self: TokenizerManager): + self.auto_create_handle_loop() + req = ProfileReq(type=ProfileReqType.STOP_PROFILE) + return await self._execute_profile(req) + + async def _execute_profile(self: TokenizerManager, req: ProfileReq): + result = (await self.profile_communicator(req))[0] + if not result.success: + raise RuntimeError(result.message) + return result + + async def start_expert_distribution_record(self: TokenizerManager): + self.auto_create_handle_loop() + await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD) + + async def stop_expert_distribution_record(self: TokenizerManager): + self.auto_create_handle_loop() + await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD) + + async def dump_expert_distribution_record(self: TokenizerManager): + self.auto_create_handle_loop() + await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD) + + async def init_weights_update_group( + self: TokenizerManager, + obj: InitWeightsUpdateGroupReqInput, + request: Optional[fastapi.Request] = None, + ) -> Tuple[bool, str]: + self.auto_create_handle_loop() + assert ( + self.server_args.dp_size == 1 + ), "dp_size must be 1 for init parameter update group" + result = (await self.init_weights_update_group_communicator(obj))[0] + return result.success, result.message + + async def update_weights_from_distributed( + self: TokenizerManager, + obj: UpdateWeightsFromDistributedReqInput, + request: Optional[fastapi.Request] = None, + ) -> Tuple[bool, str]: + self.auto_create_handle_loop() + assert ( + self.server_args.dp_size == 1 or self.server_args.enable_dp_attention + ), "dp_size must be 1 or dp attention must be 
enabled for update weights from distributed" + + if obj.abort_all_requests: + self.abort_request(abort_all=True) + + # This means that weight sync + # cannot run while requests are in progress. + async with self.model_update_lock.writer_lock: + result = (await self.update_weights_from_distributed_communicator(obj))[0] + return result.success, result.message + + async def update_weights_from_tensor( + self: TokenizerManager, + obj: UpdateWeightsFromTensorReqInput, + request: Optional[fastapi.Request] = None, + ) -> Tuple[bool, str]: + self.auto_create_handle_loop() + assert ( + self.server_args.dp_size == 1 or self.server_args.enable_dp_attention + ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor" + + if obj.abort_all_requests: + self.abort_request(abort_all=True) + + # This means that weight sync + # cannot run while requests are in progress. + async with self.model_update_lock.writer_lock: + result = (await self.update_weights_from_tensor_communicator(obj))[0] + return result.success, result.message + + async def load_lora_adapter( + self: TokenizerManager, + obj: LoadLoRAAdapterReqInput, + _: Optional[fastapi.Request] = None, + ) -> LoadLoRAAdapterReqOutput: + self.auto_create_handle_loop() + + try: + if not self.server_args.enable_lora: + raise ValueError( + "LoRA is not enabled. Please set `--enable-lora` to enable LoRA." + ) + + # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works + # with dp_size > 1. + assert ( + self.server_args.dp_size == 1 + ), "dp_size must be 1 for dynamic lora loading" + logger.info( + "Start load Lora adapter. Lora name=%s, path=%s", + obj.lora_name, + obj.lora_path, + ) + + async with self.lora_update_lock: + if ( + self.server_args.max_loaded_loras is not None + and self.lora_registry.num_registered_loras + >= self.server_args.max_loaded_loras + ): + raise ValueError( + f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. " + f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. " + "Please unload some LoRA adapters before loading new ones." + ) + + # Generate new uniquely identifiable LoRARef object. + new_adapter = LoRARef( + lora_name=obj.lora_name, + lora_path=obj.lora_path, + pinned=obj.pinned, + ) + + # Trigger the actual loading operation at the backend processes. + obj.lora_id = new_adapter.lora_id + result = (await self.update_lora_adapter_communicator(obj))[0] + + # Register the LoRA adapter only after loading is successful. + if result.success: + await self.lora_registry.register(new_adapter) + + return result + except ValueError as e: + return LoadLoRAAdapterReqOutput( + success=False, + error_message=str(e), + ) + + async def unload_lora_adapter( + self: TokenizerManager, + obj: UnloadLoRAAdapterReqInput, + _: Optional[fastapi.Request] = None, + ) -> UnloadLoRAAdapterReqOutput: + self.auto_create_handle_loop() + + try: + if not self.server_args.enable_lora: + raise ValueError( + "LoRA is not enabled. Please set `--enable-lora` to enable LoRA." + ) + + assert ( + obj.lora_name is not None + ), "lora_name must be provided to unload LoRA adapter" + + # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works + # with dp_size > 1. + assert ( + self.server_args.dp_size == 1 + ), "dp_size must be 1 for dynamic lora loading" + logger.info( + "Start unload Lora adapter. 
Lora name=%s", + obj.lora_name, + ) + + async with self.lora_update_lock: + # Unregister the LoRA adapter from the registry to stop new requests for this adapter + # from being started. + lora_id = await self.lora_registry.unregister(obj.lora_name) + obj.lora_id = lora_id + + # Initiate the actual unloading operation at the backend processes only after all + # ongoing requests using this LoRA adapter are finished. + await self.lora_registry.wait_for_unload(lora_id) + result = (await self.update_lora_adapter_communicator(obj))[0] + + return result + except ValueError as e: + return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e)) + + async def get_weights_by_name( + self: TokenizerManager, + obj: GetWeightsByNameReqInput, + request: Optional[fastapi.Request] = None, + ): + self.auto_create_handle_loop() + results = await self.get_weights_by_name_communicator(obj) + all_parameters = [r.parameter for r in results] + if self.server_args.dp_size == 1: + return all_parameters[0] + else: + return all_parameters + + async def release_memory_occupation( + self: TokenizerManager, + obj: ReleaseMemoryOccupationReqInput, + request: Optional[fastapi.Request] = None, + ): + self.auto_create_handle_loop() + await self.release_memory_occupation_communicator(obj) + + async def resume_memory_occupation( + self: TokenizerManager, + obj: ResumeMemoryOccupationReqInput, + request: Optional[fastapi.Request] = None, + ): + self.auto_create_handle_loop() + await self.resume_memory_occupation_communicator(obj) + + async def slow_down( + self: TokenizerManager, + obj: SlowDownReqInput, + request: Optional[fastapi.Request] = None, + ): + self.auto_create_handle_loop() + await self.slow_down_communicator(obj) + + async def get_internal_state(self: TokenizerManager) -> List[Dict[Any, Any]]: + req = GetInternalStateReq() + responses: List[GetInternalStateReqOutput] = ( + await self.get_internal_state_communicator(req) + ) + # Many DP ranks + return [res.internal_state for res in responses] + + async def set_internal_state( + self: TokenizerManager, obj: SetInternalStateReq + ) -> List[bool]: + responses: List[SetInternalStateReqOutput] = ( + await self.set_internal_state_communicator(obj) + ) + return [res.updated for res in responses] + + async def get_load(self: TokenizerManager) -> dict: + # TODO(lsyin): fake load report server + if not self.current_load_lock.locked(): + async with self.current_load_lock: + internal_state = await self.get_internal_state() + self.current_load = internal_state[0]["load"] + return {"load": self.current_load} diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index c00235587..4812ca180 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -31,19 +31,7 @@ from contextlib import nullcontext from datetime import datetime from enum import Enum from http import HTTPStatus -from typing import ( - Any, - Awaitable, - Deque, - Dict, - Generic, - List, - Optional, - Tuple, - Type, - TypeVar, - Union, -) +from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union import fastapi import torch @@ -70,57 +58,26 @@ from sglang.srt.managers.io_struct import ( BatchTokenIDOut, BatchTokenizedEmbeddingReqInput, BatchTokenizedGenerateReqInput, - ClearHiCacheReqInput, - ClearHiCacheReqOutput, CloseSessionReqInput, ConfigureLoggingReq, EmbeddingReqInput, - ExpertDistributionReq, - ExpertDistributionReqOutput, - FlushCacheReqInput, - FlushCacheReqOutput, 
FreezeGCReq, GenerateReqInput, - GetInternalStateReq, - GetInternalStateReqOutput, - GetWeightsByNameReqInput, - GetWeightsByNameReqOutput, HealthCheckOutput, - InitWeightsUpdateGroupReqInput, - InitWeightsUpdateGroupReqOutput, - LoadLoRAAdapterReqInput, - LoadLoRAAdapterReqOutput, - LoRAUpdateResult, MultiTokenizerWrapper, OpenSessionReqInput, OpenSessionReqOutput, - ProfileReq, - ProfileReqOutput, - ProfileReqType, - ReleaseMemoryOccupationReqInput, - ReleaseMemoryOccupationReqOutput, - ResumeMemoryOccupationReqInput, - ResumeMemoryOccupationReqOutput, SessionParams, - SetInternalStateReq, - SetInternalStateReqOutput, - SlowDownReqInput, - SlowDownReqOutput, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, - UnloadLoRAAdapterReqInput, - UnloadLoRAAdapterReqOutput, UpdateWeightFromDiskReqInput, UpdateWeightFromDiskReqOutput, - UpdateWeightsFromDistributedReqInput, - UpdateWeightsFromDistributedReqOutput, - UpdateWeightsFromTensorReqInput, - UpdateWeightsFromTensorReqOutput, ) from sglang.srt.managers.mm_utils import TensorTransportMode from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors from sglang.srt.managers.scheduler import is_health_check_generate_req from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region +from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin from sglang.srt.metrics.collector import TokenizerMetricsCollector from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs @@ -177,7 +134,7 @@ class ReqState: output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list) -class TokenizerManager: +class TokenizerManager(TokenizerCommunicatorMixin): """TokenizerManager is a process that tokenizes the text.""" def __init__( @@ -343,50 +300,6 @@ class TokenizerManager: if self.server_args.gc_warning_threshold_secs > 0.0: configure_gc_warning(self.server_args.gc_warning_threshold_secs) - # Communicators - self.init_weights_update_group_communicator = _Communicator( - self.send_to_scheduler, server_args.dp_size - ) - self.update_weights_from_distributed_communicator = _Communicator( - self.send_to_scheduler, server_args.dp_size - ) - self.update_weights_from_tensor_communicator = _Communicator( - self.send_to_scheduler, server_args.dp_size - ) - self.get_weights_by_name_communicator = _Communicator( - self.send_to_scheduler, server_args.dp_size - ) - self.release_memory_occupation_communicator = _Communicator( - self.send_to_scheduler, server_args.dp_size - ) - self.resume_memory_occupation_communicator = _Communicator( - self.send_to_scheduler, server_args.dp_size - ) - self.slow_down_communicator = _Communicator( - self.send_to_scheduler, server_args.dp_size - ) - self.flush_cache_communicator = _Communicator( - self.send_to_scheduler, server_args.dp_size - ) - self.clear_hicache_storage_communicator = _Communicator( - self.send_to_scheduler, server_args.dp_size - ) - self.profile_communicator = _Communicator( - self.send_to_scheduler, server_args.dp_size - ) - self.get_internal_state_communicator = _Communicator( - self.send_to_scheduler, server_args.dp_size - ) - self.set_internal_state_communicator = _Communicator( - self.send_to_scheduler, server_args.dp_size - ) - self.expert_distribution_communicator = _Communicator( - self.send_to_scheduler, server_args.dp_size - ) - self.update_lora_adapter_communicator = _Communicator( - self.send_to_scheduler, server_args.dp_size - ) - self._result_dispatcher 
= TypeBasedDispatcher( [ ( @@ -404,70 +317,16 @@ class TokenizerManager: UpdateWeightFromDiskReqOutput, self._handle_update_weights_from_disk_req_output, ), - ( - InitWeightsUpdateGroupReqOutput, - self.init_weights_update_group_communicator.handle_recv, - ), - ( - UpdateWeightsFromDistributedReqOutput, - self.update_weights_from_distributed_communicator.handle_recv, - ), - ( - UpdateWeightsFromTensorReqOutput, - self.update_weights_from_tensor_communicator.handle_recv, - ), - ( - GetWeightsByNameReqOutput, - self.get_weights_by_name_communicator.handle_recv, - ), - ( - ReleaseMemoryOccupationReqOutput, - self.release_memory_occupation_communicator.handle_recv, - ), - ( - ResumeMemoryOccupationReqOutput, - self.resume_memory_occupation_communicator.handle_recv, - ), - ( - SlowDownReqOutput, - self.slow_down_communicator.handle_recv, - ), - ( - ClearHiCacheReqOutput, - self.clear_hicache_storage_communicator.handle_recv, - ), - ( - FlushCacheReqOutput, - self.flush_cache_communicator.handle_recv, - ), - ( - ProfileReqOutput, - self.profile_communicator.handle_recv, - ), ( FreezeGCReq, lambda x: None, ), # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it. - ( - GetInternalStateReqOutput, - self.get_internal_state_communicator.handle_recv, - ), - ( - SetInternalStateReqOutput, - self.set_internal_state_communicator.handle_recv, - ), - ( - ExpertDistributionReqOutput, - self.expert_distribution_communicator.handle_recv, - ), - ( - LoRAUpdateResult, - self.update_lora_adapter_communicator.handle_recv, - ), (HealthCheckOutput, lambda x: None), ] ) + self.init_communicators(server_args) + async def generate_request( self, obj: Union[GenerateReqInput, EmbeddingReqInput], @@ -983,16 +842,6 @@ class TokenizerManager: except StopAsyncIteration: pass - async def flush_cache(self) -> FlushCacheReqOutput: - return (await self.flush_cache_communicator(FlushCacheReqInput()))[0] - - async def clear_hicache_storage(self) -> ClearHiCacheReqOutput: - """Clear the hierarchical cache storage.""" - # Delegate to the scheduler to handle HiCacheStorage clearing - return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[ - 0 - ] - def abort_request(self, rid: str = "", abort_all: bool = False): if not abort_all and rid not in self.rid_to_state: return @@ -1002,55 +851,6 @@ class TokenizerManager: if self.enable_metrics: self.metrics_collector.observe_one_aborted_request() - async def start_profile( - self, - output_dir: Optional[str] = None, - start_step: Optional[int] = None, - num_steps: Optional[int] = None, - activities: Optional[List[str]] = None, - with_stack: Optional[bool] = None, - record_shapes: Optional[bool] = None, - profile_by_stage: bool = False, - ): - self.auto_create_handle_loop() - env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true") - with_stack = False if with_stack is False or env_with_stack is False else True - req = ProfileReq( - type=ProfileReqType.START_PROFILE, - output_dir=output_dir, - start_step=start_step, - num_steps=num_steps, - activities=activities, - with_stack=with_stack, - record_shapes=record_shapes, - profile_by_stage=profile_by_stage, - profile_id=str(time.time()), - ) - return await self._execute_profile(req) - - async def stop_profile(self): - self.auto_create_handle_loop() - req = ProfileReq(type=ProfileReqType.STOP_PROFILE) - return await self._execute_profile(req) - - async def _execute_profile(self, req: ProfileReq): - result = (await 
self.profile_communicator(req))[0] - if not result.success: - raise RuntimeError(result.message) - return result - - async def start_expert_distribution_record(self): - self.auto_create_handle_loop() - await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD) - - async def stop_expert_distribution_record(self): - self.auto_create_handle_loop() - await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD) - - async def dump_expert_distribution_record(self): - self.auto_create_handle_loop() - await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD) - async def pause_generation(self): async with self.is_pause_cond: self.is_pause = True @@ -1111,191 +911,6 @@ class TokenizerManager: all_paused_requests = [r.num_paused_requests for r in result] return all_success, all_message, all_paused_requests - async def init_weights_update_group( - self, - obj: InitWeightsUpdateGroupReqInput, - request: Optional[fastapi.Request] = None, - ) -> Tuple[bool, str]: - self.auto_create_handle_loop() - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be 1 for init parameter update group" - result = (await self.init_weights_update_group_communicator(obj))[0] - return result.success, result.message - - async def update_weights_from_distributed( - self, - obj: UpdateWeightsFromDistributedReqInput, - request: Optional[fastapi.Request] = None, - ) -> Tuple[bool, str]: - self.auto_create_handle_loop() - assert ( - self.server_args.dp_size == 1 or self.server_args.enable_dp_attention - ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed" - - if obj.abort_all_requests: - self.abort_request(abort_all=True) - - # This means that weight sync - # cannot run while requests are in progress. - async with self.model_update_lock.writer_lock: - result = (await self.update_weights_from_distributed_communicator(obj))[0] - return result.success, result.message - - async def update_weights_from_tensor( - self, - obj: UpdateWeightsFromTensorReqInput, - request: Optional[fastapi.Request] = None, - ) -> Tuple[bool, str]: - self.auto_create_handle_loop() - assert ( - self.server_args.dp_size == 1 or self.server_args.enable_dp_attention - ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor" - - if obj.abort_all_requests: - self.abort_request(abort_all=True) - - # This means that weight sync - # cannot run while requests are in progress. - async with self.model_update_lock.writer_lock: - result = (await self.update_weights_from_tensor_communicator(obj))[0] - return result.success, result.message - - async def load_lora_adapter( - self, - obj: LoadLoRAAdapterReqInput, - _: Optional[fastapi.Request] = None, - ) -> LoadLoRAAdapterReqOutput: - self.auto_create_handle_loop() - - try: - if not self.server_args.enable_lora: - raise ValueError( - "LoRA is not enabled. Please set `--enable-lora` to enable LoRA." - ) - - # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works - # with dp_size > 1. - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be 1 for dynamic lora loading" - logger.info( - "Start load Lora adapter. Lora name=%s, path=%s", - obj.lora_name, - obj.lora_path, - ) - - async with self.lora_update_lock: - if ( - self.server_args.max_loaded_loras is not None - and self.lora_registry.num_registered_loras - >= self.server_args.max_loaded_loras - ): - raise ValueError( - f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. 
" - f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. " - "Please unload some LoRA adapters before loading new ones." - ) - - # Generate new uniquely identifiable LoRARef object. - new_adapter = LoRARef( - lora_name=obj.lora_name, - lora_path=obj.lora_path, - pinned=obj.pinned, - ) - - # Trigger the actual loading operation at the backend processes. - obj.lora_id = new_adapter.lora_id - result = (await self.update_lora_adapter_communicator(obj))[0] - - # Register the LoRA adapter only after loading is successful. - if result.success: - await self.lora_registry.register(new_adapter) - - return result - except ValueError as e: - return LoadLoRAAdapterReqOutput( - success=False, - error_message=str(e), - ) - - async def unload_lora_adapter( - self, - obj: UnloadLoRAAdapterReqInput, - _: Optional[fastapi.Request] = None, - ) -> UnloadLoRAAdapterReqOutput: - self.auto_create_handle_loop() - - try: - if not self.server_args.enable_lora: - raise ValueError( - "LoRA is not enabled. Please set `--enable-lora` to enable LoRA." - ) - - assert ( - obj.lora_name is not None - ), "lora_name must be provided to unload LoRA adapter" - - # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works - # with dp_size > 1. - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be 1 for dynamic lora loading" - logger.info( - "Start unload Lora adapter. Lora name=%s", - obj.lora_name, - ) - - async with self.lora_update_lock: - # Unregister the LoRA adapter from the registry to stop new requests for this adapter - # from being started. - lora_id = await self.lora_registry.unregister(obj.lora_name) - obj.lora_id = lora_id - - # Initiate the actual unloading operation at the backend processes only after all - # ongoing requests using this LoRA adapter are finished. 
- await self.lora_registry.wait_for_unload(lora_id) - result = (await self.update_lora_adapter_communicator(obj))[0] - - return result - except ValueError as e: - return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e)) - - async def get_weights_by_name( - self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None - ): - self.auto_create_handle_loop() - results = await self.get_weights_by_name_communicator(obj) - all_parameters = [r.parameter for r in results] - if self.server_args.dp_size == 1: - return all_parameters[0] - else: - return all_parameters - - async def release_memory_occupation( - self, - obj: ReleaseMemoryOccupationReqInput, - request: Optional[fastapi.Request] = None, - ): - self.auto_create_handle_loop() - await self.release_memory_occupation_communicator(obj) - - async def resume_memory_occupation( - self, - obj: ResumeMemoryOccupationReqInput, - request: Optional[fastapi.Request] = None, - ): - self.auto_create_handle_loop() - await self.resume_memory_occupation_communicator(obj) - - async def slow_down( - self, - obj: SlowDownReqInput, - request: Optional[fastapi.Request] = None, - ): - self.auto_create_handle_loop() - await self.slow_down_communicator(obj) - async def open_session( self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None ): @@ -1320,28 +935,6 @@ class TokenizerManager: ): await self.send_to_scheduler.send_pyobj(obj) - async def get_internal_state(self) -> List[Dict[Any, Any]]: - req = GetInternalStateReq() - responses: List[GetInternalStateReqOutput] = ( - await self.get_internal_state_communicator(req) - ) - # Many DP ranks - return [res.internal_state for res in responses] - - async def set_internal_state(self, obj: SetInternalStateReq) -> List[bool]: - responses: List[SetInternalStateReqOutput] = ( - await self.set_internal_state_communicator(obj) - ) - return [res.updated for res in responses] - - async def get_load(self) -> dict: - # TODO(lsyin): fake load report server - if not self.current_load_lock.locked(): - async with self.current_load_lock: - internal_state = await self.get_internal_state() - self.current_load = internal_state[0]["load"] - return {"load": self.current_load} - def get_log_request_metadata(self): max_length = None skip_names = None @@ -2108,51 +1701,6 @@ class SignalHandler: kill_process_tree(os.getpid()) -T = TypeVar("T") - - -class _Communicator(Generic[T]): - """Note: The communicator now only run up to 1 in-flight request at any time.""" - - enable_multi_tokenizer = False - - def __init__(self, sender, fan_out: int): - self._sender = sender - self._fan_out = fan_out - self._result_event: Optional[asyncio.Event] = None - self._result_values: Optional[List[T]] = None - self._ready_queue: Deque[asyncio.Future] = deque() - - async def __call__(self, obj): - ready_event = asyncio.Event() - if self._result_event is not None or len(self._ready_queue) > 0: - self._ready_queue.append(ready_event) - await ready_event.wait() - assert self._result_event is None - assert self._result_values is None - - if obj: - if _Communicator.enable_multi_tokenizer: - obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj) - self._sender.send_pyobj(obj) - - self._result_event = asyncio.Event() - self._result_values = [] - await self._result_event.wait() - result_values = self._result_values - self._result_event = self._result_values = None - - if len(self._ready_queue) > 0: - self._ready_queue.popleft().set() - - return result_values - - def handle_recv(self, recv_obj: T): - 
self._result_values.append(recv_obj) - if len(self._result_values) == self._fan_out: - self._result_event.set() - - # Note: request abort handling logic # We should handle all of the following cases correctly. # diff --git a/python/sglang/srt/weight_sync/utils.py b/python/sglang/srt/weight_sync/utils.py index 8f3c8adb7..f308207e2 100644 --- a/python/sglang/srt/weight_sync/utils.py +++ b/python/sglang/srt/weight_sync/utils.py @@ -6,7 +6,7 @@ from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor import DTensor from sglang.srt.entrypoints.engine import Engine -from sglang.srt.managers.tokenizer_manager import UpdateWeightsFromTensorReqInput +from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput from sglang.srt.model_executor.model_runner import LocalSerializedTensor from sglang.srt.utils import MultiprocessingSerializer diff --git a/python/sglang/utils.py b/python/sglang/utils.py index c84842e94..f6bf20c42 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -473,6 +473,10 @@ class TypeBasedDispatcher: def __init__(self, mapping: List[Tuple[Type, Callable]]): self._mapping = mapping + def __iadd__(self, other: "TypeBasedDispatcher"): + self._mapping.extend(other._mapping) + return self + def __call__(self, obj: Any): for ty, fn in self._mapping: if isinstance(obj, ty):
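
The primitive this patch relocates is `_Communicator`: an awaitable wrapper that sends one request object to the scheduler, collects `fan_out` replies (one per DP rank), and queues concurrent callers so that only a single request is in flight at a time. The sketch below is a minimal, self-contained rendition of the same pattern; `MiniCommunicator` mirrors the shape of `_Communicator`, while `EchoSender` and the message dicts are illustrative stand-ins (the real sender is a ZMQ socket and the replies are the `*ReqOutput` dataclasses), not code from this patch:

    import asyncio
    from collections import deque
    from typing import Deque, Generic, List, Optional, TypeVar

    T = TypeVar("T")


    class MiniCommunicator(Generic[T]):
        """Same shape as _Communicator: one in-flight request, fan-out replies."""

        def __init__(self, sender, fan_out: int):
            self._sender = sender
            self._fan_out = fan_out
            self._result_event: Optional[asyncio.Event] = None
            self._result_values: Optional[List[T]] = None
            # The queue holds the events that wake up callers waiting their turn.
            self._ready_queue: Deque[asyncio.Event] = deque()

        async def __call__(self, obj) -> List[T]:
            # Wait in line if another request is already in flight.
            ready_event = asyncio.Event()
            if self._result_event is not None or self._ready_queue:
                self._ready_queue.append(ready_event)
                await ready_event.wait()
            if obj is not None:
                self._sender.send_pyobj(obj)
            self._result_event = asyncio.Event()
            self._result_values = []
            await self._result_event.wait()  # released by handle_recv()
            values = self._result_values
            self._result_event = self._result_values = None
            if self._ready_queue:
                self._ready_queue.popleft().set()  # wake the next caller
            return values

        def handle_recv(self, recv_obj: T) -> None:
            # Called once per DP rank; the caller resumes after the last reply.
            self._result_values.append(recv_obj)
            if len(self._result_values) == self._fan_out:
                self._result_event.set()


    class EchoSender:
        """Stand-in for the ZMQ socket: each send triggers fan_out async replies."""

        def __init__(self, fan_out: int):
            self.fan_out = fan_out
            self.communicator = None  # wired up after construction

        def send_pyobj(self, obj):
            loop = asyncio.get_running_loop()
            for rank in range(self.fan_out):
                loop.call_soon(self.communicator.handle_recv, {"rank": rank, "req": obj})


    async def main():
        fan_out = 4  # plays the role of server_args.dp_size
        sender = EchoSender(fan_out)
        comm: MiniCommunicator[dict] = MiniCommunicator(sender, fan_out)
        sender.communicator = comm
        replies = await comm({"type": "get_internal_state"})
        assert len(replies) == fan_out
        print(replies)


    asyncio.run(main())

This is why methods such as `get_internal_state` return a list: awaiting a communicator yields one entry per DP rank, and single-rank callers like `flush_cache` index `[0]`.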
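The small `TypeBasedDispatcher.__iadd__` addition at the end of the patch is what lets `init_communicators` splice the mixin's handler table into the dispatcher already built in `TokenizerManager.__init__` via `self._result_dispatcher += self._get_communicator_dispatcher()`. A minimal sketch of that merge-and-dispatch behavior; the toy message classes are illustrative, and the no-match fallback shown is an assumption (the patch does not include that part of `__call__`):

    from typing import Any, Callable, List, Tuple, Type


    class TypeBasedDispatcher:
        def __init__(self, mapping: List[Tuple[Type, Callable]]):
            self._mapping = mapping

        def __iadd__(self, other: "TypeBasedDispatcher"):
            # Merge in place, so references held elsewhere see the new entries.
            self._mapping.extend(other._mapping)
            return self

        def __call__(self, obj: Any):
            # First isinstance match wins: entries registered earlier take
            # priority over entries merged in later.
            for ty, fn in self._mapping:
                if isinstance(obj, ty):
                    return fn(obj)
            raise ValueError(f"No handler for {type(obj).__name__}")  # assumed fallback


    # Toy stand-ins for the Req/Output dataclasses:
    class HealthCheckOutput: ...
    class FlushCacheReqOutput: ...


    dispatcher = TypeBasedDispatcher([(HealthCheckOutput, lambda x: "ignored")])
    dispatcher += TypeBasedDispatcher([(FlushCacheReqOutput, lambda x: "flushed")])
    assert dispatcher(FlushCacheReqOutput()) == "flushed"

Merging in place (rather than returning a new dispatcher) keeps the receive loop's existing reference to `self._result_dispatcher` valid after the mixin registers its handlers.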
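Finally, note how the mixin gets full type checking without a circular import: every method annotates `self: TokenizerManager`, while `TokenizerManager` itself is imported only under `TYPE_CHECKING` (at runtime, tokenizer_manager.py imports this module, not the other way around). A generic sketch of the pattern with hypothetical module and class names:

    # communicator_mixin.py (hypothetical module, mirroring the pattern)
    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Type-checking time only: at runtime this import would be circular,
        # because manager.py imports this module to compose its class.
        from manager import Manager


    class CommunicatorMixin:
        def flush_cache(self: Manager) -> None:
            # Annotating `self` lets type checkers see Manager's attributes
            # (such as the send_to_scheduler socket) on a class that does
            # not define them itself.
            self.send_to_scheduler.send_pyobj("flush_cache")


    # manager.py would then supply the real attributes:
    #   from communicator_mixin import CommunicatorMixin
    #
    #   class Manager(CommunicatorMixin):
    #       def __init__(self):
    #           self.send_to_scheduler = make_zmq_socket()

Because of `from __future__ import annotations`, the `self: Manager` annotation is never evaluated at runtime, so the mixin module imports cleanly on its own.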