Files
sglang/python/sglang/srt/managers/tokenizer_manager.py

847 lines
32 KiB
Python
Raw Normal View History

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
2024-06-08 02:06:52 -07:00
"""TokenizerManager is a process that tokenizes the text."""
import asyncio
2024-11-03 08:38:26 -08:00
import copy
import dataclasses
import logging
import os
import signal
import sys
import time
2024-11-20 00:36:53 -08:00
import uuid
from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
import fastapi
import uvloop
import zmq
import zmq.asyncio
2024-05-20 18:41:21 -07:00
from fastapi import BackgroundTasks
2024-04-22 22:38:09 +08:00
from sglang.srt.aio_rwlock import RWLock
2024-11-03 12:25:39 -08:00
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
2024-09-29 18:52:43 -07:00
from sglang.srt.managers.image_processor import (
get_dummy_image_processor,
get_image_processor,
)
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOut,
BatchStrOut,
BatchTokenIDOut,
2024-11-20 00:36:53 -08:00
CloseSessionReqInput,
EmbeddingReqInput,
FlushCacheReq,
GenerateReqInput,
GetWeightsByNameReqInput,
GetWeightsByNameReqOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
2024-11-20 00:36:53 -08:00
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
2024-11-29 17:17:00 -08:00
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromDistributedReqOutput,
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.metrics.collector import TokenizerMetricsCollector
2024-08-21 16:48:24 -07:00
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
dataclass_to_string_truncated,
get_zmq_socket,
kill_process_tree,
)
# Use uvloop's event loop implementation for all asyncio usage in this process.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class ReqState:
    """Per-request bookkeeping held by TokenizerManager while a request is in flight."""

    out_list: List  # accumulated, not-yet-consumed outputs for this request
    finished: bool  # True once the scheduler reported a finish reason
    event: asyncio.Event  # signaled whenever a new output is appended
    obj: Any  # the original GenerateReqInput / EmbeddingReqInput

    # For metrics
    created_time: float
    first_token_time: Optional[float] = None

    # For streaming output
    last_output_offset: int = 0
class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        # Parse args
        self.server_args = server_args
        self.enable_metrics = server_args.enable_metrics

        # Init inter-process communication (PULL from detokenizer, PUSH to scheduler).
        zmq_context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = get_zmq_socket(
            zmq_context, zmq.PULL, port_args.tokenizer_ipc_name
        )
        self.send_to_scheduler = get_zmq_socket(
            zmq_context, zmq.PUSH, port_args.scheduler_input_ipc_name
        )

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )
        self.is_generation = self.model_config.is_generation
        self.context_len = self.model_config.context_len
        self.image_token_id = self.model_config.image_token_id

        # Create an image processor placeholder; replaced by a real one below
        # for multimodal models.
        self.image_processor = get_dummy_image_processor()

        # Create tokenizer
        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        elif self.model_config.is_multimodal:
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
            self.tokenizer = self.processor.tokenizer
            os.environ["TOKENIZERS_PARALLELISM"] = "false"

            # We want to parallelize the image pre-processing, so we create
            # an executor for it.
            self.image_processor = get_image_processor(
                self.model_config.hf_config, server_args, self.processor
            )
        else:
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )

        # Store states
        self.to_create_loop = True
        self.rid_to_state: Dict[str, ReqState] = {}

        # Reader/writer lock: requests take the reader side; weight updates
        # take the writer side so they never overlap with in-flight requests.
        self.model_update_lock = RWLock()
        self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
            None
        )
        self.asyncio_tasks = set()

        # For session info
        self.session_futures = {}  # session_id -> asyncio event

        # Others
        self.gracefully_exit = False

        # One communicator per fan-out RPC to the scheduler (one reply per DP rank).
        self.init_weights_update_group_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_distributed_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_tensor_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.get_weights_by_name_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )

        # Metrics
        if self.enable_metrics:
            self.metrics_collector = TokenizerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
            )
