[kernel] Adapt DispatchGmmCombineDecode operator to parameters of small operators (#4790)
### What this PR does / why we need it?
This PR adapts the DispatchGmmCombineDecode operator to the parameters of the small operators:
1. The operator no longer requires permuting the weights and scales of GMM1.
2. The operator no longer requires transposing the weights of GMM2.

Therefore, the fused operator and the small operators can use the same parameters (weights and scales), which is beneficial for model adaptation; see the sketch below.
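Illustrative only (`prepare_shared_weights` is a hypothetical helper; `npu_format_cast` and `Format.FRACTAL_NZ` are the torch_npu APIs used in the test below), the shared weight preparation reduces to a plain format cast:

```python
# Illustrative sketch, assuming an Ascend environment with torch_npu.
import torch
import torch_npu


def prepare_shared_weights(gmm1_weight, gmm2_weight):
    # Both the small-operator path and the fused DispatchGmmCombineDecode
    # path can now consume the same FRACTAL_NZ-formatted tensors; no GMM1
    # permutation and no GMM2 transpose are required anymore.
    gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
                                            torch_npu.Format.FRACTAL_NZ)
    gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
                                            torch_npu.Format.FRACTAL_NZ)
    return gmm1_weight, gmm2_weight
```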
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
```diff
@@ -12,9 +12,7 @@ import torchair
 from vllm_ascend.utils import enable_custom_op
 
-config = torchair.CompilerConfig()
-config.mode = "reduce-overhead"
-npu_backend = torchair.get_npu_backend(compiler_config=config)
 torch.manual_seed(42)
 torch_npu.npu.config.allow_internal_format = True
 enable_custom_op()
+
 LOG_NAME = "dispatch_gmm_combine_decode_test_logs"
```
```diff
@@ -101,7 +99,21 @@ class DecodeMoeOps(torch.nn.Module):
 
     def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
                                        gmm2_weight, gmm2_weight_scale):
-        raise NotImplementedError("To be implemented in subclass")
+        gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
+                                                torch_npu.Format.FRACTAL_NZ)
+        gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
+                                                torch_npu.Format.FRACTAL_NZ)
+        self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False)
+        self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
+                                                    requires_grad=False)
+        self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False)
+        self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
+                                                    requires_grad=False)
+
+        self.gmm1_weight_scale_fp32 = torch.nn.Parameter(
+            gmm1_weight_scale.float(), requires_grad=False)
+        self.gmm2_weight_scale_fp32 = torch.nn.Parameter(
+            gmm2_weight_scale.float(), requires_grad=False)
 
     def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
         raise NotImplementedError("To be implemented in subclass")
```
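With this hunk, `DecodeMoeOps` owns the common weight processing: it keeps the scales in their original dtype for the small-operator kernels and additionally stores fp32 copies for the fused kernel. A sketch of the resulting subclass contract (the class name is hypothetical and the kernel call is left abstract):

```python
# Hypothetical subclass, for illustration only: weight handling is inherited
# from DecodeMoeOps, so a new backend only has to override _apply_ops.
class ExampleOps(DecodeMoeOps):

    def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
        # The base class has prepared self.gmm1_weight / self.gmm2_weight in
        # FRACTAL_NZ format, plus both original-dtype and fp32 scale copies;
        # a backend picks whichever dtype its kernel expects, e.g.
        #   my_kernel(x, self.gmm1_weight, self.gmm1_weight_scale_fp32, ...)
        ...
```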
```diff
@@ -132,19 +144,6 @@ class SmallOps(DecodeMoeOps):
                          shared_expert_rank_num)
         self.tp_hcomm_info = ""
 
-    def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
-                                       gmm2_weight, gmm2_weight_scale):
-        gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
-                                                torch_npu.Format.FRACTAL_NZ)
-        gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
-                                                torch_npu.Format.FRACTAL_NZ)
-        self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False)
-        self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
-                                                    requires_grad=False)
-        self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False)
-        self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
-                                                    requires_grad=False)
-
     def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
         outputs = torch_npu.npu_moe_distribute_dispatch_v2(
             x=x,
```
```diff
@@ -238,41 +237,14 @@ class FusionOp(DecodeMoeOps):
                          ep_world_size, moe_expert_num, global_rank_id,
                          shared_expert_rank_num)
 
-    def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale,
-                                       gmm2_weight, gmm2_weight_scale):
-        gmm1_weight = gmm1_weight.transpose(1,2).contiguous()\
-            .view(self.local_expert_num, 2, self.moe_intermediate_size // 64, 64, self.token_hidden_size)\
-            .transpose(1,2).contiguous()\
-            .view(self.local_expert_num, self.moe_intermediate_size * 2, self.token_hidden_size)\
-            .transpose(1,2).contiguous()
-        gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
-                                                torch_npu.Format.ND)
-        gmm1_weight.add_(0)
-        gmm1_weight = torch_npu.npu_format_cast(gmm1_weight,
-                                                torch_npu.Format.FRACTAL_NZ)
-        gmm1_weight_scale = permute_weight(gmm1_weight_scale, 128)
-        gmm2_weight = torch_npu.npu_format_cast(
-            gmm2_weight.transpose(1, 2).contiguous(),
-            torch_npu.Format.FRACTAL_NZ)
-
-        gmm1_weight_scale = gmm1_weight_scale.float()
-        gmm2_weight_scale = gmm2_weight_scale.float()
-
-        self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False)
-        self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale,
-                                                    requires_grad=False)
-        self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False)
-        self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale,
-                                                    requires_grad=False)
-
     def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales):
         output = torch.ops._C_ascend.dispatch_gmm_combine_decode(
             x=x,
             expert_ids=expert_ids,
             gmm1_permuted_weight=self.gmm1_weight,
-            gmm1_permuted_weight_scale=self.gmm1_weight_scale,
+            gmm1_permuted_weight_scale=self.gmm1_weight_scale_fp32,
             gmm2_weight=self.gmm2_weight,
-            gmm2_weight_scale=self.gmm2_weight_scale,
+            gmm2_weight_scale=self.gmm2_weight_scale_fp32,
             expert_smooth_scales=smooth_scales,
             expert_scales=expert_scales,
             group_ep=self.ep_hcomm_info,
```
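For reference, the GMM1 permutation the fused path previously required can be reproduced on a toy tensor; the `2` in the intermediate view presumably separates GMM1's gate and up halves, which were interleaved in 64-row blocks. The shapes below are made-up example values, not taken from the test:

```python
# Toy reproduction of the removed GMM1 permutation. E/I/H are arbitrary
# example sizes standing in for local_expert_num, moe_intermediate_size
# and token_hidden_size.
import torch

E, I, H = 1, 128, 16
w = torch.arange(E * H * 2 * I, dtype=torch.float32).view(E, H, 2 * I)

permuted = (w.transpose(1, 2).contiguous()
             .view(E, 2, I // 64, 64, H)
             .transpose(1, 2).contiguous()
             .view(E, 2 * I, H)
             .transpose(1, 2).contiguous())

print(permuted.shape)  # torch.Size([1, 16, 256]): same shape, 2*I axis reordered
```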
```diff
@@ -399,6 +371,9 @@ def run_once(local_rank_id,
     fused_ops = FusionOp(*weight_datas, ep_hcomm_info_fused,
                          *parameter).npu()  # type: ignore
     if test_graph:
+        config = torchair.CompilerConfig()
+        config.mode = "reduce-overhead"
+        npu_backend = torchair.get_npu_backend(compiler_config=config)
         fused_ops = torch.compile(fused_ops, backend=npu_backend)
     small_op_token_output, small_op_count_output = small_ops(*input_datas)
     fused_op_token_output, fused_op_count_output = fused_ops(*input_datas)
```
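The graph-mode setup that used to live at module scope is now built inside `run_once`; the same pattern in isolation (with a placeholder module) looks like:

```python
# Minimal torchair graph-mode sketch mirroring the test's setup; assumes
# torch_npu/torchair on an Ascend device. TinyModule is a placeholder.
import torch
import torch_npu
import torchair


class TinyModule(torch.nn.Module):
    # Placeholder module; any torch.nn.Module works here.
    def forward(self, x):
        return x * 2


config = torchair.CompilerConfig()
config.mode = "reduce-overhead"  # cached-graph execution, as in the test
npu_backend = torchair.get_npu_backend(compiler_config=config)

model = TinyModule().npu()
compiled = torch.compile(model, backend=npu_backend)
out = compiled(torch.ones(4).npu())
```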