Add get weights by parameter name for llama (#2266)

This commit is contained in:
Chayenne
2024-11-29 23:36:38 -08:00
committed by GitHub
parent 7d5d1d3d29
commit 7d1485d376
12 changed files with 337 additions and 17 deletions

View File

@@ -52,6 +52,7 @@ from sglang.srt.managers.io_struct import (
CloseSessionReqInput,
EmbeddingReqInput,
GenerateReqInput,
GetWeightsByNameReqInput,
OpenSessionReqInput,
UpdateWeightFromDiskReqInput,
)
@@ -210,6 +211,24 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: R
)
@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
"""Get model parameter by name."""
try:
ret = await tokenizer_manager.get_weights_by_name(obj, request)
if ret is None:
return ORJSONResponse(
{"error": {"message": "Get parameter by name failed"}},
status_code=HTTPStatus.BAD_REQUEST,
)
else:
return ORJSONResponse(ret, status_code=200)
except Exception as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
@app.api_route("/open_session", methods=["GET", "POST"])
async def open_session(obj: OpenSessionReqInput, request: Request):
"""Open a session, and return its unique session id."""
@@ -269,6 +288,18 @@ async def generate_request(obj: GenerateReqInput, request: Request):
)
@time_func_latency
async def get_weights_by_name_request(obj: GetWeightsByNameReqInput, request: Request):
"""Handle a get parameter by name request."""
try:
ret = await tokenizer_manager.get_weights_by_name(obj, request)
return ret
except ValueError as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
@app.api_route("/encode", methods=["POST", "PUT"])
@time_func_latency
async def encode_request(obj: EmbeddingReqInput, request: Request):
@@ -938,3 +969,8 @@ class Engine:
async def get_server_info(self):
return await _get_server_info()
def get_weights_by_name(self, name, truncate_size=100):
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
loop = asyncio.get_event_loop()
return loop.run_until_complete(get_weights_by_name_request(obj, None))