async def generate_request(
    self,
    obj: Union[GenerateReqInput, EmbeddingReqInput],
    request: Optional[fastapi.Request] = None,
):
    """Handle one generate/embedding request end to end; yields response dicts."""
    created_time = time.time()

    self.auto_create_handle_loop()

    if self.is_generation and isinstance(obj, EmbeddingReqInput):
        raise ValueError(
            "This model does not appear to be an embedding model by default. "
            "Please add `--is-embedding` when launching the server or try another model."
        )

    obj.normalize_batch_and_arguments()

    if self.server_args.log_requests:
        logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")

    # Take the reader side of the model-update lock so a weight sync
    # (writer) cannot run while this request is in flight.
    async with self.model_update_lock.reader_lock:
        if obj.is_single:
            tokenized_obj = await self._tokenize_one_request(obj)
            self.send_to_scheduler.send_pyobj(tokenized_obj)
            async for response in self._wait_one_response(obj, request, created_time):
                yield response
        else:
            async for response in self._handle_batch_request(
                obj, request, created_time
            ):
                yield response
2024-11-03 08:38:26 -08:00
async def _tokenize_one_request(
    self,
    obj: Union[GenerateReqInput, EmbeddingReqInput],
):
    """Tokenize one request and build the object that is sent to the scheduler."""
    # Tokenize
    input_embeds = None
    input_text = obj.text
    if obj.input_embeds is not None:
        if not self.server_args.disable_radix_cache:
            raise ValueError(
                "input_embeds is provided while disable_radix_cache is False. "
                "Please add `--disable-radix-cache` when you launch the server "
                "if you want to use input_embeds as inputs."
            )
        input_embeds = obj.input_embeds
        input_ids = obj.input_ids
    elif obj.input_ids is not None:
        input_ids = obj.input_ids
    else:
        input_ids = self.tokenizer.encode(input_text)

    if self.is_generation:
        # TODO: also support getting embeddings for multimodal models
        image_inputs: Dict = await self.image_processor.process_images_async(
            obj.image_data, input_text or input_ids, obj
        )
        if image_inputs and "input_ids" in image_inputs:
            input_ids = image_inputs["input_ids"]
        return_logprob = obj.return_logprob
        logprob_start_len = obj.logprob_start_len
        top_logprobs_num = obj.top_logprobs_num
        session_id = obj.session[0] if obj.session else None
        session_rid = obj.session[1] if obj.session else None

    # NOTE(review): the context-length check only fires when input_ids were
    # passed in directly; text that tokenizes past context_len is not
    # rejected here — confirm whether that is intentional.
    if obj.input_ids is not None and len(input_ids) >= self.context_len:
        raise ValueError(
            f"The input ({len(input_ids)} tokens) is longer than the "
            f"model's context length ({self.context_len} tokens)."
        )

    # Parse sampling parameters
    sampling_params = SamplingParams(**obj.sampling_params)
    sampling_params.normalize(self.tokenizer)
    sampling_params.verify()

    # Build return object
    if isinstance(obj, GenerateReqInput):
        tokenized_obj = TokenizedGenerateReqInput(
            obj.rid,
            input_text,
            input_ids,
            image_inputs,
            sampling_params,
            return_logprob,
            logprob_start_len,
            top_logprobs_num,
            obj.stream,
            lora_path=obj.lora_path,
            input_embeds=input_embeds,
            session_id=session_id,
            session_rid=session_rid,
        )
    elif isinstance(obj, EmbeddingReqInput):
        tokenized_obj = TokenizedEmbeddingReqInput(
            obj.rid,
            input_text,
            input_ids,
            sampling_params,
        )

    return tokenized_obj
