From 128f16a817283a2931bdfa285cbccdd238ecc35e Mon Sep 17 00:00:00 2001
From: Chunyuan WU
Date: Wed, 9 Jul 2025 10:27:24 +0800
Subject: [PATCH] [CPU] convert topk_weights to fp32 for INT8 and FP8 paths (for llama4) and fix LmHead weight pack (#7818)

---
 .../srt/layers/moe/fused_moe_triton/layer.py      |  4 +---
 .../sglang/srt/layers/vocab_parallel_embedding.py | 12 +++++++++---
 sgl-kernel/csrc/cpu/moe.cpp                       | 15 ++++++++++-----
 3 files changed, 20 insertions(+), 11 deletions(-)

diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index 84470456e..c460b2850 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -317,9 +317,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             x,
             layer.w13_weight,
             layer.w2_weight,
-            topk_weights.to(
-                torch.float
-            ),  # TODO: the topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bfloat16 while the kernel requires it to be float32
+            topk_weights,
             topk_ids,
             False,  # inplace # See [Note] inplace should be False in fused_experts.
             False,  # use_int8_w8a8
diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py
index d7056f5e0..e7a8ebe11 100644
--- a/python/sglang/srt/layers/vocab_parallel_embedding.py
+++ b/python/sglang/srt/layers/vocab_parallel_embedding.py
@@ -1,5 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/layers/vocab_parallel_embedding.py
 
+import logging
 from dataclasses import dataclass
 from typing import List, Optional, Sequence, Tuple
 
@@ -28,6 +29,8 @@ DEFAULT_VOCAB_PADDING_SIZE = 64
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 
+logger = logging.getLogger(__name__)
+
 
 class UnquantizedEmbeddingMethod(QuantizeMethodBase):
     """Unquantized method for embeddings."""
@@ -562,9 +565,12 @@ class ParallelLMHead(VocabParallelEmbedding):
         )
         self.quant_config = quant_config
 
-        # We only support pack LMHead if it's not quantized. For LMHead with quant_config, the weight_name will be "qweight"
-        if self.quant_config is None and _is_cpu and _is_cpu_amx_available:
-            self.quant_method = PackWeightMethod(weight_names=["weight"])
+        # We only support packing the LMHead weight if it is not quantized.
+        if _is_cpu and _is_cpu_amx_available:
+            if hasattr(self, "weight") and self.weight.dtype == torch.bfloat16:
+                self.quant_method = PackWeightMethod(weight_names=["weight"])
+            else:
+                logger.warning("The weight of LmHead is not packed")
 
         if bias:
             self.bias = Parameter(
diff --git a/sgl-kernel/csrc/cpu/moe.cpp b/sgl-kernel/csrc/cpu/moe.cpp
index 2a7d163bb..f755f8f08 100644
--- a/sgl-kernel/csrc/cpu/moe.cpp
+++ b/sgl-kernel/csrc/cpu/moe.cpp
@@ -1008,13 +1008,18 @@ at::Tensor fused_experts_cpu(
 
   CHECK_DIM(2, topk_ids);
   CHECK_EQ(topk_ids.scalar_type(), at::kInt);
-  CHECK_EQ(topk_weights.scalar_type(), at::kFloat);
+
+  // TODO: support bf16 or fp16 topk_weights in the kernel.
+  // The topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bf16/fp16,
+  // while the kernel currently only supports float32.
+  auto topk_weights_ = topk_weights.to(at::kFloat);
+  CHECK_EQ(topk_weights_.scalar_type(), at::kFloat);
 
   int64_t M = hidden_states.size(0);
   int64_t K = hidden_states.size(1);
   int64_t N = w1.size(1) / 2;
   int64_t E = w1.size(0);
-  int64_t topk = topk_weights.size(1);
+  int64_t topk = topk_weights_.size(1);
 
   // we use int32_t compensation for int8 w8a8
   int64_t packed_K = get_row_size(K, use_int8_w8a8);
@@ -1124,7 +1129,7 @@ at::Tensor fused_experts_cpu(
         packed_w2.data_ptr(),
         w1s.data_ptr(),
         w2s.data_ptr(),
-        topk_weights.data_ptr(),
+        topk_weights_.data_ptr(),
         sorted_ids,
         expert_ids,
         offsets,
@@ -1157,7 +1162,7 @@ at::Tensor fused_experts_cpu(
         w2s.data_ptr(),
         block_size_N,
         block_size_K,
-        topk_weights.data_ptr(),
+        topk_weights_.data_ptr(),
         sorted_ids,
         expert_ids,
         offsets,
@@ -1180,7 +1185,7 @@ at::Tensor fused_experts_cpu(
         hidden_states.data_ptr(),
         packed_w1.data_ptr(),
         packed_w2.data_ptr(),
-        topk_weights.data_ptr(),
+        topk_weights_.data_ptr(),
         sorted_ids,
         expert_ids,
         offsets,
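
A small illustrative Python sketch of the dtype situation the TODO comment above describes (this is not code from the patch; names such as router_logits, num_experts, and top_k are assumptions made for the example). A bf16 routing path hands the MoE kernel bf16 topk_weights, and fused_experts_cpu now performs the single fp32 conversion itself instead of requiring every Python call site to do it:

    # Hypothetical example: bf16 routing weights converted once to fp32,
    # mirroring what fused_experts_cpu now does internally.
    import torch

    num_tokens, num_experts, top_k = 4, 16, 2
    router_logits = torch.randn(num_tokens, num_experts, dtype=torch.bfloat16)

    # bf16 routing, similar in spirit to Llama4MoE:custom_routing_function
    topk_weights, topk_ids = torch.topk(router_logits, k=top_k, dim=-1)
    topk_weights = torch.sigmoid(topk_weights)  # still bfloat16 here

    # the kernel-side fix is equivalent to this one-time fp32 copy
    topk_weights_fp32 = topk_weights.to(torch.float32)
    assert topk_weights_fp32.dtype == torch.float32

With the conversion inside the kernel, the Python layer can pass topk_weights through unchanged, which is what the layer.py hunk does.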
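
Similarly, a hedged standalone sketch of the new LmHead packing guard (maybe_pack_lm_head is a hypothetical helper, not part of the patch; the real code assigns PackWeightMethod(weight_names=["weight"]) inside ParallelLMHead.__init__). Only an existing bfloat16 weight is repacked for the CPU AMX path; a quantized LMHead, whose parameter is a "qweight", or any non-bf16 weight is left unpacked with a warning:

    # Hypothetical helper showing the guard in isolation.
    import logging

    import torch
    import torch.nn as nn

    logger = logging.getLogger(__name__)

    def maybe_pack_lm_head(lm_head: nn.Module, is_cpu: bool, amx_available: bool) -> bool:
        if not (is_cpu and amx_available):
            return False
        weight = getattr(lm_head, "weight", None)
        if weight is not None and weight.dtype == torch.bfloat16:
            # real code: lm_head.quant_method = PackWeightMethod(weight_names=["weight"])
            return True
        logger.warning("The weight of LmHead is not packed")
        return False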