@@ -1678,9 +1678,11 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
|
||||
self.attn_mha.layer_id
|
||||
)
|
||||
latent_cache = latent_cache_buf[
|
||||
forward_batch.prefix_chunk_kv_indices[i]
|
||||
].contiguous()
|
||||
latent_cache = (
|
||||
latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
|
||||
.contiguous()
|
||||
.to(q.dtype)
|
||||
)
|
||||
|
||||
kv_a_normed, k_pe = latent_cache.split(
|
||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
||||
|
||||
Reference in New Issue
Block a user