fix: fp8 config (#723)
This commit is contained in:
@@ -15,6 +15,7 @@ from flashinfer import (
|
||||
BatchPrefillWithRaggedKVCacheWrapper,
|
||||
)
|
||||
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.config import DeviceConfig, LoadConfig
|
||||
from vllm.config import ModelConfig as VllmModelConfig
|
||||
from vllm.distributed import (
|
||||
@@ -22,6 +23,7 @@ from vllm.distributed import (
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import QKVParallelLinear
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
|
||||
from sglang.global_config import global_config
|
||||
@@ -38,6 +40,18 @@ from sglang.srt.utils import (
|
||||
logger = logging.getLogger("srt.model_runner")
|
||||
|
||||
|
||||
def is_llama3_405b_fp8(model_config):
|
||||
if (
|
||||
model_config.hf_config.architectures[0] == "LlamaForCausalLM"
|
||||
and model_config.hf_config.hidden_size == 16384
|
||||
and model_config.hf_config.intermediate_size == 53248
|
||||
and model_config.hf_config.num_hidden_layers == 126
|
||||
and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8"
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class ModelRunner:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -118,6 +132,9 @@ class ModelRunner:
|
||||
seed=42,
|
||||
skip_tokenizer_init=True,
|
||||
)
|
||||
if is_llama3_405b_fp8(self.model_config):
|
||||
self.model_config.hf_config.num_key_value_heads = 8
|
||||
vllm_model_config.hf_config.num_key_value_heads = 8
|
||||
self.dtype = vllm_model_config.dtype
|
||||
if self.model_config.model_overide_args is not None:
|
||||
vllm_model_config.hf_config.update(self.model_config.model_overide_args)
|
||||
@@ -370,5 +387,39 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
|
||||
return model_arch_name_to_cls[model_arch]
|
||||
|
||||
|
||||
def get_original_weight(loaded_weight, head_dim):
|
||||
n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
|
||||
dim = loaded_weight.shape[1]
|
||||
for i in range(n_kv_head):
|
||||
loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[
|
||||
2 * i * head_dim : (2 * i + 1) * head_dim, :
|
||||
]
|
||||
original_kv_weight = loaded_weight[: n_kv_head * head_dim, :]
|
||||
assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
|
||||
return original_kv_weight
|
||||
|
||||
|
||||
def get_weight_loader_srt(weight_loader):
|
||||
def weight_loader_srt(
|
||||
self,
|
||||
param: Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: Optional[str] = None,
|
||||
):
|
||||
if (
|
||||
loaded_shard_id in ["k", "v"]
|
||||
and loaded_weight.shape[0] == self.head_size * self.total_num_kv_heads * 2
|
||||
):
|
||||
loaded_weight = get_original_weight(loaded_weight, self.head_size)
|
||||
|
||||
weight_loader(self, param, loaded_weight, loaded_shard_id)
|
||||
|
||||
return weight_loader_srt
|
||||
|
||||
|
||||
# Monkey patch model loader
|
||||
setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
|
||||
original_weight_loader = QKVParallelLinear.weight_loader
|
||||
setattr(
|
||||
QKVParallelLinear, "weight_loader", get_weight_loader_srt(original_weight_loader)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user