[kernel] Adapt DispatchGmmCombineDecode operator to parameters of small operators (#4790)

### What this PR does / why we need it?

This PR adapts the DispatchGmmCombineDecode operator to the parameters of
the small operators.
1. This operator no longer requires permuting the weights and scales of
GMM1.
2. This operator no longer requires transposing the weights of GMM2.

Therefore, this operator and the small operator can use the same
parameters (weights and scales), which is beneficial for model
adaptation.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
This commit is contained in:
wangqiankun13
2025-12-09 16:17:06 +08:00
committed by GitHub
parent 9a885d08d0
commit 9567e5dd8c
5 changed files with 118 additions and 142 deletions

View File

@@ -12,9 +12,7 @@ import torchair
from vllm_ascend.utils import enable_custom_op
config = torchair.CompilerConfig()
config.mode = "reduce-overhead"
npu_backend = torchair.get_npu_backend(compiler_config=config)
torch.manual_seed(42)
torch_npu.npu.config.allow_internal_format = True
enable_custom_op()
LOG_NAME = "dispatch_gmm_combine_decode_test_logs"
@@ -101,7 +99,21 @@ class DecodeMoeOps(torch.nn.Module):
def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
gmm2_weight, gmm2_weight_scale):
raise NotImplementedError("To be implemented in subclass")
gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
torch_npu.Format.FRACTAL_NZ)
gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
torch_npu.Format.FRACTAL_NZ)
self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False)
self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
requires_grad=False)
self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False)
self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
requires_grad=False)
self.gmm1_weight_scale_fp32 = torch.nn.Parameter(
gmm1_weight_scale.float(), requires_grad=False)
self.gmm2_weight_scale_fp32 = torch.nn.Parameter(
gmm2_weight_scale.float(), requires_grad=False)
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
raise NotImplementedError("To be implemented in subclass")
@@ -132,19 +144,6 @@ class SmallOps(DecodeMoeOps):
shared_expert_rank_num)
self.tp_hcomm_info = ""
def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
gmm2_weight, gmm2_weight_scale):
gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
torch_npu.Format.FRACTAL_NZ)
gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
torch_npu.Format.FRACTAL_NZ)
self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False)
self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
requires_grad=False)
self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False)
self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
requires_grad=False)
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
outputs = torch_npu.npu_moe_distribute_dispatch_v2(
x=x,
@@ -238,41 +237,14 @@ class FusionOp(DecodeMoeOps):
ep_world_size, moe_expert_num, global_rank_id,
shared_expert_rank_num)
def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
gmm2_weight, gmm2_weight_scale):
gmm1_weight = gmm1_weight.transpose(1,2).contiguous()\
.view(self.local_expert_num, 2, self.moe_intermediate_size // 64, 64, self.token_hidden_size)\
.transpose(1,2).contiguous()\
.view(self.local_expert_num, self.moe_intermediate_size * 2, self.token_hidden_size)\
.transpose(1,2).contiguous()
gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
torch_npu.Format.ND)
gmm1_weight.add_(0)
gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
torch_npu.Format.FRACTAL_NZ)
gmm1_weight_scale = permute_weight(gmm1_weight_scale, 128)
gmm2_weight = torch_npu.npu_format_cast(
gmm2_weight.transpose(1, 2).contiguous(),
torch_npu.Format.FRACTAL_NZ)
gmm1_weight_scale = gmm1_weight_scale.float()
gmm2_weight_scale = gmm2_weight_scale.float()
self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False)
self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
requires_grad=False)
self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False)
self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
requires_grad=False)
def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
output = torch.ops._C_ascend.dispatch_gmm_combine_decode(
x=x,
expert_ids=expert_ids,
gmm1_permuted_weight=self.gmm1_weight,
gmm1_permuted_weight_scale=self.gmm1_weight_scale,
gmm1_permuted_weight_scale=self.gmm1_weight_scale_fp32,
gmm2_weight=self.gmm2_weight,
gmm2_weight_scale=self.gmm2_weight_scale,
gmm2_weight_scale=self.gmm2_weight_scale_fp32,
expert_smooth_scales=smooth_scales,
expert_scales=expert_scales,
group_ep=self.ep_hcomm_info,
@@ -399,6 +371,9 @@ def run_once(local_rank_id,
fused_ops = FusionOp(*weight_datas, ep_hcomm_info_fused,
*parameter).npu() # type: ignore
if test_graph:
config = torchair.CompilerConfig()
config.mode = "reduce-overhead"
npu_backend = torchair.get_npu_backend(compiler_config=config)
fused_ops = torch.compile(fused_ops, backend=npu_backend)
small_op_token_output, small_op_count_output = small_ops(*input_datas)
fused_op_token_output, fused_op_count_output = fused_ops(*input_datas)