Online weight updates from torch.distributed (#2279)
This commit is contained in:
@@ -53,8 +53,10 @@ from sglang.srt.managers.io_struct import (
|
||||
EmbeddingReqInput,
|
||||
GenerateReqInput,
|
||||
GetWeightsByNameReqInput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
OpenSessionReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
)
|
||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
@@ -80,6 +82,7 @@ from sglang.srt.utils import (
|
||||
assert_pkg_version,
|
||||
configure_logger,
|
||||
delete_directory,
|
||||
init_custom_process_group,
|
||||
is_port_available,
|
||||
kill_process_tree,
|
||||
maybe_set_triton_cache_manager,
|
||||
@@ -211,6 +214,34 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: R
|
||||
)
|
||||
|
||||
|
||||
@app.post("/init_weights_update_group")
|
||||
async def init_weights_update_group(
|
||||
obj: InitWeightsUpdateGroupReqInput, request: Request
|
||||
):
|
||||
"""Initialize the parameter update group."""
|
||||
success, message = await tokenizer_manager.init_weights_update_group(obj, request)
|
||||
content = {"success": success, "message": message}
|
||||
if success:
|
||||
return ORJSONResponse(content, status_code=200)
|
||||
else:
|
||||
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
|
||||
@app.post("/update_weights_from_distributed")
|
||||
async def update_weights_from_distributed(
|
||||
obj: UpdateWeightsFromDistributedReqInput, request: Request
|
||||
):
|
||||
"""Update model parameter from distributed online."""
|
||||
success, message = await tokenizer_manager.update_weights_from_distributed(
|
||||
obj, request
|
||||
)
|
||||
content = {"success": success, "message": message}
|
||||
if success:
|
||||
return ORJSONResponse(content, status_code=200)
|
||||
else:
|
||||
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
|
||||
@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
|
||||
async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
|
||||
"""Get model parameter by name."""
|
||||
@@ -288,18 +319,6 @@ async def generate_request(obj: GenerateReqInput, request: Request):
|
||||
)
|
||||
|
||||
|
||||
@time_func_latency
|
||||
async def get_weights_by_name_request(obj: GetWeightsByNameReqInput, request: Request):
|
||||
"""Handle a get parameter by name request."""
|
||||
try:
|
||||
ret = await tokenizer_manager.get_weights_by_name(obj, request)
|
||||
return ret
|
||||
except ValueError as e:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
|
||||
|
||||
@app.api_route("/encode", methods=["POST", "PUT"])
|
||||
@time_func_latency
|
||||
async def encode_request(obj: EmbeddingReqInput, request: Request):
|
||||
@@ -970,7 +989,51 @@ class Engine:
|
||||
async def get_server_info(self):
|
||||
return await _get_server_info()
|
||||
|
||||
def get_weights_by_name(self, name, truncate_size=100):
|
||||
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
|
||||
def init_weights_update_group(
|
||||
self,
|
||||
master_address: str,
|
||||
master_port: int,
|
||||
rank_offset: int,
|
||||
world_size: int,
|
||||
group_name: str,
|
||||
backend: str = "nccl",
|
||||
):
|
||||
"""Initialize parameter update group."""
|
||||
obj = InitWeightsUpdateGroupReqInput(
|
||||
master_address=master_address,
|
||||
master_port=master_port,
|
||||
rank_offset=rank_offset,
|
||||
world_size=world_size,
|
||||
group_name=group_name,
|
||||
backend=backend,
|
||||
)
|
||||
|
||||
async def _init_group():
|
||||
return await tokenizer_manager.init_weights_update_group(obj, None)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(get_weights_by_name_request(obj, None))
|
||||
return loop.run_until_complete(_init_group())
|
||||
|
||||
def update_weights_from_distributed(self, name, dtype, shape):
|
||||
"""Update weights from distributed source."""
|
||||
obj = UpdateWeightsFromDistributedReqInput(
|
||||
name=name,
|
||||
dtype=dtype,
|
||||
shape=shape,
|
||||
)
|
||||
|
||||
async def _update_weights():
|
||||
return await tokenizer_manager.update_weights_from_distributed(obj, None)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(_update_weights())
|
||||
|
||||
def get_weights_by_name(self, name, truncate_size=100):
|
||||
"""Get weights by parameter name."""
|
||||
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
|
||||
|
||||
async def _get_weights():
|
||||
return await tokenizer_manager.get_weights_by_name(obj, None)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(_get_weights())
|
||||
|
||||
Reference in New Issue
Block a user