Support FP8 E4M3 KV Cache (#2786)
Co-authored-by: root <bjmsong@126.com>
@@ -353,7 +353,9 @@ class FlashInferAttnBackend(AttentionBackend):
             if k is not None:
                 assert v is not None
                 if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
 
             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -362,6 +364,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -387,7 +391,9 @@ class FlashInferAttnBackend(AttentionBackend):
                 o, _ = merge_state(o1, s1, o2, s2)
 
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
@@ -412,13 +418,17 @@ class FlashInferAttnBackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )
 
         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
+            k_scale=layer.k_scale,
+            v_scale=layer.v_scale,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
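Conceptually, the scales passed to the prefill and decode wrappers above are applied on the read side: the kernel dequantizes K and V before computing attention. The sketch below is a reference computation only, not the FlashInfer kernel; names and shapes are illustrative.

```python
import torch

# Reference computation showing where k_scale / v_scale enter attention over an
# FP8 KV cache; shapes and names are illustrative, not taken from the diff.
def attn_with_fp8_kv(q, k_fp8, v_fp8, k_scale, v_scale, sm_scale):
    k = k_fp8.to(q.dtype) * k_scale   # dequantize K on read
    v = v_fp8.to(q.dtype) * v_scale   # dequantize V on read
    scores = (q @ k.transpose(-1, -2)) * sm_scale
    return torch.softmax(scores, dim=-1) @ v
```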
@@ -47,6 +47,8 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention
+        self.k_scale = 1.0
+        self.v_scale = 1.0
 
     def forward(
         self,
@@ -109,8 +109,8 @@ class BaseTokenToKVPool:
     ):
         self.size = size
         self.dtype = dtype
-        if dtype == torch.float8_e5m2:
-            # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
            self.store_dtype = torch.uint8
         else:
             self.store_dtype = dtype
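The store_dtype indirection exists because indexed writes are not implemented for float8 tensors, so the pool keeps raw bytes in uint8 buffers and reinterprets them with Tensor.view. A minimal sketch of that round trip (illustrative names, assuming a PyTorch build with float8 support):

```python
import torch

# Keep the buffer as uint8 so indexed assignment works; reinterpret (not
# convert) to/from FP8 with Tensor.view, which preserves the raw bytes.
buf = torch.zeros(16, 8, dtype=torch.uint8)            # storage buffer
vals = torch.randn(4, 8).to(torch.float8_e4m3fn)       # FP8 values to cache
loc = torch.tensor([0, 3, 5, 7])
buf[loc] = vals.view(torch.uint8)                       # write through the uint8 view
restored = buf.view(torch.float8_e4m3fn)[loc]           # same bytes, FP8 semantics on read
```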
@@ -256,11 +256,13 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
     ):
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
-            cache_k = cache_k.to(self.dtype)
-            cache_v = cache_v.to(self.dtype)
+            cache_k = (cache_k / k_scale).to(self.dtype)
+            cache_v = (cache_v / v_scale).to(self.dtype)
         if self.store_dtype != self.dtype:
             self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
             self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
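Dividing by the scale before the FP8 cast is quantize-on-write: it maps values into E4M3's representable range, and the kernel multiplies the scale back in when reading the cache. A self-contained sketch of the round trip (the max-based per-tensor scale here is an assumption for illustration, not what kv_cache_scales_loader produces):

```python
import torch

# Quantize-on-write to FP8 E4M3 with a per-tensor scale, then dequantize and
# check the error. 448 is the largest finite value representable in E4M3.
k = torch.randn(4, 8, dtype=torch.float16)
k_scale = k.abs().max().item() / 448.0
k_fp8 = (k / k_scale).to(torch.float8_e4m3fn)      # what set_kv_buffer stores
k_restored = k_fp8.to(torch.float16) * k_scale     # what the kernel reads back
print((k - k_restored).abs().max())                # quantization error stays small
```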
@@ -54,6 +54,7 @@ from sglang.srt.utils import (
|
||||
enable_show_time_cost,
|
||||
get_available_gpu_memory,
|
||||
init_custom_process_group,
|
||||
is_cuda,
|
||||
is_hip,
|
||||
monkey_patch_vllm_gguf_config,
|
||||
monkey_patch_vllm_p2p_access_check,
|
||||
@@ -277,6 +278,29 @@ class ModelRunner:
             device_config=DeviceConfig(self.device),
         )
 
+        if self.server_args.kv_cache_dtype == "fp8_e4m3":
+            if self.server_args.quantization_param_path is not None:
+                if callable(getattr(self.model, "load_kv_cache_scales", None)):
+                    self.model.load_kv_cache_scales(
+                        self.server_args.quantization_param_path
+                    )
+                    logger.info(
+                        "Loaded KV cache scaling factors from %s",
+                        self.server_args.quantization_param_path,
+                    )
+                else:
+                    raise RuntimeError(
+                        "Using FP8 KV cache and scaling factors provided but "
+                        "model %s does not support loading scaling factors.",
+                        self.model.__class__,
+                    )
+            else:
+                logger.warning(
+                    "Using FP8 KV cache but no scaling factors "
+                    "provided. Defaulting to scaling factors of 1.0. "
+                    "This may lead to less accurate results!"
+                )
+
         # Parse other args
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
@@ -516,6 +540,9 @@ class ModelRunner:
                 self.kv_cache_dtype = torch.float8_e5m2fnuz
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
+        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
+            if is_cuda():
+                self.kv_cache_dtype = torch.float8_e4m3fn
         else:
             raise ValueError(
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
@@ -22,8 +22,12 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.model_loader.weight_utils import kv_cache_scales_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -299,6 +303,30 @@ class LlamaModel(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    # If this function is called, it should always initialize KV cache scale
+    # factors (or else raise an exception). Thus, handled exceptions should
+    # make sure to leave KV cache scale factors in a known good (dummy) state
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            if not isinstance(self.layers[layer_idx], nn.Identity):
+                layer_self_attn = self.layers[layer_idx].self_attn
+
+            if hasattr(layer_self_attn.attn, "k_scale"):
+                layer_self_attn.attn.k_scale = scaling_factor
+                layer_self_attn.attn.v_scale = scaling_factor
+            else:
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling " "factor attribute!"
+                )
+
 
 class LlamaForCausalLM(nn.Module):
 
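kv_cache_scales_loader reads per-rank, per-layer scaling factors from the JSON file given by --quantization-param-path. The exact schema is defined by vLLM's loader, so the dictionary below is only an assumed illustration of its shape, with made-up values, written from Python to stay in the codebase's language:

```python
import json

# Assumed shape of a KV-cache scales file: model type, cache dtype, and a
# tp_rank -> layer_idx -> scaling_factor mapping. Field names follow vLLM's
# loader as understood here; treat them as an assumption, and the numbers
# are placeholders.
scales = {
    "model_type": "llama",
    "kv_cache": {
        "dtype": "float8_e4m3",
        "scaling_factor": {"0": {"0": 0.0132, "1": 0.0151, "2": 0.0147}},
    },
}
with open("kv_cache_scales.json", "w") as f:
    json.dump(scales, f, indent=2)
```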
@@ -534,6 +562,9 @@ class LlamaForCausalLM(nn.Module):
         torch.cuda.empty_cache()
         torch.cuda.synchronize()
 
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+
 
 class Phi3ForCausalLM(LlamaForCausalLM):
     pass
@@ -32,6 +32,7 @@ from sglang.srt.utils import (
     is_hip,
     is_ipv6,
     is_port_available,
+    nullable_str,
 )
 
 logger = logging.getLogger(__name__)
@@ -47,6 +48,7 @@ class ServerArgs:
     trust_remote_code: bool = True
     dtype: str = "auto"
     kv_cache_dtype: str = "auto"
+    quantization_param_path: nullable_str = None
     quantization: Optional[str] = None
     context_length: Optional[int] = None
     device: str = "cuda"
@@ -350,8 +352,17 @@ class ServerArgs:
             "--kv-cache-dtype",
             type=str,
             default=ServerArgs.kv_cache_dtype,
-            choices=["auto", "fp8_e5m2"],
-            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
+            choices=["auto", "fp8_e5m2", "fp8_e4m3"],
+            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
         )
+        parser.add_argument(
+            "--quantization-param-path",
+            type=nullable_str,
+            default=None,
+            help="Path to the JSON file containing the KV cache "
+            "scaling factors. This should generally be supplied, when "
+            "KV cache dtype is FP8. Otherwise, KV cache scaling factors "
+            "default to 1.0, which may cause accuracy issues. ",
+        )
         parser.add_argument(
             "--quantization",
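In terms of usage, the new options map directly onto ServerArgs fields, so an FP8 E4M3 cache can be requested from the CLI or programmatically. A minimal sketch (the model path is a placeholder; the scales file is optional and falls back to scales of 1.0):

```python
from sglang.srt.server_args import ServerArgs

# Programmatic equivalent of:
#   --kv-cache-dtype fp8_e4m3 --quantization-param-path kv_cache_scales.json
args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",     # placeholder model path
    kv_cache_dtype="fp8_e4m3",
    quantization_param_path="kv_cache_scales.json",    # optional; scales default to 1.0
)
```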
@@ -1375,3 +1375,9 @@ def debug_timing(func):
         return func(*args, **kwargs)
 
     return wrapper
+
+
+def nullable_str(val: str):
+    if not val or val == "None":
+        return None
+    return val