[CPU] add optimizations for INT8 and FP8 DeepSeek (#6769)
Co-authored-by: Zheng, Beilei <beilei.zheng@intel.com>
This commit is contained in:
@@ -291,7 +291,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
torch.float
|
torch.float
|
||||||
), # TODO: the topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bfloat16 while the kernel requires it to be float32
|
), # TODO: the topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bfloat16 while the kernel requires it to be float32
|
||||||
topk_ids,
|
topk_ids,
|
||||||
True, # inplace
|
False, # inplace # See [Note] inplace should be False in fused_experts.
|
||||||
False, # use_int8_w8a8
|
False, # use_int8_w8a8
|
||||||
False, # use_fp8_w8a16
|
False, # use_fp8_w8a16
|
||||||
None, # w1_scale
|
None, # w1_scale
|
||||||
|
|||||||
@@ -64,6 +64,7 @@ from sglang.srt.layers.quantization.utils import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.utils import is_sm100_supported
|
from sglang.srt.layers.utils import is_sm100_supported
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
|
_process_weight_after_loading,
|
||||||
cpu_has_amx_support,
|
cpu_has_amx_support,
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
is_cpu,
|
is_cpu,
|
||||||
@@ -330,6 +331,12 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
layer.input_scale = None
|
layer.input_scale = None
|
||||||
|
elif _is_cpu:
|
||||||
|
assert (
|
||||||
|
_is_cpu_amx_available
|
||||||
|
), "Fp8LinearMethod on CPU requires that CPU has AMX support"
|
||||||
|
_process_weight_after_loading(layer, ["weight"])
|
||||||
|
return
|
||||||
else:
|
else:
|
||||||
weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
|
weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
|
||||||
layer.weight = torch.nn.Parameter(weight, requires_grad=False)
|
layer.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||||
@@ -426,6 +433,17 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
|
if getattr(layer, "use_intel_amx_backend", False):
|
||||||
|
return torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
|
||||||
|
x,
|
||||||
|
layer.weight,
|
||||||
|
layer.weight_scale_inv,
|
||||||
|
self.quant_config.weight_block_size,
|
||||||
|
bias,
|
||||||
|
x.dtype,
|
||||||
|
True, # is_vnni
|
||||||
|
)
|
||||||
|
|
||||||
return self.w8a8_block_fp8_linear(
|
return self.w8a8_block_fp8_linear(
|
||||||
input=x,
|
input=x,
|
||||||
weight=layer.weight,
|
weight=layer.weight,
|
||||||
@@ -746,6 +764,13 @@ class Fp8MoEMethod:
|
|||||||
layer.w2_weight.data = shuffle_weight(
|
layer.w2_weight.data = shuffle_weight(
|
||||||
layer.w2_weight.contiguous(), (16, 16)
|
layer.w2_weight.contiguous(), (16, 16)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if _is_cpu:
|
||||||
|
assert (
|
||||||
|
_is_cpu_amx_available
|
||||||
|
), "Fp8MoEMethod on CPU requires that CPU has AMX support"
|
||||||
|
_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# If checkpoint is fp16 or bfloat16, quantize in place.
|
# If checkpoint is fp16 or bfloat16, quantize in place.
|
||||||
@@ -971,6 +996,24 @@ class Fp8MoEMethod:
|
|||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if getattr(layer, "use_intel_amx_backend", False):
|
||||||
|
return torch.ops.sgl_kernel.fused_experts_cpu(
|
||||||
|
x,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
False, # inplace See [Note] inplace should be False in fused_experts.
|
||||||
|
False, # use_int8_w8a8
|
||||||
|
True, # use_fp8_w8a16
|
||||||
|
layer.w13_weight_scale_inv, # w1_scale
|
||||||
|
layer.w2_weight_scale_inv, # w2_scale
|
||||||
|
self.quant_config.weight_block_size, # block_size
|
||||||
|
None, # a1_scale
|
||||||
|
None, # a2_scale
|
||||||
|
True, # is_vnni
|
||||||
|
)
|
||||||
|
|
||||||
if _is_hip:
|
if _is_hip:
|
||||||
ret = self.maybe_apply_hip_fused_experts(
|
ret = self.maybe_apply_hip_fused_experts(
|
||||||
layer,
|
layer,
|
||||||
|
|||||||
@@ -131,7 +131,7 @@ class MoeWNA16Config(QuantizationConfig):
|
|||||||
capability_tuple = get_device_capability()
|
capability_tuple = get_device_capability()
|
||||||
device_capability = (
|
device_capability = (
|
||||||
-1
|
-1
|
||||||
if capability_tuple is None
|
if all(capability is None for capability in capability_tuple)
|
||||||
else capability_tuple[0] * 10 + capability_tuple[1]
|
else capability_tuple[0] * 10 + capability_tuple[1]
|
||||||
)
|
)
|
||||||
# Avoid circular import
|
# Avoid circular import
|
||||||
|
|||||||
@@ -11,9 +11,17 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
||||||
from sglang.srt.utils import is_cuda, set_weight_attrs
|
from sglang.srt.utils import (
|
||||||
|
_process_weight_after_loading,
|
||||||
|
cpu_has_amx_support,
|
||||||
|
is_cpu,
|
||||||
|
is_cuda,
|
||||||
|
set_weight_attrs,
|
||||||
|
)
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
|
_is_cpu_amx_available = cpu_has_amx_support()
|
||||||
|
_is_cpu = is_cpu()
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sgl_kernel import int8_scaled_mm
|
from sgl_kernel import int8_scaled_mm
|
||||||
|
|
||||||
@@ -72,6 +80,13 @@ class W8A8Int8LinearMethod(LinearMethodBase):
|
|||||||
self.quantization_config = quantization_config
|
self.quantization_config = quantization_config
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
if _is_cpu:
|
||||||
|
assert (
|
||||||
|
_is_cpu_amx_available
|
||||||
|
), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
|
||||||
|
_process_weight_after_loading(layer, ["weight"])
|
||||||
|
return
|
||||||
|
|
||||||
layer.weight = Parameter(layer.weight.t(), requires_grad=False)
|
layer.weight = Parameter(layer.weight.t(), requires_grad=False)
|
||||||
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
|
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
|
||||||
|
|
||||||
@@ -112,6 +127,16 @@ class W8A8Int8LinearMethod(LinearMethodBase):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: Optional[torch.Tensor] = None,
|
bias: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
|
if getattr(layer, "use_intel_amx_backend", False):
|
||||||
|
return torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
|
||||||
|
x,
|
||||||
|
layer.weight,
|
||||||
|
layer.weight_scale,
|
||||||
|
bias,
|
||||||
|
x.dtype,
|
||||||
|
True, # is_vnni
|
||||||
|
)
|
||||||
|
|
||||||
x_q, x_scale = per_token_quant_int8(x)
|
x_q, x_scale = per_token_quant_int8(x)
|
||||||
|
|
||||||
return int8_scaled_mm(
|
return int8_scaled_mm(
|
||||||
@@ -206,6 +231,13 @@ class W8A8Int8MoEMethod:
|
|||||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
if _is_cpu:
|
||||||
|
assert (
|
||||||
|
_is_cpu_amx_available
|
||||||
|
), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
|
||||||
|
_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
|
||||||
|
return
|
||||||
|
|
||||||
layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
|
layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
|
||||||
layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
|
layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
|
||||||
layer.w13_weight_scale = Parameter(
|
layer.w13_weight_scale = Parameter(
|
||||||
@@ -252,6 +284,24 @@ class W8A8Int8MoEMethod:
|
|||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if getattr(layer, "use_intel_amx_backend", False):
|
||||||
|
return torch.ops.sgl_kernel.fused_experts_cpu(
|
||||||
|
x,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
False, # inplace See [Note] inplace should be False in fused_experts.
|
||||||
|
True, # use_int8_w8a8
|
||||||
|
False, # use_fp8_w8a16
|
||||||
|
layer.w13_weight_scale, # w1_scale
|
||||||
|
layer.w2_weight_scale, # w2_scale
|
||||||
|
None, # block_size
|
||||||
|
layer.w13_input_scale, # a1_scale
|
||||||
|
layer.w2_input_scale, # a2_scale
|
||||||
|
True, # is_vnni
|
||||||
|
)
|
||||||
|
|
||||||
return fused_experts(
|
return fused_experts(
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
|
|||||||
@@ -300,6 +300,9 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.shared_experts_is_int8 = False
|
||||||
|
self.shared_experts_is_fp8 = False
|
||||||
|
self.shared_experts_weight_block_size = None
|
||||||
if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
|
if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
|
||||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||||
# disable tp for shared experts when enable deepep moe
|
# disable tp for shared experts when enable deepep moe
|
||||||
@@ -316,6 +319,20 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
else {}
|
else {}
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
self.shared_experts_is_int8 = (
|
||||||
|
self.shared_experts.gate_up_proj.weight.dtype == torch.int8
|
||||||
|
)
|
||||||
|
self.shared_experts_is_fp8 = (
|
||||||
|
self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
|
||||||
|
)
|
||||||
|
if self.shared_experts_is_fp8:
|
||||||
|
assert (
|
||||||
|
self.shared_experts.gate_up_proj.quant_method.quant_config.weight_block_size
|
||||||
|
== self.shared_experts.down_proj.quant_method.quant_config.weight_block_size
|
||||||
|
)
|
||||||
|
self.shared_experts_weight_block_size = (
|
||||||
|
self.shared_experts.gate_up_proj.quant_method.quant_config.weight_block_size
|
||||||
|
)
|
||||||
|
|
||||||
self.top_k = config.num_experts_per_tok
|
self.top_k = config.num_experts_per_tok
|
||||||
|
|
||||||
@@ -394,6 +411,11 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
return final_hidden_states
|
return final_hidden_states
|
||||||
|
|
||||||
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
if hasattr(self, "shared_experts") and getattr(
|
||||||
|
self.shared_experts.gate_up_proj, "use_intel_amx_backend", False
|
||||||
|
):
|
||||||
|
return self.forward_cpu(hidden_states)
|
||||||
|
|
||||||
shared_output = self._forward_shared_experts(hidden_states)
|
shared_output = self._forward_shared_experts(hidden_states)
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits = self.gate(hidden_states)
|
router_logits = self.gate(hidden_states)
|
||||||
@@ -409,6 +431,59 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
return final_hidden_states
|
return final_hidden_states
|
||||||
|
|
||||||
|
def forward_cpu(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
# router_logits: (num_tokens, n_experts)
|
||||||
|
router_logits = self.gate(hidden_states)
|
||||||
|
fused_experts_out = self.experts(
|
||||||
|
hidden_states=hidden_states, router_logits=router_logits
|
||||||
|
)
|
||||||
|
|
||||||
|
assert getattr(
|
||||||
|
self.shared_experts.gate_up_proj, "use_intel_amx_backend", False
|
||||||
|
) == getattr(self.shared_experts.down_proj, "use_intel_amx_backend", False)
|
||||||
|
# [Note] inplace should be False in fused_experts.
|
||||||
|
# If inplace is True in fused_experts (self.experts), hidden_states will be changed after fused_experts
|
||||||
|
# While hidden_states is still needed in shared_expert.
|
||||||
|
final_hidden_states = torch.ops.sgl_kernel.shared_expert_cpu(
|
||||||
|
hidden_states,
|
||||||
|
self.shared_experts.gate_up_proj.weight,
|
||||||
|
self.shared_experts.down_proj.weight,
|
||||||
|
fused_experts_out,
|
||||||
|
self.routed_scaling_factor,
|
||||||
|
True, # inplace
|
||||||
|
self.shared_experts_is_int8, # use_int8_w8a8
|
||||||
|
self.shared_experts_is_fp8, # use_fp8_w8a16
|
||||||
|
(
|
||||||
|
self.shared_experts.gate_up_proj.weight_scale
|
||||||
|
if self.shared_experts_is_int8
|
||||||
|
else (
|
||||||
|
self.shared_experts.gate_up_proj.weight_scale_inv
|
||||||
|
if self.shared_experts_is_fp8
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
), # w1_scale
|
||||||
|
(
|
||||||
|
self.shared_experts.down_proj.weight_scale
|
||||||
|
if self.shared_experts_is_int8
|
||||||
|
else (
|
||||||
|
self.shared_experts.down_proj.weight_scale_inv
|
||||||
|
if self.shared_experts_is_fp8
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
), # w2_scale
|
||||||
|
(
|
||||||
|
self.shared_experts_weight_block_size
|
||||||
|
if self.shared_experts_is_fp8
|
||||||
|
else None
|
||||||
|
), # block_size
|
||||||
|
None, # a1_scale
|
||||||
|
None, # a2_scale
|
||||||
|
True, # is_vnni
|
||||||
|
)
|
||||||
|
if self.tp_size > 1:
|
||||||
|
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
|
return final_hidden_states
|
||||||
|
|
||||||
def forward_deepep(
|
def forward_deepep(
|
||||||
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
|
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@@ -2107,6 +2182,14 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
)
|
)
|
||||||
if _is_hip:
|
if _is_hip:
|
||||||
self_attn.w_scale *= 2.0
|
self_attn.w_scale *= 2.0
|
||||||
|
# TODO: remove this after adding FP8 support in bmm cpu kernel
|
||||||
|
if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn:
|
||||||
|
self_attn.w_kc = (
|
||||||
|
self_attn.w_kc.to(torch.bfloat16) * self_attn.w_scale
|
||||||
|
)
|
||||||
|
self_attn.w_vc = (
|
||||||
|
self_attn.w_vc.to(torch.bfloat16) * self_attn.w_scale
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
|
num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
|
||||||
num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
|
num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
|
||||||
|
|||||||
Reference in New Issue
Block a user