[Misc] Specify that DS32 only supports --kv-cache-dtype bfloat16 (#119)
* [Kernel] add kernels to torch.ops * [Misc] Specify that DS only supports --kv-cache-dtype bfloat16 --------- Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
This commit is contained in:
@@ -171,7 +171,7 @@ def kunlun_flash_mla_with_kvcache(
|
||||
p_sums = torch.zeros([batch_size, seq_len_q, num_heads_q],
|
||||
dtype=torch.float32, device=q.device)
|
||||
|
||||
xtorch_ops.fwd_kvcache_mla(
|
||||
torch.ops._C.fwd_kvcache_mla(
|
||||
q_c=q,
|
||||
kv_cache=k_cache,
|
||||
indices=indices,
|
||||
@@ -224,7 +224,7 @@ def flash_mla_sparse_prefill(
|
||||
max_logits = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device)
|
||||
lse = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device)
|
||||
|
||||
xtorch_ops.sparse_prefill_fwd_opt(
|
||||
torch.ops._C.sparse_prefill_fwd_opt(
|
||||
q=q,
|
||||
kv=kv,
|
||||
indices=indices,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import torch
|
||||
import xtorch_ops
|
||||
|
||||
def int8_mqa_logits(
|
||||
q: torch.Tensor,
|
||||
@@ -29,7 +28,7 @@ def int8_mqa_logits(
|
||||
context_q_lens_xpu = torch.tensor([0, q.shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
|
||||
context_k_lens_xpu = torch.tensor([0, kv[0].shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
|
||||
|
||||
xtorch_ops.I8_mqa_logits(
|
||||
torch.ops._C.I8_mqa_logits(
|
||||
q=q,
|
||||
fused_kv_cache=kv,
|
||||
weights=weights,
|
||||
@@ -99,7 +98,7 @@ def int8_paged_mqa_logits(
|
||||
|
||||
logits = torch.empty((batch_size, next_n, max_model_len), dtype=torch.float32, device=q_fp8.device)
|
||||
|
||||
xtorch_ops.I8_paged_mqa_logits(
|
||||
torch.ops._C.I8_paged_mqa_logits(
|
||||
q=q_fp8,
|
||||
fused_kv_cache=kv_cache,
|
||||
weights=weights,
|
||||
|
||||
@@ -731,22 +731,21 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
|
||||
|
||||
q = torch.cat([ql_nope, q_pe], dim=-1)
|
||||
|
||||
# write the latent and rope to kv cache
|
||||
if kv_cache.numel() > 0:
|
||||
torch.ops._C.concat_and_cache_mla(
|
||||
kv_c=k_c_normed,
|
||||
k_pe=k_pe.squeeze(1),
|
||||
kv_cache=kv_cache,
|
||||
slot_mapping=attn_metadata.slot_mapping.flatten(),
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype != "fp8_ds_mla":
|
||||
# write the latent and rope to kv cache
|
||||
if kv_cache.numel() > 0:
|
||||
torch.ops._C.concat_and_cache_mla(
|
||||
kv_c=k_c_normed,
|
||||
k_pe=k_pe.squeeze(1),
|
||||
kv_cache=kv_cache,
|
||||
slot_mapping=attn_metadata.slot_mapping.flatten(),
|
||||
)
|
||||
attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices,
|
||||
attn_metadata)
|
||||
else:
|
||||
# attn_out = self._forward_fp8_kv(q, kv_cache, topk_indices_global,
|
||||
# attn_metadata)
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError("Only support --kv-cache-dtype bfloat16")
|
||||
|
||||
self._v_up_proj(attn_out, out=output[:num_actual_toks])
|
||||
return output
|
||||
|
||||
@@ -1923,3 +1923,315 @@ def apply_repetition_penalties_(
|
||||
# If logits are positive, divide by penalty, otherwise multiply by penalty.
|
||||
scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
|
||||
logits *= scaling
|
||||
|
||||
##################################################
# --------------- I8_mqa_logits -----------------
##################################################
# FIX: this op writes its result into the caller-provided `logits` tensor
# (it returns None and takes `clean_logits` to optionally zero the buffer
# first), so `logits` must be declared in `mutates_args`.  Registering a
# mutating op with `mutates_args=()` tells torch.compile the op is
# functional, which can silently reorder or eliminate the call.
@custom_op("_C::I8_mqa_logits", mutates_args={"logits"})
def I8_mqa_logits(
    q: torch.Tensor,
    fused_kv_cache: List[torch.Tensor],
    weights: torch.Tensor,
    context_q_lens: List[torch.Tensor],
    context_k_lens: List[torch.Tensor],
    logits: torch.Tensor,
    clean_logits: bool,
    max_seq_q: Optional[int] = 0,
    max_seq_k: Optional[int] = 0,
    is_causal: Optional[bool] = False,
    use_xfa_boost: Optional[bool] = False,
) -> None:
    """Compute int8 MQA attention logits, writing into ``logits`` in place.

    Thin ``torch.ops._C`` registration that forwards every argument
    unchanged to the vendor kernel ``xtorch_ops.I8_mqa_logits``.

    Args:
        q: Query tensor.
        fused_kv_cache: Fused int8 KV-cache tensors.
        weights: Per-head logit weights.
        context_q_lens: Cumulative query lengths per sequence.
        context_k_lens: Cumulative key lengths per sequence.
        logits: Pre-allocated output buffer; mutated in place.
        clean_logits: Whether the kernel zeroes ``logits`` before writing.
        max_seq_q / max_seq_k: Optional sequence-length hints (0 = unset).
        is_causal: Apply a causal mask when True.
        use_xfa_boost: Vendor-kernel fast path toggle.

    Returns:
        None — the result lives in ``logits``.
    """
    xtorch_ops.I8_mqa_logits(
        q=q,
        fused_kv_cache=fused_kv_cache,
        weights=weights,
        context_q_lens=context_q_lens,
        context_k_lens=context_k_lens,
        logits=logits,
        clean_logits=clean_logits,
        max_seq_q=max_seq_q,
        max_seq_k=max_seq_k,
        is_causal=is_causal,
        use_xfa_boost=use_xfa_boost,
    )
    return None
|
||||
|
||||
@impl("_C::I8_mqa_logits", "CUDA")
def I8_mqa_logits_cuda(
    q: torch.Tensor,
    fused_kv_cache: List[torch.Tensor],
    weights: torch.Tensor,
    context_q_lens: List[torch.Tensor],
    context_k_lens: List[torch.Tensor],
    logits: torch.Tensor,
    clean_logits: bool,
    max_seq_q: Optional[int] = 0,
    max_seq_k: Optional[int] = 0,
    is_causal: Optional[bool] = False,
    use_xfa_boost: Optional[bool] = False,
) -> None:
    """CUDA dispatch for ``_C::I8_mqa_logits``.

    Forwards every argument verbatim to the vendor kernel
    ``xtorch_ops.I8_mqa_logits``; the result is written into ``logits``.
    """
    kernel_kwargs = dict(
        q=q,
        fused_kv_cache=fused_kv_cache,
        weights=weights,
        context_q_lens=context_q_lens,
        context_k_lens=context_k_lens,
        logits=logits,
        clean_logits=clean_logits,
        max_seq_q=max_seq_q,
        max_seq_k=max_seq_k,
        is_causal=is_causal,
        use_xfa_boost=use_xfa_boost,
    )
    xtorch_ops.I8_mqa_logits(**kernel_kwargs)
|
||||
|
||||
def _fake_I8_mqa_logits(
    q: torch.Tensor,
    fused_kv_cache: List[torch.Tensor],
    weights: torch.Tensor,
    context_q_lens: List[torch.Tensor],
    context_k_lens: List[torch.Tensor],
    logits: torch.Tensor,
    clean_logits: bool,
    max_seq_q: Optional[int] = 0,
    max_seq_k: Optional[int] = 0,
    is_causal: Optional[bool] = False,
    use_xfa_boost: Optional[bool] = False,
) -> None:
    """Fake (meta) impl: the real op returns None, so this is a no-op."""


I8_mqa_logits.register_fake(_fake_I8_mqa_logits)
|
||||
|
||||
##################################################
# ------------- I8_paged_mqa_logits --------------
##################################################
# FIX: `out` is a pre-allocated output buffer (the caller creates it with
# torch.empty and the op returns None), so it must be listed in
# `mutates_args`.  `mutates_args=()` marks the op functional and lets
# torch.compile silently drop or reorder the mutation.
@custom_op("_C::I8_paged_mqa_logits", mutates_args={"out"})
def I8_paged_mqa_logits(
    q: torch.Tensor,
    fused_kv_cache: List[torch.Tensor],
    weights: torch.Tensor,
    context_lens: List[torch.Tensor],
    block_table: torch.Tensor,
    max_context_len: int,
    clean_logits: bool,
    out: torch.Tensor,
    use_xfa_boost: Optional[bool] = False) -> None:
    """Compute int8 paged MQA attention logits, writing into ``out``.

    Thin ``torch.ops._C`` registration forwarding every argument unchanged
    to the vendor kernel ``xtorch_ops.I8_paged_mqa_logits``.

    Args:
        q: Query tensor.
        fused_kv_cache: Paged int8 KV-cache tensors.
        weights: Per-head logit weights.
        context_lens: Per-sequence context lengths.
        block_table: Page/block table mapping logical to physical blocks.
        max_context_len: Upper bound on context length.
        clean_logits: Whether the kernel zeroes ``out`` before writing.
        out: Pre-allocated output buffer; mutated in place.
        use_xfa_boost: Vendor-kernel fast path toggle.

    Returns:
        None — the result lives in ``out``.
    """
    xtorch_ops.I8_paged_mqa_logits(
        q=q,
        fused_kv_cache=fused_kv_cache,
        weights=weights,
        context_lens=context_lens,
        block_table=block_table,
        max_context_len=max_context_len,
        clean_logits=clean_logits,
        out=out,
        use_xfa_boost=use_xfa_boost)
    return None
|
||||
|
||||
@impl("_C::I8_paged_mqa_logits", "CUDA")
def I8_paged_mqa_logits_cuda(
    q: torch.Tensor,
    fused_kv_cache: List[torch.Tensor],
    weights: torch.Tensor,
    context_lens: List[torch.Tensor],
    block_table: torch.Tensor,
    max_context_len: int,
    clean_logits: bool,
    out: torch.Tensor,
    use_xfa_boost: Optional[bool] = False) -> None:
    """CUDA dispatch for ``_C::I8_paged_mqa_logits``.

    Forwards every argument verbatim to the vendor kernel
    ``xtorch_ops.I8_paged_mqa_logits``; the result is written into ``out``.
    """
    kernel_kwargs = dict(
        q=q,
        fused_kv_cache=fused_kv_cache,
        weights=weights,
        context_lens=context_lens,
        block_table=block_table,
        max_context_len=max_context_len,
        clean_logits=clean_logits,
        out=out,
        use_xfa_boost=use_xfa_boost,
    )
    xtorch_ops.I8_paged_mqa_logits(**kernel_kwargs)
|
||||
|
||||
def _fake_I8_paged_mqa_logits(
    q: torch.Tensor,
    fused_kv_cache: List[torch.Tensor],
    weights: torch.Tensor,
    context_lens: List[torch.Tensor],
    block_table: torch.Tensor,
    max_context_len: int,
    clean_logits: bool,
    out: torch.Tensor,
    use_xfa_boost: Optional[bool] = False) -> None:
    """Fake (meta) impl: the real op returns None, so this is a no-op."""


I8_paged_mqa_logits.register_fake(_fake_I8_paged_mqa_logits)
|
||||
|
||||
##################################################
# ----------- sparse_prefill_fwd_opt -------------
##################################################
# FIX: `out`, `max_logits` and `lse` are pre-allocated output buffers
# (callers create `max_logits`/`lse` with torch.zeros immediately before
# the call and the op returns None), so they must be declared in
# `mutates_args`.  `mutates_args=()` marks the op functional and lets
# torch.compile silently drop or reorder the mutations.
@custom_op("_C::sparse_prefill_fwd_opt", mutates_args={"out", "max_logits", "lse"})
def sparse_prefill_fwd_opt(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    out: torch.Tensor,
    max_logits: torch.Tensor,
    lse: torch.Tensor,
    sm_scale: float,
    qlod_cpu: Optional[torch.Tensor] = None,
    qlod_xpu: Optional[torch.Tensor] = None,
    kvlod_cpu: Optional[torch.Tensor] = None,
    kvlod_xpu: Optional[torch.Tensor] = None,
    d_v: Optional[int] = -1,
    is_causal: Optional[bool] = True,
    use_xfa_boost: Optional[bool] = False) -> None:
    """Sparse-attention prefill forward, writing results in place.

    Thin ``torch.ops._C`` registration forwarding every argument unchanged
    to the vendor kernel ``xtorch_ops.sparse_prefill_fwd_opt``.

    Args:
        q: Query tensor.
        kv: Key/value tensor.
        indices: Sparse top-k indices selecting KV entries per query.
        out: Pre-allocated attention output buffer; mutated in place.
        max_logits: Pre-allocated per-row max-logit buffer; mutated in place.
        lse: Pre-allocated log-sum-exp buffer; mutated in place.
        sm_scale: Softmax scaling factor.
        qlod_cpu / qlod_xpu: Optional query length-offset descriptors
            (CPU/XPU copies).
        kvlod_cpu / kvlod_xpu: Optional KV length-offset descriptors.
        d_v: Value head dimension override (-1 = derive from inputs).
        is_causal: Apply a causal mask when True.
        use_xfa_boost: Vendor-kernel fast path toggle.

    Returns:
        None — results live in ``out``, ``max_logits`` and ``lse``.
    """
    xtorch_ops.sparse_prefill_fwd_opt(
        q=q,
        kv=kv,
        indices=indices,
        out=out,
        max_logits=max_logits,
        lse=lse,
        sm_scale=sm_scale,
        qlod_cpu=qlod_cpu,
        qlod_xpu=qlod_xpu,
        kvlod_cpu=kvlod_cpu,
        kvlod_xpu=kvlod_xpu,
        d_v=d_v,
        is_causal=is_causal,
        use_xfa_boost=use_xfa_boost)
    return None
|
||||
|
||||
@impl("_C::sparse_prefill_fwd_opt", "CUDA")
def sparse_prefill_fwd_opt_cuda(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    out: torch.Tensor,
    max_logits: torch.Tensor,
    lse: torch.Tensor,
    sm_scale: float,
    qlod_cpu: Optional[torch.Tensor] = None,
    qlod_xpu: Optional[torch.Tensor] = None,
    kvlod_cpu: Optional[torch.Tensor] = None,
    kvlod_xpu: Optional[torch.Tensor] = None,
    d_v: Optional[int] = -1,
    is_causal: Optional[bool] = True,
    use_xfa_boost: Optional[bool] = False) -> None:
    """CUDA dispatch for ``_C::sparse_prefill_fwd_opt``.

    Forwards every argument verbatim to the vendor kernel
    ``xtorch_ops.sparse_prefill_fwd_opt``; results are written into
    ``out``, ``max_logits`` and ``lse``.
    """
    kernel_kwargs = dict(
        q=q,
        kv=kv,
        indices=indices,
        out=out,
        max_logits=max_logits,
        lse=lse,
        sm_scale=sm_scale,
        qlod_cpu=qlod_cpu,
        qlod_xpu=qlod_xpu,
        kvlod_cpu=kvlod_cpu,
        kvlod_xpu=kvlod_xpu,
        d_v=d_v,
        is_causal=is_causal,
        use_xfa_boost=use_xfa_boost,
    )
    xtorch_ops.sparse_prefill_fwd_opt(**kernel_kwargs)
|
||||
|
||||
def _fake_sparse_prefill_fwd_opt(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    out: torch.Tensor,
    max_logits: torch.Tensor,
    lse: torch.Tensor,
    sm_scale: float,
    qlod_cpu: Optional[torch.Tensor] = None,
    qlod_xpu: Optional[torch.Tensor] = None,
    kvlod_cpu: Optional[torch.Tensor] = None,
    kvlod_xpu: Optional[torch.Tensor] = None,
    d_v: Optional[int] = -1,
    is_causal: Optional[bool] = True,
    use_xfa_boost: Optional[bool] = False) -> None:
    """Fake (meta) impl: the real op returns None, so this is a no-op."""


sparse_prefill_fwd_opt.register_fake(_fake_sparse_prefill_fwd_opt)
|
||||
|
||||
##################################################
# ------------------ fwd_kvcache_mla -------------
##################################################
# FIX: `out`, `max_logits` and `p_sums` are pre-allocated output buffers
# (callers create `p_sums` with torch.zeros immediately before the call
# and the op returns None), so they must be declared in `mutates_args`.
# `mutates_args=()` marks the op functional and lets torch.compile
# silently drop or reorder the mutations.
# NOTE(review): assumes the kernel does not write back into kv_cache /
# pe_cache — confirm against the vendor kernel's contract.
@custom_op("_C::fwd_kvcache_mla", mutates_args={"out", "max_logits", "p_sums"})
def fwd_kvcache_mla(
    q_c: torch.Tensor,
    kv_cache: torch.Tensor,
    indices: torch.Tensor,
    kv_lod_cpu: torch.Tensor,
    out: torch.Tensor,
    max_logits: torch.Tensor,
    p_sums: torch.Tensor,
    softmax_scale: float,
    max_seq_kv: int,
    q_r: Optional[torch.Tensor] = None,
    pe_cache: Optional[torch.Tensor] = None,
    use_xfa_boost: Optional[bool] = False,
    kv_lod_xpu: Optional[torch.Tensor] = None) -> None:
    """MLA decode attention over the KV cache, writing results in place.

    Thin ``torch.ops._C`` registration forwarding every argument unchanged
    to the vendor kernel ``xtorch_ops.fwd_kvcache_mla``.

    Args:
        q_c: Compressed (latent) query tensor.
        kv_cache: Latent KV cache tensor.
        indices: Indices selecting KV entries per query.
        kv_lod_cpu: KV length-offset descriptor (CPU copy).
        out: Pre-allocated attention output buffer; mutated in place.
        max_logits: Pre-allocated per-row max-logit buffer; mutated in place.
        p_sums: Pre-allocated softmax-denominator buffer; mutated in place.
        softmax_scale: Softmax scaling factor.
        max_seq_kv: Upper bound on KV sequence length.
        q_r: Optional rotary (RoPE) query component.
        pe_cache: Optional positional-embedding cache.
        use_xfa_boost: Vendor-kernel fast path toggle.
        kv_lod_xpu: KV length-offset descriptor (device copy).

    Returns:
        None — results live in ``out``, ``max_logits`` and ``p_sums``.
    """
    xtorch_ops.fwd_kvcache_mla(
        q_c=q_c,
        kv_cache=kv_cache,
        indices=indices,
        kv_lod_cpu=kv_lod_cpu,
        out=out,
        max_logits=max_logits,
        p_sums=p_sums,
        softmax_scale=softmax_scale,
        max_seq_kv=max_seq_kv,
        q_r=q_r,
        pe_cache=pe_cache,
        use_xfa_boost=use_xfa_boost,
        kv_lod_xpu=kv_lod_xpu)
    return None
|
||||
|
||||
@impl("_C::fwd_kvcache_mla", "CUDA")
def fwd_kvcache_mla_cuda(
    q_c: torch.Tensor,
    kv_cache: torch.Tensor,
    indices: torch.Tensor,
    kv_lod_cpu: torch.Tensor,
    out: torch.Tensor,
    max_logits: torch.Tensor,
    p_sums: torch.Tensor,
    softmax_scale: float,
    max_seq_kv: int,
    q_r: Optional[torch.Tensor] = None,
    pe_cache: Optional[torch.Tensor] = None,
    use_xfa_boost: Optional[bool] = False,
    kv_lod_xpu: Optional[torch.Tensor] = None) -> None:
    """CUDA dispatch for ``_C::fwd_kvcache_mla``.

    Forwards every argument verbatim to the vendor kernel
    ``xtorch_ops.fwd_kvcache_mla``; results are written into ``out``,
    ``max_logits`` and ``p_sums``.
    """
    kernel_kwargs = dict(
        q_c=q_c,
        kv_cache=kv_cache,
        indices=indices,
        kv_lod_cpu=kv_lod_cpu,
        out=out,
        max_logits=max_logits,
        p_sums=p_sums,
        softmax_scale=softmax_scale,
        max_seq_kv=max_seq_kv,
        q_r=q_r,
        pe_cache=pe_cache,
        use_xfa_boost=use_xfa_boost,
        kv_lod_xpu=kv_lod_xpu,
    )
    xtorch_ops.fwd_kvcache_mla(**kernel_kwargs)
|
||||
|
||||
def _fake_fwd_kvcache_mla(
    q_c: torch.Tensor,
    kv_cache: torch.Tensor,
    indices: torch.Tensor,
    kv_lod_cpu: torch.Tensor,
    out: torch.Tensor,
    max_logits: torch.Tensor,
    p_sums: torch.Tensor,
    softmax_scale: float,
    max_seq_kv: int,
    q_r: Optional[torch.Tensor] = None,
    pe_cache: Optional[torch.Tensor] = None,
    use_xfa_boost: Optional[bool] = False,
    kv_lod_xpu: Optional[torch.Tensor] = None) -> None:
    """Fake (meta) impl: the real op returns None, so this is a no-op."""


fwd_kvcache_mla.register_fake(_fake_fwd_kvcache_mla)
|
||||
|
||||
Reference in New Issue
Block a user