Tiny code cleanup in tokenizer_manager.py (#2586)
This commit is contained in:
@@ -245,16 +245,11 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
|
||||
try:
|
||||
ret = await tokenizer_manager.get_weights_by_name(obj, request)
|
||||
if ret is None:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": "Get parameter by name failed"}},
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
return _create_error_response("Get parameter by name failed")
|
||||
else:
|
||||
return ORJSONResponse(ret, status_code=200)
|
||||
except Exception as e:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
return _create_error_response(e)
|
||||
|
||||
|
||||
@app.api_route("/open_session", methods=["GET", "POST"])
|
||||
@@ -264,9 +259,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
|
||||
session_id = await tokenizer_manager.open_session(obj, request)
|
||||
return session_id
|
||||
except Exception as e:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
return _create_error_response(e)
|
||||
|
||||
|
||||
@app.api_route("/close_session", methods=["GET", "POST"])
|
||||
@@ -276,9 +269,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
|
||||
await tokenizer_manager.close_session(obj, request)
|
||||
return Response(status_code=200)
|
||||
except Exception as e:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
return _create_error_response(e)
|
||||
|
||||
|
||||
# fastapi implicitly converts json in the request to obj (dataclass)
|
||||
@@ -312,9 +303,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
|
||||
return ret
|
||||
except ValueError as e:
|
||||
logger.error(f"Error: {e}")
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
return _create_error_response(e)
|
||||
|
||||
|
||||
@app.api_route("/encode", methods=["POST", "PUT"])
|
||||
@@ -325,9 +314,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
|
||||
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
||||
return ret
|
||||
except ValueError as e:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
return _create_error_response(e)
|
||||
|
||||
|
||||
@app.api_route("/classify", methods=["POST", "PUT"])
|
||||
@@ -338,9 +325,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
|
||||
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
||||
return ret
|
||||
except ValueError as e:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
return _create_error_response(e)
|
||||
|
||||
|
||||
##### OpenAI-compatible API endpoints #####
|
||||
@@ -416,6 +401,12 @@ async def retrieve_file_content(file_id: str):
|
||||
return await v1_retrieve_file_content(file_id)
|
||||
|
||||
|
||||
def _create_error_response(e):
    """Wrap an error in a JSON response with HTTP status 400 (Bad Request).

    Args:
        e: The error to report. It is converted with ``str(e)``, so both
            exception objects and plain message strings are accepted.

    Returns:
        An ``ORJSONResponse`` whose body is ``{"error": {"message": str(e)}}``.
    """
    payload = {"error": {"message": str(e)}}
    return ORJSONResponse(payload, status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
|
||||
def launch_engine(
|
||||
server_args: ServerArgs,
|
||||
):
|
||||
@@ -849,12 +840,10 @@ class Engine:
|
||||
group_name=group_name,
|
||||
backend=backend,
|
||||
)
|
||||
|
||||
async def _init_group():
|
||||
return await tokenizer_manager.init_weights_update_group(obj, None)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(_init_group())
|
||||
return loop.run_until_complete(
|
||||
tokenizer_manager.init_weights_update_group(obj, None)
|
||||
)
|
||||
|
||||
def update_weights_from_distributed(self, name, dtype, shape):
|
||||
"""Update weights from distributed source."""
|
||||
@@ -863,22 +852,16 @@ class Engine:
|
||||
dtype=dtype,
|
||||
shape=shape,
|
||||
)
|
||||
|
||||
async def _update_weights():
|
||||
return await tokenizer_manager.update_weights_from_distributed(obj, None)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(_update_weights())
|
||||
return loop.run_until_complete(
|
||||
tokenizer_manager.update_weights_from_distributed(obj, None)
|
||||
)
|
||||
|
||||
def get_weights_by_name(self, name, truncate_size=100):
    """Get weights by parameter name.

    Args:
        name: Parameter name, forwarded as ``GetWeightsByNameReqInput.name``.
        truncate_size: Forwarded as ``GetWeightsByNameReqInput.truncate_size``
            (defaults to 100).

    Returns:
        Whatever ``tokenizer_manager.get_weights_by_name`` resolves to.
    """
    obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)

    # Drive the coroutine directly on the event loop; no wrapper coroutine
    # is needed, and there must be exactly one return path.
    loop = asyncio.get_event_loop()
    return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None))
|
||||
|
||||
|
||||
class Runtime:
|
||||
|
||||
Reference in New Issue
Block a user