2024-11-03 08:38:26 -08:00
async def _wait_one_response(
    self,
    obj: Union[GenerateReqInput, EmbeddingReqInput],
    request: Optional[fastapi.Request] = None,
    created_time: Optional[float] = None,
):
    """Wait for the response of one request, yielding outputs as they arrive."""
    event = asyncio.Event()
    state = ReqState([], False, event, obj, created_time=created_time)
    self.rid_to_state[obj.rid] = state

    while True:
        try:
            # Wake up periodically so a disconnected client is detected
            # even when no new output arrives.
            await asyncio.wait_for(state.event.wait(), timeout=4)
        except asyncio.TimeoutError:
            if request is not None and await request.is_disconnected():
                self.abort_request(obj.rid)
                raise ValueError(f"Abort request {obj.rid}")
            continue

        out = state.out_list[-1]
        state.out_list = []

        if state.finished:
            if self.server_args.log_requests:
                msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
                logger.info(msg)
            del self.rid_to_state[obj.rid]
            yield out
            break

        state.event.clear()

        if obj.stream:
            yield out
        elif request is not None and await request.is_disconnected():
            self.abort_request(obj.rid)
            raise ValueError(f"Abort request {obj.rid}")
2024-11-03 08:38:26 -08:00
async def _handle_batch_request(
    self,
    obj: Union[GenerateReqInput, EmbeddingReqInput],
    request: Optional[fastapi.Request] = None,
    created_time: Optional[float] = None,
):
    """Fan a batch request out to the scheduler and yield the results.

    Non-streaming: yields one list containing every output. Streaming:
    yields outputs as they arrive, tagged with their batch "index".
    """
    batch_size = obj.batch_size

    generators = []
    rids = []
    if getattr(obj, "parallel_sample_num", 1) == 1:
        # Send all requests
        for idx in range(batch_size):
            tmp_obj = obj[idx]
            tokenized_obj = await self._tokenize_one_request(tmp_obj)
            self.send_to_scheduler.send_pyobj(tokenized_obj)
            generators.append(self._wait_one_response(tmp_obj, request, created_time))
            rids.append(tmp_obj.rid)
    else:
        # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
        if batch_size > 128:
            logger.warning(
                "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
                "The performance might be better if you just duplicate the requests n times or use "
                "many threads to send them one by one with parallel sampling (n > 1)."
            )

        # Tokenize all requests
        objs = [obj[idx] for idx in range(batch_size)]
        tokenized_objs = await asyncio.gather(
            *(self._tokenize_one_request(o) for o in objs)
        )

        # Cache the common prefix for parallel sampling: first send a
        # zero-new-token warm-up copy of each request and wait for it.
        for idx in range(batch_size):
            tmp_obj = copy.copy(objs[idx])
            tokenized_obj = copy.copy(tokenized_objs[idx])
            tokenized_obj.rid = tmp_obj.regenerate_rid()
            tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
            tokenized_obj.sampling_params.max_new_tokens = 0
            tokenized_obj.stream = False
            self.send_to_scheduler.send_pyobj(tokenized_obj)
            await self._wait_one_response(tmp_obj, request, created_time).__anext__()

        # Expand requests, assign new rids for them, and send them
        for idx in range(batch_size):
            for _ in range(obj.parallel_sample_num):
                tmp_obj = copy.copy(objs[idx])
                tokenized_obj = copy.copy(tokenized_objs[idx])
                tokenized_obj.rid = tmp_obj.regenerate_rid()
                self.send_to_scheduler.send_pyobj(tokenized_obj)
                generators.append(
                    self._wait_one_response(tmp_obj, request, created_time)
                )
                rids.append(tmp_obj.rid)

    # Wait for all requests
    is_stream = hasattr(obj, "stream") and obj.stream
    if not is_stream:
        outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
        yield outputs
    else:
        # Race all generators; re-arm each one after it produces an output.
        rid_to_index = {rid: idx for idx, rid in enumerate(rids)}
        task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
        while task_map:
            done, _ = await asyncio.wait(
                task_map.keys(), return_when=asyncio.FIRST_COMPLETED
            )
            for task in done:
                gen = task_map.pop(task)
                try:
                    result = task.result()
                    result["index"] = rid_to_index[result["meta_info"]["id"]]
                    yield result
                    new_task = asyncio.create_task(gen.__anext__())
                    task_map[new_task] = gen
                except StopAsyncIteration:
                    pass
