CPU: map changes from developing branch in sgl-kernel (#6833)

Co-authored-by: mingfeima <mingfei.ma@intel.com>
This commit is contained in:
YanbingJiang
2025-06-10 16:08:15 +08:00
committed by GitHub
parent 81372f3bef
commit fcde67b016
20 changed files with 1321 additions and 321 deletions

View File

@@ -534,7 +534,25 @@ std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
int64_t topk,
bool renormalize,
int64_t num_expert_group,
int64_t topk_group) {
int64_t topk_group,
int64_t num_fused_shared_experts,
std::optional<double> routed_scaling_factor,
std::optional<at::Tensor> num_token_non_padded) {
// TODO: Will support num_fused_shared_experts, routed_scaling_factor and num_token_non_padded.
// For now, we just check them as default value.
TORCH_CHECK(
num_fused_shared_experts == 0,
"num_fused_shared_experts must be 0 default value, got: ",
num_fused_shared_experts);
TORCH_CHECK(
!routed_scaling_factor.has_value() || routed_scaling_factor.value() == 1.0f,
"routed_scaling_factor must be None or 1.0f default value, got: ",
routed_scaling_factor.value());
TORCH_CHECK(
!num_token_non_padded.has_value(),
"num_token_non_padded must be None default value, got: ",
num_token_non_padded.value());
RECORD_FUNCTION("sgl-kernel::grouped_topk_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));
CHECK_INPUT(gating_output);
@@ -594,7 +612,21 @@ std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
int64_t topk,
bool renormalize,
int64_t num_expert_group,
int64_t topk_group) {
int64_t topk_group,
int64_t num_fused_shared_experts,
std::optional<double> routed_scaling_factor,
std::optional<at::Tensor> num_token_non_padded) {
// TODO: Will support num_fused_shared_experts, routed_scaling_factor and num_token_non_padded.
// For now, we just check them as default value.
TORCH_CHECK(
num_fused_shared_experts == 0,
"num_fused_shared_experts must be 0 default value, got: ",
num_fused_shared_experts);
TORCH_CHECK(
!num_token_non_padded.has_value(),
"num_token_non_padded must be None default value, got: ",
num_token_non_padded.value());
RECORD_FUNCTION(
"sgl-kernel::biased_grouped_topk_cpu", std::vector<c10::IValue>({hidden_states, gating_output, correction_bias}));