refactor model loader: initial refactor (#664)
@@ -92,3 +92,9 @@ python3 run_all.py
 cd test/srt
 python test_openai_server.py
 ```
+
+## Format
+pip3 install pre-commit
+cd sglang
+pre-commit install
+pre-commit run --all-files
@@ -123,6 +123,15 @@ class ModelRunner:
         if self.model_config.model_overide_args is not None:
             vllm_model_config.hf_config.update(self.model_config.model_overide_args)
 
+        if (
+            self.server_args.efficient_weight_load
+            and "llama" in self.server_args.model_path.lower()
+            and self.server_args.quantization == "fp8"
+        ):
+            from sglang.srt.model_loader.model_loader import get_model
+        else:
+            from vllm.model_executor.model_loader import get_model
+
         self.model = get_model(
             model_config=vllm_model_config,
             device_config=device_config,
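
The branch above only routes to sglang's own loader when the efficient-load flag, a Llama model path, and fp8 quantization all line up; everything else keeps using vLLM's `get_model`. A minimal standalone sketch of that routing decision, with a stand-in dataclass for the relevant `ServerArgs` fields (illustrative only):

```python
from dataclasses import dataclass


@dataclass
class _Args:
    # Stand-ins for the ServerArgs fields used by the check above.
    efficient_weight_load: bool
    model_path: str
    quantization: str


def use_sglang_loader(args: _Args) -> bool:
    """Mirror the ModelRunner condition: only the fp8 Llama path goes
    through sglang.srt.model_loader; everything else stays on vLLM."""
    return (
        args.efficient_weight_load
        and "llama" in args.model_path.lower()
        and args.quantization == "fp8"
    )


print(use_sglang_loader(_Args(True, "meta-llama/Llama-2-7b-hf", "fp8")))   # True
print(use_sglang_loader(_Args(True, "mistralai/Mistral-7B-v0.1", "fp8")))  # False
```
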
@@ -237,7 +246,16 @@ class ModelRunner:
         self.cuda_graph_runner = CudaGraphRunner(
             self, max_batch_size_to_capture=max(batch_size_list)
         )
-        self.cuda_graph_runner.capture(batch_size_list)
+        logger.info(f"Capture for batch sizes {batch_size_list}")
+        try:
+            self.cuda_graph_runner.capture(batch_size_list)
+        except:
+            raise Exception(
+                f"Capture cuda graph failed. Possible solutions:\n"
+                f"1. disable cuda graph by --disable-cuda-graph\n"
+                f"2. set --mem-fraction-static to a smaller value\n"
+                f"Open an issue on GitHub with reproducible scripts if you need help.\n"
+            )
 
     @torch.inference_mode()
     def forward_decode(self, batch: Batch):
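
One note on the bare `except:` above: re-raising a fresh `Exception` drops the original traceback from the capture failure. A sketch of the same guard that keeps it via exception chaining (illustrative alternative, not part of the commit):

```python
def capture_with_hint(cuda_graph_runner, batch_size_list):
    """Wrap CUDA-graph capture and attach troubleshooting hints while
    preserving the original error with `raise ... from e`."""
    try:
        cuda_graph_runner.capture(batch_size_list)
    except Exception as e:
        raise RuntimeError(
            "Capture cuda graph failed. Possible solutions:\n"
            "1. disable cuda graph by --disable-cuda-graph\n"
            "2. set --mem-fraction-static to a smaller value\n"
            "Open an issue on GitHub with reproducible scripts if you need help.\n"
        ) from e
```
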
@@ -304,6 +304,12 @@ class ModelTpServer:
             self.model_config.context_len - 1 - len(req.origin_input_ids),
             self.max_total_num_tokens - 128 - len(req.origin_input_ids),
         )
+        if req.sampling_params.max_new_tokens < 0:
+            req.origin_input_ids = req.origin_input_ids[
+                : self.max_total_num_tokens - 128
+            ]
+            logger.error("Request longer than memory pool size, truncated!!!")
+
         self.forward_queue.append(req)
 
     def get_new_prefill_batch(self) -> Optional[Batch]:
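
The clamp above can drive `max_new_tokens` negative when the prompt alone exceeds the memory pool; the new branch then truncates the prompt to `max_total_num_tokens - 128`. A small worked example of that arithmetic with made-up sizes:

```python
def clamp_and_truncate(requested_new_tokens, prompt_len, context_len, max_total_num_tokens):
    """Reproduce the scheduler's arithmetic with scalar inputs (sizes are illustrative)."""
    max_new_tokens = min(
        requested_new_tokens,
        context_len - 1 - prompt_len,
        max_total_num_tokens - 128 - prompt_len,
    )
    truncated_prompt_len = prompt_len
    if max_new_tokens < 0:
        # The prompt alone no longer fits in the token pool; keep only the
        # first max_total_num_tokens - 128 input tokens, as the new branch does.
        truncated_prompt_len = max_total_num_tokens - 128
    return max_new_tokens, truncated_prompt_len


# A 3000-token prompt against a 2048-token pool: max_new_tokens = -1080,
# so the prompt is cut down to 1920 tokens.
print(clamp_and_truncate(256, 3000, 4096, 2048))  # (-1080, 1920)
```
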
@@ -91,6 +91,7 @@ def _initialize_model(
         config=model_config.hf_config,
         cache_config=cache_config,
         quant_config=quant_config,
+        efficient_weight_load=True,
         **_get_model_initialization_kwargs(model_class, lora_config, multimodal_config),
     )
 
@@ -15,11 +15,6 @@ from vllm.distributed import (
 )
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -32,6 +27,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.controller.model_runner import InputMetadata
 
+MergedColumnParallelLinear = None
+QKVParallelLinear = None
+RowParallelLinear = None
+
 
 class LlamaMLP(nn.Module):
     def __init__(
@@ -267,7 +266,25 @@ class LlamaForCausalLM(nn.Module):
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         cache_config: Optional[CacheConfig] = None,
+        efficient_weight_load=False,
     ) -> None:
+        global MergedColumnParallelLinear
+        global QKVParallelLinear
+        global RowParallelLinear
+
+        if efficient_weight_load:
+            from sglang.srt.layers.linear import (
+                MergedColumnParallelLinear,
+                QKVParallelLinear,
+                RowParallelLinear,
+            )
+        else:
+            from vllm.model_executor.layers.linear import (
+                MergedColumnParallelLinear,
+                QKVParallelLinear,
+                RowParallelLinear,
+            )
+
         super().__init__()
         self.config = config
         self.quant_config = quant_config
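
The `global` declarations plus in-function `from ... import` rebind the module-level `None` placeholders the first time a model is constructed, so the layer classes defined earlier in the file pick up whichever linear implementation was selected. A self-contained sketch of the same pattern with stand-in backends (the function names here are hypothetical, not sglang/vLLM APIs):

```python
# Module-level placeholder, as in llama2.py.
LinearImpl = None


def _sglang_linear(x):  # stand-in for the sglang.srt.layers.linear variant
    return x * 2


def _vllm_linear(x):    # stand-in for the vllm.model_executor.layers.linear variant
    return x * 2


class Model:
    def __init__(self, efficient_weight_load: bool = False):
        # Rebind the module-level name on construction, exactly like the
        # `global MergedColumnParallelLinear` block in the commit.
        global LinearImpl
        LinearImpl = _sglang_linear if efficient_weight_load else _vllm_linear

    def forward(self, x):
        return LinearImpl(x)


print(Model(efficient_weight_load=True).forward(3))  # 6
```
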
@@ -288,7 +305,30 @@ class LlamaForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def get_module_name(self, name):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id, num_shard)
+            ("qkv_proj", "q_proj", "q", 3),
+            ("qkv_proj", "k_proj", "k", 3),
+            ("qkv_proj", "v_proj", "v", 3),
+            ("gate_up_proj", "gate_proj", 0, 2),
+            ("gate_up_proj", "up_proj", 1, 2),
+        ]
+        for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
+            if weight_name in name:
+                return (
+                    name.replace(weight_name, param_name)[: -len(".weight")],
+                    num_shard,
+                )
+        return name[: -len(".weight")], 1
+
+    def get_num_params(self):
+        params_dict = dict(self.named_parameters())
+        return len(params_dict)
+
+    def load_weights(
+        self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
+    ):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
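
`get_module_name` maps a checkpoint weight name to the fused parameter it is loaded into, plus the number of shards the loader should expect for that parameter. A standalone copy of its logic with a couple of example inputs (the weight names are typical Llama checkpoint keys):

```python
def get_module_name(name: str):
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id, num_shard)
        ("qkv_proj", "q_proj", "q", 3),
        ("qkv_proj", "k_proj", "k", 3),
        ("qkv_proj", "v_proj", "v", 3),
        ("gate_up_proj", "gate_proj", 0, 2),
        ("gate_up_proj", "up_proj", 1, 2),
    ]
    for param_name, weight_name, _shard_id, num_shard in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name)[: -len(".weight")], num_shard
    return name[: -len(".weight")], 1


# q/k/v projections all map to the fused qkv_proj parameter and expect 3 shards.
print(get_module_name("model.layers.0.self_attn.q_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj', 3)

# Unfused weights map to themselves with a single shard.
print(get_module_name("model.layers.0.mlp.down_proj.weight"))
# ('model.layers.0.mlp.down_proj', 1)
```
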
@@ -298,15 +338,14 @@ class LlamaForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        if get_tensor_model_parallel_rank() == 0:
-            weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
-        for name, loaded_weight in weights:
+
+        def load_weights_per_param(name, loaded_weight):
             if "rotary_emb.inv_freq" in name or "projector" in name:
-                continue
+                return
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
-                continue
+                return
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -323,12 +362,21 @@ class LlamaForCausalLM(nn.Module):
             else:
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
-                    continue
+                    return
                 if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
+                    return
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
+        if name is None or loaded_weight is None:
+            if get_tensor_model_parallel_rank() == 0:
+                weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
+
+            for name, loaded_weight in weights:
+                load_weights_per_param(name, loaded_weight)
+        else:
+            load_weights_per_param(name, loaded_weight)
+
 
 EntryClass = LlamaForCausalLM
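
The reworked `load_weights` now serves two callers: the original bulk path, which iterates over the whole checkpoint when no single weight is given, and the new per-parameter path used by the memory-efficient loader, which pushes one (name, tensor) pair at a time. A toy mirror of that dispatch (the weight source here is hypothetical, and the real per-parameter body does the sharding shown above):

```python
def load_weights(weights=None, name=None, loaded_weight=None):
    """Toy mirror of the new dispatch: bulk mode when no single weight is
    given, otherwise load just that one parameter."""
    loaded = []

    def load_weights_per_param(name, loaded_weight):
        loaded.append(name)  # stand-in for the real shard/copy logic

    if name is None or loaded_weight is None:
        for name, loaded_weight in weights:      # bulk path (standard loader)
            load_weights_per_param(name, loaded_weight)
    else:
        load_weights_per_param(name, loaded_weight)  # per-param path (efficient loader)
    return loaded


# Bulk mode: the whole (hypothetical) checkpoint at once.
print(load_weights(weights=[("lm_head.weight", 0), ("model.norm.weight", 1)]))
# Per-parameter mode: one tensor pushed by the quantizing loader.
print(load_weights(name="lm_head.weight", loaded_weight=0))
```
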
@@ -57,6 +57,7 @@ class ServerArgs:
     disable_disk_cache: bool = False
     attention_reduce_in_fp32: bool = False
     enable_p2p_check: bool = False
+    efficient_weight_load: bool = False
 
     # Distributed args
     nccl_init_addr: Optional[str] = None
@@ -327,6 +328,11 @@ class ServerArgs:
             action="store_true",
             help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
         )
+        parser.add_argument(
+            "--efficient-weight-load",
+            action="store_true",
+            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
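
The new flag is a plain `store_true` argument, so it defaults to False and flips the `efficient_weight_load` field when passed on the command line. A minimal argparse sketch of just this flag (the real parser carries many more options):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--efficient-weight-load",
    action="store_true",
    help="Turn on memory efficient weight loading with quantization "
    "(quantize per layer during loading).",
)

# argparse turns the dashes into underscores, matching the dataclass field name.
print(parser.parse_args([]).efficient_weight_load)                           # False
print(parser.parse_args(["--efficient-weight-load"]).efficient_weight_load)  # True
```
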