Fix bugs (fp8 checkpoints, triton cache manager) (#729)
@@ -30,9 +30,11 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
+    is_llama3_405b_fp8,
     is_multimodal_model,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
+    monkey_patch_vllm_qvk_linear_loader,
 )
 
 logger = logging.getLogger("srt.model_runner")
@@ -118,6 +120,13 @@ class ModelRunner:
             seed=42,
             skip_tokenizer_init=True,
         )
+
+        if is_llama3_405b_fp8(self.model_config):
+            # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
+            self.model_config.hf_config.num_key_value_heads = 8
+            vllm_model_config.hf_config.num_key_value_heads = 8
+            monkey_patch_vllm_qvk_linear_loader()
+
         self.dtype = vllm_model_config.dtype
         if self.model_config.model_overide_args is not None:
             vllm_model_config.hf_config.update(self.model_config.model_overide_args)
@@ -202,15 +202,12 @@ def launch_server(
         "reinstall the latest version by following the instructions "
         "at https://docs.flashinfer.ai/installation.html.",
     )
 
-    if server_args.tp_size // server_args.dp_size > 1:
+    if server_args.tp_size * server_args.dp_size > 1:
         # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
         maybe_set_triton_cache_manager()
 
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)
 
     if server_args.enable_torch_compile:
         _set_torch_compile_config()
@@ -21,6 +21,7 @@ import torch.distributed as dist
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
 from starlette.middleware.base import BaseHTTPMiddleware
+from torch.nn.parameter import Parameter
 from triton.runtime.cache import (
     FileCacheManager,
     default_cache_dir,
@@ -471,7 +472,7 @@ def maybe_set_triton_cache_manager() -> None:
     cache_manger = os.environ.get("TRITON_CACHE_MANAGER", None)
     if cache_manger is None:
         manager = "sglang.srt.utils:CustomCacheManager"
-        logger.info("Setting Triton cache manager to: %s", manager)
+        logger.debug("Setting Triton cache manager to: %s", manager)
         os.environ["TRITON_CACHE_MANAGER"] = manager
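launch_server now calls this whenever tp_size * dp_size > 1, i.e. whenever more than one GPU worker process may compile Triton kernels at the same time. The CustomCacheManager class itself is not part of the hunks shown on this page; a rough sketch of the per-process-cache idea behind it (my illustration with a hypothetical class name; the FileCacheManager constructor signature may differ across triton versions):

import os

from triton.runtime.cache import FileCacheManager, default_cache_dir


class PerProcessCacheManager(FileCacheManager):  # hypothetical name, for illustration only
    def __init__(self, key, override=False, dump=False):
        self.key = key
        # Suffix the cache directory with the PID so each tp/dp worker process
        # writes compiled kernels into its own directory instead of racing on one.
        base_dir = os.environ.get("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
        self.cache_dir = f"{base_dir}_{os.getpid()}"
        self.lock_path = os.path.join(self.cache_dir, "lock")
        os.makedirs(self.cache_dir, exist_ok=True)

Triton looks such a manager up through the TRITON_CACHE_MANAGER environment variable in "module:Class" form, which is exactly the string the hunk above writes ("sglang.srt.utils:CustomCacheManager").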
@@ -615,3 +616,51 @@ def set_ulimit(target_soft_limit=65535):
         resource.setrlimit(resource_type, (target_soft_limit, current_hard))
     except ValueError as e:
         logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")
+
+
+def is_llama3_405b_fp8(model_config):
+    """Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads."""
+    if (
+        model_config.hf_config.architectures[0] == "LlamaForCausalLM"
+        and model_config.hf_config.hidden_size == 16384
+        and model_config.hf_config.intermediate_size == 53248
+        and model_config.hf_config.num_hidden_layers == 126
+        and model_config.hf_config.num_key_value_heads == 16
+        and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8"
+    ):
+        return True
+    return False
+
+
+def monkey_patch_vllm_qvk_linear_loader():
+    """A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints."""
+    from vllm.model_executor.layers.linear import QKVParallelLinear
+
+    origin_weight_loader = QKVParallelLinear.weight_loader
+
+    def get_original_weight(loaded_weight, head_dim):
+        n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
+        dim = loaded_weight.shape[1]
+        for i in range(n_kv_head):
+            loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[
+                2 * i * head_dim : (2 * i + 1) * head_dim, :
+            ]
+        original_kv_weight = loaded_weight[: n_kv_head * head_dim, :]
+        assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
+        return original_kv_weight
+
+    def weight_loader_srt(
+        self,
+        param: Parameter,
+        loaded_weight: torch.Tensor,
+        loaded_shard_id: Optional[str] = None,
+    ):
+        if (
+            loaded_shard_id in ["k", "v"]
+            and loaded_weight.shape[0] == self.head_size * self.total_num_kv_heads * 2
+        ):
+            loaded_weight = get_original_weight(loaded_weight, self.head_size)
+
+        origin_weight_loader(self, param, loaded_weight, loaded_shard_id)
+
+    setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)
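The core of the new loader hack is get_original_weight: the 405B FP8 checkpoint ships the K/V projection with every head stored twice, and the loop keeps one copy per head before handing the tensor to vLLM's original weight_loader. A small self-contained check of that slicing (my own illustration, assuming the layout the loop implies: each head duplicated back-to-back):

import torch

head_dim, n_kv_head, hidden = 4, 2, 3
original = torch.arange(
    n_kv_head * head_dim * hidden, dtype=torch.float32
).reshape(n_kv_head * head_dim, hidden)

# Duplicate each head in place: (h0, h1) -> (h0, h0, h1, h1).
duplicated = (
    original.reshape(n_kv_head, head_dim, hidden)
    .repeat_interleave(2, dim=0)
    .reshape(2 * n_kv_head * head_dim, hidden)
)

# Same compaction as get_original_weight: keep every other head_dim-sized block.
recovered = duplicated.clone()
for i in range(n_kv_head):
    recovered[i * head_dim : (i + 1) * head_dim, :] = recovered[
        2 * i * head_dim : (2 * i + 1) * head_dim, :
    ]
recovered = recovered[: n_kv_head * head_dim, :]

assert torch.equal(recovered, original)

Note that weight_loader_srt only applies the compaction when the incoming k or v shard is exactly twice the expected size (self.head_size * self.total_num_kv_heads * 2 rows), so ordinary checkpoints pass through the original loader untouched.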