[CPU] convert topk_weights to fp32 for INT8 and FP8 paths (for llama4) and fix LmHead weight pack (#7818)
@@ -317,9 +317,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             x,
             layer.w13_weight,
             layer.w2_weight,
-            topk_weights.to(
-                torch.float
-            ),  # TODO: the topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bfloat16 while the kernel requires it to be float32
+            topk_weights,
             topk_ids,
             False,  # inplace # See [Note] inplace should be False in fused_experts.
             False,  # use_int8_w8a8
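Per the commit title, the bfloat16-to-float32 conversion now also covers the INT8 and FP8 paths, so the explicit `.to(torch.float)` at this unquantized call site is dropped. A minimal sketch of the dtype fix itself; the helper name below is illustrative, not the actual SGLang API:

import torch

def ensure_fp32_topk_weights(topk_weights: torch.Tensor) -> torch.Tensor:
    # llama4 computes topk_weights in bfloat16 via
    # Llama4MoE:custom_routing_function, but the CPU fused-experts kernel
    # requires float32 routing weights, including on the INT8/FP8 paths.
    if topk_weights.dtype != torch.float32:
        topk_weights = topk_weights.to(torch.float32)
    return topk_weights

# e.g. bfloat16 routing weights, as produced by llama4's router
w = torch.rand(4, 2, dtype=torch.bfloat16)
assert ensure_fp32_topk_weights(w).dtype == torch.float32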
@@ -1,5 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/layers/vocab_parallel_embedding.py
 
+import logging
 from dataclasses import dataclass
 from typing import List, Optional, Sequence, Tuple
 
@@ -28,6 +29,8 @@ DEFAULT_VOCAB_PADDING_SIZE = 64
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 
+logger = logging.getLogger(__name__)
+
 
 class UnquantizedEmbeddingMethod(QuantizeMethodBase):
     """Unquantized method for embeddings."""
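The logger added above follows the standard stdlib convention of one module-level logger named after the module, which the new warning path in ParallelLMHead below relies on. A minimal standalone illustration:

import logging

logger = logging.getLogger(__name__)  # one logger per module, named after it

if __name__ == "__main__":
    logging.basicConfig(level=logging.WARNING)
    logger.warning("The weight of LmHead is not packed")  # message used below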
@@ -562,9 +565,12 @@ class ParallelLMHead(VocabParallelEmbedding):
         )
         self.quant_config = quant_config
 
-        # We only support pack LMHead if it's not quantized. For LMHead with quant_config, the weight_name will be "qweight"
-        if self.quant_config is None and _is_cpu and _is_cpu_amx_available:
-            self.quant_method = PackWeightMethod(weight_names=["weight"])
+        # We only support pack LMHead if it's not quantized.
+        if _is_cpu and _is_cpu_amx_available:
+            if hasattr(self, "weight") and self.weight.dtype == torch.bfloat16:
+                self.quant_method = PackWeightMethod(weight_names=["weight"])
+            else:
+                logger.warning("The weight of LmHead is not packed")
 
         if bias:
             self.bias = Parameter(
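The hasattr/dtype guard supersedes the old quant_config check: per the original comment, a quantized LMHead registers its parameter as "qweight" rather than "weight", and the AMX pack path only handles bfloat16. A standalone sketch of the guard using a toy module; the class name and `packed` flag are illustrative, and the real code installs PackWeightMethod instead:

import logging
import torch
from torch import nn

logger = logging.getLogger(__name__)

class TinyLMHead(nn.Module):
    # Toy stand-in for ParallelLMHead, only to exercise the pack guard.
    def __init__(self, quantized: bool, dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        # A quantized head would register "qweight", so hasattr(self, "weight")
        # filters it out without consulting quant_config.
        name = "qweight" if quantized else "weight"
        self.register_parameter(name, nn.Parameter(torch.zeros(8, 8, dtype=dtype)))
        self.packed = False
        if hasattr(self, "weight") and self.weight.dtype == torch.bfloat16:
            self.packed = True  # real code: PackWeightMethod(weight_names=["weight"])
        else:
            logger.warning("The weight of LmHead is not packed")

assert TinyLMHead(quantized=False).packed
assert not TinyLMHead(quantized=True).packed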