support custom weight loader for model runner (#7122)
Co-authored-by: kavioyu <kavioyu@tencent.com>
This commit is contained in:
@@ -78,6 +78,40 @@ class TestUpdateWeightsFromTensor(CustomTestCase):
|
||||
|
||||
engine.shutdown()
|
||||
|
||||
def test_update_weights_from_tensor_load_format_custom(self):
|
||||
custom_loader_name = (
|
||||
"sglang.srt.model_executor.model_runner._model_load_weights_direct"
|
||||
)
|
||||
engine = sgl.Engine(
|
||||
model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
custom_weight_loader=[custom_loader_name],
|
||||
)
|
||||
|
||||
write_param_names = [
|
||||
f"model.layers.{i}.self_attn.qkv_proj.weight" for i in range(6, 16)
|
||||
]
|
||||
read_param_names = [
|
||||
f"model.layers.{i}.self_attn.k_proj.weight" for i in range(6, 16)
|
||||
]
|
||||
|
||||
_check_param(
|
||||
engine, read_param_names[0], [-0.0198, 0.0227, 0.0168, 0.0232, -0.0178]
|
||||
)
|
||||
|
||||
new_tensor = torch.full((3072, 2048), 1.5)
|
||||
engine.update_weights_from_tensor(
|
||||
[
|
||||
(write_param_name, new_tensor.clone())
|
||||
for write_param_name in write_param_names
|
||||
],
|
||||
load_format=custom_loader_name,
|
||||
)
|
||||
|
||||
for read_param_name in read_param_names[:3]:
|
||||
_check_param(engine, read_param_name, [1.5] * 5)
|
||||
|
||||
engine.shutdown()
|
||||
|
||||
|
||||
def _check_param(engine, param_name, expect_values):
|
||||
actual_values = torch.tensor(engine.get_weights_by_name(param_name))[0, :5]
|
||||
|
||||
Reference in New Issue
Block a user