refactor model loader: initial refactor (#664)
This commit is contained in:
@@ -123,6 +123,15 @@ class ModelRunner:
|
||||
if self.model_config.model_overide_args is not None:
|
||||
vllm_model_config.hf_config.update(self.model_config.model_overide_args)
|
||||
|
||||
if (
|
||||
self.server_args.efficient_weight_load
|
||||
and "llama" in self.server_args.model_path.lower()
|
||||
and self.server_args.quantization == "fp8"
|
||||
):
|
||||
from sglang.srt.model_loader.model_loader import get_model
|
||||
else:
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
|
||||
self.model = get_model(
|
||||
model_config=vllm_model_config,
|
||||
device_config=device_config,
|
||||
@@ -237,7 +246,16 @@ class ModelRunner:
|
||||
self.cuda_graph_runner = CudaGraphRunner(
|
||||
self, max_batch_size_to_capture=max(batch_size_list)
|
||||
)
|
||||
self.cuda_graph_runner.capture(batch_size_list)
|
||||
logger.info(f"Capture for batch sizes {batch_size_list}")
|
||||
try:
|
||||
self.cuda_graph_runner.capture(batch_size_list)
|
||||
except:
|
||||
raise Exception(
|
||||
f"Capture cuda graph failed. Possible solutions:\n"
|
||||
f"1. disable cuda graph by --disable-cuda-graph\n"
|
||||
f"2. set --mem-fraction-static to a smaller value\n"
|
||||
f"Open an issue on GitHub with reproducible scripts if you need help.\n"
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward_decode(self, batch: Batch):
|
||||
|
||||
@@ -304,6 +304,12 @@ class ModelTpServer:
|
||||
self.model_config.context_len - 1 - len(req.origin_input_ids),
|
||||
self.max_total_num_tokens - 128 - len(req.origin_input_ids),
|
||||
)
|
||||
if req.sampling_params.max_new_tokens < 0:
|
||||
req.origin_input_ids = req.origin_input_ids[
|
||||
: self.max_total_num_tokens - 128
|
||||
]
|
||||
logger.error("Request longer than memory pool size, truncated!!!")
|
||||
|
||||
self.forward_queue.append(req)
|
||||
|
||||
def get_new_prefill_batch(self) -> Optional[Batch]:
|
||||
|
||||
@@ -91,6 +91,7 @@ def _initialize_model(
|
||||
config=model_config.hf_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
efficient_weight_load=True,
|
||||
**_get_model_initialization_kwargs(model_class, lora_config, multimodal_config),
|
||||
)
|
||||
|
||||
|
||||
@@ -15,11 +15,6 @@ from vllm.distributed import (
|
||||
)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@@ -32,6 +27,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.managers.controller.model_runner import InputMetadata
|
||||
|
||||
MergedColumnParallelLinear = None
|
||||
QKVParallelLinear = None
|
||||
RowParallelLinear = None
|
||||
|
||||
|
||||
class LlamaMLP(nn.Module):
|
||||
def __init__(
|
||||
@@ -267,7 +266,25 @@ class LlamaForCausalLM(nn.Module):
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
efficient_weight_load=False,
|
||||
) -> None:
|
||||
global MergedColumnParallelLinear
|
||||
global QKVParallelLinear
|
||||
global RowParallelLinear
|
||||
|
||||
if efficient_weight_load:
|
||||
from sglang.srt.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
@@ -288,7 +305,30 @@ class LlamaForCausalLM(nn.Module):
|
||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def get_module_name(self, name):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id, num_shard)
|
||||
("qkv_proj", "q_proj", "q", 3),
|
||||
("qkv_proj", "k_proj", "k", 3),
|
||||
("qkv_proj", "v_proj", "v", 3),
|
||||
("gate_up_proj", "gate_proj", 0, 2),
|
||||
("gate_up_proj", "up_proj", 1, 2),
|
||||
]
|
||||
for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
|
||||
if weight_name in name:
|
||||
return (
|
||||
name.replace(weight_name, param_name)[: -len(".weight")],
|
||||
num_shard,
|
||||
)
|
||||
return name[: -len(".weight")], 1
|
||||
|
||||
def get_num_params(self):
|
||||
params_dict = dict(self.named_parameters())
|
||||
return len(params_dict)
|
||||
|
||||
def load_weights(
|
||||
self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
|
||||
):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@@ -298,15 +338,14 @@ class LlamaForCausalLM(nn.Module):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
if get_tensor_model_parallel_rank() == 0:
|
||||
weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
|
||||
for name, loaded_weight in weights:
|
||||
|
||||
def load_weights_per_param(name, loaded_weight):
|
||||
if "rotary_emb.inv_freq" in name or "projector" in name:
|
||||
continue
|
||||
return
|
||||
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
return
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
@@ -323,12 +362,21 @@ class LlamaForCausalLM(nn.Module):
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
return
|
||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||
continue
|
||||
return
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
if name is None or loaded_weight is None:
|
||||
if get_tensor_model_parallel_rank() == 0:
|
||||
weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
load_weights_per_param(name, loaded_weight)
|
||||
else:
|
||||
load_weights_per_param(name, loaded_weight)
|
||||
|
||||
|
||||
EntryClass = LlamaForCausalLM
|
||||
|
||||
@@ -57,6 +57,7 @@ class ServerArgs:
|
||||
disable_disk_cache: bool = False
|
||||
attention_reduce_in_fp32: bool = False
|
||||
enable_p2p_check: bool = False
|
||||
efficient_weight_load: bool = False
|
||||
|
||||
# Distributed args
|
||||
nccl_init_addr: Optional[str] = None
|
||||
@@ -327,6 +328,11 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--efficient-weight-load",
|
||||
action="store_true",
|
||||
help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
|
||||
Reference in New Issue
Block a user