Fix loading KV quantization scale; Enable modelopt kv cache (#4686)

Co-authored-by: qingquansong <ustcsqq@gmail.com>
This commit is contained in:
Yun Dai
2025-04-08 09:11:35 -07:00
committed by GitHub
parent 88d6fd9a11
commit 2695ab0537
38 changed files with 151 additions and 76 deletions

View File

@@ -292,6 +292,8 @@ class FlashAttentionBackend(AttentionBackend):
self.decode_cuda_graph_metadata = {}
self.target_verify_metadata = {}
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.kv_cache_dtype = model_runner.kv_cache_dtype
self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
self.page_size = model_runner.page_size
self.use_mla = (
model_runner.model_config.attention_arch == AttentionArch.MLA
@@ -520,6 +522,12 @@ class FlashAttentionBackend(AttentionBackend):
if layer.sliding_window_size is not None
else (-1, -1)
)
k_descale, v_descale = None, None
if self.kv_cache_dtype_str != "auto":
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
k_descale = layer.k_scale.expand(descale_shape)
v_descale = layer.v_scale.expand(descale_shape)
q = q.to(self.kv_cache_dtype)
causal = not layer.is_cross_attention
# Check if we should use local attention
@@ -576,8 +584,8 @@ class FlashAttentionBackend(AttentionBackend):
causal=causal,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=layer.k_scale,
v_descale=layer.v_scale,
k_descale=k_descale,
v_descale=v_descale,
)
else:
# Do absorbed multi-latent attention
@@ -609,8 +617,8 @@ class FlashAttentionBackend(AttentionBackend):
softmax_scale=layer.scaling,
causal=True,
softcap=layer.logit_cap,
k_descale=layer.k_scale,
v_descale=layer.v_scale,
k_descale=k_descale,
v_descale=v_descale,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
@@ -657,6 +665,13 @@ class FlashAttentionBackend(AttentionBackend):
)
causal = not layer.is_cross_attention
k_descale, v_descale = None, None
if self.kv_cache_dtype_str != "auto":
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
k_descale = layer.k_scale.expand(descale_shape)
v_descale = layer.v_scale.expand(descale_shape)
q = q.to(self.kv_cache_dtype)
if not self.use_mla:
# Do multi-head attention
@@ -694,8 +709,8 @@ class FlashAttentionBackend(AttentionBackend):
causal=causal,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=layer.k_scale,
v_descale=layer.v_scale,
k_descale=k_descale,
v_descale=v_descale,
)
else:
# Do absorbed multi-latent attention
@@ -729,8 +744,8 @@ class FlashAttentionBackend(AttentionBackend):
softmax_scale=layer.scaling,
causal=True,
softcap=layer.logit_cap,
k_descale=layer.k_scale,
v_descale=layer.v_scale,
k_descale=k_descale,
v_descale=v_descale,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

View File

@@ -82,6 +82,8 @@ class FlashInferAttnBackend(AttentionBackend):
self.max_context_len = model_runner.model_config.context_len
self.skip_prefill = skip_prefill
self.is_multimodal = model_runner.model_config.is_multimodal
self.kv_cache_dtype = model_runner.kv_cache_dtype
self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
assert not (
model_runner.sliding_window_size is not None
@@ -391,6 +393,8 @@ class FlashInferAttnBackend(AttentionBackend):
forward_batch: ForwardBatch,
save_kv_cache=True,
):
k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
self._get_wrapper_idx(layer)
]
@@ -407,7 +411,7 @@ class FlashInferAttnBackend(AttentionBackend):
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
layer, cache_loc, k, v, k_scale, v_scale
)
o = prefill_wrapper_paged.forward(
@@ -417,8 +421,8 @@ class FlashInferAttnBackend(AttentionBackend):
sm_scale=layer.scaling,
window_left=layer.sliding_window_size,
logits_soft_cap=logits_soft_cap,
k_scale=layer.k_scale,
v_scale=layer.v_scale,
k_scale=k_scale,
v_scale=v_scale,
)
else:
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -445,7 +449,7 @@ class FlashInferAttnBackend(AttentionBackend):
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
layer, cache_loc, k, v, k_scale, v_scale
)
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -459,6 +463,8 @@ class FlashInferAttnBackend(AttentionBackend):
forward_batch: ForwardBatch,
save_kv_cache=True,
):
k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
decode_wrapper = self.forward_metadata.decode_wrappers[
self._get_wrapper_idx(layer)
]
@@ -472,7 +478,7 @@ class FlashInferAttnBackend(AttentionBackend):
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
layer, cache_loc, k, v, k_scale, v_scale
)
o = decode_wrapper.forward(
@@ -480,8 +486,8 @@ class FlashInferAttnBackend(AttentionBackend):
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
sm_scale=layer.scaling,
logits_soft_cap=layer.logit_cap,
k_scale=layer.k_scale,
v_scale=layer.v_scale,
k_scale=k_scale,
v_scale=v_scale,
)
return o.view(-1, layer.tp_q_head_num * layer.head_dim)