Add get weights by parameter name for llama (#2266)
This commit is contained in:
@@ -46,6 +46,8 @@ from sglang.srt.managers.io_struct import (
|
||||
EmbeddingReqInput,
|
||||
FlushCacheReq,
|
||||
GenerateReqInput,
|
||||
GetWeightsByNameReqInput,
|
||||
GetWeightsByNameReqOutput,
|
||||
OpenSessionReqInput,
|
||||
OpenSessionReqOutput,
|
||||
ProfileReq,
|
||||
@@ -454,6 +456,23 @@ class TokenizerManager:
|
||||
else:
|
||||
return False, "Another update is in progress. Please try again later."
|
||||
|
||||
async def get_weights_by_name(
    self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
):
    """Forward a get-weights-by-name request to the scheduler and await the reply.

    Args:
        obj: The request object; it is passed unchanged to
            ``send_to_scheduler.send_pyobj``.
        request: Not used by this method.

    Returns:
        When ``server_args.dp_size == 1``, the ``parameter`` attribute of the
        awaited result. Otherwise the awaited result is iterated and a list
        holding each item's ``parameter`` is returned.
    """
    # Start the handle loop first if it has not been created yet.
    if self.to_create_loop:
        self.create_handle_loop()

    self.send_to_scheduler.send_pyobj(obj)

    # The future is published on ``self`` and only awaited here; it is
    # resolved outside this method.
    pending = asyncio.Future()
    self.get_weights_by_name_result = pending

    if self.server_args.dp_size != 1:
        # Fresh accumulator for this request, stored on ``self``.
        self.get_weights_by_name_tmp = []
        replies = await pending
        return [reply.parameter for reply in replies]

    reply = await pending
    return reply.parameter
|
||||
|
||||
async def open_session(
|
||||
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
|
||||
):
|
||||
@@ -527,6 +546,7 @@ class TokenizerManager:
|
||||
BatchEmbeddingOut,
|
||||
BatchTokenIDOut,
|
||||
UpdateWeightFromDiskReqOutput,
|
||||
GetWeightsByNameReqOutput,
|
||||
] = await self.recv_from_detokenizer.recv_pyobj()
|
||||
|
||||
if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
|
||||
@@ -538,6 +558,16 @@ class TokenizerManager:
|
||||
if len(self.model_update_tmp) == self.server_args.dp_size:
|
||||
self.model_update_result.set_result(self.model_update_tmp)
|
||||
continue
|
||||
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
|
||||
if self.server_args.dp_size == 1:
|
||||
self.get_weights_by_name_result.set_result(recv_obj)
|
||||
else:
|
||||
self.get_weights_by_name_tmp.append(recv_obj)
|
||||
if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
|
||||
self.get_weights_by_name_result.set_result(
|
||||
self.get_weights_by_name_tmp
|
||||
)
|
||||
continue
|
||||
elif isinstance(recv_obj, OpenSessionReqOutput):
|
||||
self.session_futures[recv_obj.session_id].set_result(
|
||||
recv_obj.session_id
|
||||
|
||||
Reference in New Issue
Block a user