Tiny code cleanup in tokenizer_manager.py (#2586)
This commit is contained in:
@@ -22,7 +22,7 @@ import signal
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
|
||||
|
||||
import fastapi
|
||||
import uvloop
|
||||
@@ -173,6 +173,15 @@ class TokenizerManager:
|
||||
|
||||
# Others
|
||||
self.gracefully_exit = False
|
||||
self.init_weights_update_group_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.update_weights_from_distributed_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.get_weights_by_name_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
|
||||
# Metrics
|
||||
if self.enable_metrics:
|
||||
@@ -190,8 +199,7 @@ class TokenizerManager:
|
||||
):
|
||||
created_time = time.time()
|
||||
|
||||
if self.to_create_loop:
|
||||
self.create_handle_loop()
|
||||
self.auto_create_handle_loop()
|
||||
|
||||
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
|
||||
raise ValueError(
|
||||
@@ -440,8 +448,7 @@ class TokenizerManager:
|
||||
obj: UpdateWeightFromDiskReqInput,
|
||||
request: Optional[fastapi.Request] = None,
|
||||
) -> Tuple[bool, str]:
|
||||
if self.to_create_loop:
|
||||
self.create_handle_loop()
|
||||
self.auto_create_handle_loop()
|
||||
|
||||
# default the load format to the server_args
|
||||
if obj.load_format is None:
|
||||
@@ -456,7 +463,7 @@ class TokenizerManager:
|
||||
|
||||
async def _wait_for_model_update_from_disk(
|
||||
self, obj: UpdateWeightFromDiskReqInput
|
||||
) -> Tuple[bool, str, int]:
|
||||
) -> Tuple[bool, str]:
|
||||
self.send_to_scheduler.send_pyobj(obj)
|
||||
self.model_update_result = asyncio.Future()
|
||||
if self.server_args.dp_size == 1:
|
||||
@@ -485,15 +492,11 @@ class TokenizerManager:
|
||||
obj: InitWeightsUpdateGroupReqInput,
|
||||
request: Optional[fastapi.Request] = None,
|
||||
) -> Tuple[bool, str]:
|
||||
if self.to_create_loop:
|
||||
self.create_handle_loop()
|
||||
self.send_to_scheduler.send_pyobj(obj)
|
||||
|
||||
self.init_weights_update_group_result = asyncio.Future()
|
||||
self.auto_create_handle_loop()
|
||||
assert (
|
||||
self.server_args.dp_size == 1
|
||||
), "dp_size must be 1 for init parameter update group"
|
||||
result = await self.init_weights_update_group_result
|
||||
result = (await self.init_weights_update_group_communicator(obj))[0]
|
||||
return result.success, result.message
|
||||
|
||||
async def update_weights_from_distributed(
|
||||
@@ -501,44 +504,32 @@ class TokenizerManager:
|
||||
obj: UpdateWeightsFromDistributedReqInput,
|
||||
request: Optional[fastapi.Request] = None,
|
||||
) -> Tuple[bool, str]:
|
||||
if self.to_create_loop:
|
||||
self.create_handle_loop()
|
||||
self.auto_create_handle_loop()
|
||||
assert (
|
||||
self.server_args.dp_size == 1
|
||||
), "dp_size must be for update weights from distributed"
|
||||
|
||||
# This means that weight sync
|
||||
# cannot run while requests are in progress.
|
||||
async with self.model_update_lock.writer_lock:
|
||||
self.send_to_scheduler.send_pyobj(obj)
|
||||
self.parameter_update_result: Awaitable[
|
||||
UpdateWeightsFromDistributedReqOutput
|
||||
] = asyncio.Future()
|
||||
assert (
|
||||
self.server_args.dp_size == 1
|
||||
), "dp_size must be for update weights from distributed"
|
||||
result = await self.parameter_update_result
|
||||
result = (await self.update_weights_from_distributed_communicator(obj))[0]
|
||||
return result.success, result.message
|
||||
|
||||
async def get_weights_by_name(
|
||||
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
|
||||
):
|
||||
if self.to_create_loop:
|
||||
self.create_handle_loop()
|
||||
|
||||
self.send_to_scheduler.send_pyobj(obj)
|
||||
self.get_weights_by_name_result = asyncio.Future()
|
||||
self.auto_create_handle_loop()
|
||||
results = await self.get_weights_by_name_communicator(obj)
|
||||
all_parameters = [r.parameter for r in results]
|
||||
if self.server_args.dp_size == 1:
|
||||
result = await self.get_weights_by_name_result
|
||||
return result.parameter
|
||||
return all_parameters[0]
|
||||
else:
|
||||
self.get_weights_by_name_tmp = []
|
||||
result = await self.get_weights_by_name_result
|
||||
all_parameters = [r.parameter for r in result]
|
||||
return all_parameters
|
||||
|
||||
async def open_session(
|
||||
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
|
||||
):
|
||||
if self.to_create_loop:
|
||||
self.create_handle_loop()
|
||||
self.auto_create_handle_loop()
|
||||
|
||||
session_id = uuid.uuid4().hex
|
||||
obj.session_id = session_id
|
||||
@@ -568,7 +559,7 @@ class TokenizerManager:
|
||||
background_tasks.add_task(abort_request)
|
||||
return background_tasks
|
||||
|
||||
def create_handle_loop(self):
|
||||
def auto_create_handle_loop(self):
|
||||
if not self.to_create_loop:
|
||||
return
|
||||
|
||||
@@ -711,21 +702,14 @@ class TokenizerManager:
|
||||
assert (
|
||||
self.server_args.dp_size == 1
|
||||
), "dp_size must be 1 for init parameter update group"
|
||||
self.init_weights_update_group_result.set_result(recv_obj)
|
||||
self.init_weights_update_group_communicator.handle_recv(recv_obj)
|
||||
elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
|
||||
assert (
|
||||
self.server_args.dp_size == 1
|
||||
), "dp_size must be 1 for update weights from distributed"
|
||||
self.parameter_update_result.set_result(recv_obj)
|
||||
self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
|
||||
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
|
||||
if self.server_args.dp_size == 1:
|
||||
self.get_weights_by_name_result.set_result(recv_obj)
|
||||
else:
|
||||
self.get_weights_by_name_tmp.append(recv_obj)
|
||||
if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
|
||||
self.get_weights_by_name_result.set_result(
|
||||
self.get_weights_by_name_tmp
|
||||
)
|
||||
self.get_weights_by_name_communicator.handle_recv(recv_obj)
|
||||
else:
|
||||
raise ValueError(f"Invalid object: {recv_obj=}")
|
||||
|
||||
@@ -809,3 +793,28 @@ class SignalHandler:
|
||||
f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
|
||||
)
|
||||
self.tokenizer_manager.gracefully_exit = True
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class _Communicator(Generic[T]):
|
||||
def __init__(self, sender, fan_out: int):
|
||||
self._sender = sender
|
||||
self._fan_out = fan_out
|
||||
self._result_future: Optional[asyncio.Future] = None
|
||||
self._result_values: Optional[List[T]] = None
|
||||
|
||||
async def __call__(self, obj):
|
||||
self._sender.send_pyobj(obj)
|
||||
self._result_future = asyncio.Future()
|
||||
self._result_values = []
|
||||
await self._result_future
|
||||
result_values = self._result_values
|
||||
self._result_future = self._result_values = None
|
||||
return result_values
|
||||
|
||||
def handle_recv(self, recv_obj: T):
|
||||
self._result_values.append(recv_obj)
|
||||
if len(self._result_values) == self._fan_out:
|
||||
self._result_future.set_result(None)
|
||||
|
||||
@@ -245,16 +245,11 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
|
||||
try:
|
||||
ret = await tokenizer_manager.get_weights_by_name(obj, request)
|
||||
if ret is None:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": "Get parameter by name failed"}},
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
return _create_error_response("Get parameter by name failed")
|
||||
else:
|
||||
return ORJSONResponse(ret, status_code=200)
|
||||
except Exception as e:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
return _create_error_response(e)
|
||||
|
||||
|
||||
@app.api_route("/open_session", methods=["GET", "POST"])
|
||||
@@ -264,9 +259,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
|
||||
session_id = await tokenizer_manager.open_session(obj, request)
|
||||
return session_id
|
||||
except Exception as e:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
return _create_error_response(e)
|
||||
|
||||
|
||||
@app.api_route("/close_session", methods=["GET", "POST"])
|
||||
@@ -276,9 +269,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
|
||||
await tokenizer_manager.close_session(obj, request)
|
||||
return Response(status_code=200)
|
||||
except Exception as e:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
return _create_error_response(e)
|
||||
|
||||
|
||||
# fastapi implicitly converts json in the request to obj (dataclass)
|
||||
@@ -312,9 +303,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
|
||||
return ret
|
||||
except ValueError as e:
|
||||
logger.error(f"Error: {e}")
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
return _create_error_response(e)
|
||||
|
||||
|
||||
@app.api_route("/encode", methods=["POST", "PUT"])
|
||||
@@ -325,9 +314,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
|
||||
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
||||
return ret
|
||||
except ValueError as e:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
return _create_error_response(e)
|
||||
|
||||
|
||||
@app.api_route("/classify", methods=["POST", "PUT"])
|
||||
@@ -338,9 +325,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
|
||||
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
||||
return ret
|
||||
except ValueError as e:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
return _create_error_response(e)
|
||||
|
||||
|
||||
##### OpenAI-compatible API endpoints #####
|
||||
@@ -416,6 +401,12 @@ async def retrieve_file_content(file_id: str):
|
||||
return await v1_retrieve_file_content(file_id)
|
||||
|
||||
|
||||
def _create_error_response(e):
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
|
||||
|
||||
def launch_engine(
|
||||
server_args: ServerArgs,
|
||||
):
|
||||
@@ -849,12 +840,10 @@ class Engine:
|
||||
group_name=group_name,
|
||||
backend=backend,
|
||||
)
|
||||
|
||||
async def _init_group():
|
||||
return await tokenizer_manager.init_weights_update_group(obj, None)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(_init_group())
|
||||
return loop.run_until_complete(
|
||||
tokenizer_manager.init_weights_update_group(obj, None)
|
||||
)
|
||||
|
||||
def update_weights_from_distributed(self, name, dtype, shape):
|
||||
"""Update weights from distributed source."""
|
||||
@@ -863,22 +852,16 @@ class Engine:
|
||||
dtype=dtype,
|
||||
shape=shape,
|
||||
)
|
||||
|
||||
async def _update_weights():
|
||||
return await tokenizer_manager.update_weights_from_distributed(obj, None)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(_update_weights())
|
||||
return loop.run_until_complete(
|
||||
tokenizer_manager.update_weights_from_distributed(obj, None)
|
||||
)
|
||||
|
||||
def get_weights_by_name(self, name, truncate_size=100):
|
||||
"""Get weights by parameter name."""
|
||||
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
|
||||
|
||||
async def _get_weights():
|
||||
return await tokenizer_manager.get_weights_by_name(obj, None)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(_get_weights())
|
||||
return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None))
|
||||
|
||||
|
||||
class Runtime:
|
||||
|
||||
Reference in New Issue
Block a user