Add update_weights_from_tensor (#2631)
This commit is contained in:
@@ -57,6 +57,7 @@ from sglang.srt.managers.io_struct import (
|
||||
OpenSessionReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
)
|
||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
@@ -109,6 +110,7 @@ app.add_middleware(
|
||||
tokenizer_manager: TokenizerManager = None
|
||||
scheduler_info: Dict = None
|
||||
|
||||
|
||||
##### Native API endpoints #####
|
||||
|
||||
|
||||
@@ -866,6 +868,14 @@ class Engine:
|
||||
tokenizer_manager.update_weights_from_distributed(obj, None)
|
||||
)
|
||||
|
||||
def update_weights_from_tensor(self, name, tensor):
|
||||
"""Update weights from distributed source."""
|
||||
obj = UpdateWeightsFromTensorReqInput(name=name, tensor=tensor)
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(
|
||||
tokenizer_manager.update_weights_from_tensor(obj, None)
|
||||
)
|
||||
|
||||
def get_weights_by_name(self, name, truncate_size=100):
|
||||
"""Get weights by parameter name."""
|
||||
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
|
||||
|
||||
Reference in New Issue
Block a user