Online weight updates from torch.distributed (#2279)
This commit is contained in:
@@ -365,6 +365,41 @@ class UpdateWeightFromDiskReqOutput:
|
||||
message: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateWeightsFromDistributedReqInput:
|
||||
name: str
|
||||
dtype: str
|
||||
shape: List[int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateWeightsFromDistributedReqOutput:
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class InitWeightsUpdateGroupReqInput:
|
||||
# The master address
|
||||
master_address: str
|
||||
# The master port
|
||||
master_port: int
|
||||
# The rank offset
|
||||
rank_offset: int
|
||||
# The world size
|
||||
world_size: int
|
||||
# The group name
|
||||
group_name: str = "weight_update_group"
|
||||
# The backend
|
||||
backend: str = "nccl"
|
||||
|
||||
|
||||
@dataclass
|
||||
class InitWeightsUpdateGroupReqOutput:
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetWeightsByNameReqInput:
|
||||
name: str
|
||||
|
||||
@@ -40,6 +40,8 @@ from sglang.srt.managers.io_struct import (
|
||||
FlushCacheReq,
|
||||
GetWeightsByNameReqInput,
|
||||
GetWeightsByNameReqOutput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
InitWeightsUpdateGroupReqOutput,
|
||||
OpenSessionReqInput,
|
||||
OpenSessionReqOutput,
|
||||
ProfileReq,
|
||||
@@ -47,6 +49,8 @@ from sglang.srt.managers.io_struct import (
|
||||
TokenizedGenerateReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightFromDiskReqOutput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromDistributedReqOutput,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
FINISH_ABORT,
|
||||
@@ -516,6 +520,19 @@ class Scheduler:
|
||||
elif isinstance(recv_req, GetWeightsByNameReqInput):
|
||||
parameter = self.get_weights_by_name(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
|
||||
elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
|
||||
success, message = self.init_weights_update_group(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
InitWeightsUpdateGroupReqOutput(success, message)
|
||||
)
|
||||
elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
|
||||
success, message = self.update_weights_from_distributed(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
UpdateWeightsFromDistributedReqOutput(success, message)
|
||||
)
|
||||
elif isinstance(recv_req, GetWeightsByNameReqInput):
|
||||
parameter = self.get_weights_by_name(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
|
||||
elif isinstance(recv_req, ProfileReq):
|
||||
if recv_req == ProfileReq.START_PROFILE:
|
||||
self.start_profile()
|
||||
@@ -1378,6 +1395,23 @@ class Scheduler:
|
||||
logger.error(message)
|
||||
return success, message
|
||||
|
||||
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
|
||||
"""Initialize the online model parameter update group."""
|
||||
success, message = self.tp_worker.init_weights_update_group(recv_req)
|
||||
return success, message
|
||||
|
||||
def update_weights_from_distributed(
|
||||
self, recv_req: UpdateWeightsFromDistributedReqInput
|
||||
):
|
||||
"""Update the online model parameter."""
|
||||
success, message = self.tp_worker.update_weights_from_distributed(recv_req)
|
||||
if success:
|
||||
flash_cache_success = self.flush_cache()
|
||||
assert flash_cache_success, "Cache flush failed after updating weights"
|
||||
else:
|
||||
logger.error(message)
|
||||
return success, message
|
||||
|
||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||
parameter = self.tp_worker.get_weights_by_name(recv_req)
|
||||
return parameter
|
||||
|
||||
@@ -48,6 +48,8 @@ from sglang.srt.managers.io_struct import (
|
||||
GenerateReqInput,
|
||||
GetWeightsByNameReqInput,
|
||||
GetWeightsByNameReqOutput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
InitWeightsUpdateGroupReqOutput,
|
||||
OpenSessionReqInput,
|
||||
OpenSessionReqOutput,
|
||||
ProfileReq,
|
||||
@@ -55,6 +57,8 @@ from sglang.srt.managers.io_struct import (
|
||||
TokenizedGenerateReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightFromDiskReqOutput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromDistributedReqOutput,
|
||||
)
|
||||
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
@@ -456,6 +460,48 @@ class TokenizerManager:
|
||||
else:
|
||||
return False, "Another update is in progress. Please try again later."
|
||||
|
||||
async def init_weights_update_group(
|
||||
self,
|
||||
obj: InitWeightsUpdateGroupReqInput,
|
||||
request: Optional[fastapi.Request] = None,
|
||||
) -> bool:
|
||||
if self.to_create_loop:
|
||||
self.create_handle_loop()
|
||||
self.send_to_scheduler.send_pyobj(obj)
|
||||
|
||||
self.init_weights_update_group_result = asyncio.Future()
|
||||
assert (
|
||||
self.server_args.dp_size == 1
|
||||
), "dp_size must be 1 for init parameter update group"
|
||||
result = await self.init_weights_update_group_result
|
||||
return result.success, result.message
|
||||
|
||||
async def update_weights_from_distributed(
|
||||
self,
|
||||
obj: UpdateWeightsFromDistributedReqInput,
|
||||
request: Optional[fastapi.Request] = None,
|
||||
):
|
||||
if self.to_create_loop:
|
||||
self.create_handle_loop()
|
||||
|
||||
if not self.model_update_lock.locked():
|
||||
async with self.model_update_lock:
|
||||
self.send_to_scheduler.send_pyobj(obj)
|
||||
self.parameter_update_result = asyncio.Future()
|
||||
assert (
|
||||
self.server_args.dp_size == 1
|
||||
), "dp_size must be for update weights from distributed"
|
||||
result = await self.parameter_update_result
|
||||
return result.success, result.message
|
||||
else:
|
||||
logger.error(
|
||||
f"Another parameter update is in progress in tokenizer manager"
|
||||
)
|
||||
return (
|
||||
False,
|
||||
"Another parameter update is in progress. Please try again later.",
|
||||
)
|
||||
|
||||
async def get_weights_by_name(
|
||||
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
|
||||
):
|
||||
@@ -546,7 +592,9 @@ class TokenizerManager:
|
||||
BatchEmbeddingOut,
|
||||
BatchTokenIDOut,
|
||||
UpdateWeightFromDiskReqOutput,
|
||||
UpdateWeightsFromDistributedReqOutput,
|
||||
GetWeightsByNameReqOutput,
|
||||
InitWeightsUpdateGroupReqOutput,
|
||||
] = await self.recv_from_detokenizer.recv_pyobj()
|
||||
|
||||
if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
|
||||
@@ -558,6 +606,12 @@ class TokenizerManager:
|
||||
if len(self.model_update_tmp) == self.server_args.dp_size:
|
||||
self.model_update_result.set_result(self.model_update_tmp)
|
||||
continue
|
||||
elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
|
||||
assert (
|
||||
self.server_args.dp_size == 1
|
||||
), "dp_size must be 1 for update weights from distributed"
|
||||
self.parameter_update_result.set_result(recv_obj)
|
||||
continue
|
||||
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
|
||||
if self.server_args.dp_size == 1:
|
||||
self.get_weights_by_name_result.set_result(recv_obj)
|
||||
@@ -568,6 +622,12 @@ class TokenizerManager:
|
||||
self.get_weights_by_name_tmp
|
||||
)
|
||||
continue
|
||||
elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
|
||||
assert (
|
||||
self.server_args.dp_size == 1
|
||||
), "dp_size must be 1 for init parameter update group"
|
||||
self.init_weights_update_group_result.set_result(recv_obj)
|
||||
continue
|
||||
elif isinstance(recv_obj, OpenSessionReqOutput):
|
||||
self.session_futures[recv_obj.session_id].set_result(
|
||||
recv_obj.session_id
|
||||
|
||||
@@ -21,7 +21,9 @@ from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||
from sglang.srt.managers.io_struct import (
|
||||
GetWeightsByNameReqInput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
@@ -164,6 +166,25 @@ class TpModelWorker:
|
||||
)
|
||||
return success, message
|
||||
|
||||
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
|
||||
success, message = self.model_runner.init_weights_update_group(
|
||||
recv_req.master_address,
|
||||
recv_req.master_port,
|
||||
recv_req.rank_offset,
|
||||
recv_req.world_size,
|
||||
recv_req.group_name,
|
||||
recv_req.backend,
|
||||
)
|
||||
return success, message
|
||||
|
||||
def update_weights_from_distributed(
|
||||
self, recv_req: UpdateWeightsFromDistributedReqInput
|
||||
):
|
||||
success, message = self.model_runner.update_weights_from_distributed(
|
||||
recv_req.name, recv_req.dtype, recv_req.shape
|
||||
)
|
||||
return success, message
|
||||
|
||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||
parameter = self.model_runner.get_weights_by_name(
|
||||
recv_req.name, recv_req.truncate_size
|
||||
|
||||
@@ -25,7 +25,9 @@ import torch
|
||||
|
||||
from sglang.srt.managers.io_struct import (
|
||||
GetWeightsByNameReqInput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
@@ -211,6 +213,16 @@ class TpModelWorkerClient:
|
||||
success, message = self.worker.update_weights_from_disk(recv_req)
|
||||
return success, message
|
||||
|
||||
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
|
||||
success, message = self.worker.init_weights_update_group(recv_req)
|
||||
return success, message
|
||||
|
||||
def update_weights_from_distributed(
|
||||
self, recv_req: UpdateWeightsFromDistributedReqInput
|
||||
):
|
||||
success, message = self.worker.update_weights_from_distributed(recv_req)
|
||||
return success, message
|
||||
|
||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||
return self.worker.get_weights_by_name(recv_req)
|
||||
|
||||
|
||||
@@ -20,10 +20,13 @@ import inspect
|
||||
import json
|
||||
import logging
|
||||
import pkgutil
|
||||
import time
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Type
|
||||
from tokenize import tabsize
|
||||
from typing import Any, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from vllm.config import DeviceConfig, LoadConfig
|
||||
from vllm.config import ModelConfig as VllmModelConfig
|
||||
@@ -59,6 +62,7 @@ from sglang.srt.utils import (
|
||||
crash_on_warnings,
|
||||
enable_show_time_cost,
|
||||
get_available_gpu_memory,
|
||||
init_custom_process_group,
|
||||
is_hip,
|
||||
monkey_patch_vllm_gguf_config,
|
||||
monkey_patch_vllm_model_config,
|
||||
@@ -404,6 +408,86 @@ class ModelRunner:
|
||||
logger.info("Update weights end.")
|
||||
return True, "Succeeded to update model weights."
|
||||
|
||||
def init_weights_update_group(
|
||||
self,
|
||||
master_address,
|
||||
master_port,
|
||||
rank_offset,
|
||||
world_size,
|
||||
group_name,
|
||||
backend="nccl",
|
||||
):
|
||||
"""Initialize the Torch process group for model parameter updates.
|
||||
|
||||
`_model_update_group` is used in the RLHF workflow, where rank
|
||||
0 is the actor model in the training engine, and the other ranks are
|
||||
the inference engine, which is used for rollout.
|
||||
|
||||
In the RLHF workflow, the training engine updates the model
|
||||
weights/parameters online, and broadcasts them to the inference
|
||||
engine through the `_model_update_group` process group.
|
||||
"""
|
||||
assert (
|
||||
torch.distributed.is_initialized()
|
||||
), "Default torch process group must be initialized"
|
||||
assert group_name != "", "Group name cannot be empty"
|
||||
|
||||
rank = rank_offset + self.tp_rank
|
||||
|
||||
logger.info(
|
||||
f"init custom process group: master_address={master_address}, master_port={master_port}, "
|
||||
f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
|
||||
)
|
||||
|
||||
try:
|
||||
self._model_update_group = init_custom_process_group(
|
||||
backend=backend,
|
||||
init_method=f"tcp://{master_address}:{master_port}",
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
group_name=group_name,
|
||||
)
|
||||
dist.barrier(group=self._model_update_group, device_ids=[rank])
|
||||
return True, "Succeeded to initialize custom process group."
|
||||
except Exception as e:
|
||||
message = f"Failed to initialize custom process group: {e}."
|
||||
logger.error(message)
|
||||
return False, message
|
||||
|
||||
def update_weights_from_distributed(self, name, dtype, shape):
|
||||
"""
|
||||
Update specific parameter in the model weights online
|
||||
through `_model_update_group` process group.
|
||||
|
||||
Args:
|
||||
name: the name of the parameter to be updated.
|
||||
dtype: the data type of the parameter to be updated.
|
||||
shape: the shape of the parameter to be updated.
|
||||
"""
|
||||
target_dtype = (
|
||||
dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
|
||||
)
|
||||
current_dtype = self.dtype if isinstance(self.dtype, str) else self.dtype
|
||||
|
||||
assert (
|
||||
self._model_update_group is not None
|
||||
), "model update group must be initialized"
|
||||
|
||||
try:
|
||||
weights = torch.empty(shape, dtype=target_dtype, device=self.device)
|
||||
torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
|
||||
self.model.load_weights([(name, weights)])
|
||||
return True, f"Succeeded to update parameter {name} online."
|
||||
|
||||
except Exception as e:
|
||||
error_msg = (
|
||||
f"Failed to update parameter online: {e}. "
|
||||
f"The full weights of the ModelRunner are partially updated. "
|
||||
f"Please discard the whole weights."
|
||||
)
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
def get_weights_by_name(
|
||||
self, name: str, truncate_size: int = 100
|
||||
) -> Optional[torch.Tensor]:
|
||||
|
||||
@@ -307,6 +307,8 @@ class LlamaForCausalLM(nn.Module):
|
||||
self.quant_config = quant_config
|
||||
self.torchao_config = global_server_args_dict["torchao_config"]
|
||||
self.model = LlamaModel(config, quant_config=quant_config)
|
||||
# Llama 3.2 1B Insturct set tie_word_embeddings to True
|
||||
# Llama 3.1 8B Insturct set tie_word_embeddings to False
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
|
||||
@@ -53,8 +53,10 @@ from sglang.srt.managers.io_struct import (
|
||||
EmbeddingReqInput,
|
||||
GenerateReqInput,
|
||||
GetWeightsByNameReqInput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
OpenSessionReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
)
|
||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
@@ -80,6 +82,7 @@ from sglang.srt.utils import (
|
||||
assert_pkg_version,
|
||||
configure_logger,
|
||||
delete_directory,
|
||||
init_custom_process_group,
|
||||
is_port_available,
|
||||
kill_process_tree,
|
||||
maybe_set_triton_cache_manager,
|
||||
@@ -211,6 +214,34 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: R
|
||||
)
|
||||
|
||||
|
||||
@app.post("/init_weights_update_group")
|
||||
async def init_weights_update_group(
|
||||
obj: InitWeightsUpdateGroupReqInput, request: Request
|
||||
):
|
||||
"""Initialize the parameter update group."""
|
||||
success, message = await tokenizer_manager.init_weights_update_group(obj, request)
|
||||
content = {"success": success, "message": message}
|
||||
if success:
|
||||
return ORJSONResponse(content, status_code=200)
|
||||
else:
|
||||
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
|
||||
@app.post("/update_weights_from_distributed")
|
||||
async def update_weights_from_distributed(
|
||||
obj: UpdateWeightsFromDistributedReqInput, request: Request
|
||||
):
|
||||
"""Update model parameter from distributed online."""
|
||||
success, message = await tokenizer_manager.update_weights_from_distributed(
|
||||
obj, request
|
||||
)
|
||||
content = {"success": success, "message": message}
|
||||
if success:
|
||||
return ORJSONResponse(content, status_code=200)
|
||||
else:
|
||||
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
|
||||
@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
|
||||
async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
|
||||
"""Get model parameter by name."""
|
||||
@@ -288,18 +319,6 @@ async def generate_request(obj: GenerateReqInput, request: Request):
|
||||
)
|
||||
|
||||
|
||||
@time_func_latency
|
||||
async def get_weights_by_name_request(obj: GetWeightsByNameReqInput, request: Request):
|
||||
"""Handle a get parameter by name request."""
|
||||
try:
|
||||
ret = await tokenizer_manager.get_weights_by_name(obj, request)
|
||||
return ret
|
||||
except ValueError as e:
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
|
||||
|
||||
@app.api_route("/encode", methods=["POST", "PUT"])
|
||||
@time_func_latency
|
||||
async def encode_request(obj: EmbeddingReqInput, request: Request):
|
||||
@@ -970,7 +989,51 @@ class Engine:
|
||||
async def get_server_info(self):
|
||||
return await _get_server_info()
|
||||
|
||||
def get_weights_by_name(self, name, truncate_size=100):
|
||||
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
|
||||
def init_weights_update_group(
|
||||
self,
|
||||
master_address: str,
|
||||
master_port: int,
|
||||
rank_offset: int,
|
||||
world_size: int,
|
||||
group_name: str,
|
||||
backend: str = "nccl",
|
||||
):
|
||||
"""Initialize parameter update group."""
|
||||
obj = InitWeightsUpdateGroupReqInput(
|
||||
master_address=master_address,
|
||||
master_port=master_port,
|
||||
rank_offset=rank_offset,
|
||||
world_size=world_size,
|
||||
group_name=group_name,
|
||||
backend=backend,
|
||||
)
|
||||
|
||||
async def _init_group():
|
||||
return await tokenizer_manager.init_weights_update_group(obj, None)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(get_weights_by_name_request(obj, None))
|
||||
return loop.run_until_complete(_init_group())
|
||||
|
||||
def update_weights_from_distributed(self, name, dtype, shape):
|
||||
"""Update weights from distributed source."""
|
||||
obj = UpdateWeightsFromDistributedReqInput(
|
||||
name=name,
|
||||
dtype=dtype,
|
||||
shape=shape,
|
||||
)
|
||||
|
||||
async def _update_weights():
|
||||
return await tokenizer_manager.update_weights_from_distributed(obj, None)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(_update_weights())
|
||||
|
||||
def get_weights_by_name(self, name, truncate_size=100):
|
||||
"""Get weights by parameter name."""
|
||||
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
|
||||
|
||||
async def _get_weights():
|
||||
return await tokenizer_manager.get_weights_by_name(obj, None)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(_get_weights())
|
||||
|
||||
@@ -39,6 +39,7 @@ import numpy as np
|
||||
import psutil
|
||||
import requests
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.distributed as dist
|
||||
import triton
|
||||
import zmq
|
||||
@@ -962,6 +963,78 @@ def get_nvgpu_memory_capacity():
|
||||
)
|
||||
|
||||
|
||||
# Copy from pytorch and OpenRLHF to allow creating multiple main groups.
|
||||
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
|
||||
# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
|
||||
def init_custom_process_group(
|
||||
backend=None,
|
||||
init_method=None,
|
||||
timeout=None,
|
||||
world_size=-1,
|
||||
rank=-1,
|
||||
store=None,
|
||||
group_name=None,
|
||||
pg_options=None,
|
||||
):
|
||||
from torch.distributed.distributed_c10d import (
|
||||
Backend,
|
||||
PrefixStore,
|
||||
_new_process_group_helper,
|
||||
_world,
|
||||
default_pg_timeout,
|
||||
rendezvous,
|
||||
)
|
||||
|
||||
assert (store is None) or (
|
||||
init_method is None
|
||||
), "Cannot specify both init_method and store."
|
||||
|
||||
if store is not None:
|
||||
assert world_size > 0, "world_size must be positive if using store"
|
||||
assert rank >= 0, "rank must be non-negative if using store"
|
||||
elif init_method is None:
|
||||
init_method = "env://"
|
||||
|
||||
if backend:
|
||||
backend = Backend(backend)
|
||||
else:
|
||||
backend = Backend("undefined")
|
||||
|
||||
if timeout is None:
|
||||
timeout = default_pg_timeout
|
||||
|
||||
# backward compatible API
|
||||
if store is None:
|
||||
rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
|
||||
store, rank, world_size = next(rendezvous_iterator)
|
||||
store.set_timeout(timeout)
|
||||
|
||||
# Use a PrefixStore to avoid accidental overrides of keys used by
|
||||
# different systems (e.g. RPC) in case the store is multi-tenant.
|
||||
store = PrefixStore(group_name, store)
|
||||
|
||||
# NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
|
||||
# https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
|
||||
# We need to determine the appropriate parameter name based on PyTorch version
|
||||
pg_options_param_name = (
|
||||
"backend_options" if str(torch.__version__) >= "2.6" else "pg_options"
|
||||
)
|
||||
pg, _ = _new_process_group_helper(
|
||||
world_size,
|
||||
rank,
|
||||
[],
|
||||
backend,
|
||||
store,
|
||||
group_name=group_name,
|
||||
**{pg_options_param_name: pg_options},
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
_world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
|
||||
|
||||
return pg
|
||||
|
||||
|
||||
def crash_on_warnings():
|
||||
# Crash on warning if we are running CI tests
|
||||
return get_bool_env_var("SGLANG_IS_IN_CI")
|
||||
|
||||
Reference in New Issue
Block a user