[CPU] Optimize FP16 decode_attention_cpu (#10652)
@@ -540,7 +540,10 @@ class ParallelLMHead(VocabParallelEmbedding):
 
         # We only support pack LMHead if it's not quantized.
         if _is_cpu and _is_cpu_amx_available:
-            if hasattr(self, "weight") and self.weight.dtype == torch.bfloat16:
+            if hasattr(self, "weight") and self.weight.dtype in [
+                torch.bfloat16,
+                torch.float16,
+            ]:
                 self.quant_method = PackWeightMethod(weight_names=["weight"])
 
         if bias:
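For context, a minimal runnable sketch of the gating logic this hunk changes. The flags `_is_cpu` and `_is_cpu_amx_available` and the `PackWeightMethod` class are stand-ins here (the real definitions live elsewhere in sglang); the point is only the dtype check: before this patch, weight packing on AMX-capable CPUs was applied to `torch.bfloat16` weights only, and the patch extends the same packed path to `torch.float16`.

```python
import torch

# Hypothetical stand-ins for the module-level flags and packing method
# referenced by the diff; real values come from the surrounding file.
_is_cpu = True
_is_cpu_amx_available = True


class PackWeightMethod:
    """Placeholder: repacks the named weights into an AMX-friendly layout."""

    def __init__(self, weight_names):
        self.weight_names = weight_names


class ParallelLMHeadSketch:
    def __init__(self, dtype):
        self.weight = torch.empty(8, 8, dtype=dtype)
        self.quant_method = None
        # We only support pack LMHead if it's not quantized.
        if _is_cpu and _is_cpu_amx_available:
            # After this change, FP16 weights take the packed path too,
            # not just BF16.
            if hasattr(self, "weight") and self.weight.dtype in [
                torch.bfloat16,
                torch.float16,
            ]:
                self.quant_method = PackWeightMethod(weight_names=["weight"])


# FP16 now selects the packed weight method on AMX-capable CPUs.
assert ParallelLMHeadSketch(torch.float16).quant_method is not None
# FP32 (or quantized) weights still skip packing.
assert ParallelLMHeadSketch(torch.float32).quant_method is None
```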