2024-11-22 22:16:53 +08:00
|
|
|
# Copyright 2023-2024 SGLang Team
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
# ==============================================================================
|
2024-06-08 02:06:52 -07:00
|
|
|
"""TokenizerManager is a process that tokenizes the text."""
|
2024-06-12 21:48:40 -07:00
|
|
|
|
2024-01-08 04:37:50 +00:00
|
|
|
import asyncio
|
2024-11-03 08:38:26 -08:00
|
|
|
import copy
|
2024-01-08 04:37:50 +00:00
|
|
|
import dataclasses
|
2024-05-12 06:41:32 -07:00
|
|
|
import logging
|
2024-01-08 04:37:50 +00:00
|
|
|
import os
|
2024-10-30 10:22:56 -07:00
|
|
|
import signal
|
|
|
|
|
import sys
|
2024-11-10 04:39:32 -08:00
|
|
|
import time
|
2024-11-20 00:36:53 -08:00
|
|
|
import uuid
|
2024-12-27 09:53:09 +08:00
|
|
|
from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
|
2024-01-08 04:37:50 +00:00
|
|
|
|
2024-08-28 06:33:05 -07:00
|
|
|
import fastapi
|
2024-01-08 04:37:50 +00:00
|
|
|
import uvloop
|
|
|
|
|
import zmq
|
|
|
|
|
import zmq.asyncio
|
2024-05-20 18:41:21 -07:00
|
|
|
from fastapi import BackgroundTasks
|
2024-04-22 22:38:09 +08:00
|
|
|
|
2024-12-22 06:25:57 -08:00
|
|
|
from sglang.srt.aio_rwlock import RWLock
|
2024-11-03 12:25:39 -08:00
|
|
|
from sglang.srt.configs.model_config import ModelConfig
|
|
|
|
|
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
2024-09-29 18:52:43 -07:00
|
|
|
from sglang.srt.managers.image_processor import (
|
|
|
|
|
get_dummy_image_processor,
|
|
|
|
|
get_image_processor,
|
|
|
|
|
)
|
2024-01-08 04:37:50 +00:00
|
|
|
from sglang.srt.managers.io_struct import (
|
2024-05-17 05:49:31 -07:00
|
|
|
AbortReq,
|
2024-08-08 16:31:19 -07:00
|
|
|
BatchEmbeddingOut,
|
2024-01-08 04:37:50 +00:00
|
|
|
BatchStrOut,
|
2024-06-12 21:48:40 -07:00
|
|
|
BatchTokenIDOut,
|
2024-11-20 00:36:53 -08:00
|
|
|
CloseSessionReqInput,
|
2024-08-08 16:31:19 -07:00
|
|
|
EmbeddingReqInput,
|
2024-01-29 17:05:42 -08:00
|
|
|
FlushCacheReq,
|
2024-01-08 04:37:50 +00:00
|
|
|
GenerateReqInput,
|
2024-11-29 23:36:38 -08:00
|
|
|
GetWeightsByNameReqInput,
|
|
|
|
|
GetWeightsByNameReqOutput,
|
2024-12-01 23:23:18 -08:00
|
|
|
InitWeightsUpdateGroupReqInput,
|
|
|
|
|
InitWeightsUpdateGroupReqOutput,
|
2024-11-20 00:36:53 -08:00
|
|
|
OpenSessionReqInput,
|
|
|
|
|
OpenSessionReqOutput,
|
2024-10-11 17:34:25 +08:00
|
|
|
ProfileReq,
|
2024-08-08 16:31:19 -07:00
|
|
|
TokenizedEmbeddingReqInput,
|
2024-01-08 04:37:50 +00:00
|
|
|
TokenizedGenerateReqInput,
|
2024-11-29 17:17:00 -08:00
|
|
|
UpdateWeightFromDiskReqInput,
|
|
|
|
|
UpdateWeightFromDiskReqOutput,
|
2024-12-01 23:23:18 -08:00
|
|
|
UpdateWeightsFromDistributedReqInput,
|
|
|
|
|
UpdateWeightsFromDistributedReqOutput,
|
2024-12-29 05:30:27 +08:00
|
|
|
UpdateWeightsFromTensorReqInput,
|
|
|
|
|
UpdateWeightsFromTensorReqOutput,
|
2024-01-08 04:37:50 +00:00
|
|
|
)
|
2024-11-10 04:39:32 -08:00
|
|
|
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
2024-08-21 16:48:24 -07:00
|
|
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
2024-01-08 04:37:50 +00:00
|
|
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
2024-12-22 06:25:57 -08:00
|
|
|
from sglang.srt.utils import (
|
|
|
|
|
dataclass_to_string_truncated,
|
|
|
|
|
get_zmq_socket,
|
|
|
|
|
kill_process_tree,
|
|
|
|
|
)
|
2024-01-08 04:37:50 +00:00
|
|
|
|
|
|
|
|
# Use uvloop as the asyncio event loop implementation for better performance.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
|
|
|
|
|
2024-05-12 06:41:32 -07:00
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
2024-01-08 04:37:50 +00:00
|
|
|
|
|
|
|
|
@dataclasses.dataclass
class ReqState:
    """Store the state of a request."""

    # Outputs produced so far and not yet consumed by `_wait_one_response`.
    out_list: List
    # Whether the request has finished.
    finished: bool
    # Signaled whenever new output is appended to `out_list`.
    event: asyncio.Event
    # The original request object (GenerateReqInput / EmbeddingReqInput).
    obj: Any

    # For metrics
    created_time: float
    first_token_time: Optional[float] = None

    # For streaming output
    last_output_offset: int = 0
