Add kernels to optimize RoPE and the decoding stage (#143)
Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
This commit is contained in:
@@ -159,10 +159,8 @@ def kunlun_flash_mla_with_kvcache(
|
||||
assert not causal, \
|
||||
"causal must be `false` if sparse attention is enabled."
|
||||
|
||||
q_r, pe_cache = None, None # packed mode is used when q_r and pe_cache are None
|
||||
batch_size, seq_len_q, num_heads_q, head_dim = q.shape
|
||||
kv_lora_rank = head_dim_v
|
||||
rope_head_dim = head_dim - kv_lora_rank
|
||||
|
||||
out = torch.zeros([batch_size, seq_len_q, num_heads_q, kv_lora_rank],
|
||||
dtype=q.dtype, device=q.device)
|
||||
|
||||
Reference in New Issue
Block a user