Add get weights by parameter name for llama (#2266)
This commit is contained in:
@@ -52,6 +52,7 @@ from sglang.srt.managers.io_struct import (
|
||||
CloseSessionReqInput,
|
||||
EmbeddingReqInput,
|
||||
GenerateReqInput,
|
||||
GetWeightsByNameReqInput,
|
||||
OpenSessionReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
)
|
||||
@@ -210,6 +211,24 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: R
|
||||
)
|
||||
|
||||
|
||||
@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
|
||||
async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
|
||||
"""Get model parameter by name."""
|
||||
try:
|
||||
ret = await tokenizer_manager.get_weights_by_name(obj, request)
|
||||
if ret is None:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": "Get parameter by name failed"}},
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
else:
|
||||
return ORJSONResponse(ret, status_code=200)
|
||||
except Exception as e:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
|
||||
|
||||
@app.api_route("/open_session", methods=["GET", "POST"])
|
||||
async def open_session(obj: OpenSessionReqInput, request: Request):
|
||||
"""Open a session, and return its unique session id."""
|
||||
@@ -269,6 +288,18 @@ async def generate_request(obj: GenerateReqInput, request: Request):
|
||||
)
|
||||
|
||||
|
||||
@time_func_latency
|
||||
async def get_weights_by_name_request(obj: GetWeightsByNameReqInput, request: Request):
|
||||
"""Handle a get parameter by name request."""
|
||||
try:
|
||||
ret = await tokenizer_manager.get_weights_by_name(obj, request)
|
||||
return ret
|
||||
except ValueError as e:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
|
||||
|
||||
@app.api_route("/encode", methods=["POST", "PUT"])
|
||||
@time_func_latency
|
||||
async def encode_request(obj: EmbeddingReqInput, request: Request):
|
||||
@@ -938,3 +969,8 @@ class Engine:
|
||||
|
||||
async def get_server_info(self):
|
||||
return await _get_server_info()
|
||||
|
||||
def get_weights_by_name(self, name, truncate_size=100):
|
||||
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(get_weights_by_name_request(obj, None))
|
||||
|
||||
Reference in New Issue
Block a user