[Feature] support deepseek v3/r1/v3.2 (#78)
* [Feature] support deepseek v3/r1/v3.2 * fix gpt_oss * update readme * update readme --------- Co-authored-by: hanhaowen <hanhaowen@baidu.com>
This commit is contained in:
@@ -1493,4 +1493,45 @@ def _fake_gptq_shuffle(
|
||||
return None
|
||||
|
||||
|
||||
gptq_shuffle.register_fake(_fake_gptq_shuffle)
|
||||
gptq_shuffle.register_fake(_fake_gptq_shuffle)
|
||||
|
||||
##################################################
|
||||
# ---------------- concat_and_cache_mla ------------------
|
||||
##################################################
|
||||
@custom_op("_C::concat_and_cache_mla", mutates_args=())
|
||||
def concat_and_cache_mla(
|
||||
kv_c: torch.Tensor, #[num_tokens, kv_lora_rank]
|
||||
k_pe: torch.Tensor, #[num_tokens, pe_dim]
|
||||
kv_cache: torch.Tensor, #[num_blocks, block_size, (kv_lora_rank + pe_dim)]
|
||||
slot_mapping: torch.Tensor, #[num_tokens] or [num_actual_tokens]
|
||||
) -> None:
|
||||
xtorch_ops.concat_and_cache_mla(
|
||||
kv_c=kv_c,
|
||||
k_pe=k_pe,
|
||||
slot_mapping=slot_mapping,
|
||||
kv_cache=kv_cache,
|
||||
)
|
||||
|
||||
@impl("_C::concat_and_cache_mla", "CUDA")
|
||||
def concat_and_cache_mla_cuda(
|
||||
kv_c: torch.Tensor, #[num_tokens, kv_lora_rank]
|
||||
k_pe: torch.Tensor, #[num_tokens, pe_dim]
|
||||
kv_cache: torch.Tensor, #[num_blocks, block_size, (kv_lora_rank + pe_dim)]
|
||||
slot_mapping: torch.Tensor, #[num_tokens] or [num_actual_tokens]
|
||||
) -> None:
|
||||
xtorch_ops.concat_and_cache_mla(
|
||||
kv_c=kv_c,
|
||||
k_pe=k_pe,
|
||||
slot_mapping=slot_mapping,
|
||||
kv_cache=kv_cache,
|
||||
)
|
||||
|
||||
def _fake_concat_and_cache_mla(
|
||||
kv_c: torch.Tensor, #[num_tokens, kv_lora_rank]
|
||||
k_pe: torch.Tensor, #[num_tokens, pe_dim]
|
||||
kv_cache: torch.Tensor, #[num_blocks, block_size, (kv_lora_rank + pe_dim)]
|
||||
slot_mapping: torch.Tensor, #[num_tokens] or [num_actual_tokens]
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
concat_and_cache_mla.register_fake(_fake_concat_and_cache_mla)
|
||||
Reference in New Issue
Block a user