|
|
|
|
|
|
2024-01-08 04:37:50 +00:00
|
|
|
|
|
|
|
|
class TokenizerManager:
|
2024-08-21 19:24:36 -07:00
|
|
|
"""TokenizerManager is a process that tokenizes the text."""
|
|
|
|
|
|
2024-01-08 04:37:50 +00:00
|
|
|
    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        """Initialize the tokenizer manager.

        Sets up the ZMQ sockets toward the detokenizer/scheduler, loads the
        model config and tokenizer (or multimodal processor), and creates the
        bookkeeping state for in-flight requests, weight updates, sessions,
        and metrics.

        Args:
            server_args: Global server configuration.
            port_args: IPC endpoint names used by the ZMQ sockets.
        """
        # Parse args
        self.server_args = server_args
        self.enable_metrics = server_args.enable_metrics

        # Init inter-process communication
        context = zmq.asyncio.Context(2)
        # PULL socket: receives finished outputs from the detokenizer.
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name
        )
        # PUSH socket: sends tokenized requests / control messages to the scheduler.
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name
        )

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )

        self.is_generation = self.model_config.is_generation
        self.context_len = self.model_config.context_len
        self.image_token_id = self.model_config.image_token_id

        # Create image processor placeholder (replaced below for multimodal models)
        self.image_processor = get_dummy_image_processor()

        # Create tokenizer
        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
                # Avoid tokenizer fork/thread warnings once we parallelize below.
                os.environ["TOKENIZERS_PARALLELISM"] = "false"

                # We want to parallelize the image pre-processing so we create an executor for it
                self.image_processor = get_image_processor(
                    self.model_config.hf_config, server_args, self.processor
                )
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )

        # Store states
        self.to_create_loop = True
        # rid -> ReqState for every in-flight request.
        self.rid_to_state: Dict[str, ReqState] = {}

        # The event to notify the weight sync is finished.
        self.model_update_lock = RWLock()
        self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
            None
        )
        # Strong references to background tasks so they are not GC'ed.
        self.asyncio_tasks = set()

        # For session info
        self.session_futures = {}  # session_id -> asyncio event

        # Others
        self.gracefully_exit = False
        self.init_weights_update_group_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_distributed_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_tensor_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.get_weights_by_name_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )

        # Metrics
        if self.enable_metrics:
            self.metrics_collector = TokenizerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
            )
|
|
|
|
|
|
2024-08-08 16:31:19 -07:00
|
|
|
    async def generate_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        """Handle one generate/embedding request end-to-end.

        This is an async generator: it tokenizes the request, forwards it to
        the scheduler, and yields responses as they come back from the
        detokenizer.

        Args:
            obj: The (possibly batched) request to process.
            request: The incoming HTTP request, used to detect client
                disconnects while waiting for responses.
        """
        created_time = time.time()

        self.auto_create_handle_loop()

        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
            raise ValueError(
                "This model does not appear to be an embedding model by default. "
                "Please add `--is-embedding` when launching the server or try another model."
            )

        obj.normalize_batch_and_arguments()

        if self.server_args.log_requests:
            logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")

        # Hold the reader lock so that requests never overlap with a weight
        # update, which takes the writer lock.
        async with self.model_update_lock.reader_lock:
            is_single = obj.is_single
            if is_single:
                tokenized_obj = await self._tokenize_one_request(obj)
                self.send_to_scheduler.send_pyobj(tokenized_obj)
                async for response in self._wait_one_response(
                    obj, request, created_time
                ):
                    yield response
            else:
                async for response in self._handle_batch_request(
                    obj, request, created_time
                ):
                    yield response
