Support FP8 E4M3 KV Cache (#2786)

Author: bjmsong
Date: 2025-01-13 13:17:11 +08:00
Committed by: GitHub
Co-authored-by: root <bjmsong@126.com>
Parent: 85b2e05770
Commit: 0bb0f76311
9 changed files with 205 additions and 10 deletions

@@ -22,8 +22,12 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import LlamaConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import kv_cache_scales_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
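
The import changes bring in the tensor-parallel rank helper and vLLM's kv_cache_scales_loader, which parses a JSON file of per-layer KV cache scaling factors and yields (layer_idx, scaling_factor) pairs for the calling rank. As a hedged illustration only (the field names below follow vLLM's schema and are not taken from this commit), the file referenced by quantization_param_path is expected to look roughly like this, keyed first by TP rank and then by layer index:

# Illustrative contents of a quantization_param_path file (assumed
# vLLM-style schema; values are made up for the example).
example_kv_cache_scales = {
    "model_type": "llama",
    "kv_cache": {
        "dtype": "float8_e4m3fn",
        "scaling_factor": {
            # TP rank -> {layer index -> scaling factor}
            "0": {"0": 0.0408, "1": 0.0503, "2": 0.0478},
        },
    },
}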
@@ -299,6 +303,30 @@ class LlamaModel(nn.Module):
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
            quantization_param_path,
            tp_rank,
            tp_size,
            self.config.num_hidden_layers,
            self.config.__class__.model_type,
        ):
            if not isinstance(self.layers[layer_idx], nn.Identity):
                layer_self_attn = self.layers[layer_idx].self_attn

            if hasattr(layer_self_attn.attn, "k_scale"):
                layer_self_attn.attn.k_scale = scaling_factor
                layer_self_attn.attn.v_scale = scaling_factor
            else:
                raise RuntimeError(
                    "Self attention has no KV cache scaling factor attribute!"
                )

class LlamaForCausalLM(nn.Module):
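
The scales loaded above follow the convention used by vLLM's scale files: quantized_value * scaling_factor ≈ true_value, i.e. the scale is roughly tensor_amax / FP8_max. The actual FP8 E4M3 quantization of cache entries happens inside the attention backend kernels, not in this Python code; the snippet below is only a minimal sketch of that convention, with FP8_E4M3_MAX and the helper names being assumptions for illustration.

import torch

FP8_E4M3_MAX = 448.0  # largest finite magnitude of torch.float8_e4m3fn

def quantize_kv_entry(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Write path: divide by the per-layer scale, clamp to the representable
    # range, and store the result as 8-bit E4M3.
    return (x / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)

def dequantize_kv_entry(x_fp8: torch.Tensor, scale: float,
                        dtype: torch.dtype = torch.float16) -> torch.Tensor:
    # Read path: widen back to the compute dtype and undo the scaling.
    return x_fp8.to(dtype) * scale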
@@ -534,6 +562,9 @@ class LlamaForCausalLM(nn.Module):
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        self.model.load_kv_cache_scales(quantization_param_path)


class Phi3ForCausalLM(LlamaForCausalLM):
    pass
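
Downstream, the model runner can invoke the new hook once weights are loaded and an FP8 E4M3 KV cache with a scales file is requested. The wiring below is a hedged sketch under assumed names (maybe_load_kv_cache_scales and the "fp8" dtype-string check are illustrative, not from this commit); it only shows how LlamaForCausalLM.load_kv_cache_scales is intended to be called.

from typing import Optional

def maybe_load_kv_cache_scales(model, kv_cache_dtype: str,
                               quantization_param_path: Optional[str]) -> None:
    # Only relevant when an FP8 KV cache and a scales file were both requested;
    # the exact dtype string is an assumption for illustration.
    if quantization_param_path is None or not kv_cache_dtype.startswith("fp8"):
        return
    if hasattr(model, "load_kv_cache_scales"):
        # LlamaForCausalLM forwards to LlamaModel.load_kv_cache_scales, which
        # sets k_scale / v_scale on each layer's attention module.
        model.load_kv_cache_scales(quantization_param_path)
    else:
        raise RuntimeError(
            f"{type(model).__name__} does not support loading KV cache scales."
        )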