CPU: map changes from developing branch in sgl-kernel (#6833)
Co-authored-by: mingfeima <mingfei.ma@intel.com>
This commit is contained in:
@@ -534,7 +534,25 @@ std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
|
||||
int64_t topk,
|
||||
bool renormalize,
|
||||
int64_t num_expert_group,
|
||||
int64_t topk_group) {
|
||||
int64_t topk_group,
|
||||
int64_t num_fused_shared_experts,
|
||||
std::optional<double> routed_scaling_factor,
|
||||
std::optional<at::Tensor> num_token_non_padded) {
|
||||
// TODO: Will support num_fused_shared_experts, routed_scaling_factor and num_token_non_padded.
|
||||
// For now, we just check them as default value.
|
||||
TORCH_CHECK(
|
||||
num_fused_shared_experts == 0,
|
||||
"num_fused_shared_experts must be 0 default value, got: ",
|
||||
num_fused_shared_experts);
|
||||
TORCH_CHECK(
|
||||
!routed_scaling_factor.has_value() || routed_scaling_factor.value() == 1.0f,
|
||||
"routed_scaling_factor must be None or 1.0f default value, got: ",
|
||||
routed_scaling_factor.value());
|
||||
TORCH_CHECK(
|
||||
!num_token_non_padded.has_value(),
|
||||
"num_token_non_padded must be None default value, got: ",
|
||||
num_token_non_padded.value());
|
||||
|
||||
RECORD_FUNCTION("sgl-kernel::grouped_topk_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));
|
||||
CHECK_INPUT(gating_output);
|
||||
|
||||
@@ -594,7 +612,21 @@ std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
|
||||
int64_t topk,
|
||||
bool renormalize,
|
||||
int64_t num_expert_group,
|
||||
int64_t topk_group) {
|
||||
int64_t topk_group,
|
||||
int64_t num_fused_shared_experts,
|
||||
std::optional<double> routed_scaling_factor,
|
||||
std::optional<at::Tensor> num_token_non_padded) {
|
||||
// TODO: Will support num_fused_shared_experts, routed_scaling_factor and num_token_non_padded.
|
||||
// For now, we just check them as default value.
|
||||
TORCH_CHECK(
|
||||
num_fused_shared_experts == 0,
|
||||
"num_fused_shared_experts must be 0 default value, got: ",
|
||||
num_fused_shared_experts);
|
||||
TORCH_CHECK(
|
||||
!num_token_non_padded.has_value(),
|
||||
"num_token_non_padded must be None default value, got: ",
|
||||
num_token_non_padded.value());
|
||||
|
||||
RECORD_FUNCTION(
|
||||
"sgl-kernel::biased_grouped_topk_cpu", std::vector<c10::IValue>({hidden_states, gating_output, correction_bias}));
|
||||
|
||||
|
||||
Reference in New Issue
Block a user