Tiny code cleanup in tokenizer_manager.py (#2586)
This commit is contained in:
@@ -22,7 +22,7 @@ import signal
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
|
from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
import uvloop
|
import uvloop
|
||||||
@@ -173,6 +173,15 @@ class TokenizerManager:
|
|||||||
|
|
||||||
# Others
|
# Others
|
||||||
self.gracefully_exit = False
|
self.gracefully_exit = False
|
||||||
|
self.init_weights_update_group_communicator = _Communicator(
|
||||||
|
self.send_to_scheduler, server_args.dp_size
|
||||||
|
)
|
||||||
|
self.update_weights_from_distributed_communicator = _Communicator(
|
||||||
|
self.send_to_scheduler, server_args.dp_size
|
||||||
|
)
|
||||||
|
self.get_weights_by_name_communicator = _Communicator(
|
||||||
|
self.send_to_scheduler, server_args.dp_size
|
||||||
|
)
|
||||||
|
|
||||||
# Metrics
|
# Metrics
|
||||||
if self.enable_metrics:
|
if self.enable_metrics:
|
||||||
@@ -190,8 +199,7 @@ class TokenizerManager:
|
|||||||
):
|
):
|
||||||
created_time = time.time()
|
created_time = time.time()
|
||||||
|
|
||||||
if self.to_create_loop:
|
self.auto_create_handle_loop()
|
||||||
self.create_handle_loop()
|
|
||||||
|
|
||||||
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
|
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -440,8 +448,7 @@ class TokenizerManager:
|
|||||||
obj: UpdateWeightFromDiskReqInput,
|
obj: UpdateWeightFromDiskReqInput,
|
||||||
request: Optional[fastapi.Request] = None,
|
request: Optional[fastapi.Request] = None,
|
||||||
) -> Tuple[bool, str]:
|
) -> Tuple[bool, str]:
|
||||||
if self.to_create_loop:
|
self.auto_create_handle_loop()
|
||||||
self.create_handle_loop()
|
|
||||||
|
|
||||||
# default the load format to the server_args
|
# default the load format to the server_args
|
||||||
if obj.load_format is None:
|
if obj.load_format is None:
|
||||||
@@ -456,7 +463,7 @@ class TokenizerManager:
|
|||||||
|
|
||||||
async def _wait_for_model_update_from_disk(
|
async def _wait_for_model_update_from_disk(
|
||||||
self, obj: UpdateWeightFromDiskReqInput
|
self, obj: UpdateWeightFromDiskReqInput
|
||||||
) -> Tuple[bool, str, int]:
|
) -> Tuple[bool, str]:
|
||||||
self.send_to_scheduler.send_pyobj(obj)
|
self.send_to_scheduler.send_pyobj(obj)
|
||||||
self.model_update_result = asyncio.Future()
|
self.model_update_result = asyncio.Future()
|
||||||
if self.server_args.dp_size == 1:
|
if self.server_args.dp_size == 1:
|
||||||
@@ -485,15 +492,11 @@ class TokenizerManager:
|
|||||||
obj: InitWeightsUpdateGroupReqInput,
|
obj: InitWeightsUpdateGroupReqInput,
|
||||||
request: Optional[fastapi.Request] = None,
|
request: Optional[fastapi.Request] = None,
|
||||||
) -> Tuple[bool, str]:
|
) -> Tuple[bool, str]:
|
||||||
if self.to_create_loop:
|
self.auto_create_handle_loop()
|
||||||
self.create_handle_loop()
|
|
||||||
self.send_to_scheduler.send_pyobj(obj)
|
|
||||||
|
|
||||||
self.init_weights_update_group_result = asyncio.Future()
|
|
||||||
assert (
|
assert (
|
||||||
self.server_args.dp_size == 1
|
self.server_args.dp_size == 1
|
||||||
), "dp_size must be 1 for init parameter update group"
|
), "dp_size must be 1 for init parameter update group"
|
||||||
result = await self.init_weights_update_group_result
|
result = (await self.init_weights_update_group_communicator(obj))[0]
|
||||||
return result.success, result.message
|
return result.success, result.message
|
||||||
|
|
||||||
async def update_weights_from_distributed(
|
async def update_weights_from_distributed(
|
||||||
@@ -501,44 +504,32 @@ class TokenizerManager:
|
|||||||
obj: UpdateWeightsFromDistributedReqInput,
|
obj: UpdateWeightsFromDistributedReqInput,
|
||||||
request: Optional[fastapi.Request] = None,
|
request: Optional[fastapi.Request] = None,
|
||||||
) -> Tuple[bool, str]:
|
) -> Tuple[bool, str]:
|
||||||
if self.to_create_loop:
|
self.auto_create_handle_loop()
|
||||||
self.create_handle_loop()
|
assert (
|
||||||
|
self.server_args.dp_size == 1
|
||||||
|
), "dp_size must be for update weights from distributed"
|
||||||
|
|
||||||
# This means that weight sync
|
# This means that weight sync
|
||||||
# cannot run while requests are in progress.
|
# cannot run while requests are in progress.
|
||||||
async with self.model_update_lock.writer_lock:
|
async with self.model_update_lock.writer_lock:
|
||||||
self.send_to_scheduler.send_pyobj(obj)
|
result = (await self.update_weights_from_distributed_communicator(obj))[0]
|
||||||
self.parameter_update_result: Awaitable[
|
|
||||||
UpdateWeightsFromDistributedReqOutput
|
|
||||||
] = asyncio.Future()
|
|
||||||
assert (
|
|
||||||
self.server_args.dp_size == 1
|
|
||||||
), "dp_size must be for update weights from distributed"
|
|
||||||
result = await self.parameter_update_result
|
|
||||||
return result.success, result.message
|
return result.success, result.message
|
||||||
|
|
||||||
async def get_weights_by_name(
|
async def get_weights_by_name(
|
||||||
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
|
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
|
||||||
):
|
):
|
||||||
if self.to_create_loop:
|
self.auto_create_handle_loop()
|
||||||
self.create_handle_loop()
|
results = await self.get_weights_by_name_communicator(obj)
|
||||||
|
all_parameters = [r.parameter for r in results]
|
||||||
self.send_to_scheduler.send_pyobj(obj)
|
|
||||||
self.get_weights_by_name_result = asyncio.Future()
|
|
||||||
if self.server_args.dp_size == 1:
|
if self.server_args.dp_size == 1:
|
||||||
result = await self.get_weights_by_name_result
|
return all_parameters[0]
|
||||||
return result.parameter
|
|
||||||
else:
|
else:
|
||||||
self.get_weights_by_name_tmp = []
|
|
||||||
result = await self.get_weights_by_name_result
|
|
||||||
all_parameters = [r.parameter for r in result]
|
|
||||||
return all_parameters
|
return all_parameters
|
||||||
|
|
||||||
async def open_session(
|
async def open_session(
|
||||||
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
|
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
|
||||||
):
|
):
|
||||||
if self.to_create_loop:
|
self.auto_create_handle_loop()
|
||||||
self.create_handle_loop()
|
|
||||||
|
|
||||||
session_id = uuid.uuid4().hex
|
session_id = uuid.uuid4().hex
|
||||||
obj.session_id = session_id
|
obj.session_id = session_id
|
||||||
@@ -568,7 +559,7 @@ class TokenizerManager:
|
|||||||
background_tasks.add_task(abort_request)
|
background_tasks.add_task(abort_request)
|
||||||
return background_tasks
|
return background_tasks
|
||||||
|
|
||||||
def create_handle_loop(self):
|
def auto_create_handle_loop(self):
|
||||||
if not self.to_create_loop:
|
if not self.to_create_loop:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -711,21 +702,14 @@ class TokenizerManager:
|
|||||||
assert (
|
assert (
|
||||||
self.server_args.dp_size == 1
|
self.server_args.dp_size == 1
|
||||||
), "dp_size must be 1 for init parameter update group"
|
), "dp_size must be 1 for init parameter update group"
|
||||||
self.init_weights_update_group_result.set_result(recv_obj)
|
self.init_weights_update_group_communicator.handle_recv(recv_obj)
|
||||||
elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
|
elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
|
||||||
assert (
|
assert (
|
||||||
self.server_args.dp_size == 1
|
self.server_args.dp_size == 1
|
||||||
), "dp_size must be 1 for update weights from distributed"
|
), "dp_size must be 1 for update weights from distributed"
|
||||||
self.parameter_update_result.set_result(recv_obj)
|
self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
|
||||||
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
|
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
|
||||||
if self.server_args.dp_size == 1:
|
self.get_weights_by_name_communicator.handle_recv(recv_obj)
|
||||||
self.get_weights_by_name_result.set_result(recv_obj)
|
|
||||||
else:
|
|
||||||
self.get_weights_by_name_tmp.append(recv_obj)
|
|
||||||
if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
|
|
||||||
self.get_weights_by_name_result.set_result(
|
|
||||||
self.get_weights_by_name_tmp
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid object: {recv_obj=}")
|
raise ValueError(f"Invalid object: {recv_obj=}")
|
||||||
|
|
||||||
@@ -809,3 +793,28 @@ class SignalHandler:
|
|||||||
f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
|
f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
|
||||||
)
|
)
|
||||||
self.tokenizer_manager.gracefully_exit = True
|
self.tokenizer_manager.gracefully_exit = True
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class _Communicator(Generic[T]):
|
||||||
|
def __init__(self, sender, fan_out: int):
|
||||||
|
self._sender = sender
|
||||||
|
self._fan_out = fan_out
|
||||||
|
self._result_future: Optional[asyncio.Future] = None
|
||||||
|
self._result_values: Optional[List[T]] = None
|
||||||
|
|
||||||
|
async def __call__(self, obj):
|
||||||
|
self._sender.send_pyobj(obj)
|
||||||
|
self._result_future = asyncio.Future()
|
||||||
|
self._result_values = []
|
||||||
|
await self._result_future
|
||||||
|
result_values = self._result_values
|
||||||
|
self._result_future = self._result_values = None
|
||||||
|
return result_values
|
||||||
|
|
||||||
|
def handle_recv(self, recv_obj: T):
|
||||||
|
self._result_values.append(recv_obj)
|
||||||
|
if len(self._result_values) == self._fan_out:
|
||||||
|
self._result_future.set_result(None)
|
||||||
|
|||||||
@@ -245,16 +245,11 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
|
|||||||
try:
|
try:
|
||||||
ret = await tokenizer_manager.get_weights_by_name(obj, request)
|
ret = await tokenizer_manager.get_weights_by_name(obj, request)
|
||||||
if ret is None:
|
if ret is None:
|
||||||
return ORJSONResponse(
|
return _create_error_response("Get parameter by name failed")
|
||||||
{"error": {"message": "Get parameter by name failed"}},
|
|
||||||
status_code=HTTPStatus.BAD_REQUEST,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return ORJSONResponse(ret, status_code=200)
|
return ORJSONResponse(ret, status_code=200)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return ORJSONResponse(
|
return _create_error_response(e)
|
||||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/open_session", methods=["GET", "POST"])
|
@app.api_route("/open_session", methods=["GET", "POST"])
|
||||||
@@ -264,9 +259,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
|
|||||||
session_id = await tokenizer_manager.open_session(obj, request)
|
session_id = await tokenizer_manager.open_session(obj, request)
|
||||||
return session_id
|
return session_id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return ORJSONResponse(
|
return _create_error_response(e)
|
||||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/close_session", methods=["GET", "POST"])
|
@app.api_route("/close_session", methods=["GET", "POST"])
|
||||||
@@ -276,9 +269,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
|
|||||||
await tokenizer_manager.close_session(obj, request)
|
await tokenizer_manager.close_session(obj, request)
|
||||||
return Response(status_code=200)
|
return Response(status_code=200)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return ORJSONResponse(
|
return _create_error_response(e)
|
||||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# fastapi implicitly converts json in the request to obj (dataclass)
|
# fastapi implicitly converts json in the request to obj (dataclass)
|
||||||
@@ -312,9 +303,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
|
|||||||
return ret
|
return ret
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.error(f"Error: {e}")
|
logger.error(f"Error: {e}")
|
||||||
return ORJSONResponse(
|
return _create_error_response(e)
|
||||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/encode", methods=["POST", "PUT"])
|
@app.api_route("/encode", methods=["POST", "PUT"])
|
||||||
@@ -325,9 +314,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
|
|||||||
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
||||||
return ret
|
return ret
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return ORJSONResponse(
|
return _create_error_response(e)
|
||||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/classify", methods=["POST", "PUT"])
|
@app.api_route("/classify", methods=["POST", "PUT"])
|
||||||
@@ -338,9 +325,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
|
|||||||
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
||||||
return ret
|
return ret
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return ORJSONResponse(
|
return _create_error_response(e)
|
||||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
##### OpenAI-compatible API endpoints #####
|
##### OpenAI-compatible API endpoints #####
|
||||||
@@ -416,6 +401,12 @@ async def retrieve_file_content(file_id: str):
|
|||||||
return await v1_retrieve_file_content(file_id)
|
return await v1_retrieve_file_content(file_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_error_response(e):
|
||||||
|
return ORJSONResponse(
|
||||||
|
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def launch_engine(
|
def launch_engine(
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
):
|
):
|
||||||
@@ -849,12 +840,10 @@ class Engine:
|
|||||||
group_name=group_name,
|
group_name=group_name,
|
||||||
backend=backend,
|
backend=backend,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _init_group():
|
|
||||||
return await tokenizer_manager.init_weights_update_group(obj, None)
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
return loop.run_until_complete(_init_group())
|
return loop.run_until_complete(
|
||||||
|
tokenizer_manager.init_weights_update_group(obj, None)
|
||||||
|
)
|
||||||
|
|
||||||
def update_weights_from_distributed(self, name, dtype, shape):
|
def update_weights_from_distributed(self, name, dtype, shape):
|
||||||
"""Update weights from distributed source."""
|
"""Update weights from distributed source."""
|
||||||
@@ -863,22 +852,16 @@ class Engine:
|
|||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
shape=shape,
|
shape=shape,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _update_weights():
|
|
||||||
return await tokenizer_manager.update_weights_from_distributed(obj, None)
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
return loop.run_until_complete(_update_weights())
|
return loop.run_until_complete(
|
||||||
|
tokenizer_manager.update_weights_from_distributed(obj, None)
|
||||||
|
)
|
||||||
|
|
||||||
def get_weights_by_name(self, name, truncate_size=100):
|
def get_weights_by_name(self, name, truncate_size=100):
|
||||||
"""Get weights by parameter name."""
|
"""Get weights by parameter name."""
|
||||||
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
|
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
|
||||||
|
|
||||||
async def _get_weights():
|
|
||||||
return await tokenizer_manager.get_weights_by_name(obj, None)
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
return loop.run_until_complete(_get_weights())
|
return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None))
|
||||||
|
|
||||||
|
|
||||||
class Runtime:
|
class Runtime:
|
||||||
|
|||||||
Reference in New Issue
Block a user