Fix triton moe error caused by TopK refactor (#8705)
@@ -146,34 +146,3 @@ def triton_kernel_fused_experts(
     )
 
     return intermediate_cache3
-
-
-def triton_kernel_moe_forward_fake(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    inplace: bool = False,
-    activation: str = "silu",
-    apply_router_weight_on_input: bool = False,
-    use_fp8_w8a8: bool = False,
-    per_channel_quant: bool = False,
-    global_num_experts: int = -1,
-    expert_map: Optional[torch.Tensor] = None,
-    w1_scale: Optional[torch.Tensor] = None,
-    w2_scale: Optional[torch.Tensor] = None,
-    a1_scale: Optional[torch.Tensor] = None,
-    a2_scale: Optional[torch.Tensor] = None,
-    block_shape: Optional[list[int]] = None,
-) -> torch.Tensor:
-    return torch.empty_like(hidden_states)
-
-
-direct_register_custom_op(
-    op_name="forward_cuda_triton",
-    op_func=triton_kernel_moe_forward,
-    mutates_args=[],
-    fake_impl=triton_kernel_moe_forward_fake,
-)
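
Note: the removed lines registered a custom op whose fake_impl only reports output metadata so the graph tracer does not have to run the Triton kernel. The snippet below is a minimal, self-contained sketch of that same pattern using stock torch.library APIs rather than this repo's direct_register_custom_op helper; the op name "mylib::moe_forward" and its reduced parameter list are hypothetical, not taken from this commit.

import torch


@torch.library.custom_op("mylib::moe_forward", mutates_args=())
def moe_forward(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
) -> torch.Tensor:
    # A real implementation would launch the fused-expert Triton kernels here;
    # a plain clone keeps this sketch runnable without a GPU.
    return hidden_states.clone()


@moe_forward.register_fake
def _(hidden_states, w1, w2, gating_output, topk):
    # Fake (meta) implementation: only the output's shape/dtype/device matter,
    # mirroring the torch.empty_like(hidden_states) in the removed code.
    return torch.empty_like(hidden_states)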