[Misc] Specify that DS32 only supports --kv-cache-dtype bfloat16 (#119)

* [Kernel] add kernels to torch.ops

* [Misc] Specify that DS only supports --kv-cache-dtype bfloat16

---------

Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
fromck authored 2026-01-17 16:52:02 +08:00, committed by GitHub
parent 8988ad08b2
commit 71a5a04e0a
4 changed files with 325 additions and 15 deletions
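
The "[Kernel] add kernels to torch.ops" bullet refers to calling the cache-write kernel through the torch.ops._C namespace, which is how the hunk below invokes concat_and_cache_mla with keyword arguments. As a minimal, self-contained sketch of that registration-and-dispatch pattern (the "demo" namespace, op name, and schema are illustrative, not the repo's real _C bindings):

import torch

# Illustrative only: define a schema, register a CPU implementation, then call
# the op as torch.ops.<namespace>.<op> with keyword arguments, the same calling
# style the diff uses for torch.ops._C.concat_and_cache_mla(kv_c=..., ...).
lib = torch.library.Library("demo", "DEF")
lib.define("scaled_copy(Tensor src, Tensor(a!) dst, float scale) -> ()")

def scaled_copy_cpu(src: torch.Tensor, dst: torch.Tensor, scale: float) -> None:
    # A real extension would register a CUDA kernel here; a plain copy stands in.
    dst.copy_(src * scale)

lib.impl("scaled_copy", scaled_copy_cpu, "CPU")

src = torch.randn(4)
dst = torch.empty_like(src)
torch.ops.demo.scaled_copy(src=src, dst=dst, scale=2.0)  # keyword-style call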


@@ -731,22 +731,21 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
         q = torch.cat([ql_nope, q_pe], dim=-1)
-        # write the latent and rope to kv cache
-        if kv_cache.numel() > 0:
-            torch.ops._C.concat_and_cache_mla(
-                kv_c=k_c_normed,
-                k_pe=k_pe.squeeze(1),
-                kv_cache=kv_cache,
-                slot_mapping=attn_metadata.slot_mapping.flatten(),
-            )
         if self.kv_cache_dtype != "fp8_ds_mla":
+            # write the latent and rope to kv cache
+            if kv_cache.numel() > 0:
+                torch.ops._C.concat_and_cache_mla(
+                    kv_c=k_c_normed,
+                    k_pe=k_pe.squeeze(1),
+                    kv_cache=kv_cache,
+                    slot_mapping=attn_metadata.slot_mapping.flatten(),
+                )
             attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices,
                                              attn_metadata)
         else:
             # attn_out = self._forward_fp8_kv(q, kv_cache, topk_indices_global,
             #                                 attn_metadata)
-            raise NotImplementedError
+            raise NotImplementedError("Only support --kv-cache-dtype bfloat16")
         self._v_up_proj(attn_out, out=output[:num_actual_toks])
         return output
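
Net effect of the hunk: the latent/rope cache write and the bf16 attention path now both live under the non-fp8_ds_mla branch, so an unsupported KV-cache dtype fails with an explicit message before anything is written to the cache, instead of writing first and then hitting a bare NotImplementedError. A standalone sketch of that guard (the function name and return value are illustrative, not the backend's real API):

def select_sparse_mla_path(kv_cache_dtype: str) -> str:
    # Mirrors the guard above: anything other than "fp8_ds_mla" falls through
    # to the bf16 KV-cache path; fp8_ds_mla is rejected with the same message
    # the commit adds to the raise.
    if kv_cache_dtype != "fp8_ds_mla":
        return "bf16_kv_path"
    raise NotImplementedError("Only support --kv-cache-dtype bfloat16")

print(select_sparse_mla_path("bfloat16"))   # -> bf16_kv_path
try:
    select_sparse_mla_path("fp8_ds_mla")
except NotImplementedError as err:
    print(err)                              # -> Only support --kv-cache-dtype bfloat16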