Tiny code cleanup in tokenizer_manager.py (#2586)

This commit is contained in:
fzyzcjy
2024-12-27 09:53:09 +08:00
committed by GitHub
parent f46f394f4d
commit b2ed5c8ea7
2 changed files with 74 additions and 82 deletions

View File

@@ -245,16 +245,11 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
try:
ret = await tokenizer_manager.get_weights_by_name(obj, request)
if ret is None:
return ORJSONResponse(
{"error": {"message": "Get parameter by name failed"}},
status_code=HTTPStatus.BAD_REQUEST,
)
return _create_error_response("Get parameter by name failed")
else:
return ORJSONResponse(ret, status_code=200)
except Exception as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
return _create_error_response(e)
@app.api_route("/open_session", methods=["GET", "POST"])
@@ -264,9 +259,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
session_id = await tokenizer_manager.open_session(obj, request)
return session_id
except Exception as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
return _create_error_response(e)
@app.api_route("/close_session", methods=["GET", "POST"])
@@ -276,9 +269,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
await tokenizer_manager.close_session(obj, request)
return Response(status_code=200)
except Exception as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
return _create_error_response(e)
# fastapi implicitly converts json in the request to obj (dataclass)
@@ -312,9 +303,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
return ret
except ValueError as e:
logger.error(f"Error: {e}")
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
return _create_error_response(e)
@app.api_route("/encode", methods=["POST", "PUT"])
@@ -325,9 +314,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret
except ValueError as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
return _create_error_response(e)
@app.api_route("/classify", methods=["POST", "PUT"])
@@ -338,9 +325,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret
except ValueError as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
return _create_error_response(e)
##### OpenAI-compatible API endpoints #####
@@ -416,6 +401,12 @@ async def retrieve_file_content(file_id: str):
return await v1_retrieve_file_content(file_id)
def _create_error_response(e):
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
def launch_engine(
server_args: ServerArgs,
):
@@ -849,12 +840,10 @@ class Engine:
group_name=group_name,
backend=backend,
)
async def _init_group():
return await tokenizer_manager.init_weights_update_group(obj, None)
loop = asyncio.get_event_loop()
return loop.run_until_complete(_init_group())
return loop.run_until_complete(
tokenizer_manager.init_weights_update_group(obj, None)
)
def update_weights_from_distributed(self, name, dtype, shape):
"""Update weights from distributed source."""
@@ -863,22 +852,16 @@ class Engine:
dtype=dtype,
shape=shape,
)
async def _update_weights():
return await tokenizer_manager.update_weights_from_distributed(obj, None)
loop = asyncio.get_event_loop()
return loop.run_until_complete(_update_weights())
return loop.run_until_complete(
tokenizer_manager.update_weights_from_distributed(obj, None)
)
def get_weights_by_name(self, name, truncate_size=100):
"""Get weights by parameter name."""
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
async def _get_weights():
return await tokenizer_manager.get_weights_by_name(obj, None)
loop = asyncio.get_event_loop()
return loop.run_until_complete(_get_weights())
return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None))
class Runtime: