fix: fp8 config (#723)
This commit is contained in:
@@ -15,6 +15,7 @@ from flashinfer import (
|
|||||||
BatchPrefillWithRaggedKVCacheWrapper,
|
BatchPrefillWithRaggedKVCacheWrapper,
|
||||||
)
|
)
|
||||||
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
|
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
from vllm.config import DeviceConfig, LoadConfig
|
from vllm.config import DeviceConfig, LoadConfig
|
||||||
from vllm.config import ModelConfig as VllmModelConfig
|
from vllm.config import ModelConfig as VllmModelConfig
|
||||||
from vllm.distributed import (
|
from vllm.distributed import (
|
||||||
@@ -22,6 +23,7 @@ from vllm.distributed import (
|
|||||||
init_distributed_environment,
|
init_distributed_environment,
|
||||||
initialize_model_parallel,
|
initialize_model_parallel,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.linear import QKVParallelLinear
|
||||||
from vllm.model_executor.models import ModelRegistry
|
from vllm.model_executor.models import ModelRegistry
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
@@ -38,6 +40,18 @@ from sglang.srt.utils import (
|
|||||||
logger = logging.getLogger("srt.model_runner")
|
logger = logging.getLogger("srt.model_runner")
|
||||||
|
|
||||||
|
|
||||||
|
def is_llama3_405b_fp8(model_config):
    """Heuristically detect the Llama-3 405B fbgemm-fp8 checkpoint.

    Fingerprints the model by architecture name and its characteristic
    dimensions (hidden_size=16384, intermediate_size=53248, 126 layers),
    plus the ``fbgemm_fp8`` quantization method.

    Args:
        model_config: config object exposing ``hf_config`` (a HuggingFace
            PretrainedConfig-like object).

    Returns:
        bool: True only when every fingerprint field matches.
    """
    hf_config = model_config.hf_config
    # quantization_config may be absent on unquantized checkpoints; default
    # to {} so the lookup cannot raise AttributeError/KeyError.
    quant_config = getattr(hf_config, "quantization_config", None) or {}
    return (
        hf_config.architectures[0] == "LlamaForCausalLM"
        and hf_config.hidden_size == 16384
        and hf_config.intermediate_size == 53248
        and hf_config.num_hidden_layers == 126
        and quant_config.get("quant_method") == "fbgemm_fp8"
    )
|
||||||
|
|
||||||
|
|
||||||
class ModelRunner:
|
class ModelRunner:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -118,6 +132,9 @@ class ModelRunner:
|
|||||||
seed=42,
|
seed=42,
|
||||||
skip_tokenizer_init=True,
|
skip_tokenizer_init=True,
|
||||||
)
|
)
|
||||||
|
if is_llama3_405b_fp8(self.model_config):
|
||||||
|
self.model_config.hf_config.num_key_value_heads = 8
|
||||||
|
vllm_model_config.hf_config.num_key_value_heads = 8
|
||||||
self.dtype = vllm_model_config.dtype
|
self.dtype = vllm_model_config.dtype
|
||||||
if self.model_config.model_overide_args is not None:
|
if self.model_config.model_overide_args is not None:
|
||||||
vllm_model_config.hf_config.update(self.model_config.model_overide_args)
|
vllm_model_config.hf_config.update(self.model_config.model_overide_args)
|
||||||
@@ -370,5 +387,39 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
|
|||||||
return model_arch_name_to_cls[model_arch]
|
return model_arch_name_to_cls[model_arch]
|
||||||
|
|
||||||
|
|
||||||
|
def get_original_weight(loaded_weight, head_dim):
    """Recover the original KV projection weight from a duplicated layout.

    The input stacks, per KV head, two consecutive ``head_dim``-row groups,
    of which only the first holds the original values. This compacts those
    first groups to the front of the tensor (in place) and returns a view
    of the compacted prefix.

    Args:
        loaded_weight: 2-D weight of shape (2 * n_kv_head * head_dim, dim).
        head_dim: number of rows per attention head.

    Returns:
        View of shape (n_kv_head * head_dim, dim) with duplicates removed.
        Note: the leading rows of ``loaded_weight`` are overwritten.
    """
    num_kv_heads = loaded_weight.shape[0] // (2 * head_dim)
    hidden = loaded_weight.shape[1]
    # Compact in place: copy each head's first (original) row group down
    # into slot `head`. For head >= 1 the source always lies strictly
    # after the destination, so no overlap corruption occurs.
    for head in range(num_kv_heads):
        dst = slice(head * head_dim, (head + 1) * head_dim)
        src = slice(2 * head * head_dim, (2 * head + 1) * head_dim)
        loaded_weight[dst, :] = loaded_weight[src, :]
    compacted = loaded_weight[: num_kv_heads * head_dim, :]
    assert compacted.shape == (num_kv_heads * head_dim, hidden)
    return compacted
|
||||||
|
|
||||||
|
|
||||||
|
def get_weight_loader_srt(weight_loader):
    """Wrap a QKV weight loader to de-duplicate oversized K/V shards.

    The returned loader detects K/V shards whose leading dimension is
    twice ``head_size * total_num_kv_heads`` (the duplicated fp8 layout)
    and compacts them with ``get_original_weight`` before delegating to
    the wrapped ``weight_loader``.
    """

    def weight_loader_srt(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[str] = None,
    ):
        # Guard-clause form: only inspect sizes for k/v shards, preserving
        # the original short-circuit (head_size is never touched otherwise).
        if loaded_shard_id in ["k", "v"]:
            duplicated_rows = self.head_size * self.total_num_kv_heads * 2
            if loaded_weight.shape[0] == duplicated_rows:
                loaded_weight = get_original_weight(loaded_weight, self.head_size)
        weight_loader(self, param, loaded_weight, loaded_shard_id)

    return weight_loader_srt
|
||||||
|
|
||||||
|
|
||||||
# Monkey patch: route vLLM's model-class lookup and QKV weight loading
# through the SRT-aware variants defined above.
ModelRegistry.load_model_cls = load_model_cls_srt

# Keep a handle on the stock loader so the wrapper can delegate to it.
original_weight_loader = QKVParallelLinear.weight_loader
QKVParallelLinear.weight_loader = get_weight_loader_srt(original_weight_loader)
|
||||||
|
|||||||
Reference in New Issue
Block a user