Tiny cleanup deepseek_v2.py (#11163)
This commit is contained in:
@@ -234,6 +234,13 @@ class FusedMoE(torch.nn.Module):
|
|||||||
self.quant_method.create_moe_runner(self, self.moe_runner_config)
|
self.quant_method.create_moe_runner(self, self.moe_runner_config)
|
||||||
self.dispatcher = StandardDispatcher()
|
self.dispatcher = StandardDispatcher()
|
||||||
|
|
||||||
|
self.should_fuse_routed_scaling_factor_in_topk = isinstance(
|
||||||
|
self.quant_method, ModelOptNvFp4FusedMoEMethod
|
||||||
|
) or (
|
||||||
|
isinstance(self.quant_method, Fp8MoEMethod)
|
||||||
|
and self.quant_method.use_cutlass_fused_experts_fp8
|
||||||
|
)
|
||||||
|
|
||||||
def _load_per_tensor_weight_scale(
|
def _load_per_tensor_weight_scale(
|
||||||
self,
|
self,
|
||||||
shard_id: str,
|
shard_id: str,
|
||||||
@@ -936,12 +943,6 @@ class FusedMoE(torch.nn.Module):
|
|||||||
for shard_id in ["w1", "w2", "w3"]
|
for shard_id in ["w1", "w2", "w3"]
|
||||||
]
|
]
|
||||||
|
|
||||||
def should_fuse_routed_scaling_factor_in_topk(self):
|
|
||||||
return isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) or (
|
|
||||||
isinstance(self.quant_method, Fp8MoEMethod)
|
|
||||||
and self.quant_method.use_cutlass_fused_experts_fp8
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FlashInferFusedMoE(FusedMoE):
|
class FlashInferFusedMoE(FusedMoE):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
|||||||
@@ -166,16 +166,15 @@ if _is_cuda:
|
|||||||
elif _is_cpu and _is_cpu_amx_available:
|
elif _is_cpu and _is_cpu_amx_available:
|
||||||
pass
|
pass
|
||||||
elif _is_hip:
|
elif _is_hip:
|
||||||
|
from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
|
||||||
|
decode_attention_fwd_grouped_rope,
|
||||||
|
)
|
||||||
from sglang.srt.layers.quantization.awq_triton import (
|
from sglang.srt.layers.quantization.awq_triton import (
|
||||||
awq_dequantize_triton as awq_dequantize,
|
awq_dequantize_triton as awq_dequantize,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if _is_hip:
|
|
||||||
from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
|
|
||||||
decode_attention_fwd_grouped_rope,
|
|
||||||
)
|
|
||||||
|
|
||||||
_is_flashinfer_available = is_flashinfer_available()
|
_is_flashinfer_available = is_flashinfer_available()
|
||||||
_is_sm100_supported = is_cuda() and is_sm100_supported()
|
_is_sm100_supported = is_cuda() and is_sm100_supported()
|
||||||
@@ -229,7 +228,7 @@ def _dispatch_mla_subtype(attn, forward_batch):
|
|||||||
return AttnForwardMethod.MLA
|
return AttnForwardMethod.MLA
|
||||||
|
|
||||||
|
|
||||||
class BackendRegistry:
|
class AttentionBackendRegistry:
|
||||||
_handlers = {}
|
_handlers = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -241,7 +240,7 @@ class BackendRegistry:
|
|||||||
return cls._handlers.get(backend_name, cls._handlers.get("triton"))
|
return cls._handlers.get(backend_name, cls._handlers.get("triton"))
|
||||||
|
|
||||||
|
|
||||||
def handle_ascend(attn, forward_batch):
|
def handle_attention_ascend(attn, forward_batch):
|
||||||
if (
|
if (
|
||||||
forward_batch.forward_mode.is_extend()
|
forward_batch.forward_mode.is_extend()
|
||||||
and not forward_batch.forward_mode.is_target_verify()
|
and not forward_batch.forward_mode.is_target_verify()
|
||||||
@@ -268,7 +267,7 @@ def _is_extend_without_speculative(forward_batch):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _handle_backend(attn, forward_batch, backend_name):
|
def _handle_attention_backend(attn, forward_batch, backend_name):
|
||||||
sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
|
sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
|
||||||
disable_ragged = (
|
disable_ragged = (
|
||||||
backend_name in ["flashinfer", "flashmla"]
|
backend_name in ["flashinfer", "flashmla"]
|
||||||
@@ -290,28 +289,28 @@ def _handle_backend(attn, forward_batch, backend_name):
|
|||||||
return _dispatch_mla_subtype(attn, forward_batch)
|
return _dispatch_mla_subtype(attn, forward_batch)
|
||||||
|
|
||||||
|
|
||||||
def handle_flashinfer(attn, forward_batch):
|
def handle_attention_flashinfer(attn, forward_batch):
|
||||||
return _handle_backend(attn, forward_batch, "flashinfer")
|
return _handle_attention_backend(attn, forward_batch, "flashinfer")
|
||||||
|
|
||||||
|
|
||||||
def handle_fa3(attn, forward_batch):
|
def handle_attention_fa3(attn, forward_batch):
|
||||||
return _handle_backend(attn, forward_batch, "fa3")
|
return _handle_attention_backend(attn, forward_batch, "fa3")
|
||||||
|
|
||||||
|
|
||||||
def handle_flashmla(attn, forward_batch):
|
def handle_attention_flashmla(attn, forward_batch):
|
||||||
return _handle_backend(attn, forward_batch, "flashmla")
|
return _handle_attention_backend(attn, forward_batch, "flashmla")
|
||||||
|
|
||||||
|
|
||||||
def handle_cutlass_mla(attn, forward_batch):
|
def handle_attention_cutlass_mla(attn, forward_batch):
|
||||||
return _handle_backend(attn, forward_batch, "cutlass_mla")
|
return _handle_attention_backend(attn, forward_batch, "cutlass_mla")
|
||||||
|
|
||||||
|
|
||||||
def handle_fa4(attn, forward_batch):
|
def handle_attention_fa4(attn, forward_batch):
|
||||||
# TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
|
# TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
|
||||||
return AttnForwardMethod.MHA_CHUNKED_KV
|
return AttnForwardMethod.MHA_CHUNKED_KV
|
||||||
|
|
||||||
|
|
||||||
def handle_trtllm_mla(attn, forward_batch):
|
def handle_attention_trtllm_mla(attn, forward_batch):
|
||||||
sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
|
sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
|
||||||
if _is_extend_without_speculative(forward_batch) and (
|
if _is_extend_without_speculative(forward_batch) and (
|
||||||
not attn.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
|
not attn.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
|
||||||
@@ -321,7 +320,7 @@ def handle_trtllm_mla(attn, forward_batch):
|
|||||||
return _dispatch_mla_subtype(attn, forward_batch)
|
return _dispatch_mla_subtype(attn, forward_batch)
|
||||||
|
|
||||||
|
|
||||||
def handle_aiter(attn, forward_batch):
|
def handle_attention_aiter(attn, forward_batch):
|
||||||
if _is_extend_without_speculative(forward_batch):
|
if _is_extend_without_speculative(forward_batch):
|
||||||
if is_dp_attention_enabled():
|
if is_dp_attention_enabled():
|
||||||
if sum(forward_batch.extend_prefix_lens_cpu) == 0:
|
if sum(forward_batch.extend_prefix_lens_cpu) == 0:
|
||||||
@@ -334,7 +333,7 @@ def handle_aiter(attn, forward_batch):
|
|||||||
return AttnForwardMethod.MLA
|
return AttnForwardMethod.MLA
|
||||||
|
|
||||||
|
|
||||||
def handle_triton(attn, forward_batch):
|
def handle_attention_triton(attn, forward_batch):
|
||||||
if (
|
if (
|
||||||
_is_extend_without_speculative(forward_batch)
|
_is_extend_without_speculative(forward_batch)
|
||||||
and sum(forward_batch.extend_prefix_lens_cpu) == 0
|
and sum(forward_batch.extend_prefix_lens_cpu) == 0
|
||||||
@@ -541,7 +540,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
correction_bias=self.gate.e_score_correction_bias,
|
correction_bias=self.gate.e_score_correction_bias,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
routed_scaling_factor=self.routed_scaling_factor,
|
routed_scaling_factor=self.routed_scaling_factor,
|
||||||
apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
|
apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk,
|
||||||
# Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
|
# Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
|
||||||
# and requires the output format to be standard. We use quant_config to determine the output format.
|
# and requires the output format to be standard. We use quant_config to determine the output format.
|
||||||
output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
|
output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
|
||||||
@@ -838,13 +837,13 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
|
|
||||||
if shared_output is not None:
|
if shared_output is not None:
|
||||||
x = shared_output
|
x = shared_output
|
||||||
if self.experts.should_fuse_routed_scaling_factor_in_topk():
|
if self.experts.should_fuse_routed_scaling_factor_in_topk:
|
||||||
x.add_(final_hidden_states)
|
x.add_(final_hidden_states)
|
||||||
else:
|
else:
|
||||||
x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
|
x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
|
||||||
final_hidden_states = x
|
final_hidden_states = x
|
||||||
else:
|
else:
|
||||||
if not self.experts.should_fuse_routed_scaling_factor_in_topk():
|
if not self.experts.should_fuse_routed_scaling_factor_in_topk:
|
||||||
final_hidden_states *= self.routed_scaling_factor
|
final_hidden_states *= self.routed_scaling_factor
|
||||||
|
|
||||||
return final_hidden_states
|
return final_hidden_states
|
||||||
@@ -1217,7 +1216,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
attention_backend = global_server_args_dict["prefill_attention_backend"]
|
attention_backend = global_server_args_dict["prefill_attention_backend"]
|
||||||
self.current_attention_backend = attention_backend
|
self.current_attention_backend = attention_backend
|
||||||
|
|
||||||
handler = BackendRegistry.get_handler(attention_backend)
|
handler = AttentionBackendRegistry.get_handler(attention_backend)
|
||||||
return handler(self, forward_batch)
|
return handler(self, forward_batch)
|
||||||
|
|
||||||
def op_prepare(self, state):
|
def op_prepare(self, state):
|
||||||
@@ -3092,15 +3091,15 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
BackendRegistry.register("ascend", handle_ascend)
|
AttentionBackendRegistry.register("ascend", handle_attention_ascend)
|
||||||
BackendRegistry.register("flashinfer", handle_flashinfer)
|
AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
|
||||||
BackendRegistry.register("fa3", handle_fa3)
|
AttentionBackendRegistry.register("fa3", handle_attention_fa3)
|
||||||
BackendRegistry.register("flashmla", handle_flashmla)
|
AttentionBackendRegistry.register("flashmla", handle_attention_flashmla)
|
||||||
BackendRegistry.register("cutlass_mla", handle_cutlass_mla)
|
AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla)
|
||||||
BackendRegistry.register("fa4", handle_fa4)
|
AttentionBackendRegistry.register("fa4", handle_attention_fa4)
|
||||||
BackendRegistry.register("trtllm_mla", handle_trtllm_mla)
|
AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla)
|
||||||
BackendRegistry.register("aiter", handle_aiter)
|
AttentionBackendRegistry.register("aiter", handle_attention_aiter)
|
||||||
BackendRegistry.register("triton", handle_triton)
|
AttentionBackendRegistry.register("triton", handle_attention_triton)
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
|
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
|
||||||
|
|||||||
Reference in New Issue
Block a user