[1/N] DP-Refactor: move communicators into tokenizer_communicator_mixin (#10028)
@@ -36,7 +36,8 @@ from sglang.srt.managers.io_struct import (
     MultiTokenizerRegisterReq,
     MultiTokenizerWrapper,
 )
-from sglang.srt.managers.tokenizer_manager import TokenizerManager, _Communicator
+from sglang.srt.managers.tokenizer_communicator_mixin import _Communicator
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import get_zmq_socket, kill_process_tree
 from sglang.utils import get_exception_traceback

python/sglang/srt/managers/tokenizer_communicator_mixin.py (new file, 491 lines)
@@ -0,0 +1,491 @@
from __future__ import annotations

import asyncio
import logging
import os
import time
from collections import deque
from typing import (
    TYPE_CHECKING,
    Any,
    Deque,
    Dict,
    Generic,
    List,
    Optional,
    Tuple,
    TypeVar,
)

import fastapi

from sglang.srt.managers.io_struct import (
    ClearHiCacheReqInput,
    ClearHiCacheReqOutput,
    ExpertDistributionReq,
    ExpertDistributionReqOutput,
    FlushCacheReqInput,
    FlushCacheReqOutput,
    GetInternalStateReq,
    GetInternalStateReqOutput,
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
    LoadLoRAAdapterReqInput,
    LoadLoRAAdapterReqOutput,
    LoRAUpdateResult,
    MultiTokenizerWrapper,
    ProfileReq,
    ProfileReqOutput,
    ProfileReqType,
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
    SetInternalStateReq,
    SetInternalStateReqOutput,
    SlowDownReqInput,
    SlowDownReqOutput,
    UnloadLoRAAdapterReqInput,
    UnloadLoRAAdapterReqOutput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.server_args import LoRARef, ServerArgs
from sglang.srt.utils import get_bool_env_var
from sglang.utils import TypeBasedDispatcher

if TYPE_CHECKING:
    from sglang.srt.managers.tokenizer_manager import TokenizerManager

T = TypeVar("T")

logger = logging.getLogger(__name__)


class _Communicator(Generic[T]):
    """Note: The communicator now only runs up to 1 in-flight request at any time."""

    enable_multi_tokenizer = False

    def __init__(self, sender, fan_out: int):
        self._sender = sender
        self._fan_out = fan_out
        self._result_event: Optional[asyncio.Event] = None
        self._result_values: Optional[List[T]] = None
        # The queue holds the ready events of waiting callers.
        self._ready_queue: Deque[asyncio.Event] = deque()

    async def __call__(self, obj):
        ready_event = asyncio.Event()
        if self._result_event is not None or len(self._ready_queue) > 0:
            self._ready_queue.append(ready_event)
            await ready_event.wait()
            assert self._result_event is None
            assert self._result_values is None

        if obj:
            if _Communicator.enable_multi_tokenizer:
                obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj)
            self._sender.send_pyobj(obj)

        self._result_event = asyncio.Event()
        self._result_values = []
        await self._result_event.wait()
        result_values = self._result_values
        self._result_event = self._result_values = None

        if len(self._ready_queue) > 0:
            self._ready_queue.popleft().set()

        return result_values

    def handle_recv(self, recv_obj: T):
        self._result_values.append(recv_obj)
        if len(self._result_values) == self._fan_out:
            self._result_event.set()

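The _Communicator above serializes callers (at most one in-flight request) and fans in one reply per scheduler rank before returning. A minimal sketch of that contract, assuming _Communicator as defined above and a hypothetical stub standing in for the ZMQ sender (the stub and driver are illustrative, not part of this commit):

import asyncio

class _StubSender:
    # Duck-typed stand-in for the ZMQ socket: only send_pyobj is needed.
    def __init__(self):
        self.sent = []

    def send_pyobj(self, obj):
        self.sent.append(obj)

async def main():
    comm = _Communicator(_StubSender(), fan_out=2)  # fan_out mirrors dp_size
    task = asyncio.create_task(comm("ping"))
    await asyncio.sleep(0)               # let the call send and block on its event
    comm.handle_recv("pong-from-dp0")    # one reply per DP rank
    comm.handle_recv("pong-from-dp1")    # fan_out reached -> result event is set
    print(await task)                    # ['pong-from-dp0', 'pong-from-dp1']

asyncio.run(main())
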
class TokenizerCommunicatorMixin:
    """Mixin class for TokenizerManager to handle communication with the scheduler."""

    def init_communicators(self: TokenizerManager, server_args: ServerArgs):
        # Communicators
        self.init_weights_update_group_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_distributed_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_tensor_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.get_weights_by_name_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.release_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.resume_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.slow_down_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.flush_cache_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.clear_hicache_storage_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.profile_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.get_internal_state_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.set_internal_state_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.expert_distribution_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_lora_adapter_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )

        self._result_dispatcher += self._get_communicator_dispatcher()

    def _get_communicator_dispatcher(self: TokenizerManager):
        return TypeBasedDispatcher(
            [
                (
                    InitWeightsUpdateGroupReqOutput,
                    self.init_weights_update_group_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromDistributedReqOutput,
                    self.update_weights_from_distributed_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromTensorReqOutput,
                    self.update_weights_from_tensor_communicator.handle_recv,
                ),
                (
                    GetWeightsByNameReqOutput,
                    self.get_weights_by_name_communicator.handle_recv,
                ),
                (
                    ReleaseMemoryOccupationReqOutput,
                    self.release_memory_occupation_communicator.handle_recv,
                ),
                (
                    ResumeMemoryOccupationReqOutput,
                    self.resume_memory_occupation_communicator.handle_recv,
                ),
                (
                    SlowDownReqOutput,
                    self.slow_down_communicator.handle_recv,
                ),
                (
                    ClearHiCacheReqOutput,
                    self.clear_hicache_storage_communicator.handle_recv,
                ),
                (
                    FlushCacheReqOutput,
                    self.flush_cache_communicator.handle_recv,
                ),
                (
                    ProfileReqOutput,
                    self.profile_communicator.handle_recv,
                ),
                (
                    GetInternalStateReqOutput,
                    self.get_internal_state_communicator.handle_recv,
                ),
                (
                    SetInternalStateReqOutput,
                    self.set_internal_state_communicator.handle_recv,
                ),
                (
                    ExpertDistributionReqOutput,
                    self.expert_distribution_communicator.handle_recv,
                ),
                (
                    LoRAUpdateResult,
                    self.update_lora_adapter_communicator.handle_recv,
                ),
            ]
        )

    async def flush_cache(self: TokenizerManager) -> FlushCacheReqOutput:
        return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]

    async def clear_hicache_storage(self: TokenizerManager) -> ClearHiCacheReqOutput:
        """Clear the hierarchical cache storage."""
        # Delegate to the scheduler to handle HiCacheStorage clearing
        return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[
            0
        ]

    async def start_profile(
        self: TokenizerManager,
        output_dir: Optional[str] = None,
        start_step: Optional[int] = None,
        num_steps: Optional[int] = None,
        activities: Optional[List[str]] = None,
        with_stack: Optional[bool] = None,
        record_shapes: Optional[bool] = None,
        profile_by_stage: bool = False,
    ):
        self.auto_create_handle_loop()
        env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
        with_stack = False if with_stack is False or env_with_stack is False else True
        req = ProfileReq(
            type=ProfileReqType.START_PROFILE,
            output_dir=output_dir,
            start_step=start_step,
            num_steps=num_steps,
            activities=activities,
            with_stack=with_stack,
            record_shapes=record_shapes,
            profile_by_stage=profile_by_stage,
            profile_id=str(time.time()),
        )
        return await self._execute_profile(req)

    async def stop_profile(self: TokenizerManager):
        self.auto_create_handle_loop()
        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
        return await self._execute_profile(req)

    async def _execute_profile(self: TokenizerManager, req: ProfileReq):
        result = (await self.profile_communicator(req))[0]
        if not result.success:
            raise RuntimeError(result.message)
        return result

    async def start_expert_distribution_record(self: TokenizerManager):
        self.auto_create_handle_loop()
        await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)

    async def stop_expert_distribution_record(self: TokenizerManager):
        self.auto_create_handle_loop()
        await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)

    async def dump_expert_distribution_record(self: TokenizerManager):
        self.auto_create_handle_loop()
        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)

    async def init_weights_update_group(
        self: TokenizerManager,
        obj: InitWeightsUpdateGroupReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for init parameter update group"
        result = (await self.init_weights_update_group_communicator(obj))[0]
        return result.success, result.message

    async def update_weights_from_distributed(
        self: TokenizerManager,
        obj: UpdateWeightsFromDistributedReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
        ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"

        if obj.abort_all_requests:
            self.abort_request(abort_all=True)

        # This means that weight sync
        # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_distributed_communicator(obj))[0]
            return result.success, result.message

    async def update_weights_from_tensor(
        self: TokenizerManager,
        obj: UpdateWeightsFromTensorReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
        ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"

        if obj.abort_all_requests:
            self.abort_request(abort_all=True)

        # This means that weight sync
        # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_tensor_communicator(obj))[0]
            return result.success, result.message

    async def load_lora_adapter(
        self: TokenizerManager,
        obj: LoadLoRAAdapterReqInput,
        _: Optional[fastapi.Request] = None,
    ) -> LoadLoRAAdapterReqOutput:
        self.auto_create_handle_loop()

        try:
            if not self.server_args.enable_lora:
                raise ValueError(
                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
                )

            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
            # with dp_size > 1.
            assert (
                self.server_args.dp_size == 1
            ), "dp_size must be 1 for dynamic lora loading"
            logger.info(
                "Start loading LoRA adapter. LoRA name=%s, path=%s",
                obj.lora_name,
                obj.lora_path,
            )

            async with self.lora_update_lock:
                if (
                    self.server_args.max_loaded_loras is not None
                    and self.lora_registry.num_registered_loras
                    >= self.server_args.max_loaded_loras
                ):
                    raise ValueError(
                        f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
                        f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
                        "Please unload some LoRA adapters before loading new ones."
                    )

                # Generate new uniquely identifiable LoRARef object.
                new_adapter = LoRARef(
                    lora_name=obj.lora_name,
                    lora_path=obj.lora_path,
                    pinned=obj.pinned,
                )

                # Trigger the actual loading operation at the backend processes.
                obj.lora_id = new_adapter.lora_id
                result = (await self.update_lora_adapter_communicator(obj))[0]

                # Register the LoRA adapter only after loading is successful.
                if result.success:
                    await self.lora_registry.register(new_adapter)

                return result
        except ValueError as e:
            return LoadLoRAAdapterReqOutput(
                success=False,
                error_message=str(e),
            )

    async def unload_lora_adapter(
        self: TokenizerManager,
        obj: UnloadLoRAAdapterReqInput,
        _: Optional[fastapi.Request] = None,
    ) -> UnloadLoRAAdapterReqOutput:
        self.auto_create_handle_loop()

        try:
            if not self.server_args.enable_lora:
                raise ValueError(
                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
                )

            assert (
                obj.lora_name is not None
            ), "lora_name must be provided to unload LoRA adapter"

            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
            # with dp_size > 1.
            assert (
                self.server_args.dp_size == 1
            ), "dp_size must be 1 for dynamic lora loading"
            logger.info(
                "Start unloading LoRA adapter. LoRA name=%s",
                obj.lora_name,
            )

            async with self.lora_update_lock:
                # Unregister the LoRA adapter from the registry to stop new requests for this adapter
                # from being started.
                lora_id = await self.lora_registry.unregister(obj.lora_name)
                obj.lora_id = lora_id

                # Initiate the actual unloading operation at the backend processes only after all
                # ongoing requests using this LoRA adapter are finished.
                await self.lora_registry.wait_for_unload(lora_id)
                result = (await self.update_lora_adapter_communicator(obj))[0]

                return result
        except ValueError as e:
            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))

    async def get_weights_by_name(
        self: TokenizerManager,
        obj: GetWeightsByNameReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        results = await self.get_weights_by_name_communicator(obj)
        all_parameters = [r.parameter for r in results]
        if self.server_args.dp_size == 1:
            return all_parameters[0]
        else:
            return all_parameters

    async def release_memory_occupation(
        self: TokenizerManager,
        obj: ReleaseMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        await self.release_memory_occupation_communicator(obj)

    async def resume_memory_occupation(
        self: TokenizerManager,
        obj: ResumeMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        await self.resume_memory_occupation_communicator(obj)

    async def slow_down(
        self: TokenizerManager,
        obj: SlowDownReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        await self.slow_down_communicator(obj)

    async def get_internal_state(self: TokenizerManager) -> List[Dict[Any, Any]]:
        req = GetInternalStateReq()
        responses: List[GetInternalStateReqOutput] = (
            await self.get_internal_state_communicator(req)
        )
        # Many DP ranks
        return [res.internal_state for res in responses]

    async def set_internal_state(
        self: TokenizerManager, obj: SetInternalStateReq
    ) -> List[bool]:
        responses: List[SetInternalStateReqOutput] = (
            await self.set_internal_state_communicator(obj)
        )
        return [res.updated for res in responses]

    async def get_load(self: TokenizerManager) -> dict:
        # TODO(lsyin): fake load report server
        if not self.current_load_lock.locked():
            async with self.current_load_lock:
                internal_state = await self.get_internal_state()
                self.current_load = internal_state[0]["load"]
        return {"load": self.current_load}

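Throughout the new file, methods annotate self as TokenizerManager, imported only under TYPE_CHECKING, so static checkers can resolve host attributes such as send_to_scheduler and server_args without a runtime import cycle. A single-file sketch of this typed-mixin pattern (Host and HostMixin are hypothetical names, not from this commit; in the real code the host class lives in another module and is imported under TYPE_CHECKING):

from __future__ import annotations

class HostMixin:
    def double_value(self: Host) -> int:
        # Type checkers resolve .value through Host; at runtime it is duck typing.
        return self.value * 2

class Host(HostMixin):
    def __init__(self) -> None:
        self.value = 21

print(Host().double_value())  # 42
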
@@ -31,19 +31,7 @@ from contextlib import nullcontext
 from datetime import datetime
 from enum import Enum
 from http import HTTPStatus
-from typing import (
-    Any,
-    Awaitable,
-    Deque,
-    Dict,
-    Generic,
-    List,
-    Optional,
-    Tuple,
-    Type,
-    TypeVar,
-    Union,
-)
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
 
 import fastapi
 import torch
@@ -70,57 +58,26 @@ from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
     BatchTokenizedEmbeddingReqInput,
     BatchTokenizedGenerateReqInput,
-    ClearHiCacheReqInput,
-    ClearHiCacheReqOutput,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
-    ExpertDistributionReq,
-    ExpertDistributionReqOutput,
-    FlushCacheReqInput,
-    FlushCacheReqOutput,
     FreezeGCReq,
     GenerateReqInput,
-    GetInternalStateReq,
-    GetInternalStateReqOutput,
-    GetWeightsByNameReqInput,
-    GetWeightsByNameReqOutput,
     HealthCheckOutput,
-    InitWeightsUpdateGroupReqInput,
-    InitWeightsUpdateGroupReqOutput,
-    LoadLoRAAdapterReqInput,
-    LoadLoRAAdapterReqOutput,
-    LoRAUpdateResult,
     MultiTokenizerWrapper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
-    ProfileReq,
-    ProfileReqOutput,
-    ProfileReqType,
-    ReleaseMemoryOccupationReqInput,
-    ReleaseMemoryOccupationReqOutput,
-    ResumeMemoryOccupationReqInput,
-    ResumeMemoryOccupationReqOutput,
     SessionParams,
-    SetInternalStateReq,
-    SetInternalStateReqOutput,
-    SlowDownReqInput,
-    SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    UnloadLoRAAdapterReqInput,
-    UnloadLoRAAdapterReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
-    UpdateWeightsFromDistributedReqInput,
-    UpdateWeightsFromDistributedReqOutput,
-    UpdateWeightsFromTensorReqInput,
-    UpdateWeightsFromTensorReqOutput,
 )
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
 from sglang.srt.managers.scheduler import is_health_check_generate_req
 from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
+from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -177,7 +134,7 @@ class ReqState:
     output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
 
 
-class TokenizerManager:
+class TokenizerManager(TokenizerCommunicatorMixin):
     """TokenizerManager is a process that tokenizes the text."""
 
     def __init__(
@@ -343,50 +300,6 @@ class TokenizerManager:
         if self.server_args.gc_warning_threshold_secs > 0.0:
             configure_gc_warning(self.server_args.gc_warning_threshold_secs)
 
-        # Communicators
-        self.init_weights_update_group_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_weights_from_distributed_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_weights_from_tensor_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.get_weights_by_name_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.release_memory_occupation_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.resume_memory_occupation_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.slow_down_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.flush_cache_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.clear_hicache_storage_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.profile_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.get_internal_state_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.set_internal_state_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.expert_distribution_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_lora_adapter_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-
         self._result_dispatcher = TypeBasedDispatcher(
             [
                 (
@@ -404,70 +317,16 @@
                     UpdateWeightFromDiskReqOutput,
                     self._handle_update_weights_from_disk_req_output,
                 ),
-                (
-                    InitWeightsUpdateGroupReqOutput,
-                    self.init_weights_update_group_communicator.handle_recv,
-                ),
-                (
-                    UpdateWeightsFromDistributedReqOutput,
-                    self.update_weights_from_distributed_communicator.handle_recv,
-                ),
-                (
-                    UpdateWeightsFromTensorReqOutput,
-                    self.update_weights_from_tensor_communicator.handle_recv,
-                ),
-                (
-                    GetWeightsByNameReqOutput,
-                    self.get_weights_by_name_communicator.handle_recv,
-                ),
-                (
-                    ReleaseMemoryOccupationReqOutput,
-                    self.release_memory_occupation_communicator.handle_recv,
-                ),
-                (
-                    ResumeMemoryOccupationReqOutput,
-                    self.resume_memory_occupation_communicator.handle_recv,
-                ),
-                (
-                    SlowDownReqOutput,
-                    self.slow_down_communicator.handle_recv,
-                ),
-                (
-                    ClearHiCacheReqOutput,
-                    self.clear_hicache_storage_communicator.handle_recv,
-                ),
-                (
-                    FlushCacheReqOutput,
-                    self.flush_cache_communicator.handle_recv,
-                ),
-                (
-                    ProfileReqOutput,
-                    self.profile_communicator.handle_recv,
-                ),
                 (
                     FreezeGCReq,
                     lambda x: None,
                 ),  # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
-                (
-                    GetInternalStateReqOutput,
-                    self.get_internal_state_communicator.handle_recv,
-                ),
-                (
-                    SetInternalStateReqOutput,
-                    self.set_internal_state_communicator.handle_recv,
-                ),
-                (
-                    ExpertDistributionReqOutput,
-                    self.expert_distribution_communicator.handle_recv,
-                ),
-                (
-                    LoRAUpdateResult,
-                    self.update_lora_adapter_communicator.handle_recv,
-                ),
                 (HealthCheckOutput, lambda x: None),
             ]
         )
 
+        self.init_communicators(server_args)
 
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -983,16 +842,6 @@
             except StopAsyncIteration:
                 pass
 
-    async def flush_cache(self) -> FlushCacheReqOutput:
-        return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
-
-    async def clear_hicache_storage(self) -> ClearHiCacheReqOutput:
-        """Clear the hierarchical cache storage."""
-        # Delegate to the scheduler to handle HiCacheStorage clearing
-        return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[
-            0
-        ]
-
     def abort_request(self, rid: str = "", abort_all: bool = False):
         if not abort_all and rid not in self.rid_to_state:
             return
@@ -1002,55 +851,6 @@
         if self.enable_metrics:
             self.metrics_collector.observe_one_aborted_request()
 
-    async def start_profile(
-        self,
-        output_dir: Optional[str] = None,
-        start_step: Optional[int] = None,
-        num_steps: Optional[int] = None,
-        activities: Optional[List[str]] = None,
-        with_stack: Optional[bool] = None,
-        record_shapes: Optional[bool] = None,
-        profile_by_stage: bool = False,
-    ):
-        self.auto_create_handle_loop()
-        env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
-        with_stack = False if with_stack is False or env_with_stack is False else True
-        req = ProfileReq(
-            type=ProfileReqType.START_PROFILE,
-            output_dir=output_dir,
-            start_step=start_step,
-            num_steps=num_steps,
-            activities=activities,
-            with_stack=with_stack,
-            record_shapes=record_shapes,
-            profile_by_stage=profile_by_stage,
-            profile_id=str(time.time()),
-        )
-        return await self._execute_profile(req)
-
-    async def stop_profile(self):
-        self.auto_create_handle_loop()
-        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
-        return await self._execute_profile(req)
-
-    async def _execute_profile(self, req: ProfileReq):
-        result = (await self.profile_communicator(req))[0]
-        if not result.success:
-            raise RuntimeError(result.message)
-        return result
-
-    async def start_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
-
-    async def stop_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
-
-    async def dump_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
-
     async def pause_generation(self):
         async with self.is_pause_cond:
             self.is_pause = True
@@ -1111,191 +911,6 @@
         all_paused_requests = [r.num_paused_requests for r in result]
         return all_success, all_message, all_paused_requests
 
-    async def init_weights_update_group(
-        self,
-        obj: InitWeightsUpdateGroupReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for init parameter update group"
-        result = (await self.init_weights_update_group_communicator(obj))[0]
-        return result.success, result.message
-
-    async def update_weights_from_distributed(
-        self,
-        obj: UpdateWeightsFromDistributedReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
-        ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
-
-        if obj.abort_all_requests:
-            self.abort_request(abort_all=True)
-
-        # This means that weight sync
-        # cannot run while requests are in progress.
-        async with self.model_update_lock.writer_lock:
-            result = (await self.update_weights_from_distributed_communicator(obj))[0]
-            return result.success, result.message
-
-    async def update_weights_from_tensor(
-        self,
-        obj: UpdateWeightsFromTensorReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
-        ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"
-
-        if obj.abort_all_requests:
-            self.abort_request(abort_all=True)
-
-        # This means that weight sync
-        # cannot run while requests are in progress.
-        async with self.model_update_lock.writer_lock:
-            result = (await self.update_weights_from_tensor_communicator(obj))[0]
-            return result.success, result.message
-
-    async def load_lora_adapter(
-        self,
-        obj: LoadLoRAAdapterReqInput,
-        _: Optional[fastapi.Request] = None,
-    ) -> LoadLoRAAdapterReqOutput:
-        self.auto_create_handle_loop()
-
-        try:
-            if not self.server_args.enable_lora:
-                raise ValueError(
-                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-                )
-
-            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-            # with dp_size > 1.
-            assert (
-                self.server_args.dp_size == 1
-            ), "dp_size must be 1 for dynamic lora loading"
-            logger.info(
-                "Start load Lora adapter. Lora name=%s, path=%s",
-                obj.lora_name,
-                obj.lora_path,
-            )
-
-            async with self.lora_update_lock:
-                if (
-                    self.server_args.max_loaded_loras is not None
-                    and self.lora_registry.num_registered_loras
-                    >= self.server_args.max_loaded_loras
-                ):
-                    raise ValueError(
-                        f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
-                        f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
-                        "Please unload some LoRA adapters before loading new ones."
-                    )
-
-                # Generate new uniquely identifiable LoRARef object.
-                new_adapter = LoRARef(
-                    lora_name=obj.lora_name,
-                    lora_path=obj.lora_path,
-                    pinned=obj.pinned,
-                )
-
-                # Trigger the actual loading operation at the backend processes.
-                obj.lora_id = new_adapter.lora_id
-                result = (await self.update_lora_adapter_communicator(obj))[0]
-
-                # Register the LoRA adapter only after loading is successful.
-                if result.success:
-                    await self.lora_registry.register(new_adapter)
-
-                return result
-        except ValueError as e:
-            return LoadLoRAAdapterReqOutput(
-                success=False,
-                error_message=str(e),
-            )
-
-    async def unload_lora_adapter(
-        self,
-        obj: UnloadLoRAAdapterReqInput,
-        _: Optional[fastapi.Request] = None,
-    ) -> UnloadLoRAAdapterReqOutput:
-        self.auto_create_handle_loop()
-
-        try:
-            if not self.server_args.enable_lora:
-                raise ValueError(
-                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-                )
-
-            assert (
-                obj.lora_name is not None
-            ), "lora_name must be provided to unload LoRA adapter"
-
-            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-            # with dp_size > 1.
-            assert (
-                self.server_args.dp_size == 1
-            ), "dp_size must be 1 for dynamic lora loading"
-            logger.info(
-                "Start unload Lora adapter. Lora name=%s",
-                obj.lora_name,
-            )
-
-            async with self.lora_update_lock:
-                # Unregister the LoRA adapter from the registry to stop new requests for this adapter
-                # from being started.
-                lora_id = await self.lora_registry.unregister(obj.lora_name)
-                obj.lora_id = lora_id
-
-                # Initiate the actual unloading operation at the backend processes only after all
-                # ongoing requests using this LoRA adapter are finished.
-                await self.lora_registry.wait_for_unload(lora_id)
-                result = (await self.update_lora_adapter_communicator(obj))[0]
-
-                return result
-        except ValueError as e:
-            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
-
-    async def get_weights_by_name(
-        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
-    ):
-        self.auto_create_handle_loop()
-        results = await self.get_weights_by_name_communicator(obj)
-        all_parameters = [r.parameter for r in results]
-        if self.server_args.dp_size == 1:
-            return all_parameters[0]
-        else:
-            return all_parameters
-
-    async def release_memory_occupation(
-        self,
-        obj: ReleaseMemoryOccupationReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.release_memory_occupation_communicator(obj)
-
-    async def resume_memory_occupation(
-        self,
-        obj: ResumeMemoryOccupationReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.resume_memory_occupation_communicator(obj)
-
-    async def slow_down(
-        self,
-        obj: SlowDownReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.slow_down_communicator(obj)
-
     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):
@@ -1320,28 +935,6 @@
     ):
         await self.send_to_scheduler.send_pyobj(obj)
 
-    async def get_internal_state(self) -> List[Dict[Any, Any]]:
-        req = GetInternalStateReq()
-        responses: List[GetInternalStateReqOutput] = (
-            await self.get_internal_state_communicator(req)
-        )
-        # Many DP ranks
-        return [res.internal_state for res in responses]
-
-    async def set_internal_state(self, obj: SetInternalStateReq) -> List[bool]:
-        responses: List[SetInternalStateReqOutput] = (
-            await self.set_internal_state_communicator(obj)
-        )
-        return [res.updated for res in responses]
-
-    async def get_load(self) -> dict:
-        # TODO(lsyin): fake load report server
-        if not self.current_load_lock.locked():
-            async with self.current_load_lock:
-                internal_state = await self.get_internal_state()
-                self.current_load = internal_state[0]["load"]
-        return {"load": self.current_load}
-
     def get_log_request_metadata(self):
         max_length = None
         skip_names = None
@@ -2108,51 +1701,6 @@ class SignalHandler:
         kill_process_tree(os.getpid())
 
 
-T = TypeVar("T")
-
-
-class _Communicator(Generic[T]):
-    """Note: The communicator now only run up to 1 in-flight request at any time."""
-
-    enable_multi_tokenizer = False
-
-    def __init__(self, sender, fan_out: int):
-        self._sender = sender
-        self._fan_out = fan_out
-        self._result_event: Optional[asyncio.Event] = None
-        self._result_values: Optional[List[T]] = None
-        self._ready_queue: Deque[asyncio.Future] = deque()
-
-    async def __call__(self, obj):
-        ready_event = asyncio.Event()
-        if self._result_event is not None or len(self._ready_queue) > 0:
-            self._ready_queue.append(ready_event)
-            await ready_event.wait()
-            assert self._result_event is None
-            assert self._result_values is None
-
-        if obj:
-            if _Communicator.enable_multi_tokenizer:
-                obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj)
-            self._sender.send_pyobj(obj)
-
-        self._result_event = asyncio.Event()
-        self._result_values = []
-        await self._result_event.wait()
-        result_values = self._result_values
-        self._result_event = self._result_values = None
-
-        if len(self._ready_queue) > 0:
-            self._ready_queue.popleft().set()
-
-        return result_values
-
-    def handle_recv(self, recv_obj: T):
-        self._result_values.append(recv_obj)
-        if len(self._result_values) == self._fan_out:
-            self._result_event.set()
-
-
 # Note: request abort handling logic
 # We should handle all of the following cases correctly.
 #

@@ -6,7 +6,7 @@ from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import DTensor
 
 from sglang.srt.entrypoints.engine import Engine
-from sglang.srt.managers.tokenizer_manager import UpdateWeightsFromTensorReqInput
+from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput
 from sglang.srt.model_executor.model_runner import LocalSerializedTensor
 from sglang.srt.utils import MultiprocessingSerializer

@@ -473,6 +473,10 @@ class TypeBasedDispatcher:
     def __init__(self, mapping: List[Tuple[Type, Callable]]):
         self._mapping = mapping
 
+    def __iadd__(self, other: "TypeBasedDispatcher"):
+        self._mapping.extend(other._mapping)
+        return self
+
     def __call__(self, obj: Any):
         for ty, fn in self._mapping:
             if isinstance(obj, ty):
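The new __iadd__ is what makes self._result_dispatcher += self._get_communicator_dispatcher() in the mixin work: it extends the existing (type, handler) mapping in place. A small sketch of the merge; the toy types and handlers are illustrative, and the print assumes __call__ returns the matched handler's result (the tail of __call__ is truncated in this hunk):

from sglang.utils import TypeBasedDispatcher

class PingOut: ...
class PongOut: ...

base = TypeBasedDispatcher([(PingOut, lambda _: "handled ping")])
extra = TypeBasedDispatcher([(PongOut, lambda _: "handled pong")])
base += extra                # extends base's (type, handler) pairs with extra's
print(base(PongOut()))       # -> "handled pong"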