Signed-off-by: Anqi Shen <amy.saq@antgroup.com> Co-authored-by: Chayenne <74843776+zhaochenyang20@users.noreply.github.com>
538 lines
20 KiB
Python
538 lines
20 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import time
|
|
from collections import deque
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Deque,
|
|
Dict,
|
|
Generic,
|
|
List,
|
|
Optional,
|
|
Tuple,
|
|
TypeVar,
|
|
)
|
|
|
|
import fastapi
|
|
|
|
from sglang.srt.managers.io_struct import (
|
|
ClearHiCacheReqInput,
|
|
ClearHiCacheReqOutput,
|
|
ExpertDistributionReq,
|
|
ExpertDistributionReqOutput,
|
|
FlushCacheReqInput,
|
|
FlushCacheReqOutput,
|
|
GetInternalStateReq,
|
|
GetInternalStateReqOutput,
|
|
GetWeightsByNameReqInput,
|
|
GetWeightsByNameReqOutput,
|
|
InitWeightsSendGroupForRemoteInstanceReqInput,
|
|
InitWeightsSendGroupForRemoteInstanceReqOutput,
|
|
InitWeightsUpdateGroupReqInput,
|
|
InitWeightsUpdateGroupReqOutput,
|
|
LoadLoRAAdapterReqInput,
|
|
LoadLoRAAdapterReqOutput,
|
|
LoRAUpdateResult,
|
|
MultiTokenizerWrapper,
|
|
ProfileReq,
|
|
ProfileReqOutput,
|
|
ProfileReqType,
|
|
ReleaseMemoryOccupationReqInput,
|
|
ReleaseMemoryOccupationReqOutput,
|
|
ResumeMemoryOccupationReqInput,
|
|
ResumeMemoryOccupationReqOutput,
|
|
SendWeightsToRemoteInstanceReqInput,
|
|
SendWeightsToRemoteInstanceReqOutput,
|
|
SetInternalStateReq,
|
|
SetInternalStateReqOutput,
|
|
SlowDownReqInput,
|
|
SlowDownReqOutput,
|
|
UnloadLoRAAdapterReqInput,
|
|
UnloadLoRAAdapterReqOutput,
|
|
UpdateWeightsFromDistributedReqInput,
|
|
UpdateWeightsFromDistributedReqOutput,
|
|
UpdateWeightsFromTensorReqInput,
|
|
UpdateWeightsFromTensorReqOutput,
|
|
)
|
|
from sglang.srt.server_args import LoRARef, ServerArgs
|
|
from sglang.srt.utils import get_bool_env_var
|
|
from sglang.utils import TypeBasedDispatcher
|
|
|
|
if TYPE_CHECKING:
|
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
|
|
|
T = TypeVar("T")

logger = logging.getLogger(__name__)


class _Communicator(Generic[T]):
    """Send one request at a time and gather ``fan_out`` replies for it.

    At most one request is in flight at any time. Calls that arrive while a
    request is in flight wait their turn and are served in arrival order.

    Usage: ``await communicator(obj)`` sends ``obj`` through ``sender`` and
    returns the list of replies; the receive path must feed every reply of the
    matching type to :meth:`handle_recv`.

    NOTE: replies carry no request identifier here, so a reply that arrives
    after its request was cancelled cannot be told apart from a reply to the
    next request. Such a reply is dropped only if nothing is in flight when it
    arrives.
    """

    # Class-wide switch. When True, every outgoing object is wrapped in a
    # MultiTokenizerWrapper tagged with this process's pid before being sent.
    enable_multi_tokenizer = False

    def __init__(self, sender, fan_out: int):
        """Create an idle communicator.

        Args:
            sender: Object exposing ``send_pyobj(obj)``; used to emit requests.
            fan_out: Number of replies that complete one request.
        """
        self._sender = sender
        self._fan_out = fan_out
        # Both are non-None exactly while a request is in flight.
        self._result_event: Optional[asyncio.Event] = None
        self._result_values: Optional[List[T]] = None
        # Created lazily inside __call__ so it is constructed while the event
        # loop that will use it is running, not at import/construction time.
        self._lock: Optional[asyncio.Lock] = None

    async def __call__(self, obj):
        """Send ``obj`` (if truthy) and wait for ``fan_out`` replies.

        Args:
            obj: Request object. A falsy value sends nothing and only waits
                for replies.

        Returns:
            The replies in the order :meth:`handle_recv` received them.
        """
        if self._lock is None:
            self._lock = asyncio.Lock()

        # asyncio.Lock hands the lock to waiters in FIFO order and does not let
        # a newcomer overtake a waiter that has been woken but has not resumed
        # yet. The previous hand-rolled event queue had exactly that gap.
        async with self._lock:
            self._result_event = asyncio.Event()
            self._result_values = []
            try:
                if obj:
                    if _Communicator.enable_multi_tokenizer:
                        obj = MultiTokenizerWrapper(worker_id=os.getpid(), obj=obj)
                    self._sender.send_pyobj(obj)

                await self._result_event.wait()
                return self._result_values
            finally:
                # Always return to the idle state, including when this call is
                # cancelled or the send fails; otherwise every later call on
                # this communicator would wait forever.
                self._result_event = self._result_values = None

    def handle_recv(self, recv_obj: T):
        """Record one reply and release the waiting caller once all arrived.

        A reply that arrives while no request is in flight is logged and
        dropped instead of raising inside the receive path.
        """
        if self._result_values is None or self._result_event is None:
            logger.warning(
                "Dropping unexpected %s: no request is in flight.",
                type(recv_obj).__name__,
            )
            return

        self._result_values.append(recv_obj)
        if len(self._result_values) == self._fan_out:
            self._result_event.set()
