[main] [bugfix] Fix misjudging quantized/unquantized scenarios (#2627)
### What this PR does / why we need it?
In a mixed-precision scenario, quant_config is not None, but MoE needs
to perform unquantized computation; however, quantized computation is
currently being used. Therefore, we move the with_quant logic into
forward to avoid misjudging in mixed-precision scenarios.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
e2e & ut
- vLLM version: v0.10.1.1
- vLLM main:
98ac0cb32d
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
This commit is contained in:
@@ -543,7 +543,6 @@ class TestUnifiedApplyMLP(TestBase):
|
|||||||
mock_get_forward_context):
|
mock_get_forward_context):
|
||||||
|
|
||||||
mock_forward_context = MagicMock()
|
mock_forward_context = MagicMock()
|
||||||
mock_forward_context.with_quant = True
|
|
||||||
mock_forward_context.fused_moe_state = FusedMoEState.MC2
|
mock_forward_context.fused_moe_state = FusedMoEState.MC2
|
||||||
mock_get_forward_context.return_value = mock_forward_context
|
mock_get_forward_context.return_value = mock_forward_context
|
||||||
|
|
||||||
@@ -587,10 +586,10 @@ class TestUnifiedApplyMLP(TestBase):
|
|||||||
group_list_type=1,
|
group_list_type=1,
|
||||||
w1_scale_bias=None,
|
w1_scale_bias=None,
|
||||||
w2_scale_bias=None,
|
w2_scale_bias=None,
|
||||||
topk_scales=None)
|
topk_scales=None,
|
||||||
|
with_quant=True)
|
||||||
|
|
||||||
mock_get_forward_context.assert_called()
|
mock_get_forward_context.assert_called()
|
||||||
self.assertTrue(mock_forward_context.with_quant)
|
|
||||||
self.assertEqual(mock_forward_context.fused_moe_state,
|
self.assertEqual(mock_forward_context.fused_moe_state,
|
||||||
FusedMoEState.MC2)
|
FusedMoEState.MC2)
|
||||||
|
|
||||||
@@ -602,19 +601,15 @@ class TestUnifiedApplyMLP(TestBase):
|
|||||||
|
|
||||||
self.assertEqual(result.dtype, torch.bfloat16)
|
self.assertEqual(result.dtype, torch.bfloat16)
|
||||||
|
|
||||||
@patch('vllm_ascend.ops.fused_moe.get_forward_context')
|
|
||||||
@patch('vllm_ascend.ops.fused_moe.is_310p')
|
@patch('vllm_ascend.ops.fused_moe.is_310p')
|
||||||
@patch('torch_npu.npu_grouped_matmul')
|
@patch('torch_npu.npu_grouped_matmul')
|
||||||
@patch('torch_npu.npu_swiglu')
|
@patch('torch_npu.npu_swiglu')
|
||||||
@patch('torch_npu.npu_dynamic_quant')
|
@patch('torch_npu.npu_dynamic_quant')
|
||||||
def test_unified_apply_mlp_without_quantization(
|
def test_unified_apply_mlp_without_quantization(self,
|
||||||
self, mock_npu_dynamic_quant, mock_npu_swiglu,
|
mock_npu_dynamic_quant,
|
||||||
mock_npu_grouped_matmul, mock_is_310p, mock_get_forward_context):
|
mock_npu_swiglu,
|
||||||
|
mock_npu_grouped_matmul,
|
||||||
mock_forward_context = MagicMock()
|
mock_is_310p):
|
||||||
mock_forward_context.with_quant = False
|
|
||||||
mock_get_forward_context.return_value = mock_forward_context
|
|
||||||
|
|
||||||
mock_is_310p.return_value = False
|
mock_is_310p.return_value = False
|
||||||
|
|
||||||
mock_npu_grouped_matmul.side_effect = [[
|
mock_npu_grouped_matmul.side_effect = [[
|
||||||
@@ -639,10 +634,8 @@ class TestUnifiedApplyMLP(TestBase):
|
|||||||
group_list_type=1,
|
group_list_type=1,
|
||||||
w1_scale_bias=None,
|
w1_scale_bias=None,
|
||||||
w2_scale_bias=None,
|
w2_scale_bias=None,
|
||||||
topk_scales=topk_scales)
|
topk_scales=topk_scales,
|
||||||
|
with_quant=False)
|
||||||
mock_get_forward_context.assert_called()
|
|
||||||
self.assertFalse(mock_forward_context.with_quant)
|
|
||||||
|
|
||||||
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
|
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
|
||||||
mock_npu_swiglu.assert_called_once()
|
mock_npu_swiglu.assert_called_once()
|
||||||
@@ -698,10 +691,10 @@ class TestUnifiedApplyMLP(TestBase):
|
|||||||
group_list_type=1,
|
group_list_type=1,
|
||||||
w1_scale_bias=w1_scale_bias,
|
w1_scale_bias=w1_scale_bias,
|
||||||
w2_scale_bias=w2_scale_bias,
|
w2_scale_bias=w2_scale_bias,
|
||||||
topk_scales=None)
|
topk_scales=None,
|
||||||
|
with_quant=True)
|
||||||
|
|
||||||
mock_get_forward_context.assert_called()
|
mock_get_forward_context.assert_called()
|
||||||
self.assertTrue(mock_forward_context.with_quant)
|
|
||||||
|
|
||||||
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
|
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
|
||||||
mock_npu_swiglu.assert_called_once()
|
mock_npu_swiglu.assert_called_once()
|
||||||
@@ -710,19 +703,13 @@ class TestUnifiedApplyMLP(TestBase):
|
|||||||
self.assertEqual(result.shape, hidden_states.shape)
|
self.assertEqual(result.shape, hidden_states.shape)
|
||||||
self.assertEqual(result.dtype, torch.bfloat16)
|
self.assertEqual(result.dtype, torch.bfloat16)
|
||||||
|
|
||||||
@patch('vllm_ascend.ops.fused_moe.get_forward_context')
|
|
||||||
@patch('vllm_ascend.ops.fused_moe.is_310p')
|
@patch('vllm_ascend.ops.fused_moe.is_310p')
|
||||||
@patch('torch_npu.npu_grouped_matmul')
|
@patch('torch_npu.npu_grouped_matmul')
|
||||||
@patch('torch_npu.npu_swiglu')
|
@patch('torch_npu.npu_swiglu')
|
||||||
@patch('torch_npu.npu_dynamic_quant')
|
@patch('torch_npu.npu_dynamic_quant')
|
||||||
def test_unified_apply_mlp_without_quantization_310p(
|
def test_unified_apply_mlp_without_quantization_310p(
|
||||||
self, mock_npu_dynamic_quant, mock_npu_swiglu,
|
self, mock_npu_dynamic_quant, mock_npu_swiglu,
|
||||||
mock_npu_grouped_matmul, mock_is_310p, mock_get_forward_context):
|
mock_npu_grouped_matmul, mock_is_310p):
|
||||||
|
|
||||||
mock_forward_context = MagicMock()
|
|
||||||
mock_forward_context.with_quant = False
|
|
||||||
mock_get_forward_context.return_value = mock_forward_context
|
|
||||||
|
|
||||||
mock_is_310p.return_value = True
|
mock_is_310p.return_value = True
|
||||||
|
|
||||||
mock_gmm1_out = torch.randn(10, 40, dtype=torch.float16)
|
mock_gmm1_out = torch.randn(10, 40, dtype=torch.float16)
|
||||||
@@ -750,10 +737,9 @@ class TestUnifiedApplyMLP(TestBase):
|
|||||||
group_list_type=1,
|
group_list_type=1,
|
||||||
w1_scale_bias=None,
|
w1_scale_bias=None,
|
||||||
w2_scale_bias=None,
|
w2_scale_bias=None,
|
||||||
topk_scales=topk_scales)
|
topk_scales=topk_scales,
|
||||||
|
with_quant=False)
|
||||||
|
|
||||||
mock_get_forward_context.assert_called()
|
|
||||||
self.assertFalse(mock_forward_context.with_quant)
|
|
||||||
mock_is_310p.assert_called_once()
|
mock_is_310p.assert_called_once()
|
||||||
|
|
||||||
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
|
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
|
||||||
|
|||||||
@@ -263,7 +263,6 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
|||||||
"max_num_tokens": 100,
|
"max_num_tokens": 100,
|
||||||
"ep_size": 2,
|
"ep_size": 2,
|
||||||
"num_experts": 128,
|
"num_experts": 128,
|
||||||
"with_quant": True,
|
|
||||||
}
|
}
|
||||||
self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs)
|
self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs)
|
||||||
|
|
||||||
@@ -460,8 +459,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
|||||||
def test_token_dispatch_with_quant(self):
|
def test_token_dispatch_with_quant(self):
|
||||||
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
|
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
|
||||||
num_experts=4,
|
num_experts=4,
|
||||||
num_local_experts=2,
|
num_local_experts=2)
|
||||||
with_quant=True)
|
|
||||||
|
|
||||||
hidden_states = torch.randn(8, 16)
|
hidden_states = torch.randn(8, 16)
|
||||||
topk_weights = torch.rand(8, 4)
|
topk_weights = torch.rand(8, 4)
|
||||||
@@ -476,7 +474,8 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
|||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
row_idx=self.row_idx,
|
row_idx=self.row_idx,
|
||||||
expert_map=expert_map)
|
expert_map=expert_map,
|
||||||
|
with_quant=True)
|
||||||
|
|
||||||
self.assertIsNotNone(result["hidden_states"])
|
self.assertIsNotNone(result["hidden_states"])
|
||||||
self.assertIsNotNone(result["group_list"])
|
self.assertIsNotNone(result["group_list"])
|
||||||
@@ -486,8 +485,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
|||||||
def test_token_dispatch_with_quant_no_active_tokens(self):
|
def test_token_dispatch_with_quant_no_active_tokens(self):
|
||||||
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
|
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
|
||||||
num_experts=4,
|
num_experts=4,
|
||||||
num_local_experts=2,
|
num_local_experts=2)
|
||||||
with_quant=True)
|
|
||||||
|
|
||||||
self.mock_repeat_interleave.return_value = torch.tensor(
|
self.mock_repeat_interleave.return_value = torch.tensor(
|
||||||
[], dtype=torch.long)
|
[], dtype=torch.long)
|
||||||
@@ -505,7 +503,8 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
|||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
row_idx=self.row_idx,
|
row_idx=self.row_idx,
|
||||||
expert_map=expert_map)
|
expert_map=expert_map,
|
||||||
|
with_quant=True)
|
||||||
|
|
||||||
self.assertIsNotNone(result["hidden_states"])
|
self.assertIsNotNone(result["hidden_states"])
|
||||||
self.assertIsNotNone(result["group_list"])
|
self.assertIsNotNone(result["group_list"])
|
||||||
|
|||||||
@@ -99,8 +99,6 @@ def set_ascend_forward_context(
|
|||||||
forward_context.fused_moe_state = fused_moe_state
|
forward_context.fused_moe_state = fused_moe_state
|
||||||
forward_context.in_profile_run = in_profile_run
|
forward_context.in_profile_run = in_profile_run
|
||||||
|
|
||||||
with_quant = vllm_config.quant_config is not None
|
|
||||||
forward_context.with_quant = with_quant
|
|
||||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
|
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
|
||||||
get_token_dispatcher
|
get_token_dispatcher
|
||||||
dispatcher_name = get_dispatcher_name(ep_size, with_prefill)
|
dispatcher_name = get_dispatcher_name(ep_size, with_prefill)
|
||||||
|
|||||||
@@ -408,19 +408,19 @@ def unquant_apply_mlp(
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
def unified_apply_mlp(
|
def unified_apply_mlp(hidden_states: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1_scale: torch.Tensor,
|
||||||
w1_scale: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
w2_scale: torch.Tensor,
|
||||||
w2_scale: torch.Tensor,
|
group_list: torch.Tensor,
|
||||||
group_list: torch.Tensor,
|
dynamic_scale: torch.Tensor = None,
|
||||||
dynamic_scale: torch.Tensor = None,
|
group_list_type: int = 1,
|
||||||
group_list_type: int = 1,
|
w1_scale_bias: torch.Tensor = None,
|
||||||
w1_scale_bias: torch.Tensor = None,
|
w2_scale_bias: torch.Tensor = None,
|
||||||
w2_scale_bias: torch.Tensor = None,
|
topk_scales: Optional[torch.Tensor] = None,
|
||||||
topk_scales: Optional[torch.Tensor] = None) -> torch.Tensor:
|
with_quant: bool = False) -> torch.Tensor:
|
||||||
if get_forward_context().with_quant:
|
if with_quant:
|
||||||
return quant_apply_mlp(hidden_states=hidden_states,
|
return quant_apply_mlp(hidden_states=hidden_states,
|
||||||
w1=w1,
|
w1=w1,
|
||||||
w1_scale=w1_scale,
|
w1_scale=w1_scale,
|
||||||
@@ -457,7 +457,8 @@ def unified_fused_experts_eager(hidden_states: torch.Tensor,
|
|||||||
shared_gate_up: Optional[Any] = None,
|
shared_gate_up: Optional[Any] = None,
|
||||||
shared_dequant_scale: Optional[Any] = None,
|
shared_dequant_scale: Optional[Any] = None,
|
||||||
mc2_mask: Optional[torch.Tensor] = None,
|
mc2_mask: Optional[torch.Tensor] = None,
|
||||||
apply_router_weight_on_input: bool = False):
|
apply_router_weight_on_input: bool = False,
|
||||||
|
with_quant: bool = False):
|
||||||
token_dispatcher = get_forward_context().token_dispatcher
|
token_dispatcher = get_forward_context().token_dispatcher
|
||||||
|
|
||||||
results = token_dispatcher.token_dispatch(
|
results = token_dispatcher.token_dispatch(
|
||||||
@@ -472,7 +473,8 @@ def unified_fused_experts_eager(hidden_states: torch.Tensor,
|
|||||||
shared_gate_up=shared_gate_up,
|
shared_gate_up=shared_gate_up,
|
||||||
shared_dequant_scale=shared_dequant_scale,
|
shared_dequant_scale=shared_dequant_scale,
|
||||||
mc2_mask=mc2_mask,
|
mc2_mask=mc2_mask,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
with_quant=with_quant)
|
||||||
|
|
||||||
expert_output = unified_apply_mlp(
|
expert_output = unified_apply_mlp(
|
||||||
hidden_states=results["hidden_states"],
|
hidden_states=results["hidden_states"],
|
||||||
@@ -485,7 +487,8 @@ def unified_fused_experts_eager(hidden_states: torch.Tensor,
|
|||||||
group_list_type=results.get("group_list_type"),
|
group_list_type=results.get("group_list_type"),
|
||||||
w1_scale_bias=w1_scale_bias,
|
w1_scale_bias=w1_scale_bias,
|
||||||
w2_scale_bias=w2_scale_bias,
|
w2_scale_bias=w2_scale_bias,
|
||||||
topk_scales=results.get("topk_scales"))
|
topk_scales=results.get("topk_scales"),
|
||||||
|
with_quant=with_quant)
|
||||||
final_hidden_states = token_dispatcher.token_combine(expert_output)
|
final_hidden_states = token_dispatcher.token_combine(expert_output)
|
||||||
return final_hidden_states
|
return final_hidden_states
|
||||||
|
|
||||||
@@ -577,7 +580,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
shared_experts=shared_experts,
|
shared_experts=shared_experts,
|
||||||
mc2_mask=kwargs.get(
|
mc2_mask=kwargs.get(
|
||||||
"mc2_mask", None))
|
"mc2_mask", None),
|
||||||
|
with_quant=False)
|
||||||
|
|
||||||
|
|
||||||
class AscendFusedMoE(FusedMoE):
|
class AscendFusedMoE(FusedMoE):
|
||||||
@@ -761,7 +765,6 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
|
|
||||||
ep_size = (get_ep_group().world_size if
|
ep_size = (get_ep_group().world_size if
|
||||||
vllm_config.parallel_config.enable_expert_parallel else 1)
|
vllm_config.parallel_config.enable_expert_parallel else 1)
|
||||||
with_quant = quant_config is not None
|
|
||||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
|
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
|
||||||
setup_token_dispatchers
|
setup_token_dispatchers
|
||||||
setup_token_dispatchers(
|
setup_token_dispatchers(
|
||||||
@@ -769,8 +772,7 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
top_k=self.top_k,
|
top_k=self.top_k,
|
||||||
num_experts=self.global_num_experts,
|
num_experts=self.global_num_experts,
|
||||||
num_global_redundant_experts=self.global_redundant_expert_num,
|
num_global_redundant_experts=self.global_redundant_expert_num,
|
||||||
num_local_experts=self.local_num_experts,
|
num_local_experts=self.local_num_experts)
|
||||||
with_quant=with_quant)
|
|
||||||
|
|
||||||
def naive_multicast(self, x: torch.Tensor,
|
def naive_multicast(self, x: torch.Tensor,
|
||||||
cu_tokens_across_dp_cpu: torch.Tensor):
|
cu_tokens_across_dp_cpu: torch.Tensor):
|
||||||
|
|||||||
@@ -490,7 +490,6 @@ class MoETokenDispatcher(ABC):
|
|||||||
"""
|
"""
|
||||||
self.top_k = kwargs.get("top_k", 0)
|
self.top_k = kwargs.get("top_k", 0)
|
||||||
self.num_experts = kwargs.get("num_experts", 0)
|
self.num_experts = kwargs.get("num_experts", 0)
|
||||||
self.with_quant = kwargs.get("with_quant", False)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ep_group(self):
|
def ep_group(self):
|
||||||
@@ -518,7 +517,8 @@ class MoETokenDispatcher(ABC):
|
|||||||
shared_gate_up: Optional[torch.Tensor] = None,
|
shared_gate_up: Optional[torch.Tensor] = None,
|
||||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||||
mc2_mask: Optional[torch.Tensor] = None,
|
mc2_mask: Optional[torch.Tensor] = None,
|
||||||
apply_router_weight_on_input: bool = False):
|
apply_router_weight_on_input: bool = False,
|
||||||
|
with_quant: bool = False):
|
||||||
raise NotImplementedError("Dispatch function not implemented.")
|
raise NotImplementedError("Dispatch function not implemented.")
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -555,6 +555,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
|||||||
self.topk_weights = None
|
self.topk_weights = None
|
||||||
self.shared_experts = None
|
self.shared_experts = None
|
||||||
self.mc2_mask = None
|
self.mc2_mask = None
|
||||||
|
self.with_quant = False
|
||||||
|
|
||||||
def get_dispatch_mc2_kwargs(
|
def get_dispatch_mc2_kwargs(
|
||||||
self,
|
self,
|
||||||
@@ -615,7 +616,9 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
|||||||
shared_gate_up: Optional[torch.Tensor] = None,
|
shared_gate_up: Optional[torch.Tensor] = None,
|
||||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||||
mc2_mask: Optional[torch.Tensor] = None,
|
mc2_mask: Optional[torch.Tensor] = None,
|
||||||
apply_router_weight_on_input: bool = False):
|
apply_router_weight_on_input: bool = False,
|
||||||
|
with_quant: bool = False):
|
||||||
|
self.with_quant = with_quant
|
||||||
self.expert_map = expert_map
|
self.expert_map = expert_map
|
||||||
self.topk_ids = topk_ids
|
self.topk_ids = topk_ids
|
||||||
self.topk_weights = topk_weights
|
self.topk_weights = topk_weights
|
||||||
@@ -738,6 +741,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
|||||||
self.expert_map = None
|
self.expert_map = None
|
||||||
self.topk_weights = None
|
self.topk_weights = None
|
||||||
self.topk_ids = None
|
self.topk_ids = None
|
||||||
|
self.with_quant = False
|
||||||
|
|
||||||
def token_dispatch(self,
|
def token_dispatch(self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -751,7 +755,9 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
|||||||
shared_gate_up: Optional[torch.Tensor] = None,
|
shared_gate_up: Optional[torch.Tensor] = None,
|
||||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||||
mc2_mask: Optional[torch.Tensor] = None,
|
mc2_mask: Optional[torch.Tensor] = None,
|
||||||
apply_router_weight_on_input: bool = False):
|
apply_router_weight_on_input: bool = False,
|
||||||
|
with_quant: bool = False):
|
||||||
|
self.with_quant = with_quant
|
||||||
self.original_shape = hidden_states.shape
|
self.original_shape = hidden_states.shape
|
||||||
|
|
||||||
num_tokens = hidden_states.shape[:-1].numel()
|
num_tokens = hidden_states.shape[:-1].numel()
|
||||||
@@ -922,7 +928,8 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
|
|||||||
shared_gate_up: Optional[torch.Tensor] = None,
|
shared_gate_up: Optional[torch.Tensor] = None,
|
||||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||||
mc2_mask: Optional[torch.Tensor] = None,
|
mc2_mask: Optional[torch.Tensor] = None,
|
||||||
apply_router_weight_on_input: bool = False):
|
apply_router_weight_on_input: bool = False,
|
||||||
|
with_quant: bool = False):
|
||||||
self.apply_router_weight_on_input = apply_router_weight_on_input
|
self.apply_router_weight_on_input = apply_router_weight_on_input
|
||||||
if self.apply_router_weight_on_input:
|
if self.apply_router_weight_on_input:
|
||||||
assert (topk_weights.dim() == 2
|
assert (topk_weights.dim() == 2
|
||||||
@@ -980,6 +987,7 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
|||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
self.with_quant = False
|
||||||
self.num_local_experts = kwargs.get("num_local_experts", 0)
|
self.num_local_experts = kwargs.get("num_local_experts", 0)
|
||||||
self.num_global_redundant_experts = kwargs.get(
|
self.num_global_redundant_experts = kwargs.get(
|
||||||
"num_global_redundant_experts", 0)
|
"num_global_redundant_experts", 0)
|
||||||
@@ -1032,7 +1040,9 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
|||||||
shared_gate_up: Optional[torch.Tensor] = None,
|
shared_gate_up: Optional[torch.Tensor] = None,
|
||||||
shared_dequant_scale: Optional[torch.Tensor] = None,
|
shared_dequant_scale: Optional[torch.Tensor] = None,
|
||||||
mc2_mask: Optional[torch.Tensor] = None,
|
mc2_mask: Optional[torch.Tensor] = None,
|
||||||
apply_router_weight_on_input: bool = False):
|
apply_router_weight_on_input: bool = False,
|
||||||
|
with_quant: bool = False):
|
||||||
|
self.with_quant = with_quant
|
||||||
self.hidden_shape = hidden_states.shape
|
self.hidden_shape = hidden_states.shape
|
||||||
self.topk_weights = topk_weights
|
self.topk_weights = topk_weights
|
||||||
assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights"
|
assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights"
|
||||||
|
|||||||
@@ -308,7 +308,8 @@ class AscendW4A8DynamicFusedMoEMethod:
|
|||||||
shared_experts=shared_experts,
|
shared_experts=shared_experts,
|
||||||
shared_gate_up=shared_gate_up,
|
shared_gate_up=shared_gate_up,
|
||||||
shared_dequant_scale=shared_dequant_scale,
|
shared_dequant_scale=shared_dequant_scale,
|
||||||
mc2_mask=kwargs.get("mc2_mask", None))
|
mc2_mask=kwargs.get("mc2_mask", None),
|
||||||
|
with_quant=True)
|
||||||
|
|
||||||
def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
|
def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
|
||||||
group_num, k, n = weight.shape
|
group_num, k, n = weight.shape
|
||||||
|
|||||||
@@ -406,7 +406,8 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
shared_experts=shared_experts,
|
shared_experts=shared_experts,
|
||||||
shared_gate_up=shared_gate_up,
|
shared_gate_up=shared_gate_up,
|
||||||
shared_dequant_scale=shared_dequant_scale,
|
shared_dequant_scale=shared_dequant_scale,
|
||||||
mc2_mask=kwargs.get("mc2_mask", None))
|
mc2_mask=kwargs.get("mc2_mask", None),
|
||||||
|
with_quant=True)
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer):
|
def process_weights_after_loading(self, layer):
|
||||||
if self.transpose_weight:
|
if self.transpose_weight:
|
||||||
|
|||||||
Reference in New Issue
Block a user