|
2024-01-30 23:12:33 +09:00
|
|
|
|
2024-11-03 08:38:26 -08:00
|
|
|
    async def _tokenize_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
    ):
        """Tokenize one request.

        Converts the text/input_ids/input_embeds of a single (non-batched)
        request into a TokenizedGenerateReqInput or TokenizedEmbeddingReqInput
        ready to be sent to the scheduler.
        """
        # Tokenize
        input_embeds = None
        input_text = obj.text
        if obj.input_embeds is not None:
            # Precomputed embeddings bypass the tokenizer entirely; they are
            # incompatible with the radix cache.
            if not self.server_args.disable_radix_cache:
                raise ValueError(
                    "input_embeds is provided while disable_radix_cache is False. "
                    "Please add `--disable-radix-cache` when you launch the server "
                    "if you want to use input_embeds as inputs."
                )
            input_embeds = obj.input_embeds
            input_ids = obj.input_ids
        elif obj.input_ids is None:
            input_ids = self.tokenizer.encode(input_text)
        else:
            input_ids = obj.input_ids

        if self.is_generation:
            # TODO: also support getting embeddings for multimodal models
            image_inputs: Dict = await self.image_processor.process_images_async(
                obj.image_data, input_text or input_ids, obj
            )
            # The image processor may expand the prompt with image tokens.
            if image_inputs and "input_ids" in image_inputs:
                input_ids = image_inputs["input_ids"]
            return_logprob = obj.return_logprob
            logprob_start_len = obj.logprob_start_len
            top_logprobs_num = obj.top_logprobs_num
            session_id = obj.session[0] if obj.session else None
            session_rid = obj.session[1] if obj.session else None

        # NOTE(review): this length check only triggers when the caller passed
        # explicit input_ids; text-tokenized inputs skip it — confirm intended.
        if obj.input_ids is not None and len(input_ids) >= self.context_len:
            raise ValueError(
                f"The input ({len(input_ids)} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

        # Parse sampling parameters
        sampling_params = SamplingParams(**obj.sampling_params)
        sampling_params.normalize(self.tokenizer)
        sampling_params.verify()

        # Build return object
        if isinstance(obj, GenerateReqInput):
            tokenized_obj = TokenizedGenerateReqInput(
                obj.rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
                return_logprob,
                logprob_start_len,
                top_logprobs_num,
                obj.stream,
                lora_path=obj.lora_path,
                input_embeds=input_embeds,
                session_id=session_id,
                session_rid=session_rid,
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
                obj.rid,
                input_text,
                input_ids,
                sampling_params,
            )

        return tokenized_obj
|
2024-10-01 10:25:32 -07:00
|
|
|
|
2024-11-03 08:38:26 -08:00
|
|
|
    async def _wait_one_response(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
        created_time: Optional[float] = None,
    ):
        """Wait for the response of one request.

        Registers a ReqState for ``obj.rid`` and yields outputs as the handle
        loop delivers them: every chunk for streaming requests, only the final
        output otherwise. Periodically checks whether the HTTP client has
        disconnected and aborts the request on the scheduler if so.
        """
        event = asyncio.Event()
        state = ReqState([], False, event, obj, created_time=created_time)
        self.rid_to_state[obj.rid] = state

        while True:
            try:
                # Wake up every 4s even without new output so we can detect
                # client disconnects.
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    self.abort_request(obj.rid)
                    raise ValueError(f"Abort request {obj.rid}")
                continue

            out = state.out_list[-1]

            state.out_list = []
            if state.finished:
                if self.server_args.log_requests:
                    msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
                    logger.info(msg)
                del self.rid_to_state[obj.rid]
                yield out
                break

            state.event.clear()

            if obj.stream:
                yield out
            else:
                # Non-streaming: intermediate outputs are dropped; still poll
                # for client disconnects between them.
                if request is not None and await request.is_disconnected():
                    self.abort_request(obj.rid)
                    raise ValueError(f"Abort request {obj.rid}")
|
2024-07-20 14:10:01 +08:00
|
|
|
|
2024-11-03 08:38:26 -08:00
|
|
|
    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
        created_time: Optional[float] = None,
    ):
        """Handle a batched request, optionally with parallel sampling.

        Yields one list containing all outputs for non-streaming requests, or
        individual outputs (tagged with their batch "index") as they arrive
        for streaming requests.
        """
        batch_size = obj.batch_size

        generators = []
        rids = []
        if getattr(obj, "parallel_sample_num", 1) == 1:
            # Send all requests
            for i in range(batch_size):
                tmp_obj = obj[i]
                tokenized_obj = await self._tokenize_one_request(tmp_obj)
                self.send_to_scheduler.send_pyobj(tokenized_obj)
                generators.append(
                    self._wait_one_response(tmp_obj, request, created_time)
                )
                rids.append(tmp_obj.rid)
        else:
            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
            if batch_size > 128:
                logger.warning(
                    "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
                    "The performance might be better if you just duplicate the requests n times or use "
                    "many threads to send them one by one with parallel sampling (n > 1)."
                )

            # Tokenize all requests
            objs = [obj[i] for i in range(batch_size)]
            tokenized_objs = await asyncio.gather(
                *(self._tokenize_one_request(obj) for obj in objs)
            )

            # Cache the common prefix for parallel sampling: a prefill-only
            # copy (max_new_tokens=0) is sent first — presumably to warm the
            # prefix cache before the expanded copies arrive.
            for i in range(batch_size):
                tmp_obj = copy.copy(objs[i])
                tokenized_obj = copy.copy(tokenized_objs[i])
                tokenized_obj.rid = tmp_obj.regenerate_rid()
                tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
                tokenized_obj.sampling_params.max_new_tokens = 0
                tokenized_obj.stream = False
                self.send_to_scheduler.send_pyobj(tokenized_obj)
                await self._wait_one_response(
                    tmp_obj, request, created_time
                ).__anext__()

            # Expand requests, assign new rids for them, and send them
            for i in range(batch_size):
                for _ in range(obj.parallel_sample_num):
                    tmp_obj = copy.copy(objs[i])
                    tokenized_obj = copy.copy(tokenized_objs[i])
                    tokenized_obj.rid = tmp_obj.regenerate_rid()
                    self.send_to_scheduler.send_pyobj(tokenized_obj)
                    generators.append(
                        self._wait_one_response(tmp_obj, request, created_time)
                    )
                    rids.append(tmp_obj.rid)

        # Wait for all requests
        is_stream = hasattr(obj, "stream") and obj.stream
        if not is_stream:
            outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
            yield outputs
        else:
            # Interleave outputs from all sub-requests as they become ready.
            rid_to_index = {rid: i for i, rid in enumerate(rids)}
            task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
            while task_map:
                done, _ = await asyncio.wait(
                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
                )

                for task in done:
                    gen = task_map.pop(task)
                    try:
                        result = task.result()
                        result["index"] = rid_to_index[result["meta_info"]["id"]]
                        yield result
                        # Re-arm the generator to get its next chunk.
                        new_task = asyncio.create_task(gen.__anext__())
                        task_map[new_task] = gen
                    except StopAsyncIteration:
                        pass
