Support DeepSeek V3.2 Exp (#11061)
Co-authored-by: Stefan He <11166516+hebiao064@users.noreply.github.com> Co-authored-by: Liangsheng Yin <95566987+hnyls2002@users.noreply.github.com> Co-authored-by: Baizhou Zhang <56809903+fridge003@users.noreply.github.com> Co-authored-by: DarkSharpness <76582120+darksharpness@users.noreply.github.com> Co-authored-by: ZhengdQin <46387172+zhengdqin@users.noreply.github.com> Co-authored-by: DarkSharpness <2040703891@qq.com> Co-authored-by: hnyls2002 <lsyincs@gmail.com> Co-authored-by: Zhengda Qin <zhengdqin@gmail.com> Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com> Co-authored-by: HAI <hixiao@gmail.com> Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
# Adapted from:
|
||||
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
|
||||
"""Inference-only DeepseekV2 model."""
|
||||
from __future__ import annotations
|
||||
|
||||
import concurrent.futures
|
||||
import logging
|
||||
@@ -25,10 +26,16 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from sglang.srt import single_batch_overlap
|
||||
from sglang.srt.configs.model_config import (
|
||||
get_nsa_index_head_dim,
|
||||
get_nsa_index_n_heads,
|
||||
get_nsa_index_topk,
|
||||
is_deepseek_nsa,
|
||||
)
|
||||
from sglang.srt.debug_utils.dumper import dumper
|
||||
from sglang.srt.distributed import (
|
||||
get_moe_expert_parallel_world_size,
|
||||
get_pp_group,
|
||||
@@ -48,6 +55,7 @@ from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
|
||||
NPUFusedMLAPreprocess,
|
||||
is_mla_preprocess_enabled,
|
||||
)
|
||||
from sglang.srt.layers.attention.nsa.nsa_indexer import Indexer
|
||||
from sglang.srt.layers.communicator import (
|
||||
LayerCommunicator,
|
||||
LayerScatterModes,
|
||||
@@ -172,10 +180,13 @@ elif _is_hip:
|
||||
from sglang.srt.layers.quantization.awq_triton import (
|
||||
awq_dequantize_triton as awq_dequantize,
|
||||
)
|
||||
elif _is_npu:
|
||||
import custom_ops
|
||||
import sgl_kernel_npu
|
||||
import torch_npu
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
_is_flashinfer_available = is_flashinfer_available()
|
||||
_is_sm100_supported = is_cuda() and is_sm100_supported()
|
||||
|
||||
@@ -184,6 +195,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [
|
||||
"fa3",
|
||||
"nsa",
|
||||
"flashinfer",
|
||||
"cutlass_mla",
|
||||
"trtllm_mla",
|
||||
@@ -204,6 +216,9 @@ class AttnForwardMethod(IntEnum):
|
||||
# Use absorbed multi-latent attention
|
||||
MLA = auto()
|
||||
|
||||
# Use Deepseek V3.2 sparse multi-latent attention
|
||||
NPU_MLA_SPARSE = auto()
|
||||
|
||||
# Use multi-head attention, but with KV cache chunked.
|
||||
# This method can avoid OOM when prefix lengths are long.
|
||||
MHA_CHUNKED_KV = auto()
|
||||
@@ -246,9 +261,15 @@ def handle_attention_ascend(attn, forward_batch):
|
||||
and not forward_batch.forward_mode.is_target_verify()
|
||||
and not forward_batch.forward_mode.is_draft_extend()
|
||||
):
|
||||
return AttnForwardMethod.MHA
|
||||
if hasattr(attn, "indexer"):
|
||||
return AttnForwardMethod.NPU_MLA_SPARSE
|
||||
else:
|
||||
return AttnForwardMethod.MHA
|
||||
else:
|
||||
return AttnForwardMethod.MLA
|
||||
if hasattr(attn, "indexer"):
|
||||
return AttnForwardMethod.NPU_MLA_SPARSE
|
||||
else:
|
||||
return AttnForwardMethod.MLA
|
||||
|
||||
|
||||
def _get_sum_extend_prefix_lens(forward_batch):
|
||||
@@ -267,7 +288,9 @@ def _is_extend_without_speculative(forward_batch):
|
||||
)
|
||||
|
||||
|
||||
def _handle_attention_backend(attn, forward_batch, backend_name):
|
||||
def _handle_attention_backend(
|
||||
attn: DeepseekV2AttentionMLA, forward_batch, backend_name
|
||||
):
|
||||
sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
|
||||
disable_ragged = (
|
||||
backend_name in ["flashinfer", "flashmla"]
|
||||
@@ -333,6 +356,10 @@ def handle_attention_aiter(attn, forward_batch):
|
||||
return AttnForwardMethod.MLA
|
||||
|
||||
|
||||
def handle_attention_nsa(attn, forward_batch):
|
||||
return AttnForwardMethod.MLA
|
||||
|
||||
|
||||
def handle_attention_triton(attn, forward_batch):
|
||||
if (
|
||||
_is_extend_without_speculative(forward_batch)
|
||||
@@ -1005,6 +1032,10 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
# NOTE modification to rope_scaling must be done early enough, b/c e.g. Indexer needs it
|
||||
if rope_scaling:
|
||||
rope_scaling["rope_type"] = "deepseek_yarn"
|
||||
|
||||
# For tensor parallel attention
|
||||
if self.q_lora_rank is not None:
|
||||
self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
|
||||
@@ -1042,6 +1073,26 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
prefix=add_prefix("kv_a_proj_with_mqa", prefix),
|
||||
)
|
||||
|
||||
self.use_nsa = is_deepseek_nsa(config)
|
||||
if self.use_nsa:
|
||||
self.indexer = Indexer(
|
||||
hidden_size=hidden_size,
|
||||
index_n_heads=get_nsa_index_n_heads(config),
|
||||
index_head_dim=get_nsa_index_head_dim(config),
|
||||
rope_head_dim=qk_rope_head_dim,
|
||||
index_topk=get_nsa_index_topk(config),
|
||||
q_lora_rank=q_lora_rank,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
rope_theta=rope_theta,
|
||||
scale_fmt="ue8m0",
|
||||
block_size=128,
|
||||
rope_scaling=rope_scaling,
|
||||
prefix=add_prefix("indexer", prefix),
|
||||
quant_config=quant_config,
|
||||
layer_id=layer_id,
|
||||
alt_stream=alt_stream,
|
||||
)
|
||||
|
||||
self.kv_b_proj = ColumnParallelLinear(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
@@ -1064,9 +1115,6 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
)
|
||||
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
|
||||
|
||||
if rope_scaling:
|
||||
rope_scaling["rope_type"] = "deepseek_yarn"
|
||||
|
||||
self.rotary_emb = get_rope_wrapper(
|
||||
qk_rope_head_dim,
|
||||
rotary_dim=qk_rope_head_dim,
|
||||
@@ -1193,8 +1241,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
self.is_mla_preprocess_enabled = is_mla_preprocess_enabled()
|
||||
if self.is_mla_preprocess_enabled:
|
||||
assert (
|
||||
quant_config.get_name() == "w8a8_int8"
|
||||
), "MLA Preprocess only works with W8A8Int8"
|
||||
quant_config is None or quant_config.get_name() == "w8a8_int8"
|
||||
), "MLA Preprocess only works with Unquant or W8A8Int8"
|
||||
self.mla_preprocess = None
|
||||
|
||||
def dispatch_attn_forward_method(
|
||||
@@ -1272,7 +1320,6 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
return hidden_states, None, forward_batch, None
|
||||
|
||||
attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
|
||||
|
||||
if attn_forward_method == AttnForwardMethod.MHA:
|
||||
inner_state = self.forward_normal_prepare(
|
||||
positions, hidden_states, forward_batch, zero_allocator
|
||||
@@ -1304,6 +1351,10 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
inner_state = self.mla_preprocess.forward(
|
||||
positions, hidden_states, forward_batch, zero_allocator
|
||||
)
|
||||
elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
|
||||
inner_state = self.forward_npu_sparse_prepare(
|
||||
positions, hidden_states, forward_batch, zero_allocator
|
||||
)
|
||||
elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
|
||||
inner_state = self.forward_absorb_fused_mla_rope_prepare(
|
||||
positions, hidden_states, forward_batch, zero_allocator
|
||||
@@ -1329,6 +1380,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
return self.forward_normal_chunked_kv_core(*inner_state)
|
||||
elif attn_forward_method == AttnForwardMethod.MLA:
|
||||
return self.forward_absorb_core(*inner_state)
|
||||
elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
|
||||
return self.forward_npu_sparse_core(*inner_state)
|
||||
elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
|
||||
return self.forward_absorb_fused_mla_rope_core(*inner_state)
|
||||
elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
|
||||
@@ -1424,6 +1477,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
):
|
||||
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||
|
||||
q_lora = None
|
||||
if self.q_lora_rank is not None:
|
||||
if (
|
||||
(not isinstance(hidden_states, tuple))
|
||||
@@ -1462,6 +1516,10 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
q = self.q_a_layernorm(q)
|
||||
k_nope = self.kv_a_layernorm(k_nope)
|
||||
|
||||
# q_lora needed by indexer
|
||||
if self.use_nsa:
|
||||
q_lora = q
|
||||
|
||||
k_nope = k_nope.unsqueeze(1)
|
||||
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
|
||||
else:
|
||||
@@ -1527,14 +1585,41 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
q_nope_out = q_nope_out.transpose(0, 1)
|
||||
|
||||
if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
|
||||
not _use_aiter or not _is_gfx95_supported
|
||||
not _use_aiter or not _is_gfx95_supported or self.use_nsa
|
||||
):
|
||||
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
||||
|
||||
return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
|
||||
topk_indices = None
|
||||
if q_lora is not None:
|
||||
topk_indices = self.indexer(
|
||||
x=hidden_states,
|
||||
q_lora=q_lora,
|
||||
positions=positions,
|
||||
forward_batch=forward_batch,
|
||||
layer_id=self.layer_id,
|
||||
)
|
||||
|
||||
return (
|
||||
q_pe,
|
||||
k_pe,
|
||||
q_nope_out,
|
||||
k_nope,
|
||||
forward_batch,
|
||||
zero_allocator,
|
||||
positions,
|
||||
topk_indices,
|
||||
)
|
||||
|
||||
def forward_absorb_core(
|
||||
self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
|
||||
self,
|
||||
q_pe,
|
||||
k_pe,
|
||||
q_nope_out,
|
||||
k_nope,
|
||||
forward_batch,
|
||||
zero_allocator,
|
||||
positions,
|
||||
topk_indices,
|
||||
):
|
||||
if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
|
||||
extra_args = {}
|
||||
@@ -1543,6 +1628,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
"cos_sin_cache": self.rotary_emb.cos_sin_cache,
|
||||
"is_neox": self.rotary_emb.is_neox_style,
|
||||
}
|
||||
|
||||
attn_output = self.attn_mqa(
|
||||
q_nope_out,
|
||||
k_nope,
|
||||
@@ -1551,6 +1637,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
q_rope=q_pe,
|
||||
k_rope=k_pe,
|
||||
**extra_args,
|
||||
**(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
|
||||
)
|
||||
else:
|
||||
if _use_aiter_gfx95:
|
||||
@@ -1570,7 +1657,13 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
q = torch.cat([q_nope_out, q_pe], dim=-1)
|
||||
k = torch.cat([k_nope, k_pe], dim=-1)
|
||||
|
||||
attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
|
||||
attn_output = self.attn_mqa(
|
||||
q,
|
||||
k,
|
||||
k_nope,
|
||||
forward_batch,
|
||||
**(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
|
||||
)
|
||||
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
||||
|
||||
if self.use_deep_gemm_bmm:
|
||||
@@ -1652,6 +1745,221 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
|
||||
return output
|
||||
|
||||
def forward_npu_sparse_prepare(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
zero_allocator: BumpAllocator,
|
||||
):
|
||||
"""
|
||||
Reuse `self.q_lora_rank is not None` branch from forward_absorb_prepare
|
||||
"""
|
||||
if self.is_mla_preprocess_enabled and forward_batch.forward_mode.is_decode():
|
||||
if self.mla_preprocess is None:
|
||||
self.mla_preprocess = NPUFusedMLAPreprocess(
|
||||
self.fused_qkv_a_proj_with_mqa,
|
||||
self.q_a_layernorm,
|
||||
self.kv_a_layernorm,
|
||||
self.q_b_proj,
|
||||
self.w_kc,
|
||||
self.rotary_emb,
|
||||
self.layer_id,
|
||||
self.num_local_heads,
|
||||
self.qk_nope_head_dim,
|
||||
self.qk_rope_head_dim,
|
||||
)
|
||||
(
|
||||
q_pe,
|
||||
k_pe,
|
||||
q_nope_out,
|
||||
k_nope,
|
||||
forward_batch,
|
||||
zero_allocator,
|
||||
positions,
|
||||
) = self.mla_preprocess.forward(
|
||||
positions, hidden_states, forward_batch, zero_allocator
|
||||
)
|
||||
|
||||
fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
|
||||
q, _ = fused_qkv_a_proj_out.split(
|
||||
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
q_lora = self.q_a_layernorm(q)
|
||||
else:
|
||||
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||
|
||||
if (
|
||||
(not isinstance(hidden_states, tuple))
|
||||
and hidden_states.shape[0] <= 16
|
||||
and self.use_min_latency_fused_a_gemm
|
||||
):
|
||||
fused_qkv_a_proj_out = dsv3_fused_a_gemm(
|
||||
hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
|
||||
)
|
||||
else:
|
||||
fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
|
||||
q, latent_cache = fused_qkv_a_proj_out.split(
|
||||
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
k_nope = latent_cache[..., : self.kv_lora_rank]
|
||||
|
||||
# overlap qk norm
|
||||
if self.alt_stream is not None and get_is_capture_mode():
|
||||
current_stream = torch.cuda.current_stream()
|
||||
self.alt_stream.wait_stream(current_stream)
|
||||
q = self.q_a_layernorm(q)
|
||||
with torch.cuda.stream(self.alt_stream):
|
||||
k_nope = self.kv_a_layernorm(k_nope)
|
||||
current_stream.wait_stream(self.alt_stream)
|
||||
else:
|
||||
if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
|
||||
q, k_nope = fused_rms_mxfp4_quant(
|
||||
q,
|
||||
self.q_a_layernorm.weight,
|
||||
self.q_a_layernorm.variance_epsilon,
|
||||
k_nope,
|
||||
self.kv_a_layernorm.weight,
|
||||
self.kv_a_layernorm.variance_epsilon,
|
||||
)
|
||||
else:
|
||||
q = self.q_a_layernorm(q)
|
||||
k_nope = self.kv_a_layernorm(k_nope)
|
||||
|
||||
q_lora = q.clone() # required for topk_indices
|
||||
k_nope = k_nope.unsqueeze(1)
|
||||
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
|
||||
|
||||
q_nope, q_pe = q.split(
|
||||
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
|
||||
|
||||
if self.use_deep_gemm_bmm:
|
||||
q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
|
||||
per_token_group_quant_mla_deep_gemm_masked_fp8(
|
||||
q_nope.transpose(0, 1)
|
||||
)
|
||||
)
|
||||
q_nope_out = q_nope.new_empty(
|
||||
(self.num_local_heads, aligned_m, self.kv_lora_rank)
|
||||
)
|
||||
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
|
||||
(q_nope_val, q_nope_scale),
|
||||
(self.w_kc, self.w_scale_k),
|
||||
q_nope_out,
|
||||
masked_m,
|
||||
expected_m,
|
||||
)
|
||||
q_nope_out = q_nope_out[:, :expected_m, :]
|
||||
elif _is_hip:
|
||||
# TODO(haishaw): add bmm_fp8 to ROCm
|
||||
if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
|
||||
x = q_nope.transpose(0, 1)
|
||||
q_nope_out = torch.empty(
|
||||
x.shape[0],
|
||||
x.shape[1],
|
||||
self.w_kc.shape[2],
|
||||
device=x.device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
batched_gemm_afp4wfp4_pre_quant(
|
||||
x,
|
||||
self.w_kc.transpose(-2, -1),
|
||||
self.w_scale_k.transpose(-2, -1),
|
||||
torch.bfloat16,
|
||||
q_nope_out,
|
||||
)
|
||||
else:
|
||||
q_nope_out = torch.bmm(
|
||||
q_nope.to(torch.bfloat16).transpose(0, 1),
|
||||
self.w_kc.to(torch.bfloat16) * self.w_scale,
|
||||
)
|
||||
elif self.w_kc.dtype == torch.float8_e4m3fn:
|
||||
q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
|
||||
q_nope.transpose(0, 1),
|
||||
zero_allocator.allocate(1),
|
||||
)
|
||||
q_nope_out = bmm_fp8(
|
||||
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
|
||||
)
|
||||
else:
|
||||
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
|
||||
|
||||
q_nope_out = q_nope_out.transpose(0, 1)
|
||||
|
||||
if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
|
||||
not _use_aiter or not _is_gfx95_supported
|
||||
):
|
||||
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
||||
|
||||
# TODO: multi-stream indexer
|
||||
topk_indices = self.indexer(
|
||||
hidden_states, q_lora, positions, forward_batch, self.layer_id
|
||||
)
|
||||
|
||||
return (
|
||||
q_pe,
|
||||
k_pe,
|
||||
q_nope_out,
|
||||
k_nope,
|
||||
topk_indices,
|
||||
forward_batch,
|
||||
zero_allocator,
|
||||
positions,
|
||||
)
|
||||
|
||||
def forward_npu_sparse_core(
|
||||
self,
|
||||
q_pe,
|
||||
k_pe,
|
||||
q_nope_out,
|
||||
k_nope,
|
||||
topk_indices,
|
||||
forward_batch,
|
||||
zero_allocator,
|
||||
positions,
|
||||
):
|
||||
attn_output = self.attn_mqa(
|
||||
q_nope_out.contiguous(),
|
||||
k_nope.contiguous(),
|
||||
k_nope.contiguous(),
|
||||
forward_batch,
|
||||
save_kv_cache=True, # False if forward_batch.forward_mode.is_extend() else True,
|
||||
q_rope=q_pe.contiguous(),
|
||||
k_rope=k_pe.contiguous(),
|
||||
topk_indices=topk_indices,
|
||||
)
|
||||
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
||||
|
||||
attn_bmm_output = torch.empty(
|
||||
(attn_output.shape[0], self.num_local_heads, self.v_head_dim),
|
||||
dtype=attn_output.dtype,
|
||||
device=attn_output.device,
|
||||
)
|
||||
|
||||
if not forward_batch.forward_mode.is_decode():
|
||||
attn_output = attn_output.transpose(0, 1)
|
||||
torch.bmm(
|
||||
attn_output,
|
||||
self.w_vc,
|
||||
out=attn_bmm_output.view(
|
||||
-1, self.num_local_heads, self.v_head_dim
|
||||
).transpose(0, 1),
|
||||
)
|
||||
else:
|
||||
attn_output = attn_output.contiguous()
|
||||
torch.ops.npu.batch_matmul_transpose(
|
||||
attn_output, self.w_vc, attn_bmm_output
|
||||
)
|
||||
|
||||
attn_bmm_output = attn_bmm_output.reshape(
|
||||
-1, self.num_local_heads * self.v_head_dim
|
||||
)
|
||||
|
||||
output, _ = self.o_proj(attn_bmm_output)
|
||||
return output
|
||||
|
||||
def forward_absorb_fused_mla_rope_prepare(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
@@ -2134,7 +2442,6 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
zero_allocator: BumpAllocator,
|
||||
gemm_output_zero_allocator: BumpAllocator = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
quant_format = (
|
||||
"mxfp4"
|
||||
if _is_gfx95_supported
|
||||
@@ -3099,6 +3406,7 @@ AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla)
|
||||
AttentionBackendRegistry.register("fa4", handle_attention_fa4)
|
||||
AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla)
|
||||
AttentionBackendRegistry.register("aiter", handle_attention_aiter)
|
||||
AttentionBackendRegistry.register("nsa", handle_attention_nsa)
|
||||
AttentionBackendRegistry.register("triton", handle_attention_triton)
|
||||
|
||||
|
||||
@@ -3106,4 +3414,8 @@ class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
|
||||
pass
|
||||
|
||||
|
||||
EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]
|
||||
class DeepseekV32ForCausalLM(DeepseekV2ForCausalLM):
|
||||
pass
|
||||
|
||||
|
||||
EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM, DeepseekV32ForCausalLM]
|
||||
|
||||
Reference in New Issue
Block a user