Online weight updates from torch.distributed (#2279)
This commit is contained in:
614
test/srt/test_update_weights_from_distributed.py
Normal file
614
test/srt/test_update_weights_from_distributed.py
Normal file
@@ -0,0 +1,614 @@
|
||||
"""Test distributed weight updates.
|
||||
|
||||
This test suite simulates a distributed training environment to ensure
|
||||
correct weight synchronization. On rank 0, the instruct model represents
|
||||
pre-training weights, and the base model represents post-training weights.
|
||||
The base model's weights are broadcasted to other ranks using the online
|
||||
weight update API.
|
||||
|
||||
On other ranks, an engine is initialized with the instruct model, and its
|
||||
parameters are verified against the Hugging Face model. After updating
|
||||
weights from the distributed system, post-training weights are loaded
|
||||
and verified again to ensure consistency and accuracy across the
|
||||
distributed setup.
|
||||
"""
|
||||
|
||||
import gc
|
||||
import os
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.srt.utils import init_custom_process_group
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
is_in_ci,
|
||||
popen_launch_server,
|
||||
)
|
||||
from sglang.utils import terminate_process
|
||||
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
|
||||
def verify_params_close(params1, params2, error_msg):
|
||||
"""Verify if two parameter arrays are close enough."""
|
||||
try:
|
||||
assert np.allclose(np.array(params1), np.array(params2)), error_msg
|
||||
except Exception as e:
|
||||
print(f"Parameters not close for {error_msg}")
|
||||
print("Params1:", np.array(params1))
|
||||
print("Params2:", np.array(params2))
|
||||
raise e
|
||||
|
||||
|
||||
def verify_params_not_close(params1, params2, error_msg):
|
||||
"""Verify if two parameter arrays are different enough."""
|
||||
assert not np.allclose(np.array(params1), np.array(params2)), error_msg
|
||||
|
||||
|
||||
def init_process(
|
||||
rank,
|
||||
world_size,
|
||||
param_queue,
|
||||
truncate_size,
|
||||
state_dict_key_to_shape,
|
||||
tp_size,
|
||||
model_name,
|
||||
backend,
|
||||
checking_parameters,
|
||||
tie_word_embeddings,
|
||||
):
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
if rank == 0:
|
||||
init_process_hf(
|
||||
rank,
|
||||
world_size,
|
||||
param_queue,
|
||||
truncate_size,
|
||||
model_name,
|
||||
checking_parameters,
|
||||
tie_word_embeddings,
|
||||
state_dict_key_to_shape,
|
||||
)
|
||||
elif rank in [1, 2]:
|
||||
init_process_sgl(
|
||||
rank,
|
||||
world_size,
|
||||
param_queue,
|
||||
truncate_size,
|
||||
model_name,
|
||||
checking_parameters,
|
||||
tie_word_embeddings,
|
||||
state_dict_key_to_shape,
|
||||
backend,
|
||||
tp_size,
|
||||
)
|
||||
|
||||
|
||||
def init_process_hf(
|
||||
rank,
|
||||
world_size,
|
||||
param_queue,
|
||||
truncate_size,
|
||||
model_name,
|
||||
checking_parameters,
|
||||
tie_word_embeddings,
|
||||
state_dict_key_to_shape,
|
||||
):
|
||||
# These two environment variables are very important
|
||||
# to avoid unexpected behaviors of CUDA and NCCL.
|
||||
os.environ["NCCL_CUMEM_ENABLE"] = "0"
|
||||
os.environ["NCCL_NVLS_ENABLE"] = "0"
|
||||
|
||||
# Load model and get parameters
|
||||
hf_instruct_model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype="bfloat16",
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
).to("cuda:0")
|
||||
base_model_name = model_name.replace("-Instruct", "")
|
||||
hf_base_model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model_name,
|
||||
torch_dtype="bfloat16",
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
).to("cuda:0")
|
||||
|
||||
hf_instruct_params = []
|
||||
hf_base_params = []
|
||||
|
||||
print(f"get parameter in hf instruct model and base model")
|
||||
for parameter_name in checking_parameters:
|
||||
hf_instruct_params.append(
|
||||
hf_instruct_model.get_parameter(parameter_name)[:truncate_size]
|
||||
.cpu()
|
||||
.detach()
|
||||
.float()
|
||||
.numpy()
|
||||
.tolist()
|
||||
)
|
||||
hf_base_params.append(
|
||||
hf_base_model.get_parameter(parameter_name)[:truncate_size]
|
||||
.cpu()
|
||||
.detach()
|
||||
.float()
|
||||
.numpy()
|
||||
.tolist()
|
||||
)
|
||||
|
||||
param_queue.put(("hf_instruct_params", hf_instruct_params))
|
||||
param_queue.put(("hf_base_params", hf_base_params))
|
||||
|
||||
# Init weight update group for rank 0 (the training engine in RLHF).
|
||||
print(f"rank {rank} world_size: {world_size} init custom process group")
|
||||
group = init_custom_process_group(
|
||||
backend="nccl",
|
||||
init_method="tcp://localhost:65500",
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
group_name="test_parameter_update_group",
|
||||
)
|
||||
dist.barrier(group=group, device_ids=[rank])
|
||||
torch.cuda.synchronize()
|
||||
time_begin_broadcast = time.time()
|
||||
|
||||
# The last parameter is lm_head.weight, which is tied
|
||||
# with embed_tokens.weight. Actually, we only need
|
||||
# to broadcast embed_tokens.weight once.
|
||||
broadcast_parameters = list(state_dict_key_to_shape.keys())
|
||||
if tie_word_embeddings:
|
||||
broadcast_parameters.remove("lm_head.weight")
|
||||
|
||||
# Broadcast all the weights from the training
|
||||
# engine to other ranks (inference engine).
|
||||
for parameter_name in broadcast_parameters:
|
||||
torch.distributed.broadcast(
|
||||
hf_base_model.get_parameter(parameter_name),
|
||||
src=0,
|
||||
group=group,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
time_end_broadcast = time.time()
|
||||
|
||||
# Measure the latency of broadcasting/weights update.
|
||||
broadcast_time = time_end_broadcast - time_begin_broadcast
|
||||
print(f"rank {rank} broadcast parameter time: {broadcast_time:.3f}s")
|
||||
param_queue.put(("broadcast_time", broadcast_time))
|
||||
|
||||
# Delete the huggingface models to free up memory.
|
||||
|
||||
del hf_instruct_model
|
||||
del hf_base_model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def init_process_sgl(
|
||||
rank,
|
||||
world_size,
|
||||
param_queue,
|
||||
truncate_size,
|
||||
model_name,
|
||||
checking_parameters,
|
||||
tie_word_embeddings,
|
||||
state_dict_key_to_shape,
|
||||
backend,
|
||||
tp_size,
|
||||
):
|
||||
torch.cuda.set_device(rank)
|
||||
torch.cuda.synchronize()
|
||||
base_gpu_id = 1 if rank == 1 else 1 + tp_size
|
||||
if backend == "Engine":
|
||||
engine = sgl.Engine(
|
||||
model_path=model_name,
|
||||
random_seed=42,
|
||||
base_gpu_id=base_gpu_id,
|
||||
tp_size=tp_size,
|
||||
)
|
||||
else:
|
||||
if rank == 1:
|
||||
url = DEFAULT_URL_FOR_TEST
|
||||
else:
|
||||
url = DEFAULT_URL_FOR_TEST.replace("2157", "2159")
|
||||
process = popen_launch_server(
|
||||
model_name,
|
||||
url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=(
|
||||
"--base-gpu-id",
|
||||
str(base_gpu_id),
|
||||
"--tp-size",
|
||||
str(tp_size),
|
||||
),
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
if backend == "Engine":
|
||||
print(f"rank {rank} init engine")
|
||||
else:
|
||||
print(f"rank {rank} init server on url: {url}")
|
||||
|
||||
# Get weights of instruct model, i.e. pre-training weights.
|
||||
|
||||
instruct_params = []
|
||||
for parameter_name in checking_parameters:
|
||||
instruct_params.append(
|
||||
engine.get_weights_by_name(parameter_name, truncate_size)
|
||||
if backend == "Engine"
|
||||
else requests.get(
|
||||
f"{url}/get_weights_by_name",
|
||||
json={"name": parameter_name, "truncate_size": truncate_size},
|
||||
).json()
|
||||
)
|
||||
|
||||
param_queue.put((f"sgl_dp_{rank}_instruct_params", instruct_params))
|
||||
|
||||
# Init weight update group with the training engine.
|
||||
|
||||
if backend == "Engine":
|
||||
engine.init_weights_update_group(
|
||||
master_address="localhost",
|
||||
master_port="65500",
|
||||
rank_offset=base_gpu_id,
|
||||
world_size=world_size,
|
||||
group_name="test_parameter_update_group",
|
||||
backend="nccl",
|
||||
)
|
||||
else:
|
||||
requests.post(
|
||||
f"{url}/init_weights_update_group",
|
||||
json={
|
||||
"master_address": "localhost",
|
||||
"master_port": "65500",
|
||||
"rank_offset": base_gpu_id,
|
||||
"world_size": world_size,
|
||||
"group_name": "test_parameter_update_group",
|
||||
"backend": "nccl",
|
||||
},
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
time_begin_update = time.time()
|
||||
|
||||
# The last parameter is lm_head.weight, which is tied
|
||||
# with embed_tokens.weight. Actually, we only need
|
||||
# to update embed_tokens.weight once.
|
||||
|
||||
tie_word_embeddings = (
|
||||
True if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else False
|
||||
)
|
||||
update_parameters = list(state_dict_key_to_shape.keys())
|
||||
if tie_word_embeddings:
|
||||
update_parameters.remove("lm_head.weight")
|
||||
|
||||
# Get weights from the training engine and update the inference engine.
|
||||
|
||||
for parameter_name in update_parameters:
|
||||
if backend == "Engine":
|
||||
engine.update_weights_from_distributed(
|
||||
parameter_name,
|
||||
dtype=torch.bfloat16,
|
||||
shape=state_dict_key_to_shape[parameter_name],
|
||||
)
|
||||
else:
|
||||
requests.post(
|
||||
f"{url}/update_weights_from_distributed",
|
||||
json={
|
||||
"name": parameter_name,
|
||||
"dtype": "bfloat16",
|
||||
"shape": state_dict_key_to_shape[parameter_name],
|
||||
},
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
time_end_update = time.time()
|
||||
|
||||
# Measure the latency of broadcast/weights update.
|
||||
|
||||
update_time = time_end_update - time_begin_update
|
||||
print(
|
||||
f"fully update model_name {model_name} rank {rank} parameter from distributed time: {update_time:.3f}s"
|
||||
)
|
||||
param_queue.put((f"update_sgl_dp_{rank}_time", update_time))
|
||||
|
||||
# Get the weights of post-training model after weights update for correctness check.
|
||||
|
||||
base_params = []
|
||||
for parameter_name in checking_parameters:
|
||||
if backend == "Engine":
|
||||
base_params.append(
|
||||
engine.get_weights_by_name(parameter_name, truncate_size)
|
||||
)
|
||||
else:
|
||||
base_params.append(
|
||||
requests.get(
|
||||
f"{url}/get_weights_by_name",
|
||||
json={
|
||||
"name": parameter_name,
|
||||
"truncate_size": truncate_size,
|
||||
},
|
||||
).json()
|
||||
)
|
||||
param_queue.put((f"sgl_dp_{rank}_base_params", base_params))
|
||||
|
||||
# Shutdown the engine or terminate the server process.
|
||||
|
||||
if backend == "Engine":
|
||||
engine.shutdown()
|
||||
else:
|
||||
terminate_process(process)
|
||||
|
||||
|
||||
def assert_tied_weights(params_list, message, should_be_tied):
|
||||
for params in params_list:
|
||||
if should_be_tied:
|
||||
assert np.allclose(params[0], params[-1]), message
|
||||
else:
|
||||
assert not np.allclose(params[0], params[-1]), message
|
||||
|
||||
|
||||
def test_update_weights_from_distributed(
|
||||
tp_size,
|
||||
dp_size,
|
||||
model_name,
|
||||
backend,
|
||||
state_dict_key_to_shape,
|
||||
truncate_size,
|
||||
checking_parameters,
|
||||
):
|
||||
tie_word_embeddings = (
|
||||
True if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else False
|
||||
)
|
||||
|
||||
print(
|
||||
f"Testing model: {model_name} tp_size: {tp_size}, dp_size: {dp_size} backend: {backend}"
|
||||
)
|
||||
param_queue = mp.Queue()
|
||||
results = {}
|
||||
|
||||
context = mp.spawn(
|
||||
init_process,
|
||||
args=(
|
||||
1 + tp_size * dp_size,
|
||||
param_queue,
|
||||
truncate_size,
|
||||
state_dict_key_to_shape,
|
||||
tp_size,
|
||||
model_name,
|
||||
backend,
|
||||
checking_parameters,
|
||||
tie_word_embeddings,
|
||||
),
|
||||
nprocs=1 + dp_size,
|
||||
join=False,
|
||||
)
|
||||
|
||||
while len(results) < 3 * (1 + dp_size):
|
||||
try:
|
||||
key, value = param_queue.get(timeout=5)
|
||||
results[key] = value
|
||||
except Exception as e:
|
||||
if all(not p.is_alive() for p in context.processes):
|
||||
break
|
||||
|
||||
context.join()
|
||||
|
||||
if len(results) != 3 * (1 + dp_size):
|
||||
raise RuntimeError(
|
||||
f"Expected {3 * (1 + dp_size)} parameters but got {len(results)}"
|
||||
)
|
||||
|
||||
params = {
|
||||
"hf_instruct": results.get("hf_instruct_params"),
|
||||
"hf_base": results.get("hf_base_params"),
|
||||
"sgl_dp_1_instruct": results.get("sgl_dp_1_instruct_params"),
|
||||
"sgl_dp_1_base": results.get("sgl_dp_1_base_params"),
|
||||
"broadcast_time": results.get("broadcast_time"),
|
||||
"update_sgl_dp_1_time": results.get("update_sgl_dp_1_time"),
|
||||
}
|
||||
|
||||
if dp_size == 2:
|
||||
dp2_params = {
|
||||
"sgl_dp_2_instruct": results.get("sgl_dp_2_instruct_params"),
|
||||
"sgl_dp_2_base": results.get("sgl_dp_2_base_params"),
|
||||
"update_sgl_dp_2_time": results.get("update_sgl_dp_2_time"),
|
||||
}
|
||||
assert all(v is not None for v in dp2_params.values())
|
||||
params.update(dp2_params)
|
||||
|
||||
# Check the correctness of weights update by verifying
|
||||
# the weights of instruct model and base model.
|
||||
|
||||
for i in range(len(params["hf_instruct"])):
|
||||
verify_params_close(
|
||||
params["hf_instruct"][i],
|
||||
params["sgl_dp_1_instruct"][i],
|
||||
f"sgl_dp_1_instruct_params rank {i}",
|
||||
)
|
||||
|
||||
verify_params_close(
|
||||
params["hf_base"][i],
|
||||
params["sgl_dp_1_base"][i],
|
||||
f"sgl_dp_1_base_params rank {i}",
|
||||
)
|
||||
|
||||
verify_params_not_close(
|
||||
params["hf_instruct"][i],
|
||||
params["hf_base"][i],
|
||||
f"hf_instruct_params rank {i}",
|
||||
)
|
||||
|
||||
if dp_size == 2:
|
||||
verify_params_close(
|
||||
params["hf_base"][i],
|
||||
params["sgl_dp_2_base"][i],
|
||||
f"sgl_dp_2_base_params rank {i}",
|
||||
)
|
||||
verify_params_close(
|
||||
params["hf_instruct"][i],
|
||||
params["sgl_dp_2_instruct"][i],
|
||||
f"sgl_dp_2_instruct_params rank {i}",
|
||||
)
|
||||
|
||||
assert len(params["hf_instruct"]) == len(
|
||||
params["hf_base"]
|
||||
), "hf_instruct_params and hf_base_params have different lengths"
|
||||
|
||||
# Check if the weights of lm_head are tied with embed_tokens.
|
||||
|
||||
params_to_check = [
|
||||
(
|
||||
params["hf_instruct"],
|
||||
"lm_head.weight is not tied with embed_tokens.weight",
|
||||
),
|
||||
(
|
||||
params["hf_base"],
|
||||
"lm_head.weight is not tied with embed_tokens.weight",
|
||||
),
|
||||
(
|
||||
params["sgl_dp_1_instruct"],
|
||||
"lm_head.weight is not tied with embed_tokens.weight",
|
||||
),
|
||||
(
|
||||
params["sgl_dp_1_base"],
|
||||
"lm_head.weight is not tied with embed_tokens.weight",
|
||||
),
|
||||
]
|
||||
|
||||
if dp_size == 2:
|
||||
params_to_check.extend(
|
||||
[
|
||||
(
|
||||
params["sgl_dp_2_instruct"],
|
||||
"lm_head.weight is not tied with embed_tokens.weight",
|
||||
),
|
||||
(
|
||||
params["sgl_dp_2_base"],
|
||||
"lm_head.weight is not tied with embed_tokens.weight",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
assert_tied_weights(
|
||||
[params for params, _ in params_to_check],
|
||||
(
|
||||
"lm_head.weight is not tied with embed_tokens.weight"
|
||||
if tie_word_embeddings
|
||||
else "lm_head.weight is tied with embed_tokens.weight"
|
||||
),
|
||||
tie_word_embeddings,
|
||||
)
|
||||
|
||||
# Time limit for broadcast and update on CI is 3 / 6
|
||||
# On local H100, it's 1 / 2
|
||||
|
||||
time_limit = 3 if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else 6
|
||||
|
||||
assert (
|
||||
params["broadcast_time"] < time_limit
|
||||
), f"broadcast_time exceeds time limit {time_limit}s"
|
||||
|
||||
assert (
|
||||
params["update_sgl_dp_1_time"] < time_limit
|
||||
), f"update_sgl_dp_one_time exceeds time limit {time_limit}s"
|
||||
|
||||
if dp_size == 2:
|
||||
assert (
|
||||
params["update_sgl_dp_2_time"] < time_limit
|
||||
), f"update_sgl_dp_two_time exceeds time limit {time_limit}s"
|
||||
|
||||
# Delete the context and close the parameter queue.
|
||||
|
||||
del context
|
||||
param_queue.close()
|
||||
param_queue.join_thread()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
class TestUpdateWeightsFromDistributed(unittest.TestCase):
|
||||
|
||||
def test_update_weights_from_distributed(self):
|
||||
|
||||
assert torch.cuda.device_count() >= 2, "At least 2 GPUs are required"
|
||||
# test_suits : tp, dp, model_name, backend
|
||||
if is_in_ci():
|
||||
test_suits = [
|
||||
(1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, "Engine"),
|
||||
]
|
||||
else:
|
||||
test_suits = [
|
||||
(1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, "Engine"),
|
||||
(1, 1, DEFAULT_MODEL_NAME_FOR_TEST, "Sever"),
|
||||
]
|
||||
|
||||
if torch.cuda.device_count() >= 4:
|
||||
test_suits.extend(
|
||||
[
|
||||
(2, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, "Engine"),
|
||||
(1, 2, DEFAULT_MODEL_NAME_FOR_TEST, "Server"),
|
||||
]
|
||||
)
|
||||
|
||||
if torch.cuda.device_count() >= 5:
|
||||
test_suits.extend(
|
||||
[
|
||||
(2, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, "Engine"),
|
||||
(2, 2, DEFAULT_MODEL_NAME_FOR_TEST, "Server"),
|
||||
]
|
||||
)
|
||||
|
||||
model_state_dict_shapes = {}
|
||||
test_models = [test_suit[2] for test_suit in test_suits]
|
||||
|
||||
for model_name in test_models:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name, torch_dtype="bfloat16"
|
||||
).to("cuda:0")
|
||||
state_dict = model.state_dict()
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
model_state_dict_shapes[model_name] = {
|
||||
key: state_dict[key].shape for key in state_dict_keys
|
||||
}
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
truncate_size = 10
|
||||
checking_parameters = [
|
||||
"model.embed_tokens.weight",
|
||||
"model.layers.0.input_layernorm.weight",
|
||||
"model.layers.1.self_attn.q_proj.weight",
|
||||
"model.layers.2.self_attn.k_proj.weight",
|
||||
"model.layers.3.self_attn.v_proj.weight",
|
||||
"model.layers.4.self_attn.o_proj.weight",
|
||||
"model.layers.5.mlp.gate_proj.weight",
|
||||
"model.layers.6.mlp.up_proj.weight",
|
||||
"model.layers.7.mlp.down_proj.weight",
|
||||
"model.layers.8.post_attention_layernorm.weight",
|
||||
"model.norm.weight",
|
||||
"lm_head.weight",
|
||||
]
|
||||
|
||||
for tp_size, dp_size, model_name, backend in test_suits:
|
||||
test_update_weights_from_distributed(
|
||||
tp_size,
|
||||
dp_size,
|
||||
model_name,
|
||||
backend,
|
||||
model_state_dict_shapes[model_name],
|
||||
truncate_size,
|
||||
checking_parameters,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user