[Misc] Specify that DS32 only supports --kv-cache-dtype bfloat16 (#119)

* [Kernel] add kernels to torch.ops
* [Misc] Specify that DS only supports --kv-cache-dtype bfloat16

---------

Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
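The first item in the body refers to exposing the cache-write kernel through the torch.ops registry, which is why the hunk below can call torch.ops._C.concat_and_cache_mla directly. As a rough, Python-only sketch of that mechanism (the real kernel is implemented in C++/CUDA and registered into the _C namespace by the extension; the _sketch namespace, the flat cache layout, and the tensor shapes below are illustrative assumptions, not the repo's actual code):

import torch

# Declare a schema for a hypothetical op so it becomes callable as
# torch.ops._sketch.concat_and_cache_mla(...), mirroring how the real kernel
# shows up under torch.ops._C. The Tensor(a!) annotation marks kv_cache as
# mutated in place.
torch.library.define(
    "_sketch::concat_and_cache_mla",
    "(Tensor kv_c, Tensor k_pe, Tensor(a!) kv_cache, Tensor slot_mapping) -> ()",
)

@torch.library.impl("_sketch::concat_and_cache_mla", "cpu")
def _concat_and_cache_mla_cpu(kv_c, k_pe, kv_cache, slot_mapping):
    # Concatenate the compressed latent (kv_c) and the rope part (k_pe) for each
    # token, then scatter the rows into the cache slots named by slot_mapping.
    entry = torch.cat([kv_c, k_pe], dim=-1)
    kv_cache.view(-1, entry.shape[-1])[slot_mapping] = entry

# Illustrative shapes only: 512 latent dims + 64 rope dims per token.
kv_c = torch.randn(4, 512, dtype=torch.bfloat16)
k_pe = torch.randn(4, 64, dtype=torch.bfloat16)
kv_cache = torch.zeros(16, 576, dtype=torch.bfloat16)
slots = torch.tensor([3, 7, 8, 12])
torch.ops._sketch.concat_and_cache_mla(kv_c, k_pe, kv_cache, slots)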
@@ -731,22 +731,21 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
         q = torch.cat([ql_nope, q_pe], dim=-1)
 
-        # write the latent and rope to kv cache
-        if kv_cache.numel() > 0:
-            torch.ops._C.concat_and_cache_mla(
-                kv_c=k_c_normed,
-                k_pe=k_pe.squeeze(1),
-                kv_cache=kv_cache,
-                slot_mapping=attn_metadata.slot_mapping.flatten(),
-            )
-
         if self.kv_cache_dtype != "fp8_ds_mla":
+            # write the latent and rope to kv cache
+            if kv_cache.numel() > 0:
+                torch.ops._C.concat_and_cache_mla(
+                    kv_c=k_c_normed,
+                    k_pe=k_pe.squeeze(1),
+                    kv_cache=kv_cache,
+                    slot_mapping=attn_metadata.slot_mapping.flatten(),
+                )
             attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices,
                                              attn_metadata)
         else:
             # attn_out = self._forward_fp8_kv(q, kv_cache, topk_indices_global,
             #                                 attn_metadata)
-            raise NotImplementedError
+            raise NotImplementedError("Only support --kv-cache-dtype bfloat16")
 
         self._v_up_proj(attn_out, out=output[:num_actual_toks])
         return output
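In net terms, the hunk moves the cache write onto the bfloat16 branch and replaces the bare NotImplementedError with an explicit message, so any other KV-cache dtype (the fp8_ds_mla path is still commented out) now fails loudly. A standalone sketch of the resulting dispatch, with a hypothetical helper name in place of the real FlashMLASparseImpl method:

# Hypothetical, simplified stand-in for the dispatch left in place by this change:
# only the bfloat16 KV-cache path is implemented; every other --kv-cache-dtype
# setting is rejected with the new, explicit error message.
def sparse_mla_dispatch(kv_cache_dtype: str) -> str:
    if kv_cache_dtype != "fp8_ds_mla":
        # bfloat16 path: the latent and rope get written into the KV cache and
        # attention runs via the bf16 kernel (_forward_bf16_kv in the real code).
        return "bf16 path"
    raise NotImplementedError("Only support --kv-cache-dtype bfloat16")

print(sparse_mla_dispatch("bfloat16"))  # -> bf16 path
# sparse_mla_dispatch("fp8_ds_mla")     # raises NotImplementedError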