def flush_cache(self):
    """Ask the scheduler to flush its cache."""
    self.send_to_scheduler.send_pyobj(FlushCacheReq())
2024-01-26 13:32:59 +08:00
def abort_request(self, rid: str):
    """Abort an in-flight request by rid; no-op if the rid is unknown."""
    state = self.rid_to_state.pop(rid, None)
    if state is None:
        return
    self.send_to_scheduler.send_pyobj(AbortReq(rid))
def start_profile(self):
    """Tell the scheduler to start the profiler."""
    self.send_to_scheduler.send_pyobj(ProfileReq.START_PROFILE)

def stop_profile(self):
    """Tell the scheduler to stop the profiler."""
    self.send_to_scheduler.send_pyobj(ProfileReq.STOP_PROFILE)
2024-11-29 17:17:00 -08:00
async def update_weights_from_disk(
    self,
    obj: UpdateWeightFromDiskReqInput,
    request: Optional[fastapi.Request] = None,
) -> Tuple[bool, str]:
    """Reload model weights from disk.

    Returns (success, message). The writer side of the model-update lock is
    held for the duration, so no requests run while weights are swapped.
    """
    self.auto_create_handle_loop()

    # Default the load format to the server_args
    if obj.load_format is None:
        obj.load_format = self.server_args.load_format
    logger.info("Start update_weights. Load format=%s", obj.load_format)

    # Fix: removed a dead `if True:` wrapper around this block.
    # Hold the lock. This means that weight sync cannot run while
    # requests are in progress.
    async with self.model_update_lock.writer_lock:
        return await self._wait_for_model_update_from_disk(obj)
async def _wait_for_model_update_from_disk(
    self, obj: UpdateWeightFromDiskReqInput
) -> Tuple[bool, str]:
    """Send the disk-update request to the scheduler and wait for the replies."""
    self.send_to_scheduler.send_pyobj(obj)
    self.model_update_result = asyncio.Future()

    if self.server_args.dp_size == 1:
        result = await self.model_update_result
        if result.success:
            # NOTE(review): served_model_name is overwritten with the path —
            # confirm this is intended rather than keeping the served name.
            self.served_model_name = obj.model_path
            self.server_args.model_path = obj.model_path
            self.server_args.load_format = obj.load_format
            self.model_path = obj.model_path
        return result.success, result.message

    # self.server_args.dp_size > 1: handle_loop collects one reply per DP
    # rank into model_update_tmp and resolves the future with the list.
    self.model_update_tmp = []
    results = await self.model_update_result

    all_success = all(r.success for r in results)
    if all_success:
        self.server_args.model_path = obj.model_path
        self.server_args.load_format = obj.load_format
        self.model_path = obj.model_path
    all_message = " | ".join(r.message for r in results)
    return all_success, all_message
async def init_weights_update_group(
    self,
    obj: InitWeightsUpdateGroupReqInput,
    request: Optional[fastapi.Request] = None,
) -> Tuple[bool, str]:
    """Create the distributed process group used for online weight updates."""
    self.auto_create_handle_loop()
    assert (
        self.server_args.dp_size == 1
    ), "dp_size must be 1 for init parameter update group"
    result = (await self.init_weights_update_group_communicator(obj))[0]
    return result.success, result.message
async def update_weights_from_distributed(
    self,
    obj: UpdateWeightsFromDistributedReqInput,
    request: Optional[fastapi.Request] = None,
) -> Tuple[bool, str]:
    """Update model weights from a distributed source. Returns (success, message)."""
    self.auto_create_handle_loop()
    assert (
        self.server_args.dp_size == 1
    ), "dp_size must be 1 for update weights from distributed"  # fix: message was missing "1"

    # Hold the writer lock: weight sync cannot run while requests are in progress.
    async with self.model_update_lock.writer_lock:
        result = (await self.update_weights_from_distributed_communicator(obj))[0]
        return result.success, result.message
