Online weight updates from torch.distributed (#2279)
This commit is contained in:
7
.github/workflows/pr-test.yml
vendored
7
.github/workflows/pr-test.yml
vendored
@@ -27,6 +27,7 @@ concurrency:
|
|||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
|
||||||
unit-test-frontend:
|
unit-test-frontend:
|
||||||
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
||||||
runs-on: 1-gpu-runner
|
runs-on: 1-gpu-runner
|
||||||
@@ -98,6 +99,11 @@ jobs:
|
|||||||
python3 test_mla_fp8.py
|
python3 test_mla_fp8.py
|
||||||
python3 test_dp_attention.py
|
python3 test_dp_attention.py
|
||||||
|
|
||||||
|
- name: Test update weights from distributed
|
||||||
|
timeout-minutes: 10
|
||||||
|
run: |
|
||||||
|
cd test/srt
|
||||||
|
python3 test_update_weights_from_distributed.py
|
||||||
|
|
||||||
performance-test-1-gpu-part-1:
|
performance-test-1-gpu-part-1:
|
||||||
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
||||||
@@ -245,6 +251,7 @@ jobs:
|
|||||||
cd test/srt
|
cd test/srt
|
||||||
python3 test_moe_eval_accuracy_large.py
|
python3 test_moe_eval_accuracy_large.py
|
||||||
|
|
||||||
|
|
||||||
finish:
|
finish:
|
||||||
needs: [
|
needs: [
|
||||||
unit-test-frontend, unit-test-backend-1-gpu, unit-test-backend-2-gpu,
|
unit-test-frontend, unit-test-backend-1-gpu, unit-test-backend-2-gpu,
|
||||||
|
|||||||
@@ -365,6 +365,41 @@ class UpdateWeightFromDiskReqOutput:
|
|||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UpdateWeightsFromDistributedReqInput:
|
||||||
|
name: str
|
||||||
|
dtype: str
|
||||||
|
shape: List[int]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UpdateWeightsFromDistributedReqOutput:
|
||||||
|
success: bool
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InitWeightsUpdateGroupReqInput:
|
||||||
|
# The master address
|
||||||
|
master_address: str
|
||||||
|
# The master port
|
||||||
|
master_port: int
|
||||||
|
# The rank offset
|
||||||
|
rank_offset: int
|
||||||
|
# The world size
|
||||||
|
world_size: int
|
||||||
|
# The group name
|
||||||
|
group_name: str = "weight_update_group"
|
||||||
|
# The backend
|
||||||
|
backend: str = "nccl"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InitWeightsUpdateGroupReqOutput:
|
||||||
|
success: bool
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GetWeightsByNameReqInput:
|
class GetWeightsByNameReqInput:
|
||||||
name: str
|
name: str
|
||||||
|
|||||||
@@ -40,6 +40,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
FlushCacheReq,
|
FlushCacheReq,
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
GetWeightsByNameReqOutput,
|
GetWeightsByNameReqOutput,
|
||||||
|
InitWeightsUpdateGroupReqInput,
|
||||||
|
InitWeightsUpdateGroupReqOutput,
|
||||||
OpenSessionReqInput,
|
OpenSessionReqInput,
|
||||||
OpenSessionReqOutput,
|
OpenSessionReqOutput,
|
||||||
ProfileReq,
|
ProfileReq,
|
||||||
@@ -47,6 +49,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
UpdateWeightFromDiskReqOutput,
|
UpdateWeightFromDiskReqOutput,
|
||||||
|
UpdateWeightsFromDistributedReqInput,
|
||||||
|
UpdateWeightsFromDistributedReqOutput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import (
|
from sglang.srt.managers.schedule_batch import (
|
||||||
FINISH_ABORT,
|
FINISH_ABORT,
|
||||||
@@ -516,6 +520,19 @@ class Scheduler:
|
|||||||
elif isinstance(recv_req, GetWeightsByNameReqInput):
|
elif isinstance(recv_req, GetWeightsByNameReqInput):
|
||||||
parameter = self.get_weights_by_name(recv_req)
|
parameter = self.get_weights_by_name(recv_req)
|
||||||
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
|
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
|
||||||
|
elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
|
||||||
|
success, message = self.init_weights_update_group(recv_req)
|
||||||
|
self.send_to_tokenizer.send_pyobj(
|
||||||
|
InitWeightsUpdateGroupReqOutput(success, message)
|
||||||
|
)
|
||||||
|
elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
|
||||||
|
success, message = self.update_weights_from_distributed(recv_req)
|
||||||
|
self.send_to_tokenizer.send_pyobj(
|
||||||
|
UpdateWeightsFromDistributedReqOutput(success, message)
|
||||||
|
)
|
||||||
|
elif isinstance(recv_req, GetWeightsByNameReqInput):
|
||||||
|
parameter = self.get_weights_by_name(recv_req)
|
||||||
|
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
|
||||||
elif isinstance(recv_req, ProfileReq):
|
elif isinstance(recv_req, ProfileReq):
|
||||||
if recv_req == ProfileReq.START_PROFILE:
|
if recv_req == ProfileReq.START_PROFILE:
|
||||||
self.start_profile()
|
self.start_profile()
|
||||||
@@ -1378,6 +1395,23 @@ class Scheduler:
|
|||||||
logger.error(message)
|
logger.error(message)
|
||||||
return success, message
|
return success, message
|
||||||
|
|
||||||
|
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
|
||||||
|
"""Initialize the online model parameter update group."""
|
||||||
|
success, message = self.tp_worker.init_weights_update_group(recv_req)
|
||||||
|
return success, message
|
||||||
|
|
||||||
|
def update_weights_from_distributed(
|
||||||
|
self, recv_req: UpdateWeightsFromDistributedReqInput
|
||||||
|
):
|
||||||
|
"""Update the online model parameter."""
|
||||||
|
success, message = self.tp_worker.update_weights_from_distributed(recv_req)
|
||||||
|
if success:
|
||||||
|
flash_cache_success = self.flush_cache()
|
||||||
|
assert flash_cache_success, "Cache flush failed after updating weights"
|
||||||
|
else:
|
||||||
|
logger.error(message)
|
||||||
|
return success, message
|
||||||
|
|
||||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||||
parameter = self.tp_worker.get_weights_by_name(recv_req)
|
parameter = self.tp_worker.get_weights_by_name(recv_req)
|
||||||
return parameter
|
return parameter
|
||||||
|
|||||||
@@ -48,6 +48,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
GetWeightsByNameReqOutput,
|
GetWeightsByNameReqOutput,
|
||||||
|
InitWeightsUpdateGroupReqInput,
|
||||||
|
InitWeightsUpdateGroupReqOutput,
|
||||||
OpenSessionReqInput,
|
OpenSessionReqInput,
|
||||||
OpenSessionReqOutput,
|
OpenSessionReqOutput,
|
||||||
ProfileReq,
|
ProfileReq,
|
||||||
@@ -55,6 +57,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
UpdateWeightFromDiskReqOutput,
|
UpdateWeightFromDiskReqOutput,
|
||||||
|
UpdateWeightsFromDistributedReqInput,
|
||||||
|
UpdateWeightsFromDistributedReqOutput,
|
||||||
)
|
)
|
||||||
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
@@ -456,6 +460,48 @@ class TokenizerManager:
|
|||||||
else:
|
else:
|
||||||
return False, "Another update is in progress. Please try again later."
|
return False, "Another update is in progress. Please try again later."
|
||||||
|
|
||||||
|
async def init_weights_update_group(
|
||||||
|
self,
|
||||||
|
obj: InitWeightsUpdateGroupReqInput,
|
||||||
|
request: Optional[fastapi.Request] = None,
|
||||||
|
) -> bool:
|
||||||
|
if self.to_create_loop:
|
||||||
|
self.create_handle_loop()
|
||||||
|
self.send_to_scheduler.send_pyobj(obj)
|
||||||
|
|
||||||
|
self.init_weights_update_group_result = asyncio.Future()
|
||||||
|
assert (
|
||||||
|
self.server_args.dp_size == 1
|
||||||
|
), "dp_size must be 1 for init parameter update group"
|
||||||
|
result = await self.init_weights_update_group_result
|
||||||
|
return result.success, result.message
|
||||||
|
|
||||||
|
async def update_weights_from_distributed(
|
||||||
|
self,
|
||||||
|
obj: UpdateWeightsFromDistributedReqInput,
|
||||||
|
request: Optional[fastapi.Request] = None,
|
||||||
|
):
|
||||||
|
if self.to_create_loop:
|
||||||
|
self.create_handle_loop()
|
||||||
|
|
||||||
|
if not self.model_update_lock.locked():
|
||||||
|
async with self.model_update_lock:
|
||||||
|
self.send_to_scheduler.send_pyobj(obj)
|
||||||
|
self.parameter_update_result = asyncio.Future()
|
||||||
|
assert (
|
||||||
|
self.server_args.dp_size == 1
|
||||||
|
), "dp_size must be for update weights from distributed"
|
||||||
|
result = await self.parameter_update_result
|
||||||
|
return result.success, result.message
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"Another parameter update is in progress in tokenizer manager"
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"Another parameter update is in progress. Please try again later.",
|
||||||
|
)
|
||||||
|
|
||||||
async def get_weights_by_name(
|
async def get_weights_by_name(
|
||||||
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
|
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
|
||||||
):
|
):
|
||||||
@@ -546,7 +592,9 @@ class TokenizerManager:
|
|||||||
BatchEmbeddingOut,
|
BatchEmbeddingOut,
|
||||||
BatchTokenIDOut,
|
BatchTokenIDOut,
|
||||||
UpdateWeightFromDiskReqOutput,
|
UpdateWeightFromDiskReqOutput,
|
||||||
|
UpdateWeightsFromDistributedReqOutput,
|
||||||
GetWeightsByNameReqOutput,
|
GetWeightsByNameReqOutput,
|
||||||
|
InitWeightsUpdateGroupReqOutput,
|
||||||
] = await self.recv_from_detokenizer.recv_pyobj()
|
] = await self.recv_from_detokenizer.recv_pyobj()
|
||||||
|
|
||||||
if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
|
if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
|
||||||
@@ -558,6 +606,12 @@ class TokenizerManager:
|
|||||||
if len(self.model_update_tmp) == self.server_args.dp_size:
|
if len(self.model_update_tmp) == self.server_args.dp_size:
|
||||||
self.model_update_result.set_result(self.model_update_tmp)
|
self.model_update_result.set_result(self.model_update_tmp)
|
||||||
continue
|
continue
|
||||||
|
elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
|
||||||
|
assert (
|
||||||
|
self.server_args.dp_size == 1
|
||||||
|
), "dp_size must be 1 for update weights from distributed"
|
||||||
|
self.parameter_update_result.set_result(recv_obj)
|
||||||
|
continue
|
||||||
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
|
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
|
||||||
if self.server_args.dp_size == 1:
|
if self.server_args.dp_size == 1:
|
||||||
self.get_weights_by_name_result.set_result(recv_obj)
|
self.get_weights_by_name_result.set_result(recv_obj)
|
||||||
@@ -568,6 +622,12 @@ class TokenizerManager:
|
|||||||
self.get_weights_by_name_tmp
|
self.get_weights_by_name_tmp
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
|
||||||
|
assert (
|
||||||
|
self.server_args.dp_size == 1
|
||||||
|
), "dp_size must be 1 for init parameter update group"
|
||||||
|
self.init_weights_update_group_result.set_result(recv_obj)
|
||||||
|
continue
|
||||||
elif isinstance(recv_obj, OpenSessionReqOutput):
|
elif isinstance(recv_obj, OpenSessionReqOutput):
|
||||||
self.session_futures[recv_obj.session_id].set_result(
|
self.session_futures[recv_obj.session_id].set_result(
|
||||||
recv_obj.session_id
|
recv_obj.session_id
|
||||||
|
|||||||
@@ -21,7 +21,9 @@ from sglang.srt.configs.model_config import ModelConfig
|
|||||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
|
InitWeightsUpdateGroupReqInput,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
|
UpdateWeightsFromDistributedReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
@@ -164,6 +166,25 @@ class TpModelWorker:
|
|||||||
)
|
)
|
||||||
return success, message
|
return success, message
|
||||||
|
|
||||||
|
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
|
||||||
|
success, message = self.model_runner.init_weights_update_group(
|
||||||
|
recv_req.master_address,
|
||||||
|
recv_req.master_port,
|
||||||
|
recv_req.rank_offset,
|
||||||
|
recv_req.world_size,
|
||||||
|
recv_req.group_name,
|
||||||
|
recv_req.backend,
|
||||||
|
)
|
||||||
|
return success, message
|
||||||
|
|
||||||
|
def update_weights_from_distributed(
|
||||||
|
self, recv_req: UpdateWeightsFromDistributedReqInput
|
||||||
|
):
|
||||||
|
success, message = self.model_runner.update_weights_from_distributed(
|
||||||
|
recv_req.name, recv_req.dtype, recv_req.shape
|
||||||
|
)
|
||||||
|
return success, message
|
||||||
|
|
||||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||||
parameter = self.model_runner.get_weights_by_name(
|
parameter = self.model_runner.get_weights_by_name(
|
||||||
recv_req.name, recv_req.truncate_size
|
recv_req.name, recv_req.truncate_size
|
||||||
|
|||||||
@@ -25,7 +25,9 @@ import torch
|
|||||||
|
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
|
InitWeightsUpdateGroupReqInput,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
|
UpdateWeightsFromDistributedReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
||||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||||
@@ -211,6 +213,16 @@ class TpModelWorkerClient:
|
|||||||
success, message = self.worker.update_weights_from_disk(recv_req)
|
success, message = self.worker.update_weights_from_disk(recv_req)
|
||||||
return success, message
|
return success, message
|
||||||
|
|
||||||
|
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
|
||||||
|
success, message = self.worker.init_weights_update_group(recv_req)
|
||||||
|
return success, message
|
||||||
|
|
||||||
|
def update_weights_from_distributed(
|
||||||
|
self, recv_req: UpdateWeightsFromDistributedReqInput
|
||||||
|
):
|
||||||
|
success, message = self.worker.update_weights_from_distributed(recv_req)
|
||||||
|
return success, message
|
||||||
|
|
||||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||||
return self.worker.get_weights_by_name(recv_req)
|
return self.worker.get_weights_by_name(recv_req)
|
||||||
|
|
||||||
|
|||||||
@@ -20,10 +20,13 @@ import inspect
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import pkgutil
|
import pkgutil
|
||||||
|
import time
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Optional, Type
|
from tokenize import tabsize
|
||||||
|
from typing import Any, Optional, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from vllm.config import DeviceConfig, LoadConfig
|
from vllm.config import DeviceConfig, LoadConfig
|
||||||
from vllm.config import ModelConfig as VllmModelConfig
|
from vllm.config import ModelConfig as VllmModelConfig
|
||||||
@@ -59,6 +62,7 @@ from sglang.srt.utils import (
|
|||||||
crash_on_warnings,
|
crash_on_warnings,
|
||||||
enable_show_time_cost,
|
enable_show_time_cost,
|
||||||
get_available_gpu_memory,
|
get_available_gpu_memory,
|
||||||
|
init_custom_process_group,
|
||||||
is_hip,
|
is_hip,
|
||||||
monkey_patch_vllm_gguf_config,
|
monkey_patch_vllm_gguf_config,
|
||||||
monkey_patch_vllm_model_config,
|
monkey_patch_vllm_model_config,
|
||||||
@@ -404,6 +408,86 @@ class ModelRunner:
|
|||||||
logger.info("Update weights end.")
|
logger.info("Update weights end.")
|
||||||
return True, "Succeeded to update model weights."
|
return True, "Succeeded to update model weights."
|
||||||
|
|
||||||
|
def init_weights_update_group(
|
||||||
|
self,
|
||||||
|
master_address,
|
||||||
|
master_port,
|
||||||
|
rank_offset,
|
||||||
|
world_size,
|
||||||
|
group_name,
|
||||||
|
backend="nccl",
|
||||||
|
):
|
||||||
|
"""Initialize the Torch process group for model parameter updates.
|
||||||
|
|
||||||
|
`_model_update_group` is used in the RLHF workflow, where rank
|
||||||
|
0 is the actor model in the training engine, and the other ranks are
|
||||||
|
the inference engine, which is used for rollout.
|
||||||
|
|
||||||
|
In the RLHF workflow, the training engine updates the model
|
||||||
|
weights/parameters online, and broadcasts them to the inference
|
||||||
|
engine through the `_model_update_group` process group.
|
||||||
|
"""
|
||||||
|
assert (
|
||||||
|
torch.distributed.is_initialized()
|
||||||
|
), "Default torch process group must be initialized"
|
||||||
|
assert group_name != "", "Group name cannot be empty"
|
||||||
|
|
||||||
|
rank = rank_offset + self.tp_rank
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"init custom process group: master_address={master_address}, master_port={master_port}, "
|
||||||
|
f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._model_update_group = init_custom_process_group(
|
||||||
|
backend=backend,
|
||||||
|
init_method=f"tcp://{master_address}:{master_port}",
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
|
group_name=group_name,
|
||||||
|
)
|
||||||
|
dist.barrier(group=self._model_update_group, device_ids=[rank])
|
||||||
|
return True, "Succeeded to initialize custom process group."
|
||||||
|
except Exception as e:
|
||||||
|
message = f"Failed to initialize custom process group: {e}."
|
||||||
|
logger.error(message)
|
||||||
|
return False, message
|
||||||
|
|
||||||
|
def update_weights_from_distributed(self, name, dtype, shape):
|
||||||
|
"""
|
||||||
|
Update specific parameter in the model weights online
|
||||||
|
through `_model_update_group` process group.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: the name of the parameter to be updated.
|
||||||
|
dtype: the data type of the parameter to be updated.
|
||||||
|
shape: the shape of the parameter to be updated.
|
||||||
|
"""
|
||||||
|
target_dtype = (
|
||||||
|
dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
|
||||||
|
)
|
||||||
|
current_dtype = self.dtype if isinstance(self.dtype, str) else self.dtype
|
||||||
|
|
||||||
|
assert (
|
||||||
|
self._model_update_group is not None
|
||||||
|
), "model update group must be initialized"
|
||||||
|
|
||||||
|
try:
|
||||||
|
weights = torch.empty(shape, dtype=target_dtype, device=self.device)
|
||||||
|
torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
|
||||||
|
self.model.load_weights([(name, weights)])
|
||||||
|
return True, f"Succeeded to update parameter {name} online."
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = (
|
||||||
|
f"Failed to update parameter online: {e}. "
|
||||||
|
f"The full weights of the ModelRunner are partially updated. "
|
||||||
|
f"Please discard the whole weights."
|
||||||
|
)
|
||||||
|
logger.error(error_msg)
|
||||||
|
return False, error_msg
|
||||||
|
|
||||||
def get_weights_by_name(
|
def get_weights_by_name(
|
||||||
self, name: str, truncate_size: int = 100
|
self, name: str, truncate_size: int = 100
|
||||||
) -> Optional[torch.Tensor]:
|
) -> Optional[torch.Tensor]:
|
||||||
|
|||||||
@@ -307,6 +307,8 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.torchao_config = global_server_args_dict["torchao_config"]
|
self.torchao_config = global_server_args_dict["torchao_config"]
|
||||||
self.model = LlamaModel(config, quant_config=quant_config)
|
self.model = LlamaModel(config, quant_config=quant_config)
|
||||||
|
# Llama 3.2 1B Insturct set tie_word_embeddings to True
|
||||||
|
# Llama 3.1 8B Insturct set tie_word_embeddings to False
|
||||||
if self.config.tie_word_embeddings:
|
if self.config.tie_word_embeddings:
|
||||||
self.lm_head = self.model.embed_tokens
|
self.lm_head = self.model.embed_tokens
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -53,8 +53,10 @@ from sglang.srt.managers.io_struct import (
|
|||||||
EmbeddingReqInput,
|
EmbeddingReqInput,
|
||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
|
InitWeightsUpdateGroupReqInput,
|
||||||
OpenSessionReqInput,
|
OpenSessionReqInput,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
|
UpdateWeightsFromDistributedReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
@@ -80,6 +82,7 @@ from sglang.srt.utils import (
|
|||||||
assert_pkg_version,
|
assert_pkg_version,
|
||||||
configure_logger,
|
configure_logger,
|
||||||
delete_directory,
|
delete_directory,
|
||||||
|
init_custom_process_group,
|
||||||
is_port_available,
|
is_port_available,
|
||||||
kill_process_tree,
|
kill_process_tree,
|
||||||
maybe_set_triton_cache_manager,
|
maybe_set_triton_cache_manager,
|
||||||
@@ -211,6 +214,34 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: R
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/init_weights_update_group")
|
||||||
|
async def init_weights_update_group(
|
||||||
|
obj: InitWeightsUpdateGroupReqInput, request: Request
|
||||||
|
):
|
||||||
|
"""Initialize the parameter update group."""
|
||||||
|
success, message = await tokenizer_manager.init_weights_update_group(obj, request)
|
||||||
|
content = {"success": success, "message": message}
|
||||||
|
if success:
|
||||||
|
return ORJSONResponse(content, status_code=200)
|
||||||
|
else:
|
||||||
|
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/update_weights_from_distributed")
|
||||||
|
async def update_weights_from_distributed(
|
||||||
|
obj: UpdateWeightsFromDistributedReqInput, request: Request
|
||||||
|
):
|
||||||
|
"""Update model parameter from distributed online."""
|
||||||
|
success, message = await tokenizer_manager.update_weights_from_distributed(
|
||||||
|
obj, request
|
||||||
|
)
|
||||||
|
content = {"success": success, "message": message}
|
||||||
|
if success:
|
||||||
|
return ORJSONResponse(content, status_code=200)
|
||||||
|
else:
|
||||||
|
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
|
@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
|
||||||
async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
|
async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
|
||||||
"""Get model parameter by name."""
|
"""Get model parameter by name."""
|
||||||
@@ -288,18 +319,6 @@ async def generate_request(obj: GenerateReqInput, request: Request):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@time_func_latency
|
|
||||||
async def get_weights_by_name_request(obj: GetWeightsByNameReqInput, request: Request):
|
|
||||||
"""Handle a get parameter by name request."""
|
|
||||||
try:
|
|
||||||
ret = await tokenizer_manager.get_weights_by_name(obj, request)
|
|
||||||
return ret
|
|
||||||
except ValueError as e:
|
|
||||||
return ORJSONResponse(
|
|
||||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/encode", methods=["POST", "PUT"])
|
@app.api_route("/encode", methods=["POST", "PUT"])
|
||||||
@time_func_latency
|
@time_func_latency
|
||||||
async def encode_request(obj: EmbeddingReqInput, request: Request):
|
async def encode_request(obj: EmbeddingReqInput, request: Request):
|
||||||
@@ -970,7 +989,51 @@ class Engine:
|
|||||||
async def get_server_info(self):
|
async def get_server_info(self):
|
||||||
return await _get_server_info()
|
return await _get_server_info()
|
||||||
|
|
||||||
def get_weights_by_name(self, name, truncate_size=100):
|
def init_weights_update_group(
|
||||||
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
|
self,
|
||||||
|
master_address: str,
|
||||||
|
master_port: int,
|
||||||
|
rank_offset: int,
|
||||||
|
world_size: int,
|
||||||
|
group_name: str,
|
||||||
|
backend: str = "nccl",
|
||||||
|
):
|
||||||
|
"""Initialize parameter update group."""
|
||||||
|
obj = InitWeightsUpdateGroupReqInput(
|
||||||
|
master_address=master_address,
|
||||||
|
master_port=master_port,
|
||||||
|
rank_offset=rank_offset,
|
||||||
|
world_size=world_size,
|
||||||
|
group_name=group_name,
|
||||||
|
backend=backend,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _init_group():
|
||||||
|
return await tokenizer_manager.init_weights_update_group(obj, None)
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
return loop.run_until_complete(get_weights_by_name_request(obj, None))
|
return loop.run_until_complete(_init_group())
|
||||||
|
|
||||||
|
def update_weights_from_distributed(self, name, dtype, shape):
|
||||||
|
"""Update weights from distributed source."""
|
||||||
|
obj = UpdateWeightsFromDistributedReqInput(
|
||||||
|
name=name,
|
||||||
|
dtype=dtype,
|
||||||
|
shape=shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _update_weights():
|
||||||
|
return await tokenizer_manager.update_weights_from_distributed(obj, None)
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
return loop.run_until_complete(_update_weights())
|
||||||
|
|
||||||
|
def get_weights_by_name(self, name, truncate_size=100):
|
||||||
|
"""Get weights by parameter name."""
|
||||||
|
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
|
||||||
|
|
||||||
|
async def _get_weights():
|
||||||
|
return await tokenizer_manager.get_weights_by_name(obj, None)
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
return loop.run_until_complete(_get_weights())
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ import numpy as np
|
|||||||
import psutil
|
import psutil
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import triton
|
import triton
|
||||||
import zmq
|
import zmq
|
||||||
@@ -962,6 +963,78 @@ def get_nvgpu_memory_capacity():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Copy from pytorch and OpenRLHF to allow creating multiple main groups.
|
||||||
|
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
|
||||||
|
# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
|
||||||
|
def init_custom_process_group(
|
||||||
|
backend=None,
|
||||||
|
init_method=None,
|
||||||
|
timeout=None,
|
||||||
|
world_size=-1,
|
||||||
|
rank=-1,
|
||||||
|
store=None,
|
||||||
|
group_name=None,
|
||||||
|
pg_options=None,
|
||||||
|
):
|
||||||
|
from torch.distributed.distributed_c10d import (
|
||||||
|
Backend,
|
||||||
|
PrefixStore,
|
||||||
|
_new_process_group_helper,
|
||||||
|
_world,
|
||||||
|
default_pg_timeout,
|
||||||
|
rendezvous,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (store is None) or (
|
||||||
|
init_method is None
|
||||||
|
), "Cannot specify both init_method and store."
|
||||||
|
|
||||||
|
if store is not None:
|
||||||
|
assert world_size > 0, "world_size must be positive if using store"
|
||||||
|
assert rank >= 0, "rank must be non-negative if using store"
|
||||||
|
elif init_method is None:
|
||||||
|
init_method = "env://"
|
||||||
|
|
||||||
|
if backend:
|
||||||
|
backend = Backend(backend)
|
||||||
|
else:
|
||||||
|
backend = Backend("undefined")
|
||||||
|
|
||||||
|
if timeout is None:
|
||||||
|
timeout = default_pg_timeout
|
||||||
|
|
||||||
|
# backward compatible API
|
||||||
|
if store is None:
|
||||||
|
rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
|
||||||
|
store, rank, world_size = next(rendezvous_iterator)
|
||||||
|
store.set_timeout(timeout)
|
||||||
|
|
||||||
|
# Use a PrefixStore to avoid accidental overrides of keys used by
|
||||||
|
# different systems (e.g. RPC) in case the store is multi-tenant.
|
||||||
|
store = PrefixStore(group_name, store)
|
||||||
|
|
||||||
|
# NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
|
||||||
|
# https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
|
||||||
|
# We need to determine the appropriate parameter name based on PyTorch version
|
||||||
|
pg_options_param_name = (
|
||||||
|
"backend_options" if str(torch.__version__) >= "2.6" else "pg_options"
|
||||||
|
)
|
||||||
|
pg, _ = _new_process_group_helper(
|
||||||
|
world_size,
|
||||||
|
rank,
|
||||||
|
[],
|
||||||
|
backend,
|
||||||
|
store,
|
||||||
|
group_name=group_name,
|
||||||
|
**{pg_options_param_name: pg_options},
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
_world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
|
||||||
|
|
||||||
|
return pg
|
||||||
|
|
||||||
|
|
||||||
def crash_on_warnings():
|
def crash_on_warnings():
|
||||||
# Crash on warning if we are running CI tests
|
# Crash on warning if we are running CI tests
|
||||||
return get_bool_env_var("SGLANG_IS_IN_CI")
|
return get_bool_env_var("SGLANG_IS_IN_CI")
|
||||||
|
|||||||
@@ -8,47 +8,46 @@ from transformers import AutoModelForCausalLM
|
|||||||
|
|
||||||
import sglang as sgl
|
import sglang as sgl
|
||||||
from sglang.test.test_utils import (
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
DEFAULT_URL_FOR_TEST,
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
is_in_ci,
|
||||||
popen_launch_server,
|
popen_launch_server,
|
||||||
)
|
)
|
||||||
from sglang.utils import terminate_process
|
from sglang.utils import terminate_process
|
||||||
|
|
||||||
|
|
||||||
|
def _process_return(ret):
|
||||||
|
if isinstance(ret, list) and len(ret) == 2:
|
||||||
|
print(f"running assert_allclose on data parallel")
|
||||||
|
np.testing.assert_allclose(ret[0], ret[1])
|
||||||
|
return np.array(ret[0])
|
||||||
|
return np.array(ret)
|
||||||
|
|
||||||
|
|
||||||
class TestGetWeightsByName(unittest.TestCase):
|
class TestGetWeightsByName(unittest.TestCase):
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
def init_hf_model(self, model_name, tie_word_embeddings):
|
||||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
self.hf_model = AutoModelForCausalLM.from_pretrained(
|
||||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
model_name, torch_dtype="bfloat16", tie_word_embeddings=tie_word_embeddings
|
||||||
cls.hf_model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
cls.model, torch_dtype="bfloat16"
|
|
||||||
).to("cuda:0")
|
).to("cuda:0")
|
||||||
|
|
||||||
@classmethod
|
def init_backend(self, backend, dp, tp, model_name):
|
||||||
def tearDownClass(cls):
|
|
||||||
del cls.hf_model
|
|
||||||
gc.collect()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
def init_backend(self, backend, dp, tp):
|
|
||||||
self.engine = None
|
|
||||||
self.process = None
|
|
||||||
self.backend = backend
|
self.backend = backend
|
||||||
self.dp = dp
|
self.dp = dp
|
||||||
self.tp = tp
|
self.tp = tp
|
||||||
if backend == "Engine":
|
if backend == "Engine":
|
||||||
self.engine = sgl.Engine(
|
self.engine = sgl.Engine(
|
||||||
model_path=self.model,
|
model_path=model_name,
|
||||||
random_seed=42,
|
random_seed=42,
|
||||||
tp_size=self.tp,
|
tp_size=tp,
|
||||||
dp_size=self.dp,
|
dp_size=dp,
|
||||||
mem_fraction_static=0.85,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.process = popen_launch_server(
|
self.process = popen_launch_server(
|
||||||
self.model,
|
model_name,
|
||||||
self.base_url,
|
DEFAULT_URL_FOR_TEST,
|
||||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
other_args=(
|
other_args=(
|
||||||
"--tp-size",
|
"--tp-size",
|
||||||
@@ -58,12 +57,50 @@ class TestGetWeightsByName(unittest.TestCase):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def close_engine_and_server(self):
|
def clean_up(self):
|
||||||
if self.engine:
|
del self.hf_model
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
if self.backend == "Engine":
|
||||||
self.engine.shutdown()
|
self.engine.shutdown()
|
||||||
if self.process:
|
else:
|
||||||
terminate_process(self.process)
|
terminate_process(self.process)
|
||||||
|
|
||||||
|
def assert_tie_word_embeddings(self, truncate_size):
|
||||||
|
print(f"assert_tie_word_embeddings")
|
||||||
|
if self.backend == "Engine":
|
||||||
|
backend_ret = _process_return(
|
||||||
|
self.engine.get_weights_by_name("lm_head.weight", truncate_size)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
backend_ret = _process_return(
|
||||||
|
requests.get(
|
||||||
|
f"{DEFAULT_URL_FOR_TEST}/get_weights_by_name",
|
||||||
|
json={"name": "lm_head.weight", "truncate_size": truncate_size},
|
||||||
|
).json()
|
||||||
|
)
|
||||||
|
print(f"assert_tie_word_embeddings of hf and backend")
|
||||||
|
assert np.allclose(
|
||||||
|
self.hf_model.get_parameter("model.embed_tokens.weight")
|
||||||
|
.cpu()
|
||||||
|
.detach()
|
||||||
|
.float()
|
||||||
|
.numpy()[:truncate_size],
|
||||||
|
backend_ret,
|
||||||
|
)
|
||||||
|
assert np.allclose(
|
||||||
|
self.hf_model.get_parameter("lm_head.weight")
|
||||||
|
.cpu()
|
||||||
|
.detach()
|
||||||
|
.float()
|
||||||
|
.numpy()[:truncate_size],
|
||||||
|
self.hf_model.get_parameter("model.embed_tokens.weight")
|
||||||
|
.cpu()
|
||||||
|
.detach()
|
||||||
|
.float()
|
||||||
|
.numpy()[:truncate_size],
|
||||||
|
)
|
||||||
|
|
||||||
def assert_weights_all_close(self, param_name, truncate_size):
|
def assert_weights_all_close(self, param_name, truncate_size):
|
||||||
print(
|
print(
|
||||||
f"param_name: {param_name}, backend: {self.backend}, dp: {self.dp}, tp: {self.tp}"
|
f"param_name: {param_name}, backend: {self.backend}, dp: {self.dp}, tp: {self.tp}"
|
||||||
@@ -73,34 +110,38 @@ class TestGetWeightsByName(unittest.TestCase):
|
|||||||
|
|
||||||
if self.backend == "Engine":
|
if self.backend == "Engine":
|
||||||
engine_ret = self.engine.get_weights_by_name(param_name, truncate_size)
|
engine_ret = self.engine.get_weights_by_name(param_name, truncate_size)
|
||||||
engine_ret = self._process_return(engine_ret)
|
engine_ret = _process_return(engine_ret)
|
||||||
np.testing.assert_allclose(engine_ret, param_np, rtol=1e-5, atol=1e-5)
|
np.testing.assert_allclose(engine_ret, param_np, rtol=1e-5, atol=1e-5)
|
||||||
|
|
||||||
if self.backend == "Runtime":
|
if self.backend == "Runtime":
|
||||||
runtime_ret = requests.get(
|
runtime_ret = requests.get(
|
||||||
f"{self.base_url}/get_weights_by_name",
|
f"{DEFAULT_URL_FOR_TEST}/get_weights_by_name",
|
||||||
json={"name": param_name, "truncate_size": truncate_size},
|
json={"name": param_name, "truncate_size": truncate_size},
|
||||||
).json()
|
).json()
|
||||||
runtime_ret = self._process_return(runtime_ret)
|
runtime_ret = _process_return(runtime_ret)
|
||||||
np.testing.assert_allclose(runtime_ret, param_np, rtol=1e-5, atol=1e-5)
|
np.testing.assert_allclose(runtime_ret, param_np, rtol=1e-5, atol=1e-5)
|
||||||
|
|
||||||
@staticmethod
|
def test_get_weights_by_name(self):
|
||||||
def _process_return(ret):
|
if is_in_ci():
|
||||||
if isinstance(ret, list) and len(ret) == 2:
|
test_suits = [
|
||||||
print("running assert_allclose on data parallel")
|
("Engine", 1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST),
|
||||||
np.testing.assert_allclose(ret[0], ret[1])
|
]
|
||||||
return np.array(ret[0])
|
else:
|
||||||
return np.array(ret)
|
test_suits = [
|
||||||
|
("Runtime", 1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST),
|
||||||
|
("Engine", 1, 1, DEFAULT_MODEL_NAME_FOR_TEST),
|
||||||
|
]
|
||||||
|
if torch.cuda.device_count() >= 2:
|
||||||
|
test_suits.append(("Engine", 1, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST))
|
||||||
|
test_suits.append(("Runtime", 2, 1, DEFAULT_MODEL_NAME_FOR_TEST))
|
||||||
|
|
||||||
def test_get_parameters_by_name(self):
|
if torch.cuda.device_count() >= 4:
|
||||||
test_suits = [("Engine", 1, 1), ("Runtime", 1, 1)]
|
test_suits.extend(
|
||||||
|
[
|
||||||
if torch.cuda.device_count() >= 2:
|
("Engine", 2, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST),
|
||||||
test_suits.append(("Engine", 1, 2))
|
("Runtime", 2, 2, DEFAULT_MODEL_NAME_FOR_TEST),
|
||||||
test_suits.append(("Runtime", 2, 1))
|
]
|
||||||
|
)
|
||||||
if torch.cuda.device_count() >= 4:
|
|
||||||
test_suits.extend([("Engine", 2, 2), ("Runtime", 2, 2)])
|
|
||||||
|
|
||||||
parameters = [
|
parameters = [
|
||||||
"model.embed_tokens.weight",
|
"model.embed_tokens.weight",
|
||||||
@@ -117,11 +158,24 @@ class TestGetWeightsByName(unittest.TestCase):
|
|||||||
"lm_head.weight",
|
"lm_head.weight",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
truncate_size = 100
|
||||||
|
|
||||||
for test_suit in test_suits:
|
for test_suit in test_suits:
|
||||||
|
if test_suit[-1] == DEFAULT_MODEL_NAME_FOR_TEST:
|
||||||
|
tie_word_embeddings = False
|
||||||
|
else:
|
||||||
|
tie_word_embeddings = True
|
||||||
|
|
||||||
|
self.init_hf_model(test_suit[-1], tie_word_embeddings)
|
||||||
self.init_backend(*test_suit)
|
self.init_backend(*test_suit)
|
||||||
|
|
||||||
for param_name in parameters:
|
for param_name in parameters:
|
||||||
self.assert_weights_all_close(param_name, 100)
|
self.assert_weights_all_close(param_name, truncate_size)
|
||||||
self.close_engine_and_server()
|
|
||||||
|
if tie_word_embeddings:
|
||||||
|
self.assert_tie_word_embeddings(truncate_size)
|
||||||
|
|
||||||
|
self.clean_up()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
614
test/srt/test_update_weights_from_distributed.py
Normal file
614
test/srt/test_update_weights_from_distributed.py
Normal file
@@ -0,0 +1,614 @@
|
|||||||
|
"""Test distributed weight updates.
|
||||||
|
|
||||||
|
This test suite simulates a distributed training environment to ensure
|
||||||
|
correct weight synchronization. On rank 0, the instruct model represents
|
||||||
|
pre-training weights, and the base model represents post-training weights.
|
||||||
|
The base model's weights are broadcasted to other ranks using the online
|
||||||
|
weight update API.
|
||||||
|
|
||||||
|
On other ranks, an engine is initialized with the instruct model, and its
|
||||||
|
parameters are verified against the Hugging Face model. After updating
|
||||||
|
weights from the distributed system, post-training weights are loaded
|
||||||
|
and verified again to ensure consistency and accuracy across the
|
||||||
|
distributed setup.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import requests
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
from sglang.srt.utils import init_custom_process_group
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
is_in_ci,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
from sglang.utils import terminate_process
|
||||||
|
|
||||||
|
mp.set_start_method("spawn", force=True)
|
||||||
|
|
||||||
|
|
||||||
|
def verify_params_close(params1, params2, error_msg):
|
||||||
|
"""Verify if two parameter arrays are close enough."""
|
||||||
|
try:
|
||||||
|
assert np.allclose(np.array(params1), np.array(params2)), error_msg
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Parameters not close for {error_msg}")
|
||||||
|
print("Params1:", np.array(params1))
|
||||||
|
print("Params2:", np.array(params2))
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
def verify_params_not_close(params1, params2, error_msg):
|
||||||
|
"""Verify if two parameter arrays are different enough."""
|
||||||
|
assert not np.allclose(np.array(params1), np.array(params2)), error_msg
|
||||||
|
|
||||||
|
|
||||||
|
def init_process(
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
param_queue,
|
||||||
|
truncate_size,
|
||||||
|
state_dict_key_to_shape,
|
||||||
|
tp_size,
|
||||||
|
model_name,
|
||||||
|
backend,
|
||||||
|
checking_parameters,
|
||||||
|
tie_word_embeddings,
|
||||||
|
):
|
||||||
|
torch.cuda.set_device(rank)
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
init_process_hf(
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
param_queue,
|
||||||
|
truncate_size,
|
||||||
|
model_name,
|
||||||
|
checking_parameters,
|
||||||
|
tie_word_embeddings,
|
||||||
|
state_dict_key_to_shape,
|
||||||
|
)
|
||||||
|
elif rank in [1, 2]:
|
||||||
|
init_process_sgl(
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
param_queue,
|
||||||
|
truncate_size,
|
||||||
|
model_name,
|
||||||
|
checking_parameters,
|
||||||
|
tie_word_embeddings,
|
||||||
|
state_dict_key_to_shape,
|
||||||
|
backend,
|
||||||
|
tp_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def init_process_hf(
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
param_queue,
|
||||||
|
truncate_size,
|
||||||
|
model_name,
|
||||||
|
checking_parameters,
|
||||||
|
tie_word_embeddings,
|
||||||
|
state_dict_key_to_shape,
|
||||||
|
):
|
||||||
|
# These two environment variables are very important
|
||||||
|
# to avoid unexpected behaviors of CUDA and NCCL.
|
||||||
|
os.environ["NCCL_CUMEM_ENABLE"] = "0"
|
||||||
|
os.environ["NCCL_NVLS_ENABLE"] = "0"
|
||||||
|
|
||||||
|
# Load model and get parameters
|
||||||
|
hf_instruct_model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype="bfloat16",
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
).to("cuda:0")
|
||||||
|
base_model_name = model_name.replace("-Instruct", "")
|
||||||
|
hf_base_model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
base_model_name,
|
||||||
|
torch_dtype="bfloat16",
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
).to("cuda:0")
|
||||||
|
|
||||||
|
hf_instruct_params = []
|
||||||
|
hf_base_params = []
|
||||||
|
|
||||||
|
print(f"get parameter in hf instruct model and base model")
|
||||||
|
for parameter_name in checking_parameters:
|
||||||
|
hf_instruct_params.append(
|
||||||
|
hf_instruct_model.get_parameter(parameter_name)[:truncate_size]
|
||||||
|
.cpu()
|
||||||
|
.detach()
|
||||||
|
.float()
|
||||||
|
.numpy()
|
||||||
|
.tolist()
|
||||||
|
)
|
||||||
|
hf_base_params.append(
|
||||||
|
hf_base_model.get_parameter(parameter_name)[:truncate_size]
|
||||||
|
.cpu()
|
||||||
|
.detach()
|
||||||
|
.float()
|
||||||
|
.numpy()
|
||||||
|
.tolist()
|
||||||
|
)
|
||||||
|
|
||||||
|
param_queue.put(("hf_instruct_params", hf_instruct_params))
|
||||||
|
param_queue.put(("hf_base_params", hf_base_params))
|
||||||
|
|
||||||
|
# Init weight update group for rank 0 (the training engine in RLHF).
|
||||||
|
print(f"rank {rank} world_size: {world_size} init custom process group")
|
||||||
|
group = init_custom_process_group(
|
||||||
|
backend="nccl",
|
||||||
|
init_method="tcp://localhost:65500",
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
|
group_name="test_parameter_update_group",
|
||||||
|
)
|
||||||
|
dist.barrier(group=group, device_ids=[rank])
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
time_begin_broadcast = time.time()
|
||||||
|
|
||||||
|
# The last parameter is lm_head.weight, which is tied
|
||||||
|
# with embed_tokens.weight. Actually, we only need
|
||||||
|
# to broadcast embed_tokens.weight once.
|
||||||
|
broadcast_parameters = list(state_dict_key_to_shape.keys())
|
||||||
|
if tie_word_embeddings:
|
||||||
|
broadcast_parameters.remove("lm_head.weight")
|
||||||
|
|
||||||
|
# Broadcast all the weights from the training
|
||||||
|
# engine to other ranks (inference engine).
|
||||||
|
for parameter_name in broadcast_parameters:
|
||||||
|
torch.distributed.broadcast(
|
||||||
|
hf_base_model.get_parameter(parameter_name),
|
||||||
|
src=0,
|
||||||
|
group=group,
|
||||||
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
time_end_broadcast = time.time()
|
||||||
|
|
||||||
|
# Measure the latency of broadcasting/weights update.
|
||||||
|
broadcast_time = time_end_broadcast - time_begin_broadcast
|
||||||
|
print(f"rank {rank} broadcast parameter time: {broadcast_time:.3f}s")
|
||||||
|
param_queue.put(("broadcast_time", broadcast_time))
|
||||||
|
|
||||||
|
# Delete the huggingface models to free up memory.
|
||||||
|
|
||||||
|
del hf_instruct_model
|
||||||
|
del hf_base_model
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
def init_process_sgl(
|
||||||
|
rank,
|
||||||
|
world_size,
|
||||||
|
param_queue,
|
||||||
|
truncate_size,
|
||||||
|
model_name,
|
||||||
|
checking_parameters,
|
||||||
|
tie_word_embeddings,
|
||||||
|
state_dict_key_to_shape,
|
||||||
|
backend,
|
||||||
|
tp_size,
|
||||||
|
):
|
||||||
|
torch.cuda.set_device(rank)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
base_gpu_id = 1 if rank == 1 else 1 + tp_size
|
||||||
|
if backend == "Engine":
|
||||||
|
engine = sgl.Engine(
|
||||||
|
model_path=model_name,
|
||||||
|
random_seed=42,
|
||||||
|
base_gpu_id=base_gpu_id,
|
||||||
|
tp_size=tp_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if rank == 1:
|
||||||
|
url = DEFAULT_URL_FOR_TEST
|
||||||
|
else:
|
||||||
|
url = DEFAULT_URL_FOR_TEST.replace("2157", "2159")
|
||||||
|
process = popen_launch_server(
|
||||||
|
model_name,
|
||||||
|
url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=(
|
||||||
|
"--base-gpu-id",
|
||||||
|
str(base_gpu_id),
|
||||||
|
"--tp-size",
|
||||||
|
str(tp_size),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
if backend == "Engine":
|
||||||
|
print(f"rank {rank} init engine")
|
||||||
|
else:
|
||||||
|
print(f"rank {rank} init server on url: {url}")
|
||||||
|
|
||||||
|
# Get weights of instruct model, i.e. pre-training weights.
|
||||||
|
|
||||||
|
instruct_params = []
|
||||||
|
for parameter_name in checking_parameters:
|
||||||
|
instruct_params.append(
|
||||||
|
engine.get_weights_by_name(parameter_name, truncate_size)
|
||||||
|
if backend == "Engine"
|
||||||
|
else requests.get(
|
||||||
|
f"{url}/get_weights_by_name",
|
||||||
|
json={"name": parameter_name, "truncate_size": truncate_size},
|
||||||
|
).json()
|
||||||
|
)
|
||||||
|
|
||||||
|
param_queue.put((f"sgl_dp_{rank}_instruct_params", instruct_params))
|
||||||
|
|
||||||
|
# Init weight update group with the training engine.
|
||||||
|
|
||||||
|
if backend == "Engine":
|
||||||
|
engine.init_weights_update_group(
|
||||||
|
master_address="localhost",
|
||||||
|
master_port="65500",
|
||||||
|
rank_offset=base_gpu_id,
|
||||||
|
world_size=world_size,
|
||||||
|
group_name="test_parameter_update_group",
|
||||||
|
backend="nccl",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
requests.post(
|
||||||
|
f"{url}/init_weights_update_group",
|
||||||
|
json={
|
||||||
|
"master_address": "localhost",
|
||||||
|
"master_port": "65500",
|
||||||
|
"rank_offset": base_gpu_id,
|
||||||
|
"world_size": world_size,
|
||||||
|
"group_name": "test_parameter_update_group",
|
||||||
|
"backend": "nccl",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
time_begin_update = time.time()
|
||||||
|
|
||||||
|
# The last parameter is lm_head.weight, which is tied
|
||||||
|
# with embed_tokens.weight. Actually, we only need
|
||||||
|
# to update embed_tokens.weight once.
|
||||||
|
|
||||||
|
tie_word_embeddings = (
|
||||||
|
True if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else False
|
||||||
|
)
|
||||||
|
update_parameters = list(state_dict_key_to_shape.keys())
|
||||||
|
if tie_word_embeddings:
|
||||||
|
update_parameters.remove("lm_head.weight")
|
||||||
|
|
||||||
|
# Get weights from the training engine and update the inference engine.
|
||||||
|
|
||||||
|
for parameter_name in update_parameters:
|
||||||
|
if backend == "Engine":
|
||||||
|
engine.update_weights_from_distributed(
|
||||||
|
parameter_name,
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
shape=state_dict_key_to_shape[parameter_name],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
requests.post(
|
||||||
|
f"{url}/update_weights_from_distributed",
|
||||||
|
json={
|
||||||
|
"name": parameter_name,
|
||||||
|
"dtype": "bfloat16",
|
||||||
|
"shape": state_dict_key_to_shape[parameter_name],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
time_end_update = time.time()
|
||||||
|
|
||||||
|
# Measure the latency of broadcast/weights update.
|
||||||
|
|
||||||
|
update_time = time_end_update - time_begin_update
|
||||||
|
print(
|
||||||
|
f"fully update model_name {model_name} rank {rank} parameter from distributed time: {update_time:.3f}s"
|
||||||
|
)
|
||||||
|
param_queue.put((f"update_sgl_dp_{rank}_time", update_time))
|
||||||
|
|
||||||
|
# Get the weights of post-training model after weights update for correctness check.
|
||||||
|
|
||||||
|
base_params = []
|
||||||
|
for parameter_name in checking_parameters:
|
||||||
|
if backend == "Engine":
|
||||||
|
base_params.append(
|
||||||
|
engine.get_weights_by_name(parameter_name, truncate_size)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
base_params.append(
|
||||||
|
requests.get(
|
||||||
|
f"{url}/get_weights_by_name",
|
||||||
|
json={
|
||||||
|
"name": parameter_name,
|
||||||
|
"truncate_size": truncate_size,
|
||||||
|
},
|
||||||
|
).json()
|
||||||
|
)
|
||||||
|
param_queue.put((f"sgl_dp_{rank}_base_params", base_params))
|
||||||
|
|
||||||
|
# Shutdown the engine or terminate the server process.
|
||||||
|
|
||||||
|
if backend == "Engine":
|
||||||
|
engine.shutdown()
|
||||||
|
else:
|
||||||
|
terminate_process(process)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_tied_weights(params_list, message, should_be_tied):
|
||||||
|
for params in params_list:
|
||||||
|
if should_be_tied:
|
||||||
|
assert np.allclose(params[0], params[-1]), message
|
||||||
|
else:
|
||||||
|
assert not np.allclose(params[0], params[-1]), message
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_weights_from_distributed(
|
||||||
|
tp_size,
|
||||||
|
dp_size,
|
||||||
|
model_name,
|
||||||
|
backend,
|
||||||
|
state_dict_key_to_shape,
|
||||||
|
truncate_size,
|
||||||
|
checking_parameters,
|
||||||
|
):
|
||||||
|
tie_word_embeddings = (
|
||||||
|
True if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else False
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Testing model: {model_name} tp_size: {tp_size}, dp_size: {dp_size} backend: {backend}"
|
||||||
|
)
|
||||||
|
param_queue = mp.Queue()
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
context = mp.spawn(
|
||||||
|
init_process,
|
||||||
|
args=(
|
||||||
|
1 + tp_size * dp_size,
|
||||||
|
param_queue,
|
||||||
|
truncate_size,
|
||||||
|
state_dict_key_to_shape,
|
||||||
|
tp_size,
|
||||||
|
model_name,
|
||||||
|
backend,
|
||||||
|
checking_parameters,
|
||||||
|
tie_word_embeddings,
|
||||||
|
),
|
||||||
|
nprocs=1 + dp_size,
|
||||||
|
join=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
while len(results) < 3 * (1 + dp_size):
|
||||||
|
try:
|
||||||
|
key, value = param_queue.get(timeout=5)
|
||||||
|
results[key] = value
|
||||||
|
except Exception as e:
|
||||||
|
if all(not p.is_alive() for p in context.processes):
|
||||||
|
break
|
||||||
|
|
||||||
|
context.join()
|
||||||
|
|
||||||
|
if len(results) != 3 * (1 + dp_size):
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Expected {3 * (1 + dp_size)} parameters but got {len(results)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
params = {
|
||||||
|
"hf_instruct": results.get("hf_instruct_params"),
|
||||||
|
"hf_base": results.get("hf_base_params"),
|
||||||
|
"sgl_dp_1_instruct": results.get("sgl_dp_1_instruct_params"),
|
||||||
|
"sgl_dp_1_base": results.get("sgl_dp_1_base_params"),
|
||||||
|
"broadcast_time": results.get("broadcast_time"),
|
||||||
|
"update_sgl_dp_1_time": results.get("update_sgl_dp_1_time"),
|
||||||
|
}
|
||||||
|
|
||||||
|
if dp_size == 2:
|
||||||
|
dp2_params = {
|
||||||
|
"sgl_dp_2_instruct": results.get("sgl_dp_2_instruct_params"),
|
||||||
|
"sgl_dp_2_base": results.get("sgl_dp_2_base_params"),
|
||||||
|
"update_sgl_dp_2_time": results.get("update_sgl_dp_2_time"),
|
||||||
|
}
|
||||||
|
assert all(v is not None for v in dp2_params.values())
|
||||||
|
params.update(dp2_params)
|
||||||
|
|
||||||
|
# Check the correctness of weights update by verifying
|
||||||
|
# the weights of instruct model and base model.
|
||||||
|
|
||||||
|
for i in range(len(params["hf_instruct"])):
|
||||||
|
verify_params_close(
|
||||||
|
params["hf_instruct"][i],
|
||||||
|
params["sgl_dp_1_instruct"][i],
|
||||||
|
f"sgl_dp_1_instruct_params rank {i}",
|
||||||
|
)
|
||||||
|
|
||||||
|
verify_params_close(
|
||||||
|
params["hf_base"][i],
|
||||||
|
params["sgl_dp_1_base"][i],
|
||||||
|
f"sgl_dp_1_base_params rank {i}",
|
||||||
|
)
|
||||||
|
|
||||||
|
verify_params_not_close(
|
||||||
|
params["hf_instruct"][i],
|
||||||
|
params["hf_base"][i],
|
||||||
|
f"hf_instruct_params rank {i}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if dp_size == 2:
|
||||||
|
verify_params_close(
|
||||||
|
params["hf_base"][i],
|
||||||
|
params["sgl_dp_2_base"][i],
|
||||||
|
f"sgl_dp_2_base_params rank {i}",
|
||||||
|
)
|
||||||
|
verify_params_close(
|
||||||
|
params["hf_instruct"][i],
|
||||||
|
params["sgl_dp_2_instruct"][i],
|
||||||
|
f"sgl_dp_2_instruct_params rank {i}",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(params["hf_instruct"]) == len(
|
||||||
|
params["hf_base"]
|
||||||
|
), "hf_instruct_params and hf_base_params have different lengths"
|
||||||
|
|
||||||
|
# Check if the weights of lm_head are tied with embed_tokens.
|
||||||
|
|
||||||
|
params_to_check = [
|
||||||
|
(
|
||||||
|
params["hf_instruct"],
|
||||||
|
"lm_head.weight is not tied with embed_tokens.weight",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
params["hf_base"],
|
||||||
|
"lm_head.weight is not tied with embed_tokens.weight",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
params["sgl_dp_1_instruct"],
|
||||||
|
"lm_head.weight is not tied with embed_tokens.weight",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
params["sgl_dp_1_base"],
|
||||||
|
"lm_head.weight is not tied with embed_tokens.weight",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
if dp_size == 2:
|
||||||
|
params_to_check.extend(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
params["sgl_dp_2_instruct"],
|
||||||
|
"lm_head.weight is not tied with embed_tokens.weight",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
params["sgl_dp_2_base"],
|
||||||
|
"lm_head.weight is not tied with embed_tokens.weight",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert_tied_weights(
|
||||||
|
[params for params, _ in params_to_check],
|
||||||
|
(
|
||||||
|
"lm_head.weight is not tied with embed_tokens.weight"
|
||||||
|
if tie_word_embeddings
|
||||||
|
else "lm_head.weight is tied with embed_tokens.weight"
|
||||||
|
),
|
||||||
|
tie_word_embeddings,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Time limit for broadcast and update on CI is 3 / 6
|
||||||
|
# On local H100, it's 1 / 2
|
||||||
|
|
||||||
|
time_limit = 3 if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else 6
|
||||||
|
|
||||||
|
assert (
|
||||||
|
params["broadcast_time"] < time_limit
|
||||||
|
), f"broadcast_time exceeds time limit {time_limit}s"
|
||||||
|
|
||||||
|
assert (
|
||||||
|
params["update_sgl_dp_1_time"] < time_limit
|
||||||
|
), f"update_sgl_dp_one_time exceeds time limit {time_limit}s"
|
||||||
|
|
||||||
|
if dp_size == 2:
|
||||||
|
assert (
|
||||||
|
params["update_sgl_dp_2_time"] < time_limit
|
||||||
|
), f"update_sgl_dp_two_time exceeds time limit {time_limit}s"
|
||||||
|
|
||||||
|
# Delete the context and close the parameter queue.
|
||||||
|
|
||||||
|
del context
|
||||||
|
param_queue.close()
|
||||||
|
param_queue.join_thread()
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpdateWeightsFromDistributed(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_update_weights_from_distributed(self):
|
||||||
|
|
||||||
|
assert torch.cuda.device_count() >= 2, "At least 2 GPUs are required"
|
||||||
|
# test_suits : tp, dp, model_name, backend
|
||||||
|
if is_in_ci():
|
||||||
|
test_suits = [
|
||||||
|
(1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, "Engine"),
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
test_suits = [
|
||||||
|
(1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, "Engine"),
|
||||||
|
(1, 1, DEFAULT_MODEL_NAME_FOR_TEST, "Sever"),
|
||||||
|
]
|
||||||
|
|
||||||
|
if torch.cuda.device_count() >= 4:
|
||||||
|
test_suits.extend(
|
||||||
|
[
|
||||||
|
(2, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, "Engine"),
|
||||||
|
(1, 2, DEFAULT_MODEL_NAME_FOR_TEST, "Server"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if torch.cuda.device_count() >= 5:
|
||||||
|
test_suits.extend(
|
||||||
|
[
|
||||||
|
(2, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, "Engine"),
|
||||||
|
(2, 2, DEFAULT_MODEL_NAME_FOR_TEST, "Server"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
model_state_dict_shapes = {}
|
||||||
|
test_models = [test_suit[2] for test_suit in test_suits]
|
||||||
|
|
||||||
|
for model_name in test_models:
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_name, torch_dtype="bfloat16"
|
||||||
|
).to("cuda:0")
|
||||||
|
state_dict = model.state_dict()
|
||||||
|
state_dict_keys = list(state_dict.keys())
|
||||||
|
model_state_dict_shapes[model_name] = {
|
||||||
|
key: state_dict[key].shape for key in state_dict_keys
|
||||||
|
}
|
||||||
|
del model
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
truncate_size = 10
|
||||||
|
checking_parameters = [
|
||||||
|
"model.embed_tokens.weight",
|
||||||
|
"model.layers.0.input_layernorm.weight",
|
||||||
|
"model.layers.1.self_attn.q_proj.weight",
|
||||||
|
"model.layers.2.self_attn.k_proj.weight",
|
||||||
|
"model.layers.3.self_attn.v_proj.weight",
|
||||||
|
"model.layers.4.self_attn.o_proj.weight",
|
||||||
|
"model.layers.5.mlp.gate_proj.weight",
|
||||||
|
"model.layers.6.mlp.up_proj.weight",
|
||||||
|
"model.layers.7.mlp.down_proj.weight",
|
||||||
|
"model.layers.8.post_attention_layernorm.weight",
|
||||||
|
"model.norm.weight",
|
||||||
|
"lm_head.weight",
|
||||||
|
]
|
||||||
|
|
||||||
|
for tp_size, dp_size, model_name, backend in test_suits:
|
||||||
|
test_update_weights_from_distributed(
|
||||||
|
tp_size,
|
||||||
|
dp_size,
|
||||||
|
model_name,
|
||||||
|
backend,
|
||||||
|
model_state_dict_shapes[model_name],
|
||||||
|
truncate_size,
|
||||||
|
checking_parameters,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user