[CPU] Convert topk_weights to fp32 for INT8 and FP8 paths (for llama4) and fix LmHead weight pack (#7818)
This commit is contained in:
@@ -1008,13 +1008,18 @@ at::Tensor fused_experts_cpu(
|
||||
CHECK_DIM(2, topk_ids);
|
||||
|
||||
CHECK_EQ(topk_ids.scalar_type(), at::kInt);
|
||||
CHECK_EQ(topk_weights.scalar_type(), at::kFloat);
|
||||
|
||||
// TODO: support bf16/fp16 topk_weights directly in the kernel.
|
||||
// For llama4, topk_weights is computed via Llama4MoE:custom_routing_function and is bf16/fp16,
|
||||
// while the kernel currently only supports float32, so it is converted to float32 here.
|
||||
auto topk_weights_ = topk_weights.to(at::kFloat);
|
||||
CHECK_EQ(topk_weights_.scalar_type(), at::kFloat);
|
||||
|
||||
int64_t M = hidden_states.size(0);
|
||||
int64_t K = hidden_states.size(1);
|
||||
int64_t N = w1.size(1) / 2;
|
||||
int64_t E = w1.size(0);
|
||||
int64_t topk = topk_weights.size(1);
|
||||
int64_t topk = topk_weights_.size(1);
|
||||
|
||||
// we use int32_t compensation for int8 w8a8
|
||||
int64_t packed_K = get_row_size(K, use_int8_w8a8);
|
||||
@@ -1124,7 +1129,7 @@ at::Tensor fused_experts_cpu(
|
||||
packed_w2.data_ptr<int8_t>(),
|
||||
w1s.data_ptr<float>(),
|
||||
w2s.data_ptr<float>(),
|
||||
topk_weights.data_ptr<float>(),
|
||||
topk_weights_.data_ptr<float>(),
|
||||
sorted_ids,
|
||||
expert_ids,
|
||||
offsets,
|
||||
@@ -1157,7 +1162,7 @@ at::Tensor fused_experts_cpu(
|
||||
w2s.data_ptr<float>(),
|
||||
block_size_N,
|
||||
block_size_K,
|
||||
topk_weights.data_ptr<float>(),
|
||||
topk_weights_.data_ptr<float>(),
|
||||
sorted_ids,
|
||||
expert_ids,
|
||||
offsets,
|
||||
@@ -1180,7 +1185,7 @@ at::Tensor fused_experts_cpu(
|
||||
hidden_states.data_ptr<scalar_t>(),
|
||||
packed_w1.data_ptr<scalar_t>(),
|
||||
packed_w2.data_ptr<scalar_t>(),
|
||||
topk_weights.data_ptr<float>(),
|
||||
topk_weights_.data_ptr<float>(),
|
||||
sorted_ids,
|
||||
expert_ids,
|
||||
offsets,
|
||||
|
||||
Reference in New Issue
Block a user