|
|
|
|
|
|
2024-05-17 05:49:31 -07:00
|
|
|
def flush_cache(self):
|
|
|
|
|
req = FlushCacheReq()
|
2024-09-29 02:36:12 -07:00
|
|
|
self.send_to_scheduler.send_pyobj(req)
|
2024-01-26 13:32:59 +08:00
|
|
|
|
2024-08-25 14:46:34 -07:00
|
|
|
def abort_request(self, rid: str):
|
|
|
|
|
if rid not in self.rid_to_state:
|
|
|
|
|
return
|
|
|
|
|
del self.rid_to_state[rid]
|
|
|
|
|
req = AbortReq(rid)
|
2024-09-29 02:36:12 -07:00
|
|
|
self.send_to_scheduler.send_pyobj(req)
|
2024-08-25 14:46:34 -07:00
|
|
|
|
2024-10-11 17:34:25 +08:00
|
|
|
def start_profile(self):
|
|
|
|
|
req = ProfileReq.START_PROFILE
|
|
|
|
|
self.send_to_scheduler.send_pyobj(req)
|
|
|
|
|
|
|
|
|
|
def stop_profile(self):
|
|
|
|
|
req = ProfileReq.STOP_PROFILE
|
|
|
|
|
self.send_to_scheduler.send_pyobj(req)
|
|
|
|
|
|
2024-11-29 17:17:00 -08:00
|
|
|
async def update_weights_from_disk(
|
|
|
|
|
self,
|
|
|
|
|
obj: UpdateWeightFromDiskReqInput,
|
|
|
|
|
request: Optional[fastapi.Request] = None,
|
2024-12-22 06:25:57 -08:00
|
|
|
) -> Tuple[bool, str]:
|
2024-12-27 09:53:09 +08:00
|
|
|
self.auto_create_handle_loop()
|
2024-08-20 13:48:24 -07:00
|
|
|
|
|
|
|
|
# default the load format to the server_args
|
|
|
|
|
if obj.load_format is None:
|
|
|
|
|
obj.load_format = self.server_args.load_format
|
2024-12-22 06:25:57 -08:00
|
|
|
logger.info("Start update_weights. Load format=%s", obj.load_format)
|
2024-08-20 13:48:24 -07:00
|
|
|
|
2024-12-22 06:25:57 -08:00
|
|
|
if True:
|
|
|
|
|
# Hold the lock if it is not async. This means that weight sync
|
|
|
|
|
# cannot run while requests are in progress.
|
|
|
|
|
async with self.model_update_lock.writer_lock:
|
|
|
|
|
return await self._wait_for_model_update_from_disk(obj)
|
2024-10-28 12:02:23 -07:00
|
|
|
|
2024-12-22 06:25:57 -08:00
|
|
|
    async def _wait_for_model_update_from_disk(
        self, obj: UpdateWeightFromDiskReqInput
    ) -> Tuple[bool, str]:
        """Send the disk weight-update request and wait for the scheduler's reply.

        On success, updates the locally cached model path / load format so that
        later requests and logs reflect the new weights. Returns a
        ``(success, message)`` tuple.
        """
        self.send_to_scheduler.send_pyobj(obj)
        # Resolved by the handle loop when the scheduler replies.
        self.model_update_result = asyncio.Future()
        if self.server_args.dp_size == 1:
            result = await self.model_update_result
            if result.success:
                # NOTE(review): served_model_name is overwritten here but not
                # in the dp_size > 1 branch below — confirm the asymmetry is
                # intended.
                self.served_model_name = obj.model_path
                self.server_args.model_path = obj.model_path
                self.server_args.load_format = obj.load_format
                self.model_path = obj.model_path
            return result.success, result.message
        else:  # self.server_args.dp_size > 1
            # Buffer used by the handle loop to collect one reply per
            # data-parallel worker; the future resolves with the full list.
            self.model_update_tmp = []
            result = await self.model_update_result

            all_success = all([r.success for r in result])
            if all_success is True:
                self.server_args.model_path = obj.model_path
                self.server_args.load_format = obj.load_format
                self.model_path = obj.model_path
            all_message = [r.message for r in result]
            all_message = " | ".join(all_message)
            return all_success, all_message
