Add get weights by parameter name for llama (#2266)
.github/workflows/pr-test.yml
@@ -128,6 +128,7 @@ jobs:
           python3 test_mla_fp8.py
+          python3 test_dp_attention.py

   performance-test-1-gpu-part-1:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
     runs-on: 1-gpu-runner
@@ -242,6 +243,7 @@ jobs:
           cd test/srt
           python3 test_eval_accuracy_large.py

   accuracy-test-2-gpu:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
     runs-on: 2-gpu-runner
3rdparty/amd/profiling/PROFILING.md
@@ -421,5 +421,5 @@ index 62d1ff9..6ecd78c 100644
 3. Modify the included server.sh by removing "loadTracer.sh" before python command and launch script ./server.sh in one terminal inside the docker container.
 4. Similar to step 6 in RPD profiling section, but remove the last 2 lines in client.sh, which converted rpd file into csv and json files. Run modified client.sh for PyTorch profiling.
-=======
+-------
 - [Torch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)
python/sglang/srt/managers/io_struct.py
@@ -365,6 +365,17 @@ class UpdateWeightFromDiskReqOutput:
     message: str


+@dataclass
+class GetWeightsByNameReqInput:
+    name: str
+    truncate_size: int = 100
+
+
+@dataclass
+class GetWeightsByNameReqOutput:
+    parameter: list
+
+
 @dataclass
 class AbortReq:
     # The request id
python/sglang/srt/managers/scheduler.py
@@ -38,6 +38,8 @@ from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
     CloseSessionReqInput,
     FlushCacheReq,
+    GetWeightsByNameReqInput,
+    GetWeightsByNameReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -511,6 +513,9 @@ class Scheduler:
                 self.send_to_tokenizer.send_pyobj(
                     UpdateWeightFromDiskReqOutput(success, message)
                 )
+            elif isinstance(recv_req, GetWeightsByNameReqInput):
+                parameter = self.get_weights_by_name(recv_req)
+                self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
             elif isinstance(recv_req, ProfileReq):
                 if recv_req == ProfileReq.START_PROFILE:
                     self.start_profile()
@@ -1373,6 +1378,10 @@ class Scheduler:
             logger.error(message)
         return success, message

+    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+        parameter = self.tp_worker.get_weights_by_name(recv_req)
+        return parameter
+
     def start_profile(self) -> None:
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
python/sglang/srt/managers/tokenizer_manager.py
@@ -46,6 +46,8 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
+    GetWeightsByNameReqInput,
+    GetWeightsByNameReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -454,6 +456,23 @@ class TokenizerManager:
         else:
             return False, "Another update is in progress. Please try again later."

+    async def get_weights_by_name(
+        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
+    ):
+        if self.to_create_loop:
+            self.create_handle_loop()
+
+        self.send_to_scheduler.send_pyobj(obj)
+        self.get_weights_by_name_result = asyncio.Future()
+        if self.server_args.dp_size == 1:
+            result = await self.get_weights_by_name_result
+            return result.parameter
+        else:
+            self.get_weights_by_name_tmp = []
+            result = await self.get_weights_by_name_result
+            all_parameters = [r.parameter for r in result]
+            return all_parameters
+
     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):
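The request/reply pairing here is a bare asyncio.Future: the coroutine ships the request over ZMQ, parks on get_weights_by_name_result, and the receive loop below resolves it when the scheduler answers. A minimal standalone sketch of that pattern (illustrative only, not SGLang code):

    import asyncio

    async def main():
        loop = asyncio.get_running_loop()
        reply = loop.create_future()

        # Stand-in for the receive loop resolving the future when the
        # scheduler's GetWeightsByNameReqOutput arrives.
        loop.call_later(0.1, reply.set_result, [1.0, 2.0])

        parameter = await reply
        print(parameter)  # [1.0, 2.0]

    asyncio.run(main())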
@@ -527,6 +546,7 @@ class TokenizerManager:
                 BatchEmbeddingOut,
                 BatchTokenIDOut,
                 UpdateWeightFromDiskReqOutput,
+                GetWeightsByNameReqOutput,
             ] = await self.recv_from_detokenizer.recv_pyobj()

             if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
@@ -538,6 +558,16 @@ class TokenizerManager:
                     if len(self.model_update_tmp) == self.server_args.dp_size:
                         self.model_update_result.set_result(self.model_update_tmp)
                 continue
+            elif isinstance(recv_obj, GetWeightsByNameReqOutput):
+                if self.server_args.dp_size == 1:
+                    self.get_weights_by_name_result.set_result(recv_obj)
+                else:
+                    self.get_weights_by_name_tmp.append(recv_obj)
+                    if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
+                        self.get_weights_by_name_result.set_result(
+                            self.get_weights_by_name_tmp
+                        )
+                continue
             elif isinstance(recv_obj, OpenSessionReqOutput):
                 self.session_futures[recv_obj.session_id].set_result(
                     recv_obj.session_id
python/sglang/srt/managers/tp_worker.py
@@ -19,7 +19,10 @@ from typing import Optional

 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput
+from sglang.srt.managers.io_struct import (
+    GetWeightsByNameReqInput,
+    UpdateWeightFromDiskReqInput,
+)
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
@@ -160,3 +163,9 @@ class TpModelWorker:
             recv_req.model_path, recv_req.load_format
         )
         return success, message
+
+    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+        parameter = self.model_runner.get_weights_by_name(
+            recv_req.name, recv_req.truncate_size
+        )
+        return parameter
python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -23,7 +23,10 @@ from typing import Optional
 import psutil
 import torch

-from sglang.srt.managers.io_struct import UpdateWeightFromDiskReqInput
+from sglang.srt.managers.io_struct import (
+    GetWeightsByNameReqInput,
+    UpdateWeightFromDiskReqInput,
+)
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
@@ -208,6 +211,9 @@ class TpModelWorkerClient:
         success, message = self.worker.update_weights_from_disk(recv_req)
         return success, message

+    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+        return self.worker.get_weights_by_name(recv_req)
+
     def __delete__(self):
         self.input_queue.put((None, None))
         self.copy_queue.put((None, None, None))
python/sglang/srt/model_executor/model_runner.py
@@ -20,13 +20,10 @@ import inspect
 import json
 import logging
-import pkgutil
 import time
-from functools import lru_cache
-from tokenize import tabsize
-from typing import Any, Optional, Type, Union
+from typing import Optional, Type

 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
@@ -403,6 +400,23 @@ class ModelRunner:
         logger.info("Update weights end.")
         return True, "Succeeded to update model weights."

+    def get_weights_by_name(
+        self, name: str, truncate_size: int = 100
+    ) -> Optional[torch.Tensor]:
+        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.
+
+        Only used for unit test with an unoptimized performance.
+        For optimized performance, please use torch.save and torch.load.
+        """
+        # TODO: (chenyang) Add support for Qwen models.
+        try:
+            return self.model.get_weights_by_name(
+                name, truncate_size, tp_size=self.tp_size
+            )
+        except Exception as e:
+            logger.error(f"Error when getting parameter {name}: {e}")
+            return None
+
     def init_lora_manager(self):
         self.lora_manager = LoRAManager(
             base_model=self.model,
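As the docstring notes, this path is for tests; round-tripping tensors through Python lists is slow. For bulk access the recommended route is torch.save/torch.load; a minimal sketch with a stand-in module (file path is a placeholder):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)  # stand-in for the real model
    torch.save(model.state_dict(), "/tmp/weights.pt")
    state = torch.load("/tmp/weights.pt", map_location="cpu")
    print(state["weight"].shape)  # torch.Size([4, 4])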
python/sglang/srt/models/llama.py
@@ -16,6 +16,7 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
 """Inference-only LLaMA model compatible with HuggingFace weights."""

+import logging
 from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
@@ -45,6 +46,8 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import make_layers

+logger = logging.getLogger(__name__)
+

 class LlamaMLP(nn.Module):
     def __init__(
@@ -305,6 +308,14 @@ class LlamaForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self.stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]

     @torch.no_grad()
     def forward(
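Hoisting the mapping onto self lets load_weights, get_module_name_from_weight_name, and the new get_weights_by_name share one table: checkpoint shards such as q_proj live inside the fused qkv_proj parameter. An illustrative sketch of the name translation it encodes (map_checkpoint_name is a hypothetical helper, not part of this diff):

    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    ]

    def map_checkpoint_name(name: str) -> str:
        # e.g. "model.layers.0.self_attn.q_proj.weight"
        #   -> "model.layers.0.self_attn.qkv_proj.weight"
        for param_name, shard_name, _shard_id in stacked_params_mapping:
            if shard_name in name:
                return name.replace(shard_name, param_name)
        return name

    print(map_checkpoint_name("model.layers.0.self_attn.q_proj.weight"))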
@@ -349,15 +360,7 @@
         return params_mapping.get(name, name)

     def get_module_name_from_weight_name(self, name):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id, num_shard)
-            ("qkv_proj", "q_proj", "q", 3),
-            ("qkv_proj", "k_proj", "k", 3),
-            ("qkv_proj", "v_proj", "v", 3),
-            ("gate_up_proj", "gate_proj", 0, 2),
-            ("gate_up_proj", "up_proj", 1, 2),
-        ]
-        for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
+        for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
             if weight_name in name:
                 return (
                     name.replace(weight_name, param_name)[: -len(".weight")],
@@ -370,6 +373,7 @@
         return len(params_dict)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        embed_tokens_weight = None
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -378,6 +382,7 @@
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
+
         params_dict = dict(self.named_parameters())

         load_tie_word_embeddings = (
@@ -425,10 +430,79 @@
             # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
             param = self.lm_head.weight
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
-            weight_loader(param, embed_tokens_weight)
+            if embed_tokens_weight is not None:
+                weight_loader(param, embed_tokens_weight)

         apply_torchao_config_(self, params_dict, set(["proj.weight"]))

+    def get_weights_by_name(
+        self, name: str, truncate_size: int = 100, tp_size: int = 1
+    ) -> Optional[torch.Tensor]:
+        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.
+
+        Only used for unit test with an unoptimized performance.
+        For optimized performance, please use torch.save and torch.load.
+        """
+        try:
+            mapped_name = name
+            mapped_shard_id = None
+            for param_name, weight_name, shard_id in self.stacked_params_mapping:
+                if weight_name in name:
+                    mapped_name = name.replace(weight_name, param_name)
+                    mapped_shard_id = shard_id
+                    break
+            params_dict = dict(self.named_parameters())
+            if mapped_name in params_dict:
+                param = params_dict[mapped_name]
+                if mapped_shard_id is not None:
+                    if mapped_shard_id in ["q", "k", "v"]:
+                        num_heads = self.config.num_attention_heads // tp_size
+                        num_kv_heads = self.config.num_key_value_heads // tp_size
+                        head_dim = (
+                            self.config.hidden_size // self.config.num_attention_heads
+                        )
+                        if mapped_shard_id == "q":
+                            offset = 0
+                            size = num_heads * head_dim
+                        elif mapped_shard_id == "k":
+                            offset = num_heads * head_dim
+                            size = num_kv_heads * head_dim
+                        elif mapped_shard_id == "v":
+                            offset = (num_heads + num_kv_heads) * head_dim
+                            size = num_kv_heads * head_dim
+                        weight = param.data.narrow(0, offset, size)
+                    elif mapped_shard_id in [0, 1]:
+                        intermediate_size = self.config.intermediate_size
+                        hidden_size = self.config.hidden_size
+                        slice_size = intermediate_size // tp_size
+                        if mapped_shard_id == 0:  # gate_proj
+                            offset = 0
+                            size = slice_size
+                        elif mapped_shard_id == 1:  # up_proj
+                            offset = slice_size
+                            size = slice_size
+
+                        weight = param.data.narrow(0, offset, size)
+                    else:
+                        weight = param.data
+                else:
+                    weight = param.data
+                if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
+                    gathered_weights = [
+                        torch.zeros_like(weight) for _ in range(tp_size)
+                    ]
+                    torch.distributed.all_gather(gathered_weights, weight)
+                    weight = torch.cat(gathered_weights, dim=1)
+                return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
+            else:
+                return None
+
+        except Exception as e:
+            logger.error(
+                f"Error getting weights by name {name} in LlamaForCausalLM: {e}"
+            )
+            return None
+

 class Phi3ForCausalLM(LlamaForCausalLM):
     pass
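The q/k/v branch indexes into the fused qkv_proj weight, laid out as [Q; K; V] rows along dim 0. A worked example with hypothetical Llama-like shapes (32 heads, 8 KV heads, hidden size 4096, tp_size=1):

    num_heads, num_kv_heads, hidden_size = 32, 8, 4096  # hypothetical shapes
    head_dim = hidden_size // num_heads                 # 128

    q_offset, q_size = 0, num_heads * head_dim                    # 0, 4096
    k_offset, k_size = q_size, num_kv_heads * head_dim            # 4096, 1024
    v_offset, v_size = q_size + k_size, num_kv_heads * head_dim   # 5120, 1024
    # qkv_proj.weight has 4096 + 1024 + 1024 = 6144 rows (before any TP split);
    # param.data.narrow(0, offset, size) then recovers the requested shard.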
python/sglang/srt/server.py
@@ -52,6 +52,7 @@ from sglang.srt.managers.io_struct import (
     CloseSessionReqInput,
     EmbeddingReqInput,
     GenerateReqInput,
+    GetWeightsByNameReqInput,
     OpenSessionReqInput,
     UpdateWeightFromDiskReqInput,
 )
@@ -210,6 +211,24 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
         )


+@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
+async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
+    """Get model parameter by name."""
+    try:
+        ret = await tokenizer_manager.get_weights_by_name(obj, request)
+        if ret is None:
+            return ORJSONResponse(
+                {"error": {"message": "Get parameter by name failed"}},
+                status_code=HTTPStatus.BAD_REQUEST,
+            )
+        else:
+            return ORJSONResponse(ret, status_code=200)
+    except Exception as e:
+        return ORJSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+
+
 @app.api_route("/open_session", methods=["GET", "POST"])
 async def open_session(obj: OpenSessionReqInput, request: Request):
     """Open a session, and return its unique session id."""
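Once a server is running, the new endpoint can be exercised directly; this mirrors the Runtime branch of the test added at the bottom of this diff (host and port are placeholders):

    import requests

    resp = requests.get(
        "http://127.0.0.1:30000/get_weights_by_name",  # placeholder host/port
        json={"name": "model.norm.weight", "truncate_size": 8},
    )
    print(resp.json())  # first 8 entries of the tensor as (nested) float lists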
@@ -269,6 +288,18 @@ async def generate_request(obj: GenerateReqInput, request: Request):
         )


+@time_func_latency
+async def get_weights_by_name_request(obj: GetWeightsByNameReqInput, request: Request):
+    """Handle a get parameter by name request."""
+    try:
+        ret = await tokenizer_manager.get_weights_by_name(obj, request)
+        return ret
+    except ValueError as e:
+        return ORJSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+
+
 @app.api_route("/encode", methods=["POST", "PUT"])
 @time_func_latency
 async def encode_request(obj: EmbeddingReqInput, request: Request):
@@ -938,3 +969,8 @@ class Engine:

     async def get_server_info(self):
         return await _get_server_info()
+
+    def get_weights_by_name(self, name, truncate_size=100):
+        obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(get_weights_by_name_request(obj, None))
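The offline Engine exposes the same query synchronously, mirroring the Engine branch of the new test (model path is a placeholder):

    import sglang as sgl

    engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # placeholder
    weights = engine.get_weights_by_name("model.embed_tokens.weight", truncate_size=8)
    engine.shutdown()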
test/srt/run_suite.py
@@ -38,6 +38,7 @@ suites = {
         "test_update_weights.py",
         "test_vision_openai_server.py",
         "test_session_control.py",
+        "test_get_parameter_by_name.py",
     ],
     "sampling/penaltylib": glob.glob(
         "sampling/penaltylib/**/test_*.py", recursive=True
test/srt/test_get_parameter_by_name.py (new file, 128 lines)
@@ -0,0 +1,128 @@
import gc
import unittest

import numpy as np
import requests
import torch
from transformers import AutoModelForCausalLM

import sglang as sgl
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)
from sglang.utils import terminate_process


class TestUpdateWeights(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.hf_model = AutoModelForCausalLM.from_pretrained(
            cls.model, torch_dtype="bfloat16"
        ).to("cuda:0")

    @classmethod
    def tearDownClass(cls):
        del cls.hf_model
        gc.collect()
        torch.cuda.empty_cache()

    def init_backend(self, backend, dp, tp):
        self.engine = None
        self.process = None
        self.backend = backend
        self.dp = dp
        self.tp = tp
        if backend == "Engine":
            self.engine = sgl.Engine(
                model_path=self.model,
                random_seed=42,
                tp_size=self.tp,
                dp_size=self.dp,
                mem_fraction_static=0.85,
            )
        else:
            self.process = popen_launch_server(
                self.model,
                self.base_url,
                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                other_args=(
                    "--tp-size",
                    str(tp),
                    "--dp-size",
                    str(dp),
                ),
            )

    def close_engine_and_server(self):
        if self.engine:
            self.engine.shutdown()
        if self.process:
            terminate_process(self.process)

    def assert_update_weights_all_close(self, param_name, truncate_size):
        print(
            f"param_name: {param_name}, backend: {self.backend}, dp: {self.dp}, tp: {self.tp}"
        )
        param = self.hf_model.get_parameter(param_name)[:truncate_size]
        param_np = param.cpu().detach().float().numpy()

        if self.backend == "Engine":
            engine_ret = self.engine.get_weights_by_name(param_name, truncate_size)
            engine_ret = self._process_return(engine_ret)
            np.testing.assert_allclose(engine_ret, param_np, rtol=1e-5, atol=1e-5)

        if self.backend == "Runtime":
            runtime_ret = requests.get(
                f"{self.base_url}/get_weights_by_name",
                json={"name": param_name, "truncate_size": truncate_size},
            ).json()
            runtime_ret = self._process_return(runtime_ret)
            np.testing.assert_allclose(runtime_ret, param_np, rtol=1e-5, atol=1e-5)

    @staticmethod
    def _process_return(ret):
        if isinstance(ret, list) and len(ret) == 2:
            print(f"running assert_allclose on data parallel")
            np.testing.assert_allclose(ret[0], ret[1])
            return np.array(ret[0])
        return np.array(ret)

    def test_update_weights_unexist_model(self):
        test_suits = [("Engine", 1, 1), ("Runtime", 1, 1)]

        if torch.cuda.device_count() >= 2:
            test_suits.append(("Engine", 1, 2))
            test_suits.append(("Runtime", 2, 1))

        if torch.cuda.device_count() >= 4:
            test_suits.extend([("Engine", 2, 2), ("Runtime", 2, 2)])

        parameters = [
            "model.embed_tokens.weight",
            "model.layers.0.input_layernorm.weight",
            "model.layers.1.self_attn.q_proj.weight",
            "model.layers.2.self_attn.k_proj.weight",
            "model.layers.3.self_attn.v_proj.weight",
            "model.layers.4.self_attn.o_proj.weight",
            "model.layers.5.mlp.gate_proj.weight",
            "model.layers.6.mlp.up_proj.weight",
            "model.layers.7.mlp.down_proj.weight",
            "model.layers.8.post_attention_layernorm.weight",
            "model.norm.weight",
            "lm_head.weight",
        ]

        for test_suit in test_suits:
            self.init_backend(*test_suit)
            for param_name in parameters:
                self.assert_update_weights_all_close(param_name, 100)
            self.close_engine_and_server()


if __name__ == "__main__":
    unittest.main()