Online weight updates from torch.distributed (#2279)

This commit is contained in:
Chayenne
2024-12-01 23:23:18 -08:00
committed by GitHub
parent 28bc60dcab
commit 983bfcf386
12 changed files with 1120 additions and 61 deletions

View File

@@ -53,8 +53,10 @@ from sglang.srt.managers.io_struct import (
EmbeddingReqInput,
GenerateReqInput,
GetWeightsByNameReqInput,
InitWeightsUpdateGroupReqInput,
OpenSessionReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -80,6 +82,7 @@ from sglang.srt.utils import (
assert_pkg_version,
configure_logger,
delete_directory,
init_custom_process_group,
is_port_available,
kill_process_tree,
maybe_set_triton_cache_manager,
@@ -211,6 +214,34 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: R
)
@app.post("/init_weights_update_group")
async def init_weights_update_group(
obj: InitWeightsUpdateGroupReqInput, request: Request
):
"""Initialize the parameter update group."""
success, message = await tokenizer_manager.init_weights_update_group(obj, request)
content = {"success": success, "message": message}
if success:
return ORJSONResponse(content, status_code=200)
else:
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
@app.post("/update_weights_from_distributed")
async def update_weights_from_distributed(
obj: UpdateWeightsFromDistributedReqInput, request: Request
):
"""Update model parameter from distributed online."""
success, message = await tokenizer_manager.update_weights_from_distributed(
obj, request
)
content = {"success": success, "message": message}
if success:
return ORJSONResponse(content, status_code=200)
else:
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
"""Get model parameter by name."""
@@ -288,18 +319,6 @@ async def generate_request(obj: GenerateReqInput, request: Request):
)
@time_func_latency
async def get_weights_by_name_request(obj: GetWeightsByNameReqInput, request: Request):
"""Handle a get parameter by name request."""
try:
ret = await tokenizer_manager.get_weights_by_name(obj, request)
return ret
except ValueError as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
@app.api_route("/encode", methods=["POST", "PUT"])
@time_func_latency
async def encode_request(obj: EmbeddingReqInput, request: Request):
@@ -970,7 +989,51 @@ class Engine:
async def get_server_info(self):
return await _get_server_info()
def get_weights_by_name(self, name, truncate_size=100):
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
def init_weights_update_group(
self,
master_address: str,
master_port: int,
rank_offset: int,
world_size: int,
group_name: str,
backend: str = "nccl",
):
"""Initialize parameter update group."""
obj = InitWeightsUpdateGroupReqInput(
master_address=master_address,
master_port=master_port,
rank_offset=rank_offset,
world_size=world_size,
group_name=group_name,
backend=backend,
)
async def _init_group():
return await tokenizer_manager.init_weights_update_group(obj, None)
loop = asyncio.get_event_loop()
return loop.run_until_complete(get_weights_by_name_request(obj, None))
return loop.run_until_complete(_init_group())
def update_weights_from_distributed(self, name, dtype, shape):
"""Update weights from distributed source."""
obj = UpdateWeightsFromDistributedReqInput(
name=name,
dtype=dtype,
shape=shape,
)
async def _update_weights():
return await tokenizer_manager.update_weights_from_distributed(obj, None)
loop = asyncio.get_event_loop()
return loop.run_until_complete(_update_weights())
def get_weights_by_name(self, name, truncate_size=100):
"""Get weights by parameter name."""
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
async def _get_weights():
return await tokenizer_manager.get_weights_by_name(obj, None)
loop = asyncio.get_event_loop()
return loop.run_until_complete(_get_weights())