|
2024-08-20 13:48:24 -07:00
|
|
|
|
2024-12-01 23:23:18 -08:00
|
|
|
async def init_weights_update_group(
|
|
|
|
|
self,
|
|
|
|
|
obj: InitWeightsUpdateGroupReqInput,
|
|
|
|
|
request: Optional[fastapi.Request] = None,
|
2024-12-22 06:25:57 -08:00
|
|
|
) -> Tuple[bool, str]:
|
2024-12-27 09:53:09 +08:00
|
|
|
self.auto_create_handle_loop()
|
2024-12-01 23:23:18 -08:00
|
|
|
assert (
|
|
|
|
|
self.server_args.dp_size == 1
|
|
|
|
|
), "dp_size must be 1 for init parameter update group"
|
2024-12-27 09:53:09 +08:00
|
|
|
result = (await self.init_weights_update_group_communicator(obj))[0]
|
2024-12-01 23:23:18 -08:00
|
|
|
return result.success, result.message
|
|
|
|
|
|
|
|
|
|
async def update_weights_from_distributed(
|
|
|
|
|
self,
|
|
|
|
|
obj: UpdateWeightsFromDistributedReqInput,
|
|
|
|
|
request: Optional[fastapi.Request] = None,
|
2024-12-22 06:25:57 -08:00
|
|
|
) -> Tuple[bool, str]:
|
2024-12-27 09:53:09 +08:00
|
|
|
self.auto_create_handle_loop()
|
|
|
|
|
assert (
|
|
|
|
|
self.server_args.dp_size == 1
|
|
|
|
|
), "dp_size must be for update weights from distributed"
|
2024-12-01 23:23:18 -08:00
|
|
|
|
2024-12-22 06:25:57 -08:00
|
|
|
# This means that weight sync
|
|
|
|
|
# cannot run while requests are in progress.
|
|
|
|
|
async with self.model_update_lock.writer_lock:
|
2024-12-27 09:53:09 +08:00
|
|
|
result = (await self.update_weights_from_distributed_communicator(obj))[0]
|
2024-12-22 06:25:57 -08:00
|
|
|
return result.success, result.message
|
2024-12-01 23:23:18 -08:00
|
|
|
|
2024-12-29 05:30:27 +08:00
|
|
|
async def update_weights_from_tensor(
|
|
|
|
|
self,
|
|
|
|
|
obj: UpdateWeightsFromTensorReqInput,
|
|
|
|
|
request: Optional[fastapi.Request] = None,
|
|
|
|
|
) -> Tuple[bool, str]:
|
|
|
|
|
self.auto_create_handle_loop()
|
|
|
|
|
assert (
|
|
|
|
|
self.server_args.dp_size == 1
|
|
|
|
|
), "dp_size must be for update weights from distributed"
|
|
|
|
|
|
|
|
|
|
# This means that weight sync
|
|
|
|
|
# cannot run while requests are in progress.
|
|
|
|
|
async with self.model_update_lock.writer_lock:
|
|
|
|
|
result = (await self.update_weights_from_tensor_communicator(obj))[0]
|
|
|
|
|
return result.success, result.message
|
|
|
|
|
|
2024-11-29 23:36:38 -08:00
|
|
|
async def get_weights_by_name(
|
|
|
|
|
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
|
|
|
|
|
):
|
2024-12-27 09:53:09 +08:00
|
|
|
self.auto_create_handle_loop()
|
|
|
|
|
results = await self.get_weights_by_name_communicator(obj)
|
|
|
|
|
all_parameters = [r.parameter for r in results]
|
2024-11-29 23:36:38 -08:00
|
|
|
if self.server_args.dp_size == 1:
|
2024-12-27 09:53:09 +08:00
|
|
|
return all_parameters[0]
|
2024-11-29 23:36:38 -08:00
|
|
|
else:
|
|
|
|
|
return all_parameters
|
|
|
|
|
|
2024-11-20 00:36:53 -08:00
|
|
|
    async def open_session(
        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        """Open a new session on the scheduler and return its session id."""
        self.auto_create_handle_loop()

        session_id = uuid.uuid4().hex
        obj.session_id = session_id
        self.send_to_scheduler.send_pyobj(obj)
        # The handle loop resolves this future with the session id echoed
        # back by the scheduler.
        self.session_futures[session_id] = asyncio.Future()
        # NOTE(review): the awaited value is assumed to equal the locally
        # generated session_id; otherwise the `del` below would KeyError —
        # confirm against the scheduler side.
        session_id = await self.session_futures[session_id]
        del self.session_futures[session_id]
        return session_id
|
|
|
|
|
|
|
|
|
|
    async def close_session(
        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        """Close a previously opened session on the scheduler."""
        assert not self.to_create_loop, "close session should not be the first request"
        # zmq.asyncio sockets return an awaitable from send_pyobj.
        await self.send_to_scheduler.send_pyobj(obj)
|
|
|
|
|
|
2024-06-08 02:06:52 -07:00
|
|
|
def create_abort_task(self, obj: GenerateReqInput):
|
2024-05-20 18:41:21 -07:00
|
|
|
# Abort the request if the client is disconnected.
|
|
|
|
|
async def abort_request():
|
2024-10-25 18:51:59 -07:00
|
|
|
await asyncio.sleep(1)
|
2024-05-20 18:41:21 -07:00
|
|
|
if obj.is_single:
|
|
|
|
|
self.abort_request(obj.rid)
|
|
|
|
|
else:
|
2024-08-13 20:47:22 +08:00
|
|
|
for rid in obj.rid:
|
2024-05-20 18:41:21 -07:00
|
|
|
self.abort_request(rid)
|
|
|
|
|
|
|
|
|
|
background_tasks = BackgroundTasks()
|
|
|
|
|
background_tasks.add_task(abort_request)
|
|
|
|
|
return background_tasks
|
|
|
|
|
|
2024-12-27 09:53:09 +08:00
|
|
|
    def auto_create_handle_loop(self):
        """Lazily start the background tasks on the first request.

        Creates the detokenizer handle loop and the SIGTERM watchdog on the
        running event loop exactly once; subsequent calls are no-ops.
        """
        if not self.to_create_loop:
            return

        self.to_create_loop = False
        loop = asyncio.get_event_loop()
        # Keep a strong reference so the task is not garbage-collected.
        self.asyncio_tasks.add(loop.create_task(self.handle_loop()))

        signal_handler = SignalHandler(self)
        loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
        self.asyncio_tasks.add(loop.create_task(self.sigterm_watchdog()))
|
2024-10-30 10:22:56 -07:00
|
|
|
|
|
|
|
|
    async def sigterm_watchdog(self):
        """Wait for SIGTERM, drain in-flight requests, then exit the process."""
        # `gracefully_exit` is flipped by the SIGTERM signal handler installed
        # in auto_create_handle_loop.
        while not self.gracefully_exit:
            await asyncio.sleep(5)

        # drain requests
        while True:
            remain_num_req = len(self.rid_to_state)
            logger.info(
                f"Gracefully exiting... remaining number of requests {remain_num_req}"
            )
            if remain_num_req > 0:
                await asyncio.sleep(5)
            else:
                break

        # Take down the whole process tree, including this process.
        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(0)
|
2024-10-30 10:22:56 -07:00
|
|
|
|
2024-01-08 04:37:50 +00:00
|
|
|
async def handle_loop(self):
    """The event loop that handles requests.

    Receives messages from the detokenizer socket forever and dispatches on
    their type: batched generation/embedding outputs are routed to the
    per-request ``ReqState`` objects, while control replies (weight updates,
    sessions, ...) resolve the corresponding futures/communicators.
    """

    while True:
        recv_obj: Union[
            BatchStrOut,
            BatchEmbeddingOut,
            BatchTokenIDOut,
            UpdateWeightFromDiskReqOutput,
            UpdateWeightsFromDistributedReqOutput,
            GetWeightsByNameReqOutput,
            InitWeightsUpdateGroupReqOutput,
        ] = await self.recv_from_detokenizer.recv_pyobj()

        if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
            # Batched outputs: fan each row out to its request's state.
            for i, rid in enumerate(recv_obj.rids):
                state = self.rid_to_state.get(rid, None)
                if state is None:
                    # Request was aborted / already cleaned up; drop the row.
                    continue

                meta_info = {
                    "id": rid,
                    "finish_reason": recv_obj.finished_reasons[i],
                    "prompt_tokens": recv_obj.prompt_tokens[i],
                }

                # Attach logprob fields only when the request asked for them.
                if getattr(state.obj, "return_logprob", False):
                    self.convert_logprob_style(
                        meta_info,
                        state.obj.top_logprobs_num,
                        state.obj.return_text_in_logprobs,
                        recv_obj,
                        i,
                    )

                # Embedding outputs carry no generation-token counters.
                if not isinstance(recv_obj, BatchEmbeddingOut):
                    meta_info.update(
                        {
                            "completion_tokens": recv_obj.completion_tokens[i],
                            "cached_tokens": recv_obj.cached_tokens[i],
                        }
                    )

                if isinstance(recv_obj, BatchStrOut):
                    out_dict = {
                        "text": recv_obj.output_strs[i],
                        "meta_info": meta_info,
                    }
                elif isinstance(recv_obj, BatchTokenIDOut):
                    out_dict = {
                        "token_ids": recv_obj.output_ids[i],
                        "meta_info": meta_info,
                    }
                else:
                    assert isinstance(recv_obj, BatchEmbeddingOut)
                    out_dict = {
                        "embedding": recv_obj.embeddings[i],
                        "meta_info": meta_info,
                    }
                # Hand the chunk to the waiting request coroutine.
                state.out_list.append(out_dict)
                state.finished = recv_obj.finished_reasons[i] is not None
                state.event.set()

                if self.enable_metrics:
                    completion_tokens = (
                        recv_obj.completion_tokens[i]
                        if recv_obj.completion_tokens
                        else 0
                    )

                    if state.first_token_time is None:
                        # First chunk for this request -> time-to-first-token.
                        state.first_token_time = time.time()
                        self.metrics_collector.observe_time_to_first_token(
                            state.first_token_time - state.created_time
                        )
                    else:
                        # Subsequent chunks -> inter-token latency; needs at
                        # least 2 completion tokens for a meaningful average.
                        if completion_tokens >= 2:
                            self.metrics_collector.observe_time_per_output_token(
                                (time.time() - state.first_token_time)
                                / (completion_tokens - 1)
                            )

                    if state.finished:
                        self.metrics_collector.inc_prompt_tokens(
                            recv_obj.prompt_tokens[i]
                        )
                        self.metrics_collector.inc_generation_tokens(
                            completion_tokens
                        )
                        self.metrics_collector.observe_e2e_request_latency(
                            time.time() - state.created_time
                        )
                        if completion_tokens >= 1:
                            self.metrics_collector.observe_time_per_output_token(
                                (time.time() - state.created_time)
                                / completion_tokens
                            )
        elif isinstance(recv_obj, OpenSessionReqOutput):
            # Resolve the future the session opener is awaiting on.
            self.session_futures[recv_obj.session_id].set_result(
                recv_obj.session_id
            )
        elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
            if self.server_args.dp_size == 1:
                self.model_update_result.set_result(recv_obj)
            else:  # self.server_args.dp_size > 1
                self.model_update_tmp.append(recv_obj)
                # Set the future once results from all DP ranks are received.
                if len(self.model_update_tmp) == self.server_args.dp_size:
                    self.model_update_result.set_result(self.model_update_tmp)
        elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
            assert (
                self.server_args.dp_size == 1
            ), "dp_size must be 1 for init parameter update group"
            self.init_weights_update_group_communicator.handle_recv(recv_obj)
        elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
            assert (
                self.server_args.dp_size == 1
            ), "dp_size must be 1 for update weights from distributed"
            self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
        elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput):
            assert (
                self.server_args.dp_size == 1
                # NOTE(review): message says "from distributed" but this is the
                # from-tensor path — looks like a copy/paste; confirm and fix.
            ), "dp_size must be 1 for update weights from distributed"
            self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
        elif isinstance(recv_obj, GetWeightsByNameReqOutput):
            self.get_weights_by_name_communicator.handle_recv(recv_obj)
        else:
            raise ValueError(f"Invalid object: {recv_obj=}")
2024-05-14 22:40:46 +08:00
|
|
|
def convert_logprob_style(
    self,
    meta_info: dict,
    top_logprobs_num: int,
    return_text_in_logprobs: bool,
    recv_obj: BatchStrOut,
    recv_obj_index: int,
):
    """Fill ``meta_info`` in place with the logprob fields of one request.

    ``recv_obj`` holds columnar per-batch lists; ``recv_obj_index`` selects
    this request's row. Token-level entries are detokenized to
    ``(logprob, token_id, text?)`` tuples via ``detokenize_logprob_tokens``.
    """
    idx = recv_obj_index

    # Per-position logprobs for the prompt and the generated tokens.
    for key, vals, ids in (
        (
            "input_token_logprobs",
            recv_obj.input_token_logprobs_val,
            recv_obj.input_token_logprobs_idx,
        ),
        (
            "output_token_logprobs",
            recv_obj.output_token_logprobs_val,
            recv_obj.output_token_logprobs_idx,
        ),
    ):
        meta_info[key] = self.detokenize_logprob_tokens(
            vals[idx], ids[idx], return_text_in_logprobs
        )

    meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[idx]

    # Top-k alternatives per position, only when the request asked for them.
    if top_logprobs_num > 0:
        for key, vals, ids in (
            (
                "input_top_logprobs",
                recv_obj.input_top_logprobs_val,
                recv_obj.input_top_logprobs_idx,
            ),
            (
                "output_top_logprobs",
                recv_obj.output_top_logprobs_val,
                recv_obj.output_top_logprobs_idx,
            ),
        ):
            meta_info[key] = self.detokenize_top_logprobs_tokens(
                vals[idx], ids[idx], return_text_in_logprobs
            )
def detokenize_logprob_tokens(
    self,
    token_logprobs_val: List[float],
    token_logprobs_idx: List[int],
    decode_to_text: bool,
):
    """Pair each logprob with its token id and, optionally, its decoded text.

    Returns a list of ``(logprob, token_id, text_or_None)`` tuples; the text
    slot is ``None`` when ``decode_to_text`` is false.
    """
    if decode_to_text:
        # Decoding requires a loaded tokenizer.
        assert self.tokenizer is not None
        token_texts = self.tokenizer.batch_decode(token_logprobs_idx)
    else:
        token_texts = [None] * len(token_logprobs_val)
    return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))
