Fix loading KV quantization scale; Enable modelopt kv cache (#4686)
Co-authored-by: qingquansong <ustcsqq@gmail.com>
This commit is contained in:
@@ -82,6 +82,8 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
self.max_context_len = model_runner.model_config.context_len
|
||||
self.skip_prefill = skip_prefill
|
||||
self.is_multimodal = model_runner.model_config.is_multimodal
|
||||
self.kv_cache_dtype = model_runner.kv_cache_dtype
|
||||
self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
|
||||
|
||||
assert not (
|
||||
model_runner.sliding_window_size is not None
|
||||
@@ -391,6 +393,8 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
forward_batch: ForwardBatch,
|
||||
save_kv_cache=True,
|
||||
):
|
||||
k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
|
||||
v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
|
||||
prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
|
||||
self._get_wrapper_idx(layer)
|
||||
]
|
||||
@@ -407,7 +411,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
assert v is not None
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||
layer, cache_loc, k, v, k_scale, v_scale
|
||||
)
|
||||
|
||||
o = prefill_wrapper_paged.forward(
|
||||
@@ -417,8 +421,8 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
sm_scale=layer.scaling,
|
||||
window_left=layer.sliding_window_size,
|
||||
logits_soft_cap=logits_soft_cap,
|
||||
k_scale=layer.k_scale,
|
||||
v_scale=layer.v_scale,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
)
|
||||
else:
|
||||
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
|
||||
@@ -445,7 +449,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||
layer, cache_loc, k, v, k_scale, v_scale
|
||||
)
|
||||
|
||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||
@@ -459,6 +463,8 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
forward_batch: ForwardBatch,
|
||||
save_kv_cache=True,
|
||||
):
|
||||
k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
|
||||
v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
|
||||
decode_wrapper = self.forward_metadata.decode_wrappers[
|
||||
self._get_wrapper_idx(layer)
|
||||
]
|
||||
@@ -472,7 +478,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
assert v is not None
|
||||
if save_kv_cache:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||
layer, cache_loc, k, v, k_scale, v_scale
|
||||
)
|
||||
|
||||
o = decode_wrapper.forward(
|
||||
@@ -480,8 +486,8 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||
sm_scale=layer.scaling,
|
||||
logits_soft_cap=layer.logit_cap,
|
||||
k_scale=layer.k_scale,
|
||||
v_scale=layer.v_scale,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
)
|
||||
|
||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||
|
||||
Reference in New Issue
Block a user