sglangv0.5.2 & support Qwen3-Next-80B-A3B-Instruct

This commit is contained in:
maxiao1
2025-09-13 17:00:20 +08:00
commit 118f1fc726
2037 changed files with 515371 additions and 0 deletions

View File

@@ -0,0 +1,371 @@
import json
import random
import time
import unittest
from concurrent.futures import ThreadPoolExecutor, as_completed
import requests
import sglang as sgl
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
)
###############################################################################
# Engine Mode Tests (Single-configuration)
###############################################################################
class TestEngineUpdateWeightsFromDisk(CustomTestCase):
def setUp(self):
self.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
# Initialize the engine in offline (direct) mode.
self.engine = sgl.Engine(model_path=self.model)
def tearDown(self):
self.engine.shutdown()
def run_decode(self):
prompts = ["The capital of France is"]
sampling_params = {"temperature": 0, "max_new_tokens": 32}
outputs = self.engine.generate(prompts, sampling_params)
print("=" * 100)
print(
f"[Engine Mode] Prompt: {prompts[0]}\nGenerated text: {outputs[0]['text']}"
)
return outputs[0]["text"]
def run_update_weights(self, model_path):
ret = self.engine.update_weights_from_disk(model_path)
print(json.dumps(ret))
return ret
def test_update_weights(self):
origin_response = self.run_decode()
# Update weights: use new model (remove "-Instruct")
new_model_path = self.model.replace("-Instruct", "")
ret = self.run_update_weights(new_model_path)
self.assertTrue(ret[0]) # ret is a tuple; index 0 holds the success flag
updated_response = self.run_decode()
self.assertNotEqual(origin_response[:32], updated_response[:32])
# Revert back to original weights
ret = self.run_update_weights(self.model)
self.assertTrue(ret[0])
reverted_response = self.run_decode()
self.assertEqual(origin_response[:32], reverted_response[:32])
def test_update_weights_unexist_model(self):
origin_response = self.run_decode()
new_model_path = self.model.replace("-Instruct", "wrong")
ret = self.run_update_weights(new_model_path)
self.assertFalse(ret[0])
updated_response = self.run_decode()
self.assertEqual(origin_response[:32], updated_response[:32])
###############################################################################
# HTTP Server Mode Tests (Single-configuration)
###############################################################################
class TestServerUpdateWeightsFromDisk(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def run_decode(self):
response = requests.post(
self.base_url + "/generate",
json={
"text": "The capital of France is",
"sampling_params": {"temperature": 0, "max_new_tokens": 32},
},
)
print("=" * 100)
print(f"[Server Mode] Generated text: {response.json()['text']}")
return response.json()["text"]
def get_model_info(self):
response = requests.get(self.base_url + "/get_model_info")
model_path = response.json()["model_path"]
print(json.dumps(response.json()))
return model_path
def run_update_weights(self, model_path):
response = requests.post(
self.base_url + "/update_weights_from_disk",
json={"model_path": model_path},
)
ret = response.json()
print(json.dumps(ret))
return ret
def test_update_weights(self):
origin_model_path = self.get_model_info()
print(f"[Server Mode] origin_model_path: {origin_model_path}")
origin_response = self.run_decode()
new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
ret = self.run_update_weights(new_model_path)
self.assertTrue(ret["success"])
updated_model_path = self.get_model_info()
print(f"[Server Mode] updated_model_path: {updated_model_path}")
self.assertEqual(updated_model_path, new_model_path)
self.assertNotEqual(updated_model_path, origin_model_path)
updated_response = self.run_decode()
self.assertNotEqual(origin_response[:32], updated_response[:32])
ret = self.run_update_weights(origin_model_path)
self.assertTrue(ret["success"])
updated_model_path = self.get_model_info()
self.assertEqual(updated_model_path, origin_model_path)
updated_response = self.run_decode()
self.assertEqual(origin_response[:32], updated_response[:32])
def test_update_weights_unexist_model(self):
origin_model_path = self.get_model_info()
print(f"[Server Mode] origin_model_path: {origin_model_path}")
origin_response = self.run_decode()
new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "wrong")
ret = self.run_update_weights(new_model_path)
self.assertFalse(ret["success"])
updated_model_path = self.get_model_info()
print(f"[Server Mode] updated_model_path: {updated_model_path}")
self.assertEqual(updated_model_path, origin_model_path)
updated_response = self.run_decode()
self.assertEqual(origin_response[:32], updated_response[:32])
class TestServerUpdateWeightsFromDiskAbortAllRequests(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--max-running-requests", 8],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def run_decode(self, max_new_tokens=32):
response = requests.post(
self.base_url + "/generate",
json={
"text": "The capital of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": max_new_tokens,
"ignore_eos": True,
},
},
)
return response.json()
def get_model_info(self):
response = requests.get(self.base_url + "/get_model_info")
model_path = response.json()["model_path"]
print(json.dumps(response.json()))
return model_path
def run_update_weights(self, model_path, abort_all_requests=False):
response = requests.post(
self.base_url + "/update_weights_from_disk",
json={
"model_path": model_path,
"abort_all_requests": abort_all_requests,
},
)
ret = response.json()
print(json.dumps(ret))
return ret
def test_update_weights_abort_all_requests(self):
origin_model_path = self.get_model_info()
print(f"[Server Mode] origin_model_path: {origin_model_path}")
num_requests = 32
with ThreadPoolExecutor(num_requests) as executor:
futures = [
executor.submit(self.run_decode, 16000) for _ in range(num_requests)
]
# ensure the decode has been started
time.sleep(2)
new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
ret = self.run_update_weights(new_model_path, abort_all_requests=True)
self.assertTrue(ret["success"])
for future in as_completed(futures):
self.assertEqual(
future.result()["meta_info"]["finish_reason"]["type"], "abort"
)
updated_model_path = self.get_model_info()
print(f"[Server Mode] updated_model_path: {updated_model_path}")
self.assertEqual(updated_model_path, new_model_path)
self.assertNotEqual(updated_model_path, origin_model_path)
###############################################################################
# Parameterized Tests for update_weights_from_disk
# Test coverage is determined based on the value of is_in_ci:
# - In a CI environment: randomly select one mode (Engine or Server) and test only with tp=1, dp=1.
# - In a non-CI environment: test both Engine and Server modes, and enumerate all combinations
# with tp and dp ranging from 1 to 2.
###############################################################################
class TestUpdateWeightsFromDiskParameterized(CustomTestCase):
def run_common_test(self, mode, tp, dp):
"""
Common test procedure for update_weights_from_disk.
For Engine mode, we instantiate the engine with tp_size=tp.
For Server mode, we launch the server with additional arguments for tp (dp is not used in server launch here).
"""
if mode == "Engine":
# Instantiate engine with additional parameter tp_size.
print(f"[Parameterized Engine] Testing with tp={tp}, dp={dp}")
engine = sgl.Engine(
model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
random_seed=42,
tp_size=tp,
# dp parameter is not explicitly used in this API.
)
try:
origin_response = self._engine_update_weights_test(engine)
finally:
engine.shutdown()
elif mode == "Server":
print(f"[Parameterized Server] Testing with tp={tp}, dp={dp}")
# Pass additional arguments to launch the server.
base_args = ["--tp-size", str(tp)]
process = popen_launch_server(
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_TEST,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=base_args,
)
try:
origin_response = self._server_update_weights_test(DEFAULT_URL_FOR_TEST)
finally:
kill_process_tree(process.pid)
else:
raise ValueError(f"Unknown mode: {mode}")
def _engine_update_weights_test(self, engine):
# Run the update weights test on the given engine instance.
def run_decode():
prompts = ["The capital of France is"]
sampling_params = {"temperature": 0, "max_new_tokens": 32}
outputs = engine.generate(prompts, sampling_params)
print("=" * 100)
print(
f"[Parameterized Engine] Prompt: {prompts[0]}\nGenerated text: {outputs[0]['text']}"
)
return outputs[0]["text"]
def run_update_weights(model_path):
ret = engine.update_weights_from_disk(model_path)
print(json.dumps(ret))
return ret
origin_response = run_decode()
new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
ret = run_update_weights(new_model_path)
self.assertTrue(ret[0])
updated_response = run_decode()
self.assertNotEqual(origin_response[:32], updated_response[:32])
ret = run_update_weights(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
self.assertTrue(ret[0])
reverted_response = run_decode()
self.assertEqual(origin_response[:32], reverted_response[:32])
return origin_response
def _server_update_weights_test(self, base_url):
def run_decode():
response = requests.post(
base_url + "/generate",
json={
"text": "The capital of France is",
"sampling_params": {"temperature": 0, "max_new_tokens": 32},
},
)
print("=" * 100)
print(f"[Parameterized Server] Generated text: {response.json()['text']}")
return response.json()["text"]
def get_model_info():
response = requests.get(base_url + "/get_model_info")
model_path = response.json()["model_path"]
print(json.dumps(response.json()))
return model_path
def run_update_weights(model_path):
response = requests.post(
base_url + "/update_weights_from_disk",
json={"model_path": model_path},
)
ret = response.json()
print(json.dumps(ret))
return ret
origin_model_path = get_model_info()
origin_response = run_decode()
new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
ret = run_update_weights(new_model_path)
self.assertTrue(ret["success"])
updated_model_path = get_model_info()
self.assertEqual(updated_model_path, new_model_path)
self.assertNotEqual(updated_model_path, origin_model_path)
updated_response = run_decode()
self.assertNotEqual(origin_response[:32], updated_response[:32])
ret = run_update_weights(origin_model_path)
self.assertTrue(ret["success"])
updated_model_path = get_model_info()
self.assertEqual(updated_model_path, origin_model_path)
reverted_response = run_decode()
self.assertEqual(origin_response[:32], reverted_response[:32])
return origin_response
def test_parameterized_update_weights(self):
if is_in_ci():
# In CI, choose one random mode (Engine or Server) with tp=1, dp=1.
mode = random.choice(["Engine", "Server"])
test_suits = [(1, 1, mode)]
else:
# Otherwise, test both modes and enumerate tp,dp combinations from 1 to 2.
test_suits = []
for mode in ["Engine", "Server"]:
for tp in [1, 2]:
for dp in [1, 2]:
test_suits.append((tp, dp, mode))
for tp, dp, mode in test_suits:
with self.subTest(mode=mode, tp=tp, dp=dp):
self.run_common_test(mode, tp, dp)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,615 @@
"""Test distributed weight updates.
This test suite simulates a distributed training environment to ensure
correct weight synchronization. On rank 0, the instruct model represents
pre-training weights, and the base model represents post-training weights.
The base model's weights are broadcasted to other ranks using the online
weight update API.
On other ranks, an engine is initialized with the instruct model, and its
parameters are verified against the Hugging Face model. After updating
weights from the distributed system, post-training weights are loaded
and verified again to ensure consistency and accuracy across the
distributed setup.
"""
import gc
import os
import random
import time
import unittest
import numpy as np
import requests
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from transformers import AutoModelForCausalLM
import sglang as sgl
from sglang.srt.utils import init_custom_process_group
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
)
from sglang.utils import terminate_process
mp.set_start_method("spawn", force=True)
def verify_params_close(params1, params2, error_msg):
"""Verify if two parameter arrays are close enough."""
try:
assert np.allclose(np.array(params1), np.array(params2)), error_msg
except Exception as e:
print(f"Parameters not close for {error_msg}")
print("Params1:", np.array(params1))
print("Params2:", np.array(params2))
raise e
def verify_params_not_close(params1, params2, error_msg):
"""Verify if two parameter arrays are different enough."""
assert not np.allclose(np.array(params1), np.array(params2)), error_msg
def init_process(
rank,
world_size,
param_queue,
truncate_size,
state_dict_key_to_shape,
tp_size,
model_name,
backend,
checking_parameters,
tie_word_embeddings,
):
torch.cuda.set_device(rank)
if rank == 0:
init_process_hf(
rank,
world_size,
param_queue,
truncate_size,
model_name,
checking_parameters,
tie_word_embeddings,
state_dict_key_to_shape,
)
elif rank in [1, 2]:
init_process_sgl(
rank,
world_size,
param_queue,
truncate_size,
model_name,
checking_parameters,
tie_word_embeddings,
state_dict_key_to_shape,
backend,
tp_size,
)
def init_process_hf(
rank,
world_size,
param_queue,
truncate_size,
model_name,
checking_parameters,
tie_word_embeddings,
state_dict_key_to_shape,
):
# These two environment variables are very important
# to avoid unexpected behaviors of CUDA and NCCL.
os.environ["NCCL_CUMEM_ENABLE"] = "0"
os.environ["NCCL_NVLS_ENABLE"] = "0"
# Load model and get parameters
hf_instruct_model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="bfloat16",
tie_word_embeddings=tie_word_embeddings,
).to("cuda:0")
base_model_name = model_name.replace("-Instruct", "")
hf_base_model = AutoModelForCausalLM.from_pretrained(
base_model_name,
torch_dtype="bfloat16",
tie_word_embeddings=tie_word_embeddings,
).to("cuda:0")
hf_instruct_params = []
hf_base_params = []
print("[hf] get parameter in hf instruct model and base model")
for parameter_name in checking_parameters:
hf_instruct_params.append(
hf_instruct_model.get_parameter(parameter_name)[:truncate_size]
.cpu()
.detach()
.float()
.numpy()
.tolist()
)
hf_base_params.append(
hf_base_model.get_parameter(parameter_name)[:truncate_size]
.cpu()
.detach()
.float()
.numpy()
.tolist()
)
param_queue.put(("hf_instruct_params", hf_instruct_params))
param_queue.put(("hf_base_params", hf_base_params))
# Init weight update group for rank 0 (the training engine in RLHF).
port = 60000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 100
init_method = f"tcp://localhost:{port}"
print(f"[hf] {rank=} {world_size=} init custom process group. {init_method=}")
group = init_custom_process_group(
backend="nccl",
init_method=init_method,
world_size=world_size,
rank=rank,
group_name="test_parameter_update_group",
)
torch.cuda.synchronize()
time_begin_broadcast = time.perf_counter()
# The last parameter is lm_head.weight, which is tied
# with embed_tokens.weight. Actually, we only need
# to broadcast embed_tokens.weight once.
broadcast_parameters = list(state_dict_key_to_shape.keys())
if tie_word_embeddings:
broadcast_parameters.remove("lm_head.weight")
# Broadcast all the weights from the training
# engine to other ranks (inference engine).
for parameter_name in broadcast_parameters:
torch.distributed.broadcast(
hf_base_model.get_parameter(parameter_name),
src=0,
group=group,
)
torch.cuda.synchronize()
time_end_broadcast = time.perf_counter()
# Measure the latency of broadcasting/weights update.
broadcast_time = time_end_broadcast - time_begin_broadcast
print(f"[hf] {rank=} {broadcast_time=:.3f}s")
param_queue.put(("broadcast_time", broadcast_time))
# Delete the huggingface models to free up memory.
del hf_instruct_model
del hf_base_model
gc.collect()
torch.cuda.empty_cache()
def init_process_sgl(
rank,
world_size,
param_queue,
truncate_size,
model_name,
checking_parameters,
tie_word_embeddings,
state_dict_key_to_shape,
backend,
tp_size,
):
torch.cuda.set_device(rank)
torch.cuda.synchronize()
base_gpu_id = 1 if rank == 1 else 1 + tp_size
if backend == "Engine":
print(f"[sgl] rank {rank} init engine")
engine = sgl.Engine(
model_path=model_name,
base_gpu_id=base_gpu_id,
tp_size=tp_size,
cuda_graph_max_bs=2,
)
else:
if rank == 1:
url = DEFAULT_URL_FOR_TEST
else:
host, _, port = DEFAULT_URL_FOR_TEST.rpartition(":")
url = ":".join([host, str(int(port) + 10000)])
print(f"[sgl] rank {rank} init server on url: {url}")
process = popen_launch_server(
model_name,
url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=(
"--base-gpu-id",
str(base_gpu_id),
"--tp-size",
str(tp_size),
"--cuda-graph-max-bs",
2,
),
)
torch.cuda.synchronize()
# Get weights of instruct model, i.e. pre-training weights.
instruct_params = []
for parameter_name in checking_parameters:
instruct_params.append(
engine.get_weights_by_name(parameter_name, truncate_size)
if backend == "Engine"
else requests.get(
f"{url}/get_weights_by_name",
json={"name": parameter_name, "truncate_size": truncate_size},
).json()
)
param_queue.put((f"sgl_dp_{rank}_instruct_params", instruct_params))
port = 60000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 100
# Init weight update group with the training engine.
if backend == "Engine":
engine.init_weights_update_group(
master_address="localhost",
master_port=str(port),
rank_offset=base_gpu_id,
world_size=world_size,
group_name="test_parameter_update_group",
backend="nccl",
)
else:
requests.post(
f"{url}/init_weights_update_group",
json={
"master_address": "localhost",
"master_port": str(port),
"rank_offset": base_gpu_id,
"world_size": world_size,
"group_name": "test_parameter_update_group",
"backend": "nccl",
},
)
torch.cuda.synchronize()
time_begin_update = time.perf_counter()
# The last parameter is lm_head.weight, which is tied
# with embed_tokens.weight. Actually, we only need
# to update embed_tokens.weight once.
tie_word_embeddings = (
True if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else False
)
update_parameters = list(state_dict_key_to_shape.keys())
if tie_word_embeddings:
update_parameters.remove("lm_head.weight")
# Get weights from the training engine and update the inference engine.
names = [parameter_name for parameter_name in update_parameters]
dtypes = [torch.bfloat16 if backend == "Engine" else "bfloat16"] * len(names)
shapes = [state_dict_key_to_shape[parameter_name] for parameter_name in names]
if backend == "Engine":
engine.update_weights_from_distributed(
names,
dtypes=dtypes,
shapes=shapes,
group_name="test_parameter_update_group",
)
else:
requests.post(
f"{url}/update_weights_from_distributed",
json={
"names": names,
"dtypes": dtypes,
"shapes": shapes,
"group_name": "test_parameter_update_group",
},
)
torch.cuda.synchronize()
time_end_update = time.perf_counter()
# Measure the latency of broadcast/weights update.
update_time = time_end_update - time_begin_update
print(
f"[sgl] fully update model_name {model_name} rank {rank} parameter from distributed time: {update_time:.3f}s"
)
param_queue.put((f"update_sgl_dp_{rank}_time", update_time))
# Get the weights of post-training model after weights update for correctness check.
base_params = []
for parameter_name in checking_parameters:
if backend == "Engine":
base_params.append(
engine.get_weights_by_name(parameter_name, truncate_size)
)
else:
base_params.append(
requests.get(
f"{url}/get_weights_by_name",
json={
"name": parameter_name,
"truncate_size": truncate_size,
},
).json()
)
param_queue.put((f"sgl_dp_{rank}_base_params", base_params))
# Shutdown the engine or terminate the server process.
if backend == "Engine":
engine.shutdown()
else:
terminate_process(process)
def assert_tied_weights(params_list, message, should_be_tied):
for params in params_list:
if should_be_tied:
assert np.allclose(params[0], params[-1]), message
else:
assert not np.allclose(params[0], params[-1]), message
def test_update_weights_from_distributed(
tp_size,
dp_size,
model_name,
backend,
state_dict_key_to_shape,
truncate_size,
checking_parameters,
):
tie_word_embeddings = (
True if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else False
)
print(
f"Testing model: {model_name} tp_size: {tp_size}, dp_size: {dp_size} backend: {backend}"
)
param_queue = mp.Queue()
results = {}
context = mp.spawn(
init_process,
args=(
1 + tp_size * dp_size,
param_queue,
truncate_size,
state_dict_key_to_shape,
tp_size,
model_name,
backend,
checking_parameters,
tie_word_embeddings,
),
nprocs=1 + dp_size,
join=False,
)
while len(results) < 3 * (1 + dp_size):
try:
key, value = param_queue.get(timeout=5)
results[key] = value
except Exception as e:
if all(not p.is_alive() for p in context.processes):
break
context.join()
if len(results) != 3 * (1 + dp_size):
raise RuntimeError(
f"Expected {3 * (1 + dp_size)} parameters but got {len(results)}"
)
params = {
"hf_instruct": results.get("hf_instruct_params"),
"hf_base": results.get("hf_base_params"),
"sgl_dp_1_instruct": results.get("sgl_dp_1_instruct_params"),
"sgl_dp_1_base": results.get("sgl_dp_1_base_params"),
"broadcast_time": results.get("broadcast_time"),
"update_sgl_dp_1_time": results.get("update_sgl_dp_1_time"),
}
if dp_size == 2:
dp2_params = {
"sgl_dp_2_instruct": results.get("sgl_dp_2_instruct_params"),
"sgl_dp_2_base": results.get("sgl_dp_2_base_params"),
"update_sgl_dp_2_time": results.get("update_sgl_dp_2_time"),
}
assert all(v is not None for v in dp2_params.values())
params.update(dp2_params)
# Check the correctness of weights update by verifying
# the weights of instruct model and base model.
for i in range(len(params["hf_instruct"])):
verify_params_close(
params["hf_instruct"][i],
params["sgl_dp_1_instruct"][i],
f"sgl_dp_1_instruct_params rank {i}",
)
verify_params_close(
params["hf_base"][i],
params["sgl_dp_1_base"][i],
f"sgl_dp_1_base_params rank {i}",
)
verify_params_not_close(
params["hf_instruct"][i],
params["hf_base"][i],
f"hf_instruct_params rank {i}",
)
if dp_size == 2:
verify_params_close(
params["hf_base"][i],
params["sgl_dp_2_base"][i],
f"sgl_dp_2_base_params rank {i}",
)
verify_params_close(
params["hf_instruct"][i],
params["sgl_dp_2_instruct"][i],
f"sgl_dp_2_instruct_params rank {i}",
)
assert len(params["hf_instruct"]) == len(
params["hf_base"]
), "hf_instruct_params and hf_base_params have different lengths"
# Check if the weights of lm_head are tied with embed_tokens.
params_to_check = [
(
params["hf_instruct"],
"lm_head.weight is not tied with embed_tokens.weight",
),
(
params["hf_base"],
"lm_head.weight is not tied with embed_tokens.weight",
),
(
params["sgl_dp_1_instruct"],
"lm_head.weight is not tied with embed_tokens.weight",
),
(
params["sgl_dp_1_base"],
"lm_head.weight is not tied with embed_tokens.weight",
),
]
if dp_size == 2:
params_to_check.extend(
[
(
params["sgl_dp_2_instruct"],
"lm_head.weight is not tied with embed_tokens.weight",
),
(
params["sgl_dp_2_base"],
"lm_head.weight is not tied with embed_tokens.weight",
),
]
)
assert_tied_weights(
[params for params, _ in params_to_check],
(
"lm_head.weight is not tied with embed_tokens.weight"
if tie_word_embeddings
else "lm_head.weight is tied with embed_tokens.weight"
),
tie_word_embeddings,
)
# Time limit for broadcast and update on CI is 3 / 6
# On local H100, it's 1 / 2
time_limit = 3 if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else 6
assert (
params["broadcast_time"] < time_limit
), f"broadcast_time exceeds time limit {time_limit}s"
assert (
params["update_sgl_dp_1_time"] < time_limit
), f"update_sgl_dp_one_time exceeds time limit {time_limit}s"
if dp_size == 2:
assert (
params["update_sgl_dp_2_time"] < time_limit
), f"update_sgl_dp_two_time exceeds time limit {time_limit}s"
# Delete the context and close the parameter queue.
del context
param_queue.close()
param_queue.join_thread()
gc.collect()
torch.cuda.empty_cache()
class TestUpdateWeightsFromDistributed(CustomTestCase):
def test_update_weights_from_distributed(self):
assert torch.cuda.device_count() >= 2, "At least 2 GPUs are required"
# test_suits : tp, dp, model_name, backend
if is_in_ci():
mode = random.choice(["Engine", "Server"])
test_suits = [
(1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, mode),
]
else:
test_suits = [
(1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, "Engine"),
(1, 1, DEFAULT_MODEL_NAME_FOR_TEST, "Sever"),
]
if torch.cuda.device_count() >= 4:
test_suits.extend(
[
(2, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, "Engine"),
(1, 2, DEFAULT_MODEL_NAME_FOR_TEST, "Server"),
]
)
if torch.cuda.device_count() >= 5:
test_suits.extend(
[
(2, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, "Engine"),
(2, 2, DEFAULT_MODEL_NAME_FOR_TEST, "Server"),
]
)
model_state_dict_shapes = {}
test_models = [test_suit[2] for test_suit in test_suits]
for model_name in test_models:
model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype="bfloat16"
).to("cuda:0")
state_dict = model.state_dict()
state_dict_keys = list(state_dict.keys())
model_state_dict_shapes[model_name] = {
key: state_dict[key].shape for key in state_dict_keys
}
del model
gc.collect()
torch.cuda.empty_cache()
truncate_size = 10
checking_parameters = [
"model.embed_tokens.weight",
"model.layers.0.input_layernorm.weight",
"model.layers.1.self_attn.q_proj.weight",
"model.layers.2.self_attn.k_proj.weight",
"model.layers.3.self_attn.v_proj.weight",
"model.layers.4.self_attn.o_proj.weight",
"model.layers.5.mlp.gate_proj.weight",
"model.layers.6.mlp.up_proj.weight",
"model.layers.7.mlp.down_proj.weight",
"model.layers.8.post_attention_layernorm.weight",
"model.norm.weight",
"lm_head.weight",
]
for tp_size, dp_size, model_name, backend in test_suits:
test_update_weights_from_distributed(
tp_size,
dp_size,
model_name,
backend,
model_state_dict_shapes[model_name],
truncate_size,
checking_parameters,
)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,178 @@
import gc
import time
import unittest
import torch
import sglang as sgl
from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase
def test_update_weights_from_tensor(tp_size):
assert torch.cuda.device_count() >= tp_size, f"At least {tp_size} GPUs are required"
torch.cuda.empty_cache()
engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, tp_size=tp_size)
param_names = [f"model.layers.{i}.mlp.up_proj.weight" for i in range(6, 16)]
_check_param(engine, param_names[0], [0.0087, -0.0214, -0.0004, 0.0039, 0.0110])
memory_before = torch.cuda.memory_allocated()
new_tensor = torch.full((16384, 2048), 1.5, device="cuda")
time_start = time.perf_counter()
engine.update_weights_from_tensor([(x, new_tensor) for x in param_names])
print(f"Time delta: {time.perf_counter() - time_start:.03f}")
for param_name in param_names[:3]:
_check_param(engine, param_name, [1.5] * 5)
engine.shutdown()
del new_tensor
gc.collect()
torch.cuda.ipc_collect()
torch.cuda.empty_cache()
memory_after = torch.cuda.memory_allocated()
assert (
memory_after <= memory_before + 1024
), f"Memory leak detected: {memory_after - memory_before} bytes"
class TestUpdateWeightsFromTensor(CustomTestCase):
def test_update_weights_from_tensor(self):
tp_sizes = [1, 2]
for tp_size in tp_sizes:
if torch.cuda.device_count() < tp_size:
continue
with self.subTest(tp_size=tp_size):
test_update_weights_from_tensor(tp_size)
def test_update_weights_from_tensor_load_format_direct(self):
engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
write_param_names = [
f"model.layers.{i}.self_attn.qkv_proj.weight" for i in range(6, 16)
]
read_param_names = [
f"model.layers.{i}.self_attn.k_proj.weight" for i in range(6, 16)
]
_check_param(
engine, read_param_names[0], [-0.0198, 0.0227, 0.0168, 0.0232, -0.0178]
)
new_tensor = torch.full((3072, 2048), 1.5)
engine.update_weights_from_tensor(
[
(write_param_name, new_tensor.clone())
for write_param_name in write_param_names
],
load_format="direct",
)
for read_param_name in read_param_names[:3]:
_check_param(engine, read_param_name, [1.5] * 5)
engine.shutdown()
def test_update_weights_from_tensor_load_format_custom(self):
custom_loader_name = (
"sglang.srt.model_executor.model_runner._model_load_weights_direct"
)
engine = sgl.Engine(
model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
custom_weight_loader=[custom_loader_name],
)
write_param_names = [
f"model.layers.{i}.self_attn.qkv_proj.weight" for i in range(6, 16)
]
read_param_names = [
f"model.layers.{i}.self_attn.k_proj.weight" for i in range(6, 16)
]
_check_param(
engine, read_param_names[0], [-0.0198, 0.0227, 0.0168, 0.0232, -0.0178]
)
new_tensor = torch.full((3072, 2048), 1.5)
engine.update_weights_from_tensor(
[
(write_param_name, new_tensor.clone())
for write_param_name in write_param_names
],
load_format=custom_loader_name,
)
for read_param_name in read_param_names[:3]:
_check_param(engine, read_param_name, [1.5] * 5)
engine.shutdown()
def test_update_weights_from_tensor_load_format_flattened_bucket(self):
"""Test updating weights using flattened_bucket format"""
engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
# Create a small set of parameters for testing
param_names = [f"model.layers.{i}.mlp.up_proj.weight" for i in range(6, 10)]
# Check original values
_check_param(engine, param_names[0], [0.0087, -0.0214, -0.0004, 0.0039, 0.0110])
# Create new tensors with different values
new_tensors = []
for _, name in enumerate(param_names):
# Create tensors with different values for each parameter
value = 2.0 # Different value for each parameter
new_tensor = torch.full((16384, 2048), value, device="cuda")
new_tensors.append((name, new_tensor))
# Create a flattened bucket
flattened_bucket = FlattenedTensorBucket(named_tensors=new_tensors)
# Extract the flattened tensor and metadata in the format expected by model_runner
flattened_tensor = flattened_bucket.get_flattened_tensor()
metadata = flattened_bucket.get_metadata()
# Create the dict format expected by _update_weights_from_flattened_bucket
bucket_dict = {"flattened_tensor": flattened_tensor, "metadata": metadata}
# Serialize the bucket data
from sglang.srt.utils import MultiprocessingSerializer
serialized_bucket = MultiprocessingSerializer.serialize(
bucket_dict, output_str=True
)
# Create a list where each rank contains the same serialized data
# This simulates the distributed environment where each rank has the same data
serialized_bucket_list = [serialized_bucket]
# Update weights using flattened_bucket format
time_start = time.perf_counter()
engine.update_weights_from_tensor(
named_tensors=serialized_bucket_list, load_format="flattened_bucket"
)
update_time = time.perf_counter() - time_start
print(f"Flattened bucket update time: {update_time:.03f}")
# Verify the weights were updated correctly
for i, param_name in enumerate(param_names):
_check_param(engine, param_name, [2.0] * 5)
engine.shutdown()
def _check_param(engine, param_name, expect_values):
actual_values = torch.tensor(engine.get_weights_by_name(param_name))[0, :5]
assert torch.allclose(
actual_values, torch.tensor(expect_values), atol=0.002
), f"{actual_values=}"
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,277 @@
import multiprocessing
import multiprocessing as mp
import os
import random
import traceback
import unittest
from multiprocessing import Process
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.api import (
ShardedStateDictConfig,
ShardingStrategy,
StateDictType,
)
from transformers import AutoModelForCausalLM
from sglang.srt.entrypoints.verl_engine import VerlEngine
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import is_port_available
from sglang.test.runners import (
HFRunner,
SRTRunner,
check_close_model_outputs,
get_dtype_str,
)
from sglang.test.test_utils import CustomTestCase, find_available_port, is_in_ci
_MAX_NEW_TOKENS = 8
_PROMPTS = ["1+1=2, 1+2=3, 1+3=4, 1+4=5, 1+5=", "1*1=1, 1*2=2, 1*3=3, 1*4=4, 1*5="]
_TORCH_DTYPE = torch.float16
# Set to false to temporarily debug issues unrelated to weight update
_ENABLE_UPDATE_WEIGHTS = True
# _ENABLE_UPDATE_WEIGHTS = False
# TODO maybe we should add more other models? should we keep it in sync with test_generation_models.py?
ALL_MODELS = [
dict(model_path="meta-llama/Llama-3.2-1B-Instruct"),
dict(model_path="Qwen/Qwen2-1.5B"),
dict(model_path="allenai/OLMo-1B-0724-hf"),
dict(model_path="allenai/OLMo-2-1124-7B-Instruct"),
dict(
model_path="ibm-granite/granite-3.0-2b-instruct",
prefill_tolerance=0.22,
decode_tolerance=0.22,
),
]
class TestVerlEngine(CustomTestCase):
@classmethod
def setUpClass(cls):
multiprocessing.set_start_method("spawn")
def assert_fragment_e2e_execution(
self,
index: int,
model_path: str,
mem_fraction_static: float = 0.4,
dp_size: int = 1,
tp_size: int = 2,
tight_memory: bool = False,
prefill_tolerance: float = 0.1,
decode_tolerance: float = 0.1,
):
master_port = find_available_port(23456)
print(f"assert_fragment_e2e_execution START {index=} {model_path=}")
processes = []
output_reader, output_writer = mp.Pipe(duplex=False)
world_size = dp_size * tp_size
for rank in range(world_size):
p = Process(
target=_run_subprocess,
kwargs=dict(
rank=rank,
dp_size=dp_size,
tp_size=tp_size,
master_port=master_port,
output_writer=output_writer,
model_path=model_path,
mem_fraction_static=mem_fraction_static,
tight_memory=tight_memory,
prefill_tolerance=prefill_tolerance,
decode_tolerance=decode_tolerance,
),
)
p.start()
processes.append(p)
for _ in range(tp_size):
self.assertTrue(
output_reader.recv(),
f"Subprocess has error, please see logs above. ({index=} {model_path=})",
)
for p in processes:
p.join()
def test_ci_models(self):
ci_models = [random.choice(ALL_MODELS)]
for index, model_info in enumerate(ci_models):
self.assert_fragment_e2e_execution(index=index, **model_info)
def test_others(self):
if is_in_ci():
return
for index, model_info in enumerate(ALL_MODELS):
self.assert_fragment_e2e_execution(index=index, **model_info)
# def test_adhoc(self):
# self.assert_fragment_e2e_execution(index=0, model_path="meta-llama/Llama-3.2-1B-Instruct")
def _run_subprocess(
rank: int,
dp_size: int,
tp_size: int,
master_port: int,
output_writer,
model_path: str,
mem_fraction_static: float,
tight_memory: bool,
prefill_tolerance: float,
decode_tolerance: float,
):
try:
print(f"subprocess[{rank=}] Start {os.environ.get('CUDA_VISIBLE_DEVICES')=}")
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(master_port)
torch.distributed.init_process_group(rank=rank, world_size=dp_size * tp_size)
torch.cuda.set_device(rank)
base_gpu_id = rank // tp_size * tp_size
mesh_kwargs = dict(
mesh_shape=(dp_size, tp_size, 1), mesh_dim_names=["dp", "tp", "pp"]
)
inference_device_mesh_device = init_device_mesh("cuda", **mesh_kwargs)
inference_device_mesh_cpu = init_device_mesh("cpu", **mesh_kwargs)
print(
f"subprocess[{rank=},{base_gpu_id=}] {inference_device_mesh_device=} {inference_device_mesh_cpu=}"
)
# hf model is used for comparison
hf_model = AutoModelForCausalLM.from_pretrained(
model_path, torch_dtype=_TORCH_DTYPE, trust_remote_code=True
).cuda()
hf_tokenizer = get_tokenizer(model_path, trust_remote_code=True)
hf_outputs = HFRunner.forward_generation_raw(
base_model=hf_model,
prompts=_PROMPTS,
max_new_tokens=_MAX_NEW_TOKENS,
tokenizer=hf_tokenizer,
lora_paths=None,
torch_dtype=_TORCH_DTYPE,
output_str_only=False,
)
print(
f"subprocess[{rank=}] call hf.forward {hf_outputs=}",
flush=True,
)
if _ENABLE_UPDATE_WEIGHTS:
if tight_memory:
hf_model.cpu()
torch.cuda.empty_cache()
# test update weights
print(f"subprocess[{rank=}] get_fsdp_state_dict", flush=True)
fsdp_state_dict = _get_fsdp_state_dict(
hf_model=hf_model, world_size=dp_size * tp_size
)
engine = VerlEngine(
model_path=model_path,
load_format="dummy" if _ENABLE_UPDATE_WEIGHTS else "auto",
mem_fraction_static=mem_fraction_static,
random_seed=42,
base_gpu_id=base_gpu_id,
trust_remote_code=True,
dtype=get_dtype_str(_TORCH_DTYPE),
device_mesh_cpu=inference_device_mesh_cpu["tp"],
)
print(f"subprocess[{rank=}] {engine=}", flush=True)
if _ENABLE_UPDATE_WEIGHTS:
print(f"subprocess[{rank=}] call update_weights_from_tensor", flush=True)
engine.update_weights_from_tensor(
[(k, v) for k, v in fsdp_state_dict.items()]
)
for enable_batch in [False, True]:
if enable_batch:
fn = SRTRunner.batch_forward_generation_raw
else:
fn = SRTRunner.forward_generation_raw
srt_outputs = fn(
prompts=_PROMPTS,
max_new_tokens=_MAX_NEW_TOKENS,
lora_paths=None,
engine=engine,
)
print(
f"subprocess[{rank=}] call srt.forward {enable_batch=} {srt_outputs=}",
flush=True,
)
check_close_model_outputs(
hf_outputs=hf_outputs,
srt_outputs=srt_outputs,
prefill_tolerance=prefill_tolerance,
decode_tolerance=decode_tolerance,
rouge_l_tolerance=1,
check_logprobs=not enable_batch,
debug_text=f"{enable_batch=} {rank=}",
)
execution_ok = True
except Exception as e:
print(f"subprocess[{rank=}] has error: {e}", flush=True)
traceback.print_exc()
execution_ok = False
output_writer.send(execution_ok)
output_writer.close()
if "engine" in locals() and engine is not None:
engine.shutdown()
print(f"subprocess[{rank=}] end", flush=True)
# Adapted from https://github.com/volcengine/verl/blob/main/tests/rollout/run_fsdp_vllm.py
def _get_fsdp_state_dict(hf_model, world_size: int):
device_mesh = init_device_mesh(
"cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]
)
mixed_precision = MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
)
fsdp_model = FSDP(
hf_model,
use_orig_params=True,
auto_wrap_policy=None,
device_id=torch.cuda.current_device(),
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mixed_precision,
cpu_offload=CPUOffload(offload_params=False),
sync_module_states=False,
device_mesh=device_mesh,
)
print(f"{fsdp_model=}")
FSDP.set_state_dict_type(
fsdp_model,
state_dict_type=StateDictType.SHARDED_STATE_DICT,
state_dict_config=ShardedStateDictConfig(),
)
return fsdp_model.state_dict()
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,291 @@
import multiprocessing
import multiprocessing as mp
import os
import random
import traceback
import unittest
from multiprocessing import Process
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.api import (
ShardedStateDictConfig,
ShardingStrategy,
StateDictType,
)
from transformers import AutoModelForCausalLM
from sglang.srt.entrypoints.verl_engine import VerlEngine
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import is_port_available
from sglang.test.runners import (
HFRunner,
SRTRunner,
check_close_model_outputs,
get_dtype_str,
)
from sglang.test.test_utils import CustomTestCase, find_available_port, is_in_ci
_MAX_NEW_TOKENS = 8
_PROMPTS = ["1+1=2, 1+2=3, 1+3=4, 1+4=5, 1+5=", "1*1=1, 1*2=2, 1*3=3, 1*4=4, 1*5="]
_TORCH_DTYPE = torch.float16
# Set to false to temporarily debug issues unrelated to weight update
_ENABLE_UPDATE_WEIGHTS = True
# _ENABLE_UPDATE_WEIGHTS = False
# TODO maybe we should add more other models? should we keep it in sync with test_generation_models.py?
ALL_MODELS = [
dict(
model_path="Qwen/Qwen2.5-0.5B",
dp_size=2,
tp_size=2, # default to 2
),
dict(
model_path="Qwen/Qwen2.5-14B-Instruct",
mem_fraction_static=0.7,
dp_size=2,
tp_size=2,
tight_memory=True,
decode_tolerance=1.3,
), # test_generation_models.py same config (qwen + tp=8) gives 1.22 decode error
dict(
model_path="THUDM/glm-4-9b-chat",
mem_fraction_static=0.5,
dp_size=2,
tp_size=2,
tight_memory=True,
),
# Fail to run these models in test_generation_models.py, need to fix that first
# dict(model_path="openai-community/gpt2"),
# dict(model_path="microsoft/Phi-3-small-8k-instruct"),
]
class TestVerlEngine(CustomTestCase):
@classmethod
def setUpClass(cls):
multiprocessing.set_start_method("spawn")
def assert_fragment_e2e_execution(
self,
index: int,
model_path: str,
mem_fraction_static: float = 0.4,
dp_size: int = 1,
tp_size: int = 2,
tight_memory: bool = False,
prefill_tolerance: float = 0.1,
decode_tolerance: float = 0.1,
):
master_port = find_available_port(23456)
print(f"assert_fragment_e2e_execution START {index=} {model_path=}")
processes = []
output_reader, output_writer = mp.Pipe(duplex=False)
world_size = dp_size * tp_size
for rank in range(world_size):
p = Process(
target=_run_subprocess,
kwargs=dict(
rank=rank,
dp_size=dp_size,
tp_size=tp_size,
master_port=master_port,
output_writer=output_writer,
model_path=model_path,
mem_fraction_static=mem_fraction_static,
tight_memory=tight_memory,
prefill_tolerance=prefill_tolerance,
decode_tolerance=decode_tolerance,
),
)
p.start()
processes.append(p)
for _ in range(tp_size):
self.assertTrue(
output_reader.recv(),
f"Subprocess has error, please see logs above. ({index=} {model_path=})",
)
for p in processes:
p.join()
def test_ci_models(self):
ci_models = [random.choice(ALL_MODELS)]
for index, model_info in enumerate(ci_models):
self.assert_fragment_e2e_execution(index=index, **model_info)
def test_others(self):
if is_in_ci():
return
for index, model_info in enumerate(ALL_OTHER_MODELS):
self.assert_fragment_e2e_execution(index=index, **model_info)
# def test_adhoc(self):
# self.assert_fragment_e2e_execution(index=0, model_path="meta-llama/Llama-3.2-1B-Instruct")
def _run_subprocess(
rank: int,
dp_size: int,
tp_size: int,
master_port: int,
output_writer,
model_path: str,
mem_fraction_static: float,
tight_memory: bool,
prefill_tolerance: float,
decode_tolerance: float,
):
try:
print(f"subprocess[{rank=}] Start {os.environ.get('CUDA_VISIBLE_DEVICES')=}")
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(master_port)
torch.distributed.init_process_group(rank=rank, world_size=dp_size * tp_size)
torch.cuda.set_device(rank)
base_gpu_id = rank // tp_size * tp_size
mesh_kwargs = dict(
mesh_shape=(dp_size, tp_size, 1), mesh_dim_names=["dp", "tp", "pp"]
)
inference_device_mesh_device = init_device_mesh("cuda", **mesh_kwargs)
inference_device_mesh_cpu = init_device_mesh("cpu", **mesh_kwargs)
print(
f"subprocess[{rank=},{base_gpu_id=}] {inference_device_mesh_device=} {inference_device_mesh_cpu=}"
)
# hf model is used for comparison
hf_model = AutoModelForCausalLM.from_pretrained(
model_path, torch_dtype=_TORCH_DTYPE, trust_remote_code=True
).cuda()
hf_tokenizer = get_tokenizer(model_path, trust_remote_code=True)
hf_outputs = HFRunner.forward_generation_raw(
base_model=hf_model,
prompts=_PROMPTS,
max_new_tokens=_MAX_NEW_TOKENS,
tokenizer=hf_tokenizer,
lora_paths=None,
torch_dtype=_TORCH_DTYPE,
output_str_only=False,
)
print(
f"subprocess[{rank=}] call hf.forward {hf_outputs=}",
flush=True,
)
if _ENABLE_UPDATE_WEIGHTS:
if tight_memory:
hf_model.cpu()
torch.cuda.empty_cache()
# test update weights
print(f"subprocess[{rank=}] get_fsdp_state_dict", flush=True)
fsdp_state_dict = _get_fsdp_state_dict(
hf_model=hf_model, world_size=dp_size * tp_size
)
engine = VerlEngine(
model_path=model_path,
load_format="dummy" if _ENABLE_UPDATE_WEIGHTS else "auto",
mem_fraction_static=mem_fraction_static,
random_seed=42,
base_gpu_id=base_gpu_id,
trust_remote_code=True,
dtype=get_dtype_str(_TORCH_DTYPE),
device_mesh_cpu=inference_device_mesh_cpu["tp"],
)
print(f"subprocess[{rank=}] {engine=}", flush=True)
if _ENABLE_UPDATE_WEIGHTS:
print(f"subprocess[{rank=}] call update_weights_from_tensor", flush=True)
engine.update_weights_from_tensor(
[(k, v) for k, v in fsdp_state_dict.items()]
)
for enable_batch in [False, True]:
if enable_batch:
fn = SRTRunner.batch_forward_generation_raw
else:
fn = SRTRunner.forward_generation_raw
srt_outputs = fn(
prompts=_PROMPTS,
max_new_tokens=_MAX_NEW_TOKENS,
lora_paths=None,
engine=engine,
)
print(
f"subprocess[{rank=}] call srt.forward {enable_batch=} {srt_outputs=}",
flush=True,
)
check_close_model_outputs(
hf_outputs=hf_outputs,
srt_outputs=srt_outputs,
prefill_tolerance=prefill_tolerance,
decode_tolerance=decode_tolerance,
rouge_l_tolerance=1,
check_logprobs=not enable_batch,
debug_text=f"{enable_batch=} {rank=}",
)
execution_ok = True
except Exception as e:
print(f"subprocess[{rank=}] has error: {e}", flush=True)
traceback.print_exc()
execution_ok = False
output_writer.send(execution_ok)
output_writer.close()
if "engine" in locals() and engine is not None:
engine.shutdown()
print(f"subprocess[{rank=}] end", flush=True)
# Adapted from https://github.com/volcengine/verl/blob/main/tests/rollout/run_fsdp_vllm.py
def _get_fsdp_state_dict(hf_model, world_size: int):
device_mesh = init_device_mesh(
"cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]
)
mixed_precision = MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
)
fsdp_model = FSDP(
hf_model,
use_orig_params=True,
auto_wrap_policy=None,
device_id=torch.cuda.current_device(),
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mixed_precision,
cpu_offload=CPUOffload(offload_params=False),
sync_module_states=False,
device_mesh=device_mesh,
)
print(f"{fsdp_model=}")
FSDP.set_state_dict_type(
fsdp_model,
state_dict_type=StateDictType.SHARDED_STATE_DICT,
state_dict_config=ShardedStateDictConfig(),
)
return fsdp_model.state_dict()
if __name__ == "__main__":
unittest.main()