def detokenize_top_logprobs_tokens(
    self,
    token_logprobs_val: List[float],
    token_logprobs_idx: List[int],
    decode_to_text: bool,
):
    """Detokenize the top-k logprob lists for every position of a request.

    Each position with a non-empty top-k list becomes a list of
    ``(logprob, token_id, text?)`` tuples; empty positions map to ``None``.
    """
    # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
    # We should batch all top-k tokens in all positions.
    return [
        (
            self.detokenize_logprob_tokens(
                vals, token_logprobs_idx[pos], decode_to_text
            )
            if vals
            else None
        )
        for pos, vals in enumerate(token_logprobs_val)
    ]
class SignalHandler:
    """SIGTERM callback holder for the tokenizer manager.

    The handler only flips ``gracefully_exit``; the actual draining and
    process shutdown happen in ``sigterm_watchdog``, which polls that flag.
    """

    def __init__(self, tokenizer_manager):
        # Keep a reference so the callback can reach the manager's exit flag.
        self.tokenizer_manager = tokenizer_manager

    def signal_handler(self, signum=None, frame=None):
        """Log the signal and request a graceful shutdown."""
        logger.warning(
            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
        )
        self.tokenizer_manager.gracefully_exit = True
T = TypeVar("T")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _Communicator(Generic[T]):
|
|
|
|
|
def __init__(self, sender, fan_out: int):
|
|
|
|
|
self._sender = sender
|
|
|
|
|
self._fan_out = fan_out
|
|
|
|
|
self._result_future: Optional[asyncio.Future] = None
|
|
|
|
|
self._result_values: Optional[List[T]] = None
|
|
|
|
|
|
|
|
|
|
async def __call__(self, obj):
|
|
|
|
|
self._sender.send_pyobj(obj)
|
|
|
|
|
self._result_future = asyncio.Future()
|
|
|
|
|
self._result_values = []
|
|
|
|
|
await self._result_future
|
|
|
|
|
result_values = self._result_values
|
|
|
|
|
self._result_future = self._result_values = None
|
|
|
|
|
return result_values
|
|
|
|
|
|
|
|
|
|
def handle_recv(self, recv_obj: T):
|
|
|
|
|
self._result_values.append(recv_obj)
|
|
|
|
|
if len(self._result_values) == self._fan_out:
|
|
|
|
|
self._result_future.set_result(None)
|