async def update_weights_from_tensor(
    self,
    obj: UpdateWeightsFromTensorReqInput,
    request: Optional[fastapi.Request] = None,
) -> Tuple[bool, str]:
    """Update model weights from in-memory tensors. Returns (success, message)."""
    self.auto_create_handle_loop()
    assert (
        self.server_args.dp_size == 1
    ), "dp_size must be 1 for update weights from tensor"  # fix: message was copy-pasted from the "distributed" variant

    # Hold the writer lock: weight sync cannot run while requests are in progress.
    async with self.model_update_lock.writer_lock:
        result = (await self.update_weights_from_tensor_communicator(obj))[0]
        return result.success, result.message
async def get_weights_by_name(
    self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
):
    """Fetch named parameters; one result per DP rank (unwrapped when dp_size == 1)."""
    self.auto_create_handle_loop()
    results = await self.get_weights_by_name_communicator(obj)
    all_parameters = [r.parameter for r in results]
    return all_parameters[0] if self.server_args.dp_size == 1 else all_parameters
2024-11-20 00:36:53 -08:00
async def open_session(
    self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
):
    """Open a scheduler session and return its id."""
    self.auto_create_handle_loop()

    new_session_id = uuid.uuid4().hex
    obj.session_id = new_session_id
    self.send_to_scheduler.send_pyobj(obj)

    # handle_loop resolves this future with the session id echoed back by
    # the scheduler.
    future = asyncio.Future()
    self.session_futures[new_session_id] = future
    echoed_id = await future
    del self.session_futures[echoed_id]
    return echoed_id
async def close_session(
    self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
):
    """Close a scheduler session previously opened via open_session."""
    # The handle loop must already exist; opening a session would have created it.
    assert not self.to_create_loop, "close session should not be the first request"
    await self.send_to_scheduler.send_pyobj(obj)
2024-06-08 02:06:52 -07:00
def create_abort_task(self, obj: GenerateReqInput):
    """Return BackgroundTasks that abort this request once the client disconnects."""

    async def abort_request():
        # Give in-flight work a moment before aborting.
        await asyncio.sleep(1)
        if obj.is_single:
            self.abort_request(obj.rid)
        else:
            for rid in obj.rid:
                self.abort_request(rid)

    tasks = BackgroundTasks()
    tasks.add_task(abort_request)
    return tasks
def auto_create_handle_loop(self):
    """Lazily start the receive loop and the SIGTERM watchdog (idempotent)."""
    if not self.to_create_loop:
        return
    self.to_create_loop = False

    loop = asyncio.get_event_loop()
    self.asyncio_tasks.add(loop.create_task(self.handle_loop()))

    # Drain gracefully on SIGTERM instead of dying mid-request.
    signal_handler = SignalHandler(self)
    loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
    self.asyncio_tasks.add(loop.create_task(self.sigterm_watchdog()))
async def sigterm_watchdog(self):
    """Wait for the SIGTERM flag, drain outstanding requests, then exit."""
    while not self.gracefully_exit:
        await asyncio.sleep(5)

    # Drain requests before killing the process tree.
    while True:
        remain_num_req = len(self.rid_to_state)
        logger.info(
            f"Gracefully exiting... remaining number of requests {remain_num_req}"
        )
        if remain_num_req == 0:
            break
        await asyncio.sleep(5)

    kill_process_tree(os.getpid(), include_parent=True)
    sys.exit(0)
