Fix loading KV quantization scale; Enable modelopt kv cache (#4686)
Co-authored-by: qingquansong <ustcsqq@gmail.com>
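At a glance: the backend used to forward the raw per-layer `layer.k_scale` / `layer.v_scale` tensors straight to the attention kernel. After this change it builds `k_descale` / `v_descale` only when an explicit KV cache dtype is configured (`kv_cache_dtype != "auto"`), expands each scale to the `(batch_size, tp_k_head_num)` shape the kernel consumes, and casts `q` to the KV cache dtype. The helper below is a minimal sketch that condenses the logic repeated in the diff; `prepare_descales` itself is hypothetical and does not exist in the codebase.

def prepare_descales(q, k_scale, v_scale, batch_size, num_kv_heads,
                     kv_cache_dtype_str, kv_cache_dtype):
    # Hypothetical condensation of the logic this commit adds to both
    # forward_extend and forward_decode; names mirror the diff.
    k_descale, v_descale = None, None
    if kv_cache_dtype_str != "auto":
        # The kernel takes one descale value per (batch entry, KV head),
        # which is exactly the (batch_size, tp_k_head_num) shape in the diff.
        descale_shape = (batch_size, num_kv_heads)
        k_descale = k_scale.expand(descale_shape)
        v_descale = v_scale.expand(descale_shape)
        # q must be in the quantized KV cache dtype before the kernel call.
        q = q.to(kv_cache_dtype)
    return q, k_descale, v_descale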
@@ -292,6 +292,8 @@ class FlashAttentionBackend(AttentionBackend):
         self.decode_cuda_graph_metadata = {}
         self.target_verify_metadata = {}
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.kv_cache_dtype = model_runner.kv_cache_dtype
+        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
         self.page_size = model_runner.page_size
         self.use_mla = (
             model_runner.model_config.attention_arch == AttentionArch.MLA
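The two added attributes split responsibilities: `self.kv_cache_dtype` holds the resolved torch dtype already computed by the model runner, while `self.kv_cache_dtype_str` keeps the raw server-argument string so the forward paths can test for the `"auto"` sentinel. The sketch below only illustrates that resolution step; the real mapping lives in the model runner, and the exact set of accepted strings is an assumption here.

import torch

# Illustrative only: "auto" means "follow the default KV cache dtype",
# so the attention backend applies no descaling in that case.
_KV_CACHE_DTYPES = {
    "fp8_e4m3": torch.float8_e4m3fn,   # assumed accepted values
    "fp8_e5m2": torch.float8_e5m2,
}

def resolve_kv_cache_dtype(kv_cache_dtype_str, default_dtype):
    if kv_cache_dtype_str == "auto":
        return default_dtype
    return _KV_CACHE_DTYPES[kv_cache_dtype_str]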
@@ -520,6 +522,12 @@ class FlashAttentionBackend(AttentionBackend):
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
+        k_descale, v_descale = None, None
+        if self.kv_cache_dtype_str != "auto":
+            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+            k_descale = layer.k_scale.expand(descale_shape)
+            v_descale = layer.v_scale.expand(descale_shape)
+            q = q.to(self.kv_cache_dtype)
         causal = not layer.is_cross_attention

         # Check if we should use local attention
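`expand` turns the per-tensor scale stored with the checkpoint (typically a single scalar) into a broadcast view with one value per batch entry and KV head, without copying any memory. A small, runnable illustration of the shape handling (sizes are arbitrary):

import torch

k_scale = torch.tensor(0.02734375)        # assumed per-tensor scalar scale
batch_size, tp_k_head_num = 4, 8          # arbitrary example sizes

k_descale = k_scale.expand(batch_size, tp_k_head_num)
print(k_descale.shape)                               # torch.Size([4, 8])
print(k_descale.data_ptr() == k_scale.data_ptr())    # True: expand() is a no-copy view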
@@ -576,8 +584,8 @@ class FlashAttentionBackend(AttentionBackend):
                 causal=causal,
                 window_size=window_size,
                 softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
         else:
             # Do absorbed multi-latent attention
@@ -609,8 +617,8 @@ class FlashAttentionBackend(AttentionBackend):
                 softmax_scale=layer.scaling,
                 causal=True,
                 softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
+                k_descale=k_descale,
+                v_descale=v_descale,
             )

         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
@@ -657,6 +665,13 @@ class FlashAttentionBackend(AttentionBackend):
         )
         causal = not layer.is_cross_attention
+
+        k_descale, v_descale = None, None
+        if self.kv_cache_dtype_str != "auto":
+            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
+            k_descale = layer.k_scale.expand(descale_shape)
+            v_descale = layer.v_scale.expand(descale_shape)
+            q = q.to(self.kv_cache_dtype)

         if not self.use_mla:
             # Do multi-head attention

@@ -694,8 +709,8 @@ class FlashAttentionBackend(AttentionBackend):
                 causal=causal,
                 window_size=window_size,
                 softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
         else:
             # Do absorbed multi-latent attention
@@ -729,8 +744,8 @@ class FlashAttentionBackend(AttentionBackend):
                 softmax_scale=layer.scaling,
                 causal=True,
                 softcap=layer.logit_cap,
-                k_descale=layer.k_scale,
-                v_descale=layer.v_scale,
+                k_descale=k_descale,
+                v_descale=v_descale,
             )
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

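The new branch is only exercised when the server is launched with an explicit KV cache dtype; with the default "auto", `k_descale` / `v_descale` stay None and no cast is applied. A hypothetical launch sketch, assuming the offline `sglang.Engine` forwards keyword arguments into `ServerArgs` and that the model path and dtype string shown are valid choices:

import sglang as sgl

# Hypothetical usage: keyword arguments are assumed to map onto ServerArgs
# fields, and the model path is a placeholder for a ModelOpt FP8 checkpoint.
llm = sgl.Engine(
    model_path="path/to/modelopt-fp8-checkpoint",
    kv_cache_dtype="fp8_e5m2",   # anything other than "auto" enables the descale path
)
print(llm.generate("The capital of France is"))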