[MOE Refactor] Remove QuantType in prepare_finalize.py (#6534)
### What this PR does / why we need it? To prevent confusion between different QuantType classes, we remove the duplicate QuantType enum defined in prepare_finalize.py; all users now import QuantType from vllm_ascend.quantization.methods.base instead. - vLLM version: v0.15.0 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0 Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
This commit is contained in:
@@ -7,7 +7,7 @@ from tests.ut.base import TestBase
|
|||||||
from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
|
from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
|
||||||
AlltoAllCommImpl,
|
AlltoAllCommImpl,
|
||||||
MC2CommImpl)
|
MC2CommImpl)
|
||||||
from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
|
from vllm_ascend.quantization.methods.base import QuantType
|
||||||
from vllm_ascend.ops.fused_moe.token_dispatcher import (TokenCombineResult,
|
from vllm_ascend.ops.fused_moe.token_dispatcher import (TokenCombineResult,
|
||||||
TokenDispatchResult)
|
TokenDispatchResult)
|
||||||
|
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ from vllm_ascend.eplb.core.eplb_utils import init_eplb_config
|
|||||||
from vllm_ascend.flash_common3_context import get_flash_common3_context, set_flash_common3_context
|
from vllm_ascend.flash_common3_context import get_flash_common3_context, set_flash_common3_context
|
||||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute
|
from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute
|
||||||
from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult, setup_moe_comm_method
|
from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult, setup_moe_comm_method
|
||||||
from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
|
from vllm_ascend.quantization.methods.base import QuantType
|
||||||
from vllm_ascend.utils import (
|
from vllm_ascend.utils import (
|
||||||
enable_sp,
|
enable_sp,
|
||||||
maybe_trans_nz,
|
maybe_trans_nz,
|
||||||
@@ -235,22 +235,13 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
self.quant_type = self._get_quant_type()
|
self.quant_type = self._get_quant_type()
|
||||||
|
|
||||||
def _get_quant_type(self) -> QuantType:
|
def _get_quant_type(self) -> QuantType:
|
||||||
quant_method = self.quant_method
|
quant_type = QuantType.NONE
|
||||||
if not hasattr(quant_method, "quant_method") or quant_method.quant_method is None:
|
method = getattr(self.quant_method, "quant_method", None)
|
||||||
return QuantType.NONE
|
|
||||||
|
|
||||||
method = quant_method.quant_method
|
if method is not None:
|
||||||
|
quant_type = getattr(method, "quant_type", QuantType.NONE)
|
||||||
|
|
||||||
if hasattr(method, "quant_type"):
|
return quant_type
|
||||||
from vllm_ascend.quantization.methods.base import QuantType as SchemeQuantType
|
|
||||||
|
|
||||||
scheme_quant_type = method.quant_type
|
|
||||||
if scheme_quant_type == SchemeQuantType.W8A8:
|
|
||||||
return QuantType.W8A8
|
|
||||||
elif scheme_quant_type == SchemeQuantType.W4A8:
|
|
||||||
return QuantType.W4A8
|
|
||||||
|
|
||||||
return QuantType.NONE
|
|
||||||
|
|
||||||
def update_expert_map(self, new_expert_map):
|
def update_expert_map(self, new_expert_map):
|
||||||
self._expert_map = new_expert_map
|
self._expert_map = new_expert_map
|
||||||
|
|||||||
@@ -30,7 +30,6 @@ from vllm_ascend.ops.fused_moe.prepare_finalize import (
|
|||||||
PrepareAndFinalizeWithAll2All,
|
PrepareAndFinalizeWithAll2All,
|
||||||
PrepareAndFinalizeWithAllGather,
|
PrepareAndFinalizeWithAllGather,
|
||||||
PrepareAndFinalizeWithMC2,
|
PrepareAndFinalizeWithMC2,
|
||||||
QuantType,
|
|
||||||
)
|
)
|
||||||
from vllm_ascend.ops.fused_moe.token_dispatcher import (
|
from vllm_ascend.ops.fused_moe.token_dispatcher import (
|
||||||
MoETokenDispatcher,
|
MoETokenDispatcher,
|
||||||
@@ -38,6 +37,7 @@ from vllm_ascend.ops.fused_moe.token_dispatcher import (
|
|||||||
TokenDispatcherWithAllGather,
|
TokenDispatcherWithAllGather,
|
||||||
TokenDispatcherWithMC2,
|
TokenDispatcherWithMC2,
|
||||||
)
|
)
|
||||||
|
from vllm_ascend.quantization.methods.base import QuantType
|
||||||
|
|
||||||
_MoECommMethods: dict[MoECommType | None, MoECommMethod] = {}
|
_MoECommMethods: dict[MoECommType | None, MoECommMethod] = {}
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
# This file is a part of the vllm-ascend project.
|
# This file is a part of the vllm-ascend project.
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@@ -32,15 +31,10 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
|||||||
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
|
from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
|
||||||
|
from vllm_ascend.quantization.methods.base import QuantType
|
||||||
from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable
|
from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable
|
||||||
|
|
||||||
|
|
||||||
class QuantType(Enum):
|
|
||||||
NONE = 0
|
|
||||||
W8A8 = 1
|
|
||||||
W4A8 = 2
|
|
||||||
|
|
||||||
|
|
||||||
class PrepareAndFinalize(ABC):
|
class PrepareAndFinalize(ABC):
|
||||||
"""
|
"""
|
||||||
Abstract base class for MoE (Mixture-of-Experts) tensor preparation and finalization
|
Abstract base class for MoE (Mixture-of-Experts) tensor preparation and finalization
|
||||||
|
|||||||
Reference in New Issue
Block a user