async def handle_loop(self):
    """The event loop that receives scheduler/detokenizer outputs and dispatches them."""
    while True:
        recv_obj: Union[
            BatchStrOut,
            BatchEmbeddingOut,
            BatchTokenIDOut,
            UpdateWeightFromDiskReqOutput,
            UpdateWeightsFromDistributedReqOutput,
            GetWeightsByNameReqOutput,
            InitWeightsUpdateGroupReqOutput,
        ] = await self.recv_from_detokenizer.recv_pyobj()

        if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
            # A batch of per-request outputs; route each entry to its ReqState.
            for idx, rid in enumerate(recv_obj.rids):
                state = self.rid_to_state.get(rid, None)
                if state is None:
                    # Request already aborted/finished; drop the output.
                    continue

                meta_info = {
                    "id": rid,
                    "finish_reason": recv_obj.finished_reasons[idx],
                    "prompt_tokens": recv_obj.prompt_tokens[idx],
                }

                if getattr(state.obj, "return_logprob", False):
                    self.convert_logprob_style(
                        meta_info,
                        state.obj.top_logprobs_num,
                        state.obj.return_text_in_logprobs,
                        recv_obj,
                        idx,
                    )

                if not isinstance(recv_obj, BatchEmbeddingOut):
                    meta_info.update(
                        {
                            "completion_tokens": recv_obj.completion_tokens[idx],
                            "cached_tokens": recv_obj.cached_tokens[idx],
                        }
                    )

                if isinstance(recv_obj, BatchStrOut):
                    out_dict = {
                        "text": recv_obj.output_strs[idx],
                        "meta_info": meta_info,
                    }
                elif isinstance(recv_obj, BatchTokenIDOut):
                    out_dict = {
                        "token_ids": recv_obj.output_ids[idx],
                        "meta_info": meta_info,
                    }
                else:
                    assert isinstance(recv_obj, BatchEmbeddingOut)
                    out_dict = {
                        "embedding": recv_obj.embeddings[idx],
                        "meta_info": meta_info,
                    }

                state.out_list.append(out_dict)
                state.finished = recv_obj.finished_reasons[idx] is not None
                state.event.set()

                if self.enable_metrics:
                    completion_tokens = (
                        recv_obj.completion_tokens[idx]
                        if recv_obj.completion_tokens
                        else 0
                    )

                    if state.first_token_time is None:
                        state.first_token_time = time.time()
                        self.metrics_collector.observe_time_to_first_token(
                            state.first_token_time - state.created_time
                        )
                    elif completion_tokens >= 2:
                        # Streaming case: per-token latency since first token.
                        self.metrics_collector.observe_time_per_output_token(
                            (time.time() - state.first_token_time)
                            / (completion_tokens - 1)
                        )

                    if state.finished:
                        self.metrics_collector.inc_prompt_tokens(
                            recv_obj.prompt_tokens[idx]
                        )
                        self.metrics_collector.inc_generation_tokens(
                            completion_tokens
                        )
                        self.metrics_collector.observe_e2e_request_latency(
                            time.time() - state.created_time
                        )
                        if completion_tokens >= 1:
                            self.metrics_collector.observe_time_per_output_token(
                                (time.time() - state.created_time)
                                / completion_tokens
                            )
        elif isinstance(recv_obj, OpenSessionReqOutput):
            self.session_futures[recv_obj.session_id].set_result(
                recv_obj.session_id
            )
        elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
            if self.server_args.dp_size == 1:
                self.model_update_result.set_result(recv_obj)
            else:  # self.server_args.dp_size > 1
                self.model_update_tmp.append(recv_obj)
                # Set the future once all DP-rank results are received.
                if len(self.model_update_tmp) == self.server_args.dp_size:
                    self.model_update_result.set_result(self.model_update_tmp)
        elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
            assert (
                self.server_args.dp_size == 1
            ), "dp_size must be 1 for init parameter update group"
            self.init_weights_update_group_communicator.handle_recv(recv_obj)
        elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
            assert (
                self.server_args.dp_size == 1
            ), "dp_size must be 1 for update weights from distributed"
            self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
        elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput):
            # NOTE(review): this assert message looks copy-pasted from the
            # "distributed" branch above.
            assert (
                self.server_args.dp_size == 1
            ), "dp_size must be 1 for update weights from distributed"
            self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
        elif isinstance(recv_obj, GetWeightsByNameReqOutput):
            self.get_weights_by_name_communicator.handle_recv(recv_obj)
        else:
            raise ValueError(f"Invalid object: {recv_obj=}")
