[Perf]set moe w2_weight default to be nz (#2842)
### What this PR does / why we need it?
This PR sets the default format of GMM w2_weight in w8a8_dynamic to be
NZ to improve performance.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- vLLM version: main
- vLLM main:
e40827280b
---------
Signed-off-by: Angazenn <supperccell@163.com>
This commit is contained in:
@@ -23,7 +23,6 @@ from vllm.config import CompilationLevel, get_current_vllm_config
|
|||||||
from vllm.distributed import get_ep_group
|
from vllm.distributed import get_ep_group
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
|
|
||||||
import vllm_ascend.envs as envs_ascend
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||||
@@ -103,8 +102,9 @@ class AscendW8A8DynamicLinearMethod:
|
|||||||
def process_weights_after_loading(self, layer):
|
def process_weights_after_loading(self, layer):
|
||||||
if self.transpose_weight:
|
if self.transpose_weight:
|
||||||
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
||||||
# cast quantized weight tensors in NZ format (29) for higher inference speed
|
# cast quantized weight tensors in NZ format for higher inference speed
|
||||||
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
|
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data,
|
||||||
|
ACL_FORMAT_FRACTAL_NZ)
|
||||||
layer.weight_scale.data = layer.weight_scale.data.flatten()
|
layer.weight_scale.data = layer.weight_scale.data.flatten()
|
||||||
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
|
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
|
||||||
layer.weight_offset.data = layer.weight_offset.data.flatten()
|
layer.weight_offset.data = layer.weight_offset.data.flatten()
|
||||||
@@ -275,8 +275,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
||||||
1, 2).contiguous()
|
1, 2).contiguous()
|
||||||
torch_npu.npu_format_cast_(layer.w13_weight, ACL_FORMAT_FRACTAL_NZ)
|
torch_npu.npu_format_cast_(layer.w13_weight, ACL_FORMAT_FRACTAL_NZ)
|
||||||
if envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP:
|
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
|
||||||
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
|
|
||||||
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
|
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
|
||||||
layer.w13_weight_scale.data.shape[0], -1)
|
layer.w13_weight_scale.data.shape[0], -1)
|
||||||
layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
|
layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ import torch_npu
|
|||||||
from vllm.distributed import GroupCoordinator, get_ep_group
|
from vllm.distributed import GroupCoordinator, get_ep_group
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
|
|
||||||
import vllm_ascend.envs as envs_ascend
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||||
@@ -1021,8 +1020,7 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
|
|||||||
1, 2).contiguous()
|
1, 2).contiguous()
|
||||||
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
||||||
1, 2).contiguous()
|
1, 2).contiguous()
|
||||||
if envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP:
|
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
|
||||||
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
|
|
||||||
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
|
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
|
||||||
layer.w13_weight_scale.data.shape[0], -1)
|
layer.w13_weight_scale.data.shape[0], -1)
|
||||||
layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
|
layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
|
||||||
|
|||||||
Reference in New Issue
Block a user