|
|
|
|
|
class TokenizerCommunicatorMixin:
    """Mixin class for TokenizerManager to handle communication with the scheduler.

    Each request kind has its own ``_Communicator`` (see
    :meth:`init_communicators`). The coroutines below send one request object
    through the matching communicator and interpret the list of replies it
    returns; a communicator completes once it has collected
    ``server_args.dp_size`` replies.
    """

    def init_communicators(self: TokenizerManager, server_args: ServerArgs):
        """Create one communicator per request kind and register reply routing.

        Every communicator sends through ``self.send_to_scheduler`` and waits
        for ``server_args.dp_size`` replies. The reply-type routing table is
        then appended to ``self._result_dispatcher``.
        """

        def new_communicator() -> _Communicator:
            return _Communicator(self.send_to_scheduler, server_args.dp_size)

        # Communicators
        self.init_weights_update_group_communicator = new_communicator()
        self.update_weights_from_distributed_communicator = new_communicator()
        self.init_weights_send_group_for_remote_instance_communicator = (
            new_communicator()
        )
        self.send_weights_to_remote_instance_communicator = new_communicator()
        self.update_weights_from_tensor_communicator = new_communicator()
        self.get_weights_by_name_communicator = new_communicator()
        self.release_memory_occupation_communicator = new_communicator()
        self.resume_memory_occupation_communicator = new_communicator()
        self.slow_down_communicator = new_communicator()
        self.flush_cache_communicator = new_communicator()
        self.clear_hicache_storage_communicator = new_communicator()
        self.profile_communicator = new_communicator()
        self.get_internal_state_communicator = new_communicator()
        self.set_internal_state_communicator = new_communicator()
        self.expert_distribution_communicator = new_communicator()
        self.update_lora_adapter_communicator = new_communicator()

        self._result_dispatcher += self._get_communicator_dispatcher()

    def _get_communicator_dispatcher(self: TokenizerManager):
        """Build the dispatcher mapping each reply type to its communicator.

        Both LoRA load and unload replies arrive as ``LoRAUpdateResult`` and
        share ``update_lora_adapter_communicator``.
        """
        return TypeBasedDispatcher(
            [
                (
                    InitWeightsUpdateGroupReqOutput,
                    self.init_weights_update_group_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromDistributedReqOutput,
                    self.update_weights_from_distributed_communicator.handle_recv,
                ),
                (
                    InitWeightsSendGroupForRemoteInstanceReqOutput,
                    self.init_weights_send_group_for_remote_instance_communicator.handle_recv,
                ),
                (
                    SendWeightsToRemoteInstanceReqOutput,
                    self.send_weights_to_remote_instance_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromTensorReqOutput,
                    self.update_weights_from_tensor_communicator.handle_recv,
                ),
                (
                    GetWeightsByNameReqOutput,
                    self.get_weights_by_name_communicator.handle_recv,
                ),
                (
                    ReleaseMemoryOccupationReqOutput,
                    self.release_memory_occupation_communicator.handle_recv,
                ),
                (
                    ResumeMemoryOccupationReqOutput,
                    self.resume_memory_occupation_communicator.handle_recv,
                ),
                (
                    SlowDownReqOutput,
                    self.slow_down_communicator.handle_recv,
                ),
                (
                    ClearHiCacheReqOutput,
                    self.clear_hicache_storage_communicator.handle_recv,
                ),
                (
                    FlushCacheReqOutput,
                    self.flush_cache_communicator.handle_recv,
                ),
                (
                    ProfileReqOutput,
                    self.profile_communicator.handle_recv,
                ),
                (
                    GetInternalStateReqOutput,
                    self.get_internal_state_communicator.handle_recv,
                ),
                (
                    SetInternalStateReqOutput,
                    self.set_internal_state_communicator.handle_recv,
                ),
                (
                    ExpertDistributionReqOutput,
                    self.expert_distribution_communicator.handle_recv,
                ),
                (
                    LoRAUpdateResult,
                    self.update_lora_adapter_communicator.handle_recv,
                ),
            ]
        )

    async def flush_cache(self: TokenizerManager) -> FlushCacheReqOutput:
        """Send a flush-cache request and return the first reply."""
        # NOTE(review): unlike most request methods in this mixin, this one
        # does not call auto_create_handle_loop() first -- confirm the handle
        # loop is guaranteed to be running when this is awaited.
        return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]

    async def clear_hicache_storage(self: TokenizerManager) -> ClearHiCacheReqOutput:
        """Clear the hierarchical cache storage."""
        # Delegate to the scheduler to handle HiCacheStorage clearing
        return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[
            0
        ]

    async def start_profile(
        self: TokenizerManager,
        output_dir: Optional[str] = None,
        start_step: Optional[int] = None,
        num_steps: Optional[int] = None,
        activities: Optional[List[str]] = None,
        with_stack: Optional[bool] = None,
        record_shapes: Optional[bool] = None,
        profile_by_stage: bool = False,
    ):
        """Send a START_PROFILE request built from the given options.

        ``with_stack`` is forced to False when either the argument is
        explicitly False or the ``SGLANG_PROFILE_WITH_STACK`` environment flag
        is False; otherwise it is True (an argument of ``None`` counts as
        "not disabled").

        Raises:
            RuntimeError: If the first reply reports failure.
        """
        self.auto_create_handle_loop()
        env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
        with_stack = with_stack is not False and env_with_stack is not False
        req = ProfileReq(
            type=ProfileReqType.START_PROFILE,
            output_dir=output_dir,
            start_step=start_step,
            num_steps=num_steps,
            activities=activities,
            with_stack=with_stack,
            record_shapes=record_shapes,
            profile_by_stage=profile_by_stage,
            # Wall-clock timestamp used as the id of this profiling session.
            profile_id=str(time.time()),
        )
        return await self._execute_profile(req)

    async def stop_profile(self: TokenizerManager):
        """Send a STOP_PROFILE request.

        Raises:
            RuntimeError: If the first reply reports failure.
        """
        self.auto_create_handle_loop()
        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
        return await self._execute_profile(req)

    async def _execute_profile(self: TokenizerManager, req: ProfileReq):
        """Send ``req`` and return the first reply, raising if it failed.

        Raises:
            RuntimeError: With the reply's ``message`` when ``success`` is falsy.
        """
        result = (await self.profile_communicator(req))[0]
        if not result.success:
            raise RuntimeError(result.message)
        return result

    async def start_expert_distribution_record(self: TokenizerManager):
        """Send ``ExpertDistributionReq.START_RECORD`` and wait for the replies."""
        self.auto_create_handle_loop()
        await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)

    async def stop_expert_distribution_record(self: TokenizerManager):
        """Send ``ExpertDistributionReq.STOP_RECORD`` and wait for the replies."""
        self.auto_create_handle_loop()
        await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)

    async def dump_expert_distribution_record(self: TokenizerManager):
        """Send ``ExpertDistributionReq.DUMP_RECORD`` and wait for the replies."""
        self.auto_create_handle_loop()
        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)

    async def init_weights_update_group(
        self: TokenizerManager,
        obj: InitWeightsUpdateGroupReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        """Forward ``obj`` and return ``(success, message)`` of the first reply.

        ``request`` is accepted but not used by this method. Requires
        ``dp_size == 1``.
        """
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for init parameter update group"
        result = (await self.init_weights_update_group_communicator(obj))[0]
        return result.success, result.message

    async def update_weights_from_distributed(
        self: TokenizerManager,
        obj: UpdateWeightsFromDistributedReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        """Forward ``obj`` under the model-update writer lock.

        Aborts all requests first when ``obj.abort_all_requests`` is set.
        Returns ``(success, message)`` of the first reply. ``request`` is
        accepted but not used by this method.
        """
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
        ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"

        if obj.abort_all_requests:
            self.abort_request(abort_all=True)

        # This means that weight sync
        # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_distributed_communicator(obj))[0]
            return result.success, result.message

    async def init_weights_send_group_for_remote_instance(
        self: TokenizerManager,
        obj: InitWeightsSendGroupForRemoteInstanceReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        """Forward ``obj`` and return ``(success, message)`` of the first reply.

        ``request`` is accepted but not used by this method. Requires
        ``dp_size == 1``.
        """
        self.auto_create_handle_loop()
        # TODO: support DP
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for init_weights_send_group_for_remote_instance"
        result = (
            await self.init_weights_send_group_for_remote_instance_communicator(obj)
        )[0]
        return result.success, result.message

    async def send_weights_to_remote_instance(
        self: TokenizerManager,
        obj: SendWeightsToRemoteInstanceReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        """Forward ``obj`` and return ``(success, message)`` of the first reply.

        ``request`` is accepted but not used by this method. Requires
        ``dp_size == 1``.
        """
        self.auto_create_handle_loop()
        # TODO: support DP
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for send_weights_to_remote_instance"
        result = (await self.send_weights_to_remote_instance_communicator(obj))[0]
        return result.success, result.message

    async def update_weights_from_tensor(
        self: TokenizerManager,
        obj: UpdateWeightsFromTensorReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        """Forward ``obj`` under the model-update writer lock.

        Aborts all requests first when ``obj.abort_all_requests`` is set.
        Returns ``(success, message)`` of the first reply. ``request`` is
        accepted but not used by this method.
        """
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
        ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"

        if obj.abort_all_requests:
            self.abort_request(abort_all=True)

        # This means that weight sync
        # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_tensor_communicator(obj))[0]
            return result.success, result.message

    async def load_lora_adapter(
        self: TokenizerManager,
        obj: LoadLoRAAdapterReqInput,
        _: Optional[fastapi.Request] = None,
    ) -> LoadLoRAAdapterReqOutput:
        """Load a LoRA adapter in the backend and register it on success.

        Validation failures (LoRA disabled, ``dp_size != 1``, adapter limit
        reached) and any other ``ValueError`` raised along the way are returned
        as a ``success=False`` result rather than raised.
        """
        self.auto_create_handle_loop()

        try:
            if not self.server_args.enable_lora:
                raise ValueError(
                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
                )

            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
            # with dp_size > 1.
            # Raised as ValueError (not assert) so it is reported through the
            # structured error result below and still runs under ``python -O``.
            if self.server_args.dp_size != 1:
                raise ValueError("dp_size must be 1 for dynamic lora loading")
            logger.info(
                "Start load Lora adapter. Lora name=%s, path=%s",
                obj.lora_name,
                obj.lora_path,
            )

            async with self.lora_update_lock:
                if (
                    self.server_args.max_loaded_loras is not None
                    and self.lora_registry.num_registered_loras
                    >= self.server_args.max_loaded_loras
                ):
                    raise ValueError(
                        f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
                        f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
                        "Please unload some LoRA adapters before loading new ones."
                    )

                # Generate new uniquely identifiable LoRARef object.
                new_adapter = LoRARef(
                    lora_name=obj.lora_name,
                    lora_path=obj.lora_path,
                    pinned=obj.pinned,
                )

                # Trigger the actual loading operation at the backend processes.
                obj.lora_id = new_adapter.lora_id
                result = (await self.update_lora_adapter_communicator(obj))[0]

                # Register the LoRA adapter only after loading is successful.
                if result.success:
                    await self.lora_registry.register(new_adapter)

                return result
        except ValueError as e:
            return LoadLoRAAdapterReqOutput(
                success=False,
                error_message=str(e),
            )

    async def unload_lora_adapter(
        self: TokenizerManager,
        obj: UnloadLoRAAdapterReqInput,
        _: Optional[fastapi.Request] = None,
    ) -> UnloadLoRAAdapterReqOutput:
        """Unregister a LoRA adapter, wait for its users, then unload it.

        Validation failures (LoRA disabled, missing ``lora_name``,
        ``dp_size != 1``) and any other ``ValueError`` raised along the way are
        returned as a ``success=False`` result rather than raised.
        """
        self.auto_create_handle_loop()

        try:
            if not self.server_args.enable_lora:
                raise ValueError(
                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
                )

            # Input validation is raised as ValueError (not assert) so it is
            # reported through the structured error result below and still
            # runs under ``python -O``.
            if obj.lora_name is None:
                raise ValueError("lora_name must be provided to unload LoRA adapter")

            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
            # with dp_size > 1.
            if self.server_args.dp_size != 1:
                raise ValueError("dp_size must be 1 for dynamic lora loading")
            logger.info(
                "Start unload Lora adapter. Lora name=%s",
                obj.lora_name,
            )

            async with self.lora_update_lock:
                # Unregister the LoRA adapter from the registry to stop new requests for this adapter
                # from being started.
                lora_id = await self.lora_registry.unregister(obj.lora_name)
                obj.lora_id = lora_id

                # Initiate the actual unloading operation at the backend processes only after all
                # ongoing requests using this LoRA adapter are finished.
                await self.lora_registry.wait_for_unload(lora_id)
                result = (await self.update_lora_adapter_communicator(obj))[0]

                return result
        except ValueError as e:
            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))

    async def get_weights_by_name(
        self: TokenizerManager,
        obj: GetWeightsByNameReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        """Forward ``obj`` and return the ``parameter`` field of the replies.

        Returns the single reply's ``parameter`` when ``dp_size == 1``,
        otherwise a list with one ``parameter`` per reply. ``request`` is
        accepted but not used by this method.
        """
        self.auto_create_handle_loop()
        results = await self.get_weights_by_name_communicator(obj)
        all_parameters = [r.parameter for r in results]
        if self.server_args.dp_size == 1:
            return all_parameters[0]
        return all_parameters

    async def release_memory_occupation(
        self: TokenizerManager,
        obj: ReleaseMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        """Forward ``obj`` and wait for the replies; nothing is returned.

        ``request`` is accepted but not used by this method.
        """
        self.auto_create_handle_loop()
        await self.release_memory_occupation_communicator(obj)

    async def resume_memory_occupation(
        self: TokenizerManager,
        obj: ResumeMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        """Forward ``obj`` and wait for the replies; nothing is returned.

        ``request`` is accepted but not used by this method.
        """
        self.auto_create_handle_loop()
        await self.resume_memory_occupation_communicator(obj)

    async def slow_down(
        self: TokenizerManager,
        obj: SlowDownReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        """Forward ``obj`` and wait for the replies; nothing is returned.

        ``request`` is accepted but not used by this method.
        """
        self.auto_create_handle_loop()
        await self.slow_down_communicator(obj)

    async def get_internal_state(self: TokenizerManager) -> List[Dict[Any, Any]]:
        """Return the ``internal_state`` of every reply, one entry per reply."""
        req = GetInternalStateReq()
        responses: List[GetInternalStateReqOutput] = (
            await self.get_internal_state_communicator(req)
        )
        # Many DP ranks
        return [res.internal_state for res in responses]

    async def set_internal_state(
        self: TokenizerManager, obj: SetInternalStateReq
    ) -> List[bool]:
        """Forward ``obj`` and return the ``updated`` flag of every reply."""
        responses: List[SetInternalStateReqOutput] = (
            await self.set_internal_state_communicator(obj)
        )
        return [res.updated for res in responses]

    async def get_load(self: TokenizerManager) -> dict:
        """Return ``{"load": ...}``, refreshing the cached value when possible.

        If ``current_load_lock`` is free, the load is re-read from the first
        internal-state reply and cached in ``self.current_load``. If another
        caller already holds the lock, the refresh is skipped and the value
        currently cached is returned.
        """
        # TODO(lsyin): fake load report server
        if not self.current_load_lock.locked():
            async with self.current_load_lock:
                internal_state = await self.get_internal_state()
                self.current_load = internal_state[0]["load"]
        return {"load": self.current_load}
|