[CPU] convert topk_weights to fp32 for INT8 and FP8 paths (for llama4) and fix LmHead weight pack (#7818)
@@ -317,9 +317,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             x,
             layer.w13_weight,
             layer.w2_weight,
-            topk_weights.to(
-                torch.float
-            ),  # TODO: the topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bfloat16 while the kernel requires it to be float32
+            topk_weights,
             topk_ids,
             False,  # inplace # See [Note] inplace should be False in fused_experts.
             False,  # use_int8_w8a8
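Note: the Python-side upcast is dropped here because the CPU kernel wrapper now converts topk_weights itself (see the fused_experts_cpu hunk further down). Llama4's Llama4MoE:custom_routing_function yields bfloat16 routing weights, so the old code paid for an extra fp32 copy at this call site on every forward. A minimal sketch of the dtype contract, using only public torch APIs and hypothetical shapes (not part of the patch):

# Illustrative only: Llama4-style routing weights come out in bf16, while the
# CPU fused-experts kernel consumes fp32. After this change the upcast happens
# once inside the C++ wrapper instead of at this Python call site.
import torch

router_logits = torch.randn(4, 16, dtype=torch.bfloat16)  # hypothetical shapes
topk_weights, topk_ids = torch.topk(torch.sigmoid(router_logits), k=2, dim=-1)

assert topk_weights.dtype == torch.bfloat16  # what the caller now passes through
assert topk_weights.to(torch.float).dtype == torch.float32  # what the kernel needs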
@@ -1,5 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/layers/vocab_parallel_embedding.py

+import logging
 from dataclasses import dataclass
 from typing import List, Optional, Sequence, Tuple

@@ -28,6 +29,8 @@ DEFAULT_VOCAB_PADDING_SIZE = 64
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()

+logger = logging.getLogger(__name__)
+

 class UnquantizedEmbeddingMethod(QuantizeMethodBase):
     """Unquantized method for embeddings."""
@@ -562,9 +565,12 @@ class ParallelLMHead(VocabParallelEmbedding):
         )
         self.quant_config = quant_config

-        # We only support pack LMHead if it's not quantized. For LMHead with quant_config, the weight_name will be "qweight"
-        if self.quant_config is None and _is_cpu and _is_cpu_amx_available:
-            self.quant_method = PackWeightMethod(weight_names=["weight"])
+        # We only support pack LMHead if it's not quantized.
+        if _is_cpu and _is_cpu_amx_available:
+            if hasattr(self, "weight") and self.weight.dtype == torch.bfloat16:
+                self.quant_method = PackWeightMethod(weight_names=["weight"])
+            else:
+                logger.warning("The weight of LmHead is not packed")

         if bias:
             self.bias = Parameter(
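Note: the packing guard switches from checking quant_config to checking the weight tensor itself. AMX weight packing only applies to a materialized bfloat16 weight; anything else (for example a quantized head whose tensor is named "qweight", or a head without a local weight attribute) now logs a warning instead of registering PackWeightMethod. The same decision in isolation, as a small hypothetical helper (not part of the patch; names are illustrative):

# Hypothetical helper mirroring the guard above.
import logging
import torch

logger = logging.getLogger(__name__)

def should_pack_lm_head(module: torch.nn.Module, cpu_amx_available: bool) -> bool:
    """Pack the LM head weight only when it exists and is bfloat16."""
    if not cpu_amx_available:
        return False
    weight = getattr(module, "weight", None)
    if weight is not None and weight.dtype == torch.bfloat16:
        return True
    logger.warning("The weight of LmHead is not packed")
    return False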
@@ -1008,13 +1008,18 @@ at::Tensor fused_experts_cpu(
   CHECK_DIM(2, topk_ids);

   CHECK_EQ(topk_ids.scalar_type(), at::kInt);
-  CHECK_EQ(topk_weights.scalar_type(), at::kFloat);
+
+  // TODO: support topk_weights to be bf16 or fp16 in the kernel.
+  // The topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bf16/fp16
+  // while the kernel currently only supports it to be float32
+  auto topk_weights_ = topk_weights.to(at::kFloat);
+  CHECK_EQ(topk_weights_.scalar_type(), at::kFloat);

   int64_t M = hidden_states.size(0);
   int64_t K = hidden_states.size(1);
   int64_t N = w1.size(1) / 2;
   int64_t E = w1.size(0);
-  int64_t topk = topk_weights.size(1);
+  int64_t topk = topk_weights_.size(1);

   // we use int32_t compensation for int8 w8a8
   int64_t packed_K = get_row_size(K, use_int8_w8a8);
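Note: doing the upcast inside the wrapper keeps the change transparent for existing callers. at::Tensor::to (like torch.Tensor.to in Python) returns the input tensor unchanged when the dtype already matches, so fp32 routing weights cost nothing extra, while bf16/fp16 inputs pay for a single conversion per call. The equivalent behavior seen from the Python side (illustration only, not part of the patch):

# Why the unconditional conversion is cheap for fp32 callers.
import torch

w_fp32 = torch.rand(4, 2, dtype=torch.float32)
w_bf16 = w_fp32.to(torch.bfloat16)

assert w_fp32.to(torch.float32) is w_fp32               # already fp32: no copy
assert w_bf16.to(torch.float32).dtype == torch.float32  # bf16: one upcast copy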
@@ -1124,7 +1129,7 @@ at::Tensor fused_experts_cpu(
        packed_w2.data_ptr<int8_t>(),
        w1s.data_ptr<float>(),
        w2s.data_ptr<float>(),
-       topk_weights.data_ptr<float>(),
+       topk_weights_.data_ptr<float>(),
        sorted_ids,
        expert_ids,
        offsets,
@@ -1157,7 +1162,7 @@ at::Tensor fused_experts_cpu(
        w2s.data_ptr<float>(),
        block_size_N,
        block_size_K,
-       topk_weights.data_ptr<float>(),
+       topk_weights_.data_ptr<float>(),
        sorted_ids,
        expert_ids,
        offsets,
@@ -1180,7 +1185,7 @@ at::Tensor fused_experts_cpu(
        hidden_states.data_ptr<scalar_t>(),
        packed_w1.data_ptr<scalar_t>(),
        packed_w2.data_ptr<scalar_t>(),
-       topk_weights.data_ptr<float>(),
+       topk_weights_.data_ptr<float>(),
        sorted_ids,
        expert_ids,
        offsets,
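Note: all three dispatch branches here (the int8 w8a8 path, what appears to be the fp8 block-quantized path, and the unquantized path) now read from the converted topk_weights_ tensor, so the upcast is performed once regardless of which kernel runs. The conversion is also value-preserving, since every bfloat16 value is exactly representable in float32; a quick check, not part of the patch:

# Sanity check that bf16 routing weights survive the fp32 upcast exactly.
import torch

w_bf16 = torch.softmax(torch.randn(4, 2), dim=-1).to(torch.bfloat16)
w_fp32 = w_bf16.to(torch.float32)
assert torch.equal(w_fp32.to(torch.bfloat16), w_bf16)  # exact round-trip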