Tiny cleanup deepseek_v2.py (#11163)
This commit is contained in:
@@ -234,6 +234,13 @@ class FusedMoE(torch.nn.Module):
|
||||
self.quant_method.create_moe_runner(self, self.moe_runner_config)
|
||||
self.dispatcher = StandardDispatcher()
|
||||
|
||||
self.should_fuse_routed_scaling_factor_in_topk = isinstance(
|
||||
self.quant_method, ModelOptNvFp4FusedMoEMethod
|
||||
) or (
|
||||
isinstance(self.quant_method, Fp8MoEMethod)
|
||||
and self.quant_method.use_cutlass_fused_experts_fp8
|
||||
)
|
||||
|
||||
def _load_per_tensor_weight_scale(
|
||||
self,
|
||||
shard_id: str,
|
||||
@@ -936,12 +943,6 @@ class FusedMoE(torch.nn.Module):
|
||||
for shard_id in ["w1", "w2", "w3"]
|
||||
]
|
||||
|
||||
def should_fuse_routed_scaling_factor_in_topk(self):
|
||||
return isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) or (
|
||||
isinstance(self.quant_method, Fp8MoEMethod)
|
||||
and self.quant_method.use_cutlass_fused_experts_fp8
|
||||
)
|
||||
|
||||
|
||||
class FlashInferFusedMoE(FusedMoE):
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
||||
@@ -166,16 +166,15 @@ if _is_cuda:
|
||||
elif _is_cpu and _is_cpu_amx_available:
|
||||
pass
|
||||
elif _is_hip:
|
||||
from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
|
||||
decode_attention_fwd_grouped_rope,
|
||||
)
|
||||
from sglang.srt.layers.quantization.awq_triton import (
|
||||
awq_dequantize_triton as awq_dequantize,
|
||||
)
|
||||
else:
|
||||
pass
|
||||
|
||||
if _is_hip:
|
||||
from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
|
||||
decode_attention_fwd_grouped_rope,
|
||||
)
|
||||
|
||||
_is_flashinfer_available = is_flashinfer_available()
|
||||
_is_sm100_supported = is_cuda() and is_sm100_supported()
|
||||
@@ -229,7 +228,7 @@ def _dispatch_mla_subtype(attn, forward_batch):
|
||||
return AttnForwardMethod.MLA
|
||||
|
||||
|
||||
class BackendRegistry:
|
||||
class AttentionBackendRegistry:
|
||||
_handlers = {}
|
||||
|
||||
@classmethod
|
||||
@@ -241,7 +240,7 @@ class BackendRegistry:
|
||||
return cls._handlers.get(backend_name, cls._handlers.get("triton"))
|
||||
|
||||
|
||||
def handle_ascend(attn, forward_batch):
|
||||
def handle_attention_ascend(attn, forward_batch):
|
||||
if (
|
||||
forward_batch.forward_mode.is_extend()
|
||||
and not forward_batch.forward_mode.is_target_verify()
|
||||
@@ -268,7 +267,7 @@ def _is_extend_without_speculative(forward_batch):
|
||||
)
|
||||
|
||||
|
||||
def _handle_backend(attn, forward_batch, backend_name):
|
||||
def _handle_attention_backend(attn, forward_batch, backend_name):
|
||||
sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
|
||||
disable_ragged = (
|
||||
backend_name in ["flashinfer", "flashmla"]
|
||||
@@ -290,28 +289,28 @@ def _handle_backend(attn, forward_batch, backend_name):
|
||||
return _dispatch_mla_subtype(attn, forward_batch)
|
||||
|
||||
|
||||
def handle_flashinfer(attn, forward_batch):
|
||||
return _handle_backend(attn, forward_batch, "flashinfer")
|
||||
def handle_attention_flashinfer(attn, forward_batch):
|
||||
return _handle_attention_backend(attn, forward_batch, "flashinfer")
|
||||
|
||||
|
||||
def handle_fa3(attn, forward_batch):
|
||||
return _handle_backend(attn, forward_batch, "fa3")
|
||||
def handle_attention_fa3(attn, forward_batch):
|
||||
return _handle_attention_backend(attn, forward_batch, "fa3")
|
||||
|
||||
|
||||
def handle_flashmla(attn, forward_batch):
|
||||
return _handle_backend(attn, forward_batch, "flashmla")
|
||||
def handle_attention_flashmla(attn, forward_batch):
|
||||
return _handle_attention_backend(attn, forward_batch, "flashmla")
|
||||
|
||||
|
||||
def handle_cutlass_mla(attn, forward_batch):
|
||||
return _handle_backend(attn, forward_batch, "cutlass_mla")
|
||||
def handle_attention_cutlass_mla(attn, forward_batch):
|
||||
return _handle_attention_backend(attn, forward_batch, "cutlass_mla")
|
||||
|
||||
|
||||
def handle_fa4(attn, forward_batch):
|
||||
def handle_attention_fa4(attn, forward_batch):
|
||||
# TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
|
||||
return AttnForwardMethod.MHA_CHUNKED_KV
|
||||
|
||||
|
||||
def handle_trtllm_mla(attn, forward_batch):
|
||||
def handle_attention_trtllm_mla(attn, forward_batch):
|
||||
sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
|
||||
if _is_extend_without_speculative(forward_batch) and (
|
||||
not attn.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
|
||||
@@ -321,7 +320,7 @@ def handle_trtllm_mla(attn, forward_batch):
|
||||
return _dispatch_mla_subtype(attn, forward_batch)
|
||||
|
||||
|
||||
def handle_aiter(attn, forward_batch):
|
||||
def handle_attention_aiter(attn, forward_batch):
|
||||
if _is_extend_without_speculative(forward_batch):
|
||||
if is_dp_attention_enabled():
|
||||
if sum(forward_batch.extend_prefix_lens_cpu) == 0:
|
||||
@@ -334,7 +333,7 @@ def handle_aiter(attn, forward_batch):
|
||||
return AttnForwardMethod.MLA
|
||||
|
||||
|
||||
def handle_triton(attn, forward_batch):
|
||||
def handle_attention_triton(attn, forward_batch):
|
||||
if (
|
||||
_is_extend_without_speculative(forward_batch)
|
||||
and sum(forward_batch.extend_prefix_lens_cpu) == 0
|
||||
@@ -541,7 +540,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
correction_bias=self.gate.e_score_correction_bias,
|
||||
quant_config=quant_config,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
|
||||
apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk,
|
||||
# Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
|
||||
# and requires the output format to be standard. We use quant_config to determine the output format.
|
||||
output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
|
||||
@@ -838,13 +837,13 @@ class DeepseekV2MoE(nn.Module):
|
||||
|
||||
if shared_output is not None:
|
||||
x = shared_output
|
||||
if self.experts.should_fuse_routed_scaling_factor_in_topk():
|
||||
if self.experts.should_fuse_routed_scaling_factor_in_topk:
|
||||
x.add_(final_hidden_states)
|
||||
else:
|
||||
x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
|
||||
final_hidden_states = x
|
||||
else:
|
||||
if not self.experts.should_fuse_routed_scaling_factor_in_topk():
|
||||
if not self.experts.should_fuse_routed_scaling_factor_in_topk:
|
||||
final_hidden_states *= self.routed_scaling_factor
|
||||
|
||||
return final_hidden_states
|
||||
@@ -1217,7 +1216,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
attention_backend = global_server_args_dict["prefill_attention_backend"]
|
||||
self.current_attention_backend = attention_backend
|
||||
|
||||
handler = BackendRegistry.get_handler(attention_backend)
|
||||
handler = AttentionBackendRegistry.get_handler(attention_backend)
|
||||
return handler(self, forward_batch)
|
||||
|
||||
def op_prepare(self, state):
|
||||
@@ -3092,15 +3091,15 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
BackendRegistry.register("ascend", handle_ascend)
|
||||
BackendRegistry.register("flashinfer", handle_flashinfer)
|
||||
BackendRegistry.register("fa3", handle_fa3)
|
||||
BackendRegistry.register("flashmla", handle_flashmla)
|
||||
BackendRegistry.register("cutlass_mla", handle_cutlass_mla)
|
||||
BackendRegistry.register("fa4", handle_fa4)
|
||||
BackendRegistry.register("trtllm_mla", handle_trtllm_mla)
|
||||
BackendRegistry.register("aiter", handle_aiter)
|
||||
BackendRegistry.register("triton", handle_triton)
|
||||
AttentionBackendRegistry.register("ascend", handle_attention_ascend)
|
||||
AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
|
||||
AttentionBackendRegistry.register("fa3", handle_attention_fa3)
|
||||
AttentionBackendRegistry.register("flashmla", handle_attention_flashmla)
|
||||
AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla)
|
||||
AttentionBackendRegistry.register("fa4", handle_attention_fa4)
|
||||
AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla)
|
||||
AttentionBackendRegistry.register("aiter", handle_attention_aiter)
|
||||
AttentionBackendRegistry.register("triton", handle_attention_triton)
|
||||
|
||||
|
||||
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
|
||||
|
||||
Reference in New Issue
Block a user