Clean up model loader (#1440)

This commit is contained in:
Lianmin Zheng
2024-09-16 18:16:27 -07:00
committed by GitHub
parent 93dffd699b
commit 27b557aea7
5 changed files with 33 additions and 80 deletions

View File

@@ -187,7 +187,7 @@ def allocate_init_ports(
cur_port += 1
if port is not None and ret_ports[0] != port:
logger.warn(
logger.warning(
f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
)
@@ -623,56 +623,7 @@ def set_ulimit(target_soft_limit=65535):
try:
resource.setrlimit(resource_type, (target_soft_limit, current_hard))
except ValueError as e:
logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")
def is_llama3_405b_fp8_head_16(model_config):
"""Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads."""
if (
model_config.hf_config.architectures[0] == "LlamaForCausalLM"
and model_config.hf_config.hidden_size == 16384
and model_config.hf_config.intermediate_size == 53248
and model_config.hf_config.num_hidden_layers == 126
and model_config.hf_config.num_key_value_heads == 16
and hasattr(model_config.hf_config, "quantization_config")
and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8"
):
return True
return False
def monkey_patch_vllm_qvk_linear_loader():
"""A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints."""
from vllm.model_executor.layers.linear import QKVParallelLinear
origin_weight_loader = QKVParallelLinear.weight_loader
def get_original_weight(loaded_weight, head_dim):
n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
dim = loaded_weight.shape[1]
for i in range(n_kv_head):
loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[
2 * i * head_dim : (2 * i + 1) * head_dim, :
]
original_kv_weight = loaded_weight[: n_kv_head * head_dim, :]
assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
return original_kv_weight
def weight_loader_srt(
self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None,
):
if (
loaded_shard_id in ["k", "v"]
and loaded_weight.shape[0] == self.head_size * self.total_num_kv_heads * 2
):
loaded_weight = get_original_weight(loaded_weight, self.head_size)
origin_weight_loader(self, param, loaded_weight, loaded_shard_id)
setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)
logger.warning(f"Fail to set RLIMIT_NOFILE: {e}")
def add_api_key_middleware(app, api_key: str):