[Misc]Specify that DS32 only supports --kv-cache-dtype bfloat16 (#119)
* [Kernel] add kernels to torch.ops
* [Misc] Specify that DS only supports --kv-cache-dtype bfloat16
---------
Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
This commit is contained in:
@@ -171,7 +171,7 @@ def kunlun_flash_mla_with_kvcache(
|
|||||||
p_sums = torch.zeros([batch_size, seq_len_q, num_heads_q],
|
p_sums = torch.zeros([batch_size, seq_len_q, num_heads_q],
|
||||||
dtype=torch.float32, device=q.device)
|
dtype=torch.float32, device=q.device)
|
||||||
|
|
||||||
xtorch_ops.fwd_kvcache_mla(
|
torch.ops._C.fwd_kvcache_mla(
|
||||||
q_c=q,
|
q_c=q,
|
||||||
kv_cache=k_cache,
|
kv_cache=k_cache,
|
||||||
indices=indices,
|
indices=indices,
|
||||||
@@ -224,7 +224,7 @@ def flash_mla_sparse_prefill(
|
|||||||
max_logits = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device)
|
max_logits = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device)
|
||||||
lse = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device)
|
lse = torch.zeros([s_q, h_q], dtype=torch.float32, device=q.device)
|
||||||
|
|
||||||
xtorch_ops.sparse_prefill_fwd_opt(
|
torch.ops._C.sparse_prefill_fwd_opt(
|
||||||
q=q,
|
q=q,
|
||||||
kv=kv,
|
kv=kv,
|
||||||
indices=indices,
|
indices=indices,
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import torch
|
import torch
|
||||||
import xtorch_ops
|
|
||||||
|
|
||||||
def int8_mqa_logits(
|
def int8_mqa_logits(
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
@@ -29,7 +28,7 @@ def int8_mqa_logits(
|
|||||||
context_q_lens_xpu = torch.tensor([0, q.shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
|
context_q_lens_xpu = torch.tensor([0, q.shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
|
||||||
context_k_lens_xpu = torch.tensor([0, kv[0].shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
|
context_k_lens_xpu = torch.tensor([0, kv[0].shape[0]], dtype=torch.int32, device=cu_seqlen_ks.device)
|
||||||
|
|
||||||
xtorch_ops.I8_mqa_logits(
|
torch.ops._C.I8_mqa_logits(
|
||||||
q=q,
|
q=q,
|
||||||
fused_kv_cache=kv,
|
fused_kv_cache=kv,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
@@ -99,7 +98,7 @@ def int8_paged_mqa_logits(
|
|||||||
|
|
||||||
logits = torch.empty((batch_size, next_n, max_model_len), dtype=torch.float32, device=q_fp8.device)
|
logits = torch.empty((batch_size, next_n, max_model_len), dtype=torch.float32, device=q_fp8.device)
|
||||||
|
|
||||||
xtorch_ops.I8_paged_mqa_logits(
|
torch.ops._C.I8_paged_mqa_logits(
|
||||||
q=q_fp8,
|
q=q_fp8,
|
||||||
fused_kv_cache=kv_cache,
|
fused_kv_cache=kv_cache,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
|
|||||||
@@ -731,22 +731,21 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
|
|||||||
|
|
||||||
q = torch.cat([ql_nope, q_pe], dim=-1)
|
q = torch.cat([ql_nope, q_pe], dim=-1)
|
||||||
|
|
||||||
# write the latent and rope to kv cache
|
|
||||||
if kv_cache.numel() > 0:
|
|
||||||
torch.ops._C.concat_and_cache_mla(
|
|
||||||
kv_c=k_c_normed,
|
|
||||||
k_pe=k_pe.squeeze(1),
|
|
||||||
kv_cache=kv_cache,
|
|
||||||
slot_mapping=attn_metadata.slot_mapping.flatten(),
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.kv_cache_dtype != "fp8_ds_mla":
|
if self.kv_cache_dtype != "fp8_ds_mla":
|
||||||
|
# write the latent and rope to kv cache
|
||||||
|
if kv_cache.numel() > 0:
|
||||||
|
torch.ops._C.concat_and_cache_mla(
|
||||||
|
kv_c=k_c_normed,
|
||||||
|
k_pe=k_pe.squeeze(1),
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
slot_mapping=attn_metadata.slot_mapping.flatten(),
|
||||||
|
)
|
||||||
attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices,
|
attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices,
|
||||||
attn_metadata)
|
attn_metadata)
|
||||||
else:
|
else:
|
||||||
# attn_out = self._forward_fp8_kv(q, kv_cache, topk_indices_global,
|
# attn_out = self._forward_fp8_kv(q, kv_cache, topk_indices_global,
|
||||||
# attn_metadata)
|
# attn_metadata)
|
||||||
raise NotImplementedError
|
raise NotImplementedError("Only support --kv-cache-dtype bfloat16")
|
||||||
|
|
||||||
self._v_up_proj(attn_out, out=output[:num_actual_toks])
|
self._v_up_proj(attn_out, out=output[:num_actual_toks])
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -1923,3 +1923,315 @@ def apply_repetition_penalties_(
|
|||||||
# If logits are positive, divide by penalty, otherwise multiply by penalty.
|
# If logits are positive, divide by penalty, otherwise multiply by penalty.
|
||||||
scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
|
scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
|
||||||
logits *= scaling
|
logits *= scaling
|
||||||
|
|
||||||
|
##################################################
|
||||||
|
# --------------- I8_mqa_logits -----------------
|
||||||
|
##################################################
|
||||||
|
@custom_op("_C::I8_mqa_logits", mutates_args=())
|
||||||
|
def I8_mqa_logits(
|
||||||
|
q: torch.Tensor,
|
||||||
|
fused_kv_cache: List[torch.Tensor],
|
||||||
|
weights: torch.Tensor,
|
||||||
|
context_q_lens: List[torch.Tensor],
|
||||||
|
context_k_lens: List[torch.Tensor],
|
||||||
|
logits: torch.Tensor,
|
||||||
|
clean_logits: bool,
|
||||||
|
max_seq_q: Optional[int] = 0,
|
||||||
|
max_seq_k: Optional[int] = 0,
|
||||||
|
is_causal: Optional[bool] = False,
|
||||||
|
use_xfa_boost: Optional[bool] = False,
|
||||||
|
) -> None:
|
||||||
|
xtorch_ops.I8_mqa_logits(
|
||||||
|
q=q,
|
||||||
|
fused_kv_cache=fused_kv_cache,
|
||||||
|
weights=weights,
|
||||||
|
context_q_lens=context_q_lens,
|
||||||
|
context_k_lens=context_k_lens,
|
||||||
|
logits=logits,
|
||||||
|
clean_logits=clean_logits,
|
||||||
|
max_seq_q=max_seq_q,
|
||||||
|
max_seq_k=max_seq_k,
|
||||||
|
is_causal=is_causal,
|
||||||
|
use_xfa_boost=use_xfa_boost,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@impl("_C::I8_mqa_logits", "CUDA")
def I8_mqa_logits_cuda(
    q: torch.Tensor,
    fused_kv_cache: List[torch.Tensor],
    weights: torch.Tensor,
    context_q_lens: List[torch.Tensor],
    context_k_lens: List[torch.Tensor],
    logits: torch.Tensor,
    clean_logits: bool,
    max_seq_q: Optional[int] = 0,
    max_seq_k: Optional[int] = 0,
    is_causal: Optional[bool] = False,
    use_xfa_boost: Optional[bool] = False,
) -> None:
    """Implementation of ``_C::I8_mqa_logits`` registered for the "CUDA" key.

    Forwards every argument unchanged, by keyword, to
    ``xtorch_ops.I8_mqa_logits`` and returns nothing.
    """
    # Pure pass-through: same keyword names on both sides of the call.
    xtorch_ops.I8_mqa_logits(
        q=q, fused_kv_cache=fused_kv_cache, weights=weights,
        context_q_lens=context_q_lens, context_k_lens=context_k_lens,
        logits=logits, clean_logits=clean_logits,
        max_seq_q=max_seq_q, max_seq_k=max_seq_k,
        is_causal=is_causal, use_xfa_boost=use_xfa_boost,
    )
|
||||||
|
|
||||||
|
@I8_mqa_logits.register_fake
def _fake_I8_mqa_logits(
    q: torch.Tensor,
    fused_kv_cache: List[torch.Tensor],
    weights: torch.Tensor,
    context_q_lens: List[torch.Tensor],
    context_k_lens: List[torch.Tensor],
    logits: torch.Tensor,
    clean_logits: bool,
    max_seq_q: Optional[int] = 0,
    max_seq_k: Optional[int] = 0,
    is_causal: Optional[bool] = False,
    use_xfa_boost: Optional[bool] = False,
) -> None:
    """Fake (meta) implementation of ``I8_mqa_logits``.

    Takes the same arguments as the real op, does no work, and returns None,
    since the op has no return value to describe.
    """
    return None
|
||||||
|
|
||||||
|
##################################################
|
||||||
|
# ------------- I8_paged_mqa_logits --------------
|
||||||
|
##################################################
|
||||||
|
@custom_op("_C::I8_paged_mqa_logits", mutates_args=())
|
||||||
|
def I8_paged_mqa_logits(
|
||||||
|
q: torch.Tensor,
|
||||||
|
fused_kv_cache: List[torch.Tensor],
|
||||||
|
weights: torch.Tensor,
|
||||||
|
context_lens: List[torch.Tensor],
|
||||||
|
block_table: torch.Tensor,
|
||||||
|
max_context_len: int,
|
||||||
|
clean_logits: bool,
|
||||||
|
out: torch.Tensor,
|
||||||
|
use_xfa_boost: Optional[bool] = False) -> None:
|
||||||
|
xtorch_ops.I8_paged_mqa_logits(
|
||||||
|
q=q,
|
||||||
|
fused_kv_cache=fused_kv_cache,
|
||||||
|
weights=weights,
|
||||||
|
context_lens=context_lens,
|
||||||
|
block_table=block_table,
|
||||||
|
max_context_len=max_context_len,
|
||||||
|
clean_logits=clean_logits,
|
||||||
|
out=out,
|
||||||
|
use_xfa_boost=use_xfa_boost)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@impl("_C::I8_paged_mqa_logits", "CUDA")
def I8_paged_mqa_logits_cuda(
    q: torch.Tensor,
    fused_kv_cache: List[torch.Tensor],
    weights: torch.Tensor,
    context_lens: List[torch.Tensor],
    block_table: torch.Tensor,
    max_context_len: int,
    clean_logits: bool,
    out: torch.Tensor,
    use_xfa_boost: Optional[bool] = False,
) -> None:
    """Implementation of ``_C::I8_paged_mqa_logits`` registered for "CUDA".

    Forwards every argument unchanged, by keyword, to
    ``xtorch_ops.I8_paged_mqa_logits`` and returns nothing.
    """
    # Pure pass-through: same keyword names on both sides of the call.
    xtorch_ops.I8_paged_mqa_logits(
        q=q, fused_kv_cache=fused_kv_cache, weights=weights,
        context_lens=context_lens, block_table=block_table,
        max_context_len=max_context_len, clean_logits=clean_logits,
        out=out, use_xfa_boost=use_xfa_boost,
    )
|
||||||
|
|
||||||
|
@I8_paged_mqa_logits.register_fake
def _fake_I8_paged_mqa_logits(
    q: torch.Tensor,
    fused_kv_cache: List[torch.Tensor],
    weights: torch.Tensor,
    context_lens: List[torch.Tensor],
    block_table: torch.Tensor,
    max_context_len: int,
    clean_logits: bool,
    out: torch.Tensor,
    use_xfa_boost: Optional[bool] = False,
) -> None:
    """Fake (meta) implementation of ``I8_paged_mqa_logits``.

    Takes the same arguments as the real op, does no work, and returns None,
    since the op has no return value to describe.
    """
    return None
|
||||||
|
|
||||||
|
##################################################
|
||||||
|
# ----------- sparse_prefill_fwd_opt -------------
|
||||||
|
##################################################
|
||||||
|
@custom_op("_C::sparse_prefill_fwd_opt", mutates_args=())
|
||||||
|
def sparse_prefill_fwd_opt(
|
||||||
|
q: torch.Tensor,
|
||||||
|
kv: torch.Tensor,
|
||||||
|
indices: torch.Tensor,
|
||||||
|
out: torch.Tensor,
|
||||||
|
max_logits: torch.Tensor,
|
||||||
|
lse: torch.Tensor,
|
||||||
|
sm_scale: float,
|
||||||
|
qlod_cpu: Optional[torch.Tensor] = None,
|
||||||
|
qlod_xpu: Optional[torch.Tensor] = None,
|
||||||
|
kvlod_cpu: Optional[torch.Tensor] = None,
|
||||||
|
kvlod_xpu: Optional[torch.Tensor] = None,
|
||||||
|
d_v: Optional[int] = -1,
|
||||||
|
is_causal: Optional[bool] = True,
|
||||||
|
use_xfa_boost: Optional[bool] = False) -> None:
|
||||||
|
xtorch_ops.sparse_prefill_fwd_opt(
|
||||||
|
q=q,
|
||||||
|
kv=kv,
|
||||||
|
indices=indices,
|
||||||
|
out=out,
|
||||||
|
max_logits=max_logits,
|
||||||
|
lse=lse,
|
||||||
|
sm_scale=sm_scale,
|
||||||
|
qlod_cpu=qlod_cpu,
|
||||||
|
qlod_xpu=qlod_xpu,
|
||||||
|
kvlod_cpu=kvlod_cpu,
|
||||||
|
kvlod_xpu=kvlod_xpu,
|
||||||
|
d_v=d_v,
|
||||||
|
is_causal=is_causal,
|
||||||
|
use_xfa_boost=use_xfa_boost)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@impl("_C::sparse_prefill_fwd_opt", "CUDA")
def sparse_prefill_fwd_opt_cuda(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    out: torch.Tensor,
    max_logits: torch.Tensor,
    lse: torch.Tensor,
    sm_scale: float,
    qlod_cpu: Optional[torch.Tensor] = None,
    qlod_xpu: Optional[torch.Tensor] = None,
    kvlod_cpu: Optional[torch.Tensor] = None,
    kvlod_xpu: Optional[torch.Tensor] = None,
    d_v: Optional[int] = -1,
    is_causal: Optional[bool] = True,
    use_xfa_boost: Optional[bool] = False,
) -> None:
    """Implementation of ``_C::sparse_prefill_fwd_opt`` registered for "CUDA".

    Forwards every argument unchanged, by keyword, to
    ``xtorch_ops.sparse_prefill_fwd_opt`` and returns nothing.
    """
    # Pure pass-through: same keyword names on both sides of the call.
    xtorch_ops.sparse_prefill_fwd_opt(
        q=q, kv=kv, indices=indices,
        out=out, max_logits=max_logits, lse=lse,
        sm_scale=sm_scale,
        qlod_cpu=qlod_cpu, qlod_xpu=qlod_xpu,
        kvlod_cpu=kvlod_cpu, kvlod_xpu=kvlod_xpu,
        d_v=d_v, is_causal=is_causal, use_xfa_boost=use_xfa_boost,
    )
|
||||||
|
|
||||||
|
@sparse_prefill_fwd_opt.register_fake
def _fake_sparse_prefill_fwd_opt(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    out: torch.Tensor,
    max_logits: torch.Tensor,
    lse: torch.Tensor,
    sm_scale: float,
    qlod_cpu: Optional[torch.Tensor] = None,
    qlod_xpu: Optional[torch.Tensor] = None,
    kvlod_cpu: Optional[torch.Tensor] = None,
    kvlod_xpu: Optional[torch.Tensor] = None,
    d_v: Optional[int] = -1,
    is_causal: Optional[bool] = True,
    use_xfa_boost: Optional[bool] = False,
) -> None:
    """Fake (meta) implementation of ``sparse_prefill_fwd_opt``.

    Takes the same arguments as the real op, does no work, and returns None,
    since the op has no return value to describe.
    """
    return None
|
||||||
|
|
||||||
|
##################################################
|
||||||
|
# ------------------ fwd_kvcache_mla -------------
|
||||||
|
##################################################
|
||||||
|
@custom_op("_C::fwd_kvcache_mla", mutates_args=())
|
||||||
|
def fwd_kvcache_mla(
|
||||||
|
q_c: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
indices: torch.Tensor,
|
||||||
|
kv_lod_cpu: torch.Tensor,
|
||||||
|
out: torch.Tensor,
|
||||||
|
max_logits: torch.Tensor,
|
||||||
|
p_sums: torch.Tensor,
|
||||||
|
softmax_scale: float,
|
||||||
|
max_seq_kv: int,
|
||||||
|
q_r: Optional[torch.Tensor] = None,
|
||||||
|
pe_cache: Optional[torch.Tensor] = None,
|
||||||
|
use_xfa_boost: Optional[bool] = False,
|
||||||
|
kv_lod_xpu: Optional[torch.Tensor] = None) -> None:
|
||||||
|
xtorch_ops.fwd_kvcache_mla(
|
||||||
|
q_c=q_c,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
indices=indices,
|
||||||
|
kv_lod_cpu=kv_lod_cpu,
|
||||||
|
out=out,
|
||||||
|
max_logits=max_logits,
|
||||||
|
p_sums=p_sums,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
max_seq_kv=max_seq_kv,
|
||||||
|
q_r=q_r,
|
||||||
|
pe_cache=pe_cache,
|
||||||
|
use_xfa_boost=use_xfa_boost,
|
||||||
|
kv_lod_xpu=kv_lod_xpu)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@impl("_C::fwd_kvcache_mla", "CUDA")
def fwd_kvcache_mla_cuda(
    q_c: torch.Tensor,
    kv_cache: torch.Tensor,
    indices: torch.Tensor,
    kv_lod_cpu: torch.Tensor,
    out: torch.Tensor,
    max_logits: torch.Tensor,
    p_sums: torch.Tensor,
    softmax_scale: float,
    max_seq_kv: int,
    q_r: Optional[torch.Tensor] = None,
    pe_cache: Optional[torch.Tensor] = None,
    use_xfa_boost: Optional[bool] = False,
    kv_lod_xpu: Optional[torch.Tensor] = None,
) -> None:
    """Implementation of ``_C::fwd_kvcache_mla`` registered for "CUDA".

    Forwards every argument unchanged, by keyword, to
    ``xtorch_ops.fwd_kvcache_mla`` and returns nothing.
    """
    # Pure pass-through: same keyword names on both sides of the call.
    xtorch_ops.fwd_kvcache_mla(
        q_c=q_c, kv_cache=kv_cache, indices=indices,
        kv_lod_cpu=kv_lod_cpu,
        out=out, max_logits=max_logits, p_sums=p_sums,
        softmax_scale=softmax_scale, max_seq_kv=max_seq_kv,
        q_r=q_r, pe_cache=pe_cache,
        use_xfa_boost=use_xfa_boost, kv_lod_xpu=kv_lod_xpu,
    )
|
||||||
|
|
||||||
|
@fwd_kvcache_mla.register_fake
def _fake_fwd_kvcache_mla(
    q_c: torch.Tensor,
    kv_cache: torch.Tensor,
    indices: torch.Tensor,
    kv_lod_cpu: torch.Tensor,
    out: torch.Tensor,
    max_logits: torch.Tensor,
    p_sums: torch.Tensor,
    softmax_scale: float,
    max_seq_kv: int,
    q_r: Optional[torch.Tensor] = None,
    pe_cache: Optional[torch.Tensor] = None,
    use_xfa_boost: Optional[bool] = False,
    kv_lod_xpu: Optional[torch.Tensor] = None,
) -> None:
    """Fake (meta) implementation of ``fwd_kvcache_mla``.

    Takes the same arguments as the real op, does no work, and returns None,
    since the op has no return value to describe.
    """
    return None
|
||||||
|
|||||||
Reference in New Issue
Block a user