[CPU] [BF16] Call fused_experts_cpu, weight_packed_linear and bmm_cpu kernel in DeepSeek model (#6641)
Co-authored-by: Thien Tran <gau.nernst@yahoo.com.sg>
@@ -30,7 +30,12 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import (
+    _process_weight_after_loading,
+    cpu_has_amx_support,
+    is_cpu,
+    set_weight_attrs,
+)

 logger = logging.getLogger(__name__)

@@ -52,6 +57,9 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "IPEXAWQLinearMethod",
 ]

+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
+

 def adjust_marlin_shard(param, shard_size, shard_offset):
     marlin_tile_size = getattr(param, "marlin_tile_size", None)
@@ -165,6 +173,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
         layer.register_parameter("weight", weight)
         set_weight_attrs(weight, extra_weight_attrs)

+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if _is_cpu and _is_cpu_amx_available:
+            _process_weight_after_loading(layer, ["weight"])
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -172,6 +184,11 @@ class UnquantizedLinearMethod(LinearMethodBase):
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+
+        if getattr(layer, "use_intel_amx_backend", False):
+            return torch.ops.sgl_kernel.weight_packed_linear(
+                x, layer.weight, bias, True  # is_vnni
+            )
+
         return F.linear(x, layer.weight, bias)


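
For illustration (not part of this commit): a minimal sketch of how the new AMX path in UnquantizedLinearMethod.apply relates to the stock F.linear fallback. It assumes a CPU with AMX support and the sgl_kernel extension loaded; the shapes and tolerances are illustrative only.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes satisfying the packed-GEMM constraints (OC % 16 == 0, IC % 32 == 0).
x = torch.randn(4, 64, dtype=torch.bfloat16)          # [M, IC]
weight = torch.randn(128, 64, dtype=torch.bfloat16)   # [OC, IC]
bias = torch.randn(128, dtype=torch.float)            # bias is kept in float32 on the CPU path

ref = F.linear(x, weight, bias.to(x.dtype))

# Packed path taken when layer.use_intel_amx_backend is True:
packed = torch.ops.sgl_kernel.convert_weight_packed(weight)             # VNNI-packed weight
out = torch.ops.sgl_kernel.weight_packed_linear(x, packed, bias, True)  # is_vnni=True

torch.testing.assert_close(out.float(), ref.float(), rtol=2e-2, atol=2e-2)
```
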
@@ -442,11 +442,20 @@ class LogitsProcessor(nn.Module):
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)

         if hasattr(lm_head, "weight"):
-            logits = torch.matmul(
-                hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
-            )
+            if getattr(lm_head, "use_intel_amx_backend", False):
+                logits = torch.ops.sgl_kernel.weight_packed_linear(
+                    hidden_states.to(lm_head.weight.dtype),
+                    lm_head.weight,
+                    None,  # bias
+                    True,  # is_vnni
+                )
+            else:
+                logits = torch.matmul(
+                    hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
+                )
         else:
             # GGUF models
+            # TODO: use weight_packed_linear for GGUF models
             logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)

         if self.logit_scale is not None:
@@ -77,8 +77,15 @@ def moe_forward_native(
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    inplace: bool = True,
+    no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
+
+    if apply_router_weight_on_input:
+        raise NotImplementedError()

     topk_weights, topk_ids = select_experts(
         hidden_states=x,
         router_logits=router_logits,
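
For illustration (not part of this commit): apply_router_weight_on_input, which the native path rejects above, is meant to scale the expert input by the routing weight instead of scaling the expert output. A rough sketch of the distinction, with a made-up stand-in expert function:

```python
import torch

topk_weights = torch.tensor([0.7, 0.3])   # routing weights for one token's two selected experts
x = torch.randn(16)

def expert(i, inp):
    # Stand-in for an expert MLP (the real experts are gate/up/down projections).
    return torch.tanh(inp) * (i + 1)

# Default behavior: run each expert on x, then combine the outputs with the routing weights.
out_weight_on_output = sum(w * expert(i, x) for i, w in enumerate(topk_weights))

# apply_router_weight_on_input=True: scale the input before each expert instead.
out_weight_on_input = sum(expert(i, w * x) for i, w in enumerate(topk_weights))

# The two differ for non-linear experts, which is why this path only supports the
# default for now (see the NotImplementedError above).
```
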
@@ -18,7 +18,14 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs
+from sglang.srt.utils import (
+    _process_weight_after_loading,
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_hip,
+    set_weight_attrs,
+)

 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -28,6 +35,8 @@ else:
 import logging

 _is_hip = is_hip()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _use_aiter:
@@ -117,6 +126,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
+
+        # Pack weights for better performance on CPU
+        if _is_cpu and _is_cpu_amx_available:
+            _process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+
         return

     def apply(
@@ -248,19 +262,64 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
-        return moe_forward_native(
-            layer,
-            x,
-            use_grouped_topk,
-            top_k,
-            router_logits,
-            renormalize,
-            topk_group,
-            num_expert_group,
-            num_fused_shared_experts,
-            custom_routing_function,
-            correction_bias,
-        )
+        assert activation == "silu", f"activation = {activation} is not supported."
+
+        if (
+            getattr(layer, "use_intel_amx_backend", False)
+            and not apply_router_weight_on_input
+        ):
+            topk_weights, topk_ids = select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                num_fused_shared_experts=num_fused_shared_experts,
+                custom_routing_function=custom_routing_function,
+                correction_bias=correction_bias,
+                routed_scaling_factor=routed_scaling_factor,
+            )
+
+            # TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel
+            return torch.ops.sgl_kernel.fused_experts_cpu(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights.to(
+                    torch.float
+                ),  # TODO: the topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bfloat16 while the kernel requires it to be float32
+                topk_ids,
+                True,  # inplace
+                False,  # use_int8_w8a8
+                False,  # use_fp8_w8a16
+                None,  # w1_scale
+                None,  # w2_scale
+                None,  # block_size
+                None,  # a1_scale
+                None,  # a2_scale
+                True,  # is_vnni
+            )
+        else:
+            return moe_forward_native(
+                layer,
+                x,
+                use_grouped_topk,
+                top_k,
+                router_logits,
+                renormalize,
+                topk_group,
+                num_expert_group,
+                num_fused_shared_experts,
+                custom_routing_function,
+                correction_bias,
+                activation,
+                apply_router_weight_on_input,
+                inplace,
+                no_combine,
+                routed_scaling_factor,
+            )

     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("The TPU backend currently does not support MoE.")
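
For illustration (not part of this commit): a rough pure-PyTorch reference for what the fused_experts_cpu call above computes, assuming the usual fused-MoE weight layout (w13_weight: [E, 2*I, H], w2_weight: [E, H, I]) and the silu activation asserted above. The real kernel additionally expects VNNI-packed weights (is_vnni=True) and float32 topk_weights.

```python
import torch
import torch.nn.functional as F

def moe_reference(x, w13, w2, topk_weights, topk_ids):
    """Unfused reference: route each token to its top-k experts and combine their outputs."""
    # x: [T, H], w13: [E, 2I, H], w2: [E, H, I], topk_*: [T, k]
    I = w2.shape[-1]
    out = torch.zeros_like(x)
    for t in range(x.shape[0]):
        for w, e in zip(topk_weights[t], topk_ids[t]):
            gate_up = w13[e] @ x[t]                 # [2I]: gate and up projections fused in w13
            h = F.silu(gate_up[:I]) * gate_up[I:]   # SiLU-and-mul activation
            out[t] += w * (w2[e] @ h)               # down projection, weighted by the router
    return out
```
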
@@ -20,10 +20,18 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
     method_has_implemented_embedding,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import (
+    PackWeightMethod,
+    cpu_has_amx_support,
+    is_cpu,
+    set_weight_attrs,
+)

 DEFAULT_VOCAB_PADDING_SIZE = 64

+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
+

 class UnquantizedEmbeddingMethod(QuantizeMethodBase):
     """Unquantized method for embeddings."""
@@ -549,6 +557,11 @@ class ParallelLMHead(VocabParallelEmbedding):
             use_presharded_weights=use_presharded_weights,
         )
         self.quant_config = quant_config
+
+        # We only support packing the LMHead if it is not quantized. For an LMHead with quant_config, the weight_name will be "qweight"
+        if self.quant_config is None and _is_cpu and _is_cpu_amx_available:
+            self.quant_method = PackWeightMethod(weight_names=["weight"])
+
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition, dtype=params_dtype)
@@ -93,6 +93,7 @@ from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,
     LazyValue,
+    PackWeightMethod,
     add_prefix,
     bind_or_assign,
     cpu_has_amx_support,
@@ -144,6 +145,9 @@ class AttnForwardMethod(IntEnum):
     # Use MLA but with fused RoPE
     MLA_FUSED_ROPE = auto()

+    # Use MLA with fused RoPE kernel for CPU
+    MLA_FUSED_ROPE_CPU = auto()
+

 class DeepseekV2MLP(nn.Module):
     def __init__(
@@ -212,8 +216,18 @@ class MoEGate(nn.Module):
             )
         else:
             self.e_score_correction_bias = None
+        if _is_cpu and _is_cpu_amx_available:
+            self.quant_method = PackWeightMethod(weight_names=["weight"])

     def forward(self, hidden_states):
+        if getattr(self, "use_intel_amx_backend", False):
+            return torch.ops.sgl_kernel.weight_packed_linear(
+                hidden_states,
+                self.weight,
+                None,  # bias
+                True,  # is_vnni
+            )
+
         logits = F.linear(hidden_states, self.weight, None)
         return logits

@@ -778,6 +792,37 @@ class DeepseekV2AttentionMLA(nn.Module):
             "SGL_CHUNKED_PREFIX_CACHE_THRESHOLD", 8192
         )

+        # If we have self.fused_qkv_a_proj_with_mqa and we're running on CPU, we will choose the torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight kernel,
+        # which requires self.w_kc and self.w_vc to be packed.
+        # If not, we will use torch.bmm, and the weight shouldn't be packed in this case.
+        if (
+            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and _is_cpu
+            and _is_cpu_amx_available
+        ):
+            self.quant_method = PackWeightMethod(
+                weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
+            )
+
+        self.qkv_proj_with_rope_is_int8 = (
+            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.int8
+        )
+        self.qkv_proj_with_rope_is_fp8 = (
+            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.float8_e4m3fn
+        )
+
+        self.weight_block_size = None
+        if self.qkv_proj_with_rope_is_fp8:
+            assert (
+                self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+                == self.q_b_proj.quant_method.quant_config.weight_block_size
+            )
+            self.weight_block_size = (
+                self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+            )
+
     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
     ) -> AttnForwardMethod:
@@ -791,7 +836,12 @@ class DeepseekV2AttentionMLA(nn.Module):
             else:
                 return AttnForwardMethod.MLA
         else:
-            return AttnForwardMethod.MLA
+            if hasattr(self, "fused_qkv_a_proj_with_mqa") and getattr(
+                self, "use_intel_amx_backend", False
+            ):
+                return AttnForwardMethod.MLA_FUSED_ROPE_CPU
+            else:
+                return AttnForwardMethod.MLA

         if self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
@@ -905,6 +955,10 @@ class DeepseekV2AttentionMLA(nn.Module):
             inner_state = self.forward_absorb_fused_mla_rope_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
             )
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
+            inner_state = self.forward_absorb_fused_mla_rope_cpu_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
         else:
             raise NotImplementedError
         return None, attn_forward_method, forward_batch, inner_state
@@ -924,6 +978,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             return self.forward_absorb_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
             return self.forward_absorb_fused_mla_rope_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
+            return self.forward_absorb_fused_mla_rope_cpu_core(*inner_state)
         else:
             raise NotImplementedError

@@ -1241,6 +1297,57 @@ class DeepseekV2AttentionMLA(nn.Module):
             zero_allocator,
         )

+    def forward_absorb_fused_mla_rope_cpu_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ):
+        assert self.q_lora_rank is not None and getattr(
+            self, "use_intel_amx_backend", False
+        ), "forward_absorb_fused_mla_rope_cpu_prepare requires q_lora_rank is not None and use_intel_amx_backend"
+
+        q_input, k_input, v_input = (
+            torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight(
+                hidden_states,
+                self.fused_qkv_a_proj_with_mqa.weight,
+                self.q_b_proj.weight,
+                self.w_kc,
+                self.q_a_layernorm.weight,
+                self.kv_a_layernorm.weight,
+                positions,
+                self.rotary_emb.cos_sin_cache,
+                self.kv_a_layernorm.variance_epsilon,
+                self.qkv_proj_with_rope_is_int8,
+                self.qkv_proj_with_rope_is_fp8,
+                (
+                    self.fused_qkv_a_proj_with_mqa.weight_scale
+                    if self.qkv_proj_with_rope_is_int8
+                    else (
+                        self.fused_qkv_a_proj_with_mqa.weight_scale_inv
+                        if self.qkv_proj_with_rope_is_fp8
+                        else None
+                    )
+                ),
+                (
+                    self.q_b_proj.weight_scale
+                    if self.qkv_proj_with_rope_is_int8
+                    else (
+                        self.q_b_proj.weight_scale_inv
+                        if self.qkv_proj_with_rope_is_fp8
+                        else None
+                    )
+                ),
+                True,  # is_vnni
+                self.weight_block_size,
+                self.q_lora_rank,
+                self.kv_lora_rank,
+                self.qk_rope_head_dim,
+            )
+        )
+        return (q_input, k_input, v_input, forward_batch, zero_allocator)
+
     def forward_absorb_fused_mla_rope_core(
         self,
         q_input,
@@ -1314,6 +1421,43 @@ class DeepseekV2AttentionMLA(nn.Module):

         return output

+    def forward_absorb_fused_mla_rope_cpu_core(
+        self, q_input, k_input, v_input, forward_batch, zero_allocator
+    ):
+        assert self.q_lora_rank is not None and getattr(
+            self, "use_intel_amx_backend", False
+        ), "forward_absorb_fused_mla_rope_cpu_core requires q_lora_rank is not None and use_intel_amx_backend"
+
+        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
+        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+
+        # [Note] Align shapes of bmm inputs.
+        # Shapes of inputs:
+        # q_nope: [M, B, K]
+        # original self.w_kc: [B, K, N]
+        # current self.w_kc (which has been converted in PackWeightMethod): [B, N, K]
+
+        # Shapes of inputs to sgl_kernel.cpu.bmm:
+        # out: [B, M, N]
+        # mat1: [B, M, K]
+        # mat2: [B, N, K]
+        B = self.w_vc.size(0)
+        N = self.w_vc.size(1)
+        M = attn_output.size(0)
+        output = torch.empty([M, int(B * N)], dtype=attn_output.dtype)
+        attn_bmm_output = output.view([M, B, N]).transpose_(0, 1)
+        torch.ops.sgl_kernel.bmm_cpu(
+            attn_bmm_output,
+            attn_output.transpose(0, 1),
+            self.w_vc,
+            True,  # is_vnni
+            None,  # scale
+        )
+        attn_output = output
+        output, _ = self.o_proj(attn_output)
+
+        return output
+
     def _chunked_prefix_attn_mha(
         self,
         q: torch.Tensor,
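
For illustration (not part of this commit): the shape bookkeeping from the [Note] above, checked against plain torch.bmm. PackWeightMethod transposes w_vc from [B, K, N] to [B, N, K] before packing, and bmm_cpu writes into a [M, B*N] buffer through a transposed [B, M, N] view; the shapes here are made up.

```python
import torch

B, M, K, N = 4, 3, 32, 16                 # heads, tokens, kv_lora_rank, v_head_dim (illustrative)
attn_output = torch.randn(M, B, K)        # as produced by attn_mqa after .view
w_vc = torch.randn(B, K, N)               # original [B, K, N] layout

# Reference: out[B, M, N] = mat1[B, M, K] @ mat2[B, K, N]
ref = torch.bmm(attn_output.transpose(0, 1), w_vc)

# CPU path: w_vc is pre-transposed to [B, N, K] (then VNNI-packed), and the result is
# written into a [M, B * N] buffer viewed as [B, M, N].
w_vc_t = w_vc.transpose(1, 2).contiguous()                 # [B, N, K], as after PackWeightMethod
output = torch.empty(M, B * N)
attn_bmm_output = output.view(M, B, N).transpose(0, 1)     # [B, M, N] view into `output`
attn_bmm_output.copy_(torch.bmm(attn_output.transpose(0, 1), w_vc_t.transpose(1, 2)))

torch.testing.assert_close(output.view(M, B, N).transpose(0, 1), ref)
```
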
@@ -2457,6 +2457,77 @@ def cpu_has_amx_support():
     return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available


+def prepack_weight_if_needed(weight):
+    if weight.device != torch.device("cpu"):
+        return weight
+    if not cpu_has_amx_support():
+        return weight
+
+    return torch.ops.sgl_kernel.convert_weight_packed(weight)
+
+
+# TODO: currently the gemm kernel has the below requirements:
+# OC % TILE_N == 0, where TILE_N = 16
+# IC % TILE_K == 0, where TILE_K = 32
+def dim_is_supported(weight):
+    return weight.size(0) % 16 == 0 and weight.size(1) % 32 == 0
+
+
+def _process_weight_after_loading(module, weight_names, transpose_dims=None) -> None:
+    # Pack weights for better performance on CPU
+    devices = {getattr(module, weight_name).device for weight_name in weight_names}
+    assert len(devices) == 1, f"Expects all weights to be on the same device"
+    device = devices.pop()
+
+    if transpose_dims:
+        assert len(weight_names) == len(
+            transpose_dims
+        ), "len(weight_names) should be equal to len(transpose_dims)"
+
+    for i, weight_name in enumerate(weight_names):
+        weight_tensor = getattr(module, weight_name)
+
+        # We don't pack the weight or use the intel amx backend if any weight of this module has an unsupported dim.
+        if not dim_is_supported(weight_tensor):
+            logger.warning(
+                f"Expects weight.size(0) % 16 == 0 and weight.size(1) % 32 == 0 "
+                f"but {weight_tensor.size(0)=} and {weight_tensor.size(1)=} in {module}. "
+                f"{module} won't use intel amx backend."
+            )
+            module.use_intel_amx_backend = False
+            return
+
+        if transpose_dims and transpose_dims[i]:
+            weight_tensor = weight_tensor.transpose(*transpose_dims[i])
+
+        packed_weight = torch.nn.Parameter(
+            prepack_weight_if_needed(weight_tensor),
+            requires_grad=False,
+        )
+        packed_weight.__dict__ = weight_tensor.__dict__
+        setattr(module, weight_name, packed_weight)
+
+    module.use_intel_amx_backend = (
+        device == torch.device("cpu") and cpu_has_amx_support()
+    )
+
+    if (
+        module.use_intel_amx_backend
+        and hasattr(module, "bias")
+        and module.bias is not None
+    ):
+        module.bias = torch.nn.Parameter(module.bias.data.float(), requires_grad=False)
+
+
+class PackWeightMethod:
+    def __init__(self, weight_names, transpose_dims=None):
+        self.weight_names = weight_names
+        self.transpose_dims = transpose_dims
+
+    def process_weights_after_loading(self, module) -> None:
+        _process_weight_after_loading(module, self.weight_names, self.transpose_dims)
+
+
 class LazyValue:
     def __init__(self, creator: Callable):
         self._creator = creator
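
For illustration (not part of this commit): a minimal usage sketch of the new PackWeightMethod and _process_weight_after_loading helpers, following the same pattern the commit applies to MoEGate and ParallelLMHead. It assumes a CPU with AMX support and the sgl_kernel extension; the weight dims are chosen to satisfy dim_is_supported (OC % 16 == 0, IC % 32 == 0).

```python
import torch
from sglang.srt.utils import PackWeightMethod

class TinyGate(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # 64 x 128 satisfies the packed-GEMM constraints checked by dim_is_supported().
        self.weight = torch.nn.Parameter(
            torch.randn(64, 128, dtype=torch.bfloat16), requires_grad=False
        )
        # Same hook the commit attaches to MoEGate / ParallelLMHead:
        self.quant_method = PackWeightMethod(weight_names=["weight"])

gate = TinyGate()
# The model loader invokes this after checkpoint weights are loaded; it replaces
# gate.weight with a VNNI-packed copy and sets gate.use_intel_amx_backend.
gate.quant_method.process_weights_after_loading(gate)

x = torch.randn(8, 128, dtype=torch.bfloat16)
if getattr(gate, "use_intel_amx_backend", False):
    out = torch.ops.sgl_kernel.weight_packed_linear(x, gate.weight, None, True)  # is_vnni
else:
    out = torch.nn.functional.linear(x, gate.weight)
```
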