[Feature] SPMD for SGLang + Verl (#3852)

This commit is contained in:
fzyzcjy
2025-03-01 01:53:10 +08:00
committed by GitHub
parent bac414ab53
commit e3e0bc50a9
19 changed files with 890 additions and 202 deletions

View File

@@ -26,6 +26,34 @@ class TestUpdateWeightsFromTensor(unittest.TestCase):
engine.shutdown()
def test_update_weights_from_tensor_load_format_direct(self):
engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
write_param_names = [
f"model.layers.{i}.self_attn.qkv_proj.weight" for i in range(6, 16)
]
read_param_names = [
f"model.layers.{i}.self_attn.k_proj.weight" for i in range(6, 16)
]
_check_param(
engine, read_param_names[0], [-0.0198, 0.0227, 0.0168, 0.0232, -0.0178]
)
new_tensor = torch.full((3072, 2048), 1.5)
engine.update_weights_from_tensor(
[
(write_param_name, new_tensor.clone())
for write_param_name in write_param_names
],
load_format="direct",
)
for read_param_name in read_param_names[:3]:
_check_param(engine, read_param_name, [1.5] * 5)
engine.shutdown()
def _check_param(engine, param_name, expect_values):
actual_values = torch.tensor(engine.get_weights_by_name(param_name))[0, :5]