Files
sglang/test/srt/test_get_parameter_by_name.py

129 lines
4.2 KiB
Python
Raw Normal View History

import gc
import unittest
import numpy as np
import requests
import torch
from transformers import AutoModelForCausalLM
import sglang as sgl
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
from sglang.utils import terminate_process
class TestUpdateWeights(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.hf_model = AutoModelForCausalLM.from_pretrained(
cls.model, torch_dtype="bfloat16"
).to("cuda:0")
@classmethod
def tearDownClass(cls):
del cls.hf_model
gc.collect()
torch.cuda.empty_cache()
def init_backend(self, backend, dp, tp):
self.engine = None
self.process = None
self.backend = backend
self.dp = dp
self.tp = tp
if backend == "Engine":
self.engine = sgl.Engine(
model_path=self.model,
random_seed=42,
tp_size=self.tp,
dp_size=self.dp,
mem_fraction_static=0.85,
)
else:
self.process = popen_launch_server(
self.model,
self.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=(
"--tp-size",
str(tp),
"--dp-size",
str(dp),
),
)
def close_engine_and_server(self):
if self.engine:
self.engine.shutdown()
if self.process:
terminate_process(self.process)
def assert_update_weights_all_close(self, param_name, truncate_size):
print(
f"param_name: {param_name}, backend: {self.backend}, dp: {self.dp}, tp: {self.tp}"
)
param = self.hf_model.get_parameter(param_name)[:truncate_size]
param_np = param.cpu().detach().float().numpy()
if self.backend == "Engine":
engine_ret = self.engine.get_weights_by_name(param_name, truncate_size)
engine_ret = self._process_return(engine_ret)
np.testing.assert_allclose(engine_ret, param_np, rtol=1e-5, atol=1e-5)
if self.backend == "Runtime":
runtime_ret = requests.get(
f"{self.base_url}/get_weights_by_name",
json={"name": param_name, "truncate_size": truncate_size},
).json()
runtime_ret = self._process_return(runtime_ret)
np.testing.assert_allclose(runtime_ret, param_np, rtol=1e-5, atol=1e-5)
@staticmethod
def _process_return(ret):
if isinstance(ret, list) and len(ret) == 2:
print(f"running assert_allclose on data parallel")
np.testing.assert_allclose(ret[0], ret[1])
return np.array(ret[0])
return np.array(ret)
def test_update_weights_unexist_model(self):
test_suits = [("Engine", 1, 1), ("Runtime", 1, 1)]
if torch.cuda.device_count() >= 2:
test_suits.append(("Engine", 1, 2))
test_suits.append(("Runtime", 2, 1))
if torch.cuda.device_count() >= 4:
test_suits.extend([("Engine", 2, 2), ("Runtime", 2, 2)])
parameters = [
"model.embed_tokens.weight",
"model.layers.0.input_layernorm.weight",
"model.layers.1.self_attn.q_proj.weight",
"model.layers.2.self_attn.k_proj.weight",
"model.layers.3.self_attn.v_proj.weight",
"model.layers.4.self_attn.o_proj.weight",
"model.layers.5.mlp.gate_proj.weight",
"model.layers.6.mlp.up_proj.weight",
"model.layers.7.mlp.down_proj.weight",
"model.layers.8.post_attention_layernorm.weight",
"model.norm.weight",
"lm_head.weight",
]
for test_suit in test_suits:
self.init_backend(*test_suit)
for param_name in parameters:
self.assert_update_weights_all_close(param_name, 100)
self.close_engine_and_server()
if __name__ == "__main__":
unittest.main()