2024-05-14 22:40:46 +08:00
def convert_logprob_style(
    self,
    meta_info: dict,
    top_logprobs_num: int,
    return_text_in_logprobs: bool,
    recv_obj: BatchStrOut,
    recv_obj_index: int,
):
    """Attach (de)tokenized logprob fields from recv_obj to meta_info in place."""
    pos = recv_obj_index

    meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
        recv_obj.input_token_logprobs_val[pos],
        recv_obj.input_token_logprobs_idx[pos],
        return_text_in_logprobs,
    )
    meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
        recv_obj.output_token_logprobs_val[pos],
        recv_obj.output_token_logprobs_idx[pos],
        return_text_in_logprobs,
    )
    meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[pos]

    # Top-k logprobs are only present when the request asked for them.
    if top_logprobs_num > 0:
        meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
            recv_obj.input_top_logprobs_val[pos],
            recv_obj.input_top_logprobs_idx[pos],
            return_text_in_logprobs,
        )
        meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
            recv_obj.output_top_logprobs_val[pos],
            recv_obj.output_top_logprobs_idx[pos],
            return_text_in_logprobs,
        )
2024-07-28 05:22:14 -07:00
def detokenize_logprob_tokens(
    self,
    token_logprobs_val: List[float],
    token_logprobs_idx: List[int],
    decode_to_text: bool,
):
    """Pair logprobs with token ids, optionally decoding ids to text.

    Returns a list of (logprob, token_id, text_or_None) tuples.
    """
    if decode_to_text:
        assert self.tokenizer is not None
        token_texts = self.tokenizer.batch_decode(token_logprobs_idx)
        return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))

    # Without decoding, the text slot is filled with None.
    return [
        (logprob, token_id, None)
        for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
    ]
2024-05-12 04:54:07 -07:00
2024-12-08 12:27:13 -08:00
def detokenize_top_logprobs_tokens(
    self,
    token_logprobs_val: List[float],
    token_logprobs_idx: List[int],
    decode_to_text: bool,
):
    """Detokenize top-k logprobs position by position.

    Positions with no top-k data (falsy entries) map to None.
    """
    # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
    # We should batch all top-k tokens in all positions.
    ret = []
    for pos, vals in enumerate(token_logprobs_val):
        if vals:
            ret.append(
                self.detokenize_logprob_tokens(
                    vals, token_logprobs_idx[pos], decode_to_text
                )
            )
        else:
            ret.append(None)
    return ret
class SignalHandler:
    """On SIGTERM, flips the manager's gracefully_exit flag so the watchdog drains."""

    def __init__(self, tokenizer_manager):
        self.tokenizer_manager = tokenizer_manager

    def signal_handler(self, signum=None, frame=None):
        logger.warning(
            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
        )
        self.tokenizer_manager.gracefully_exit = True
T = TypeVar("T")
class _Communicator(Generic[T]):
def __init__(self, sender, fan_out: int):
self._sender = sender
self._fan_out = fan_out
self._result_future: Optional[asyncio.Future] = None
self._result_values: Optional[List[T]] = None
async def __call__(self, obj):
self._sender.send_pyobj(obj)
self._result_future = asyncio.Future()
self._result_values = []
await self._result_future
result_values = self._result_values
self._result_future = self._result_values = None
return result_values
def handle_recv(self, recv_obj: T):
self._result_values.append(recv_obj)
if len(self._result_values) == self._fan_out:
self._result_future.set_result(None)