# tests/ut/quantization/test_w8a8_dynamic.py
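"""Unit tests for the AscendW8A8DynamicFusedMoEMethod quantization scheme."""
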
from unittest.mock import Mock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.quantization.methods.w8a8_dynamic import AscendW8A8DynamicFusedMoEMethod


class TestAscendW8A8FusedMoEMethod(TestBase):
    num_experts = 8
    hidden_size = 128
    intermediate_size = 128

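    # setUp patches the rank/EP/MC2 group helpers and the Ascend/vLLM configs
    # so the scheme can be constructed without a real distributed environment.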
    @patch("torch.distributed.get_rank")
    @patch("vllm_ascend.quantization.methods.w8a8_dynamic.get_mc2_group")
    @patch("vllm_ascend.quantization.methods.w8a8_dynamic.get_ascend_config")
    @patch("vllm_ascend.quantization.methods.w8a8_dynamic.get_ep_group")
    def setUp(self, mock_get_ep_group, mock_get_ascend_config,
              mock_get_mc2_group, mock_get_rank):
        with patch(
                'vllm_ascend.quantization.methods.w8a8_dynamic.get_current_vllm_config'
        ) as mock_get_current_vllm_config:
            mock_vllm_config = Mock()
            mock_vllm_config.quant_config = Mock(
                quant_description={"group_size": 256})
            mock_vllm_config.scheduler_config = Mock(
                max_num_batched_tokens=2048,
                max_model_len=2048,
                enable_chunked_prefill=False)
            mock_get_current_vllm_config.return_value = mock_vllm_config
            mock_ep_group = Mock()
            mock_get_ep_group.return_value = mock_ep_group
            mock_ascend_config = Mock()
            mock_ascend_config.enable_chunked_prefill = False
            mock_ascend_config.multistream_overlap_gate = False
            mock_ascend_config.eplb_config = Mock(dynamic_eplb=False)
            mock_get_ascend_config.return_value = mock_ascend_config
            mock_mc2_group = Mock(device_group=0)
            mock_get_mc2_group.return_value = mock_mc2_group
            mock_rank = Mock()
            mock_get_rank.return_value = mock_rank
            self.quant_method = AscendW8A8DynamicFusedMoEMethod()

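    # get_weight should allocate per-expert int8 weights for the fused
    # gate/up (w13) projection.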
    def test_get_weight(self):
        param_dict = self.quant_method.get_weight(self.num_experts,
                                                  self.intermediate_size,
                                                  self.hidden_size,
                                                  torch.bfloat16)
        self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
        self.assertEqual(
            param_dict["w13_weight"].shape,
            (self.num_experts, 2 * self.intermediate_size, self.hidden_size))

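    # get_dynamic_quant_param should allocate per-channel w13 weight scales
    # in the requested dtype.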
    def test_get_dynamic_quant_param(self):
        param_dict = self.quant_method.get_dynamic_quant_param(
            self.num_experts, self.intermediate_size, self.hidden_size,
            torch.bfloat16)
        self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
        self.assertEqual(param_dict["w13_weight_scale"].shape,
                         (self.num_experts, 2 * self.intermediate_size, 1))

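    # Helper that builds a bare nn.Module carrying int8 expert weights and
    # fp32 scale/offset parameters, shaped like the tensors tested above.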
    def build_layer(self):
        layer = torch.nn.Module()
        layer.w13_weight = torch.nn.Parameter(torch.empty(
            self.num_experts,
            2 * self.intermediate_size,
            self.hidden_size,
            dtype=torch.int8),
                                              requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(torch.empty(
            self.num_experts,
            self.hidden_size,
            self.intermediate_size,
            dtype=torch.int8),
                                             requires_grad=False)
        w13_weight_scale = torch.zeros(
            (self.num_experts, 2 * self.intermediate_size, 1),
            dtype=torch.float32)
        layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale,
                                                    requires_grad=False)
        w13_weight_offset = torch.zeros(
            (self.num_experts, 2 * self.intermediate_size, 1),
            dtype=torch.float32)
        layer.w13_weight_offset = torch.nn.Parameter(w13_weight_offset,
                                                     requires_grad=False)
        w2_weight_scale = torch.zeros((self.num_experts, self.hidden_size, 1),
                                      dtype=torch.float32)
        layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                   requires_grad=False)
        w2_weight_offset = torch.zeros(
            (self.num_experts, self.hidden_size, 1), dtype=torch.float32)
        layer.w2_weight_offset = torch.nn.Parameter(w2_weight_offset,
                                                    requires_grad=False)
        return layer

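    # npu_format_cast is stubbed to return its input unchanged so the
    # post-load weight processing can run without an NPU.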
    @patch('torch_npu.npu_format_cast')
    def test_process_weights_after_loading(self, mock_npu_format_cast):

        def func_by_args(weight, num_format):
            return weight

        mock_npu_format_cast.side_effect = func_by_args
        new_layer = self.build_layer()
        self.quant_method.process_weights_after_loading(new_layer)
        mock_npu_format_cast.assert_called()

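    # With multistream_overlap_gate disabled, apply() should route via
    # select_experts and pass the routing/activation arguments through to the
    # comm method's fused_experts call.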
    @patch("vllm_ascend.quantization.methods.w8a8_dynamic._EXTRA_CTX")
    @patch("vllm_ascend.quantization.methods.w8a8_dynamic.select_experts")
    def test_apply_uses_explicit_dispatch_and_mlp_args(self, mock_select_experts, mock_extra_ctx):
        tokens = 4
        hidden_size = self.hidden_size

        layer = torch.nn.Module()
        layer.w13_weight = torch.randint(
            -8,
            8,
            (self.num_experts, 2 * self.intermediate_size, hidden_size),
            dtype=torch.int8,
        )
        layer.w2_weight = torch.randint(
            -8,
            8,
            (self.num_experts, hidden_size, self.intermediate_size),
            dtype=torch.int8,
        )
        layer.w13_weight_scale_fp32 = torch.ones(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32)
        layer.w2_weight_scale = torch.ones(self.num_experts, hidden_size, dtype=torch.float32)

        x = torch.randn(tokens, hidden_size, dtype=torch.float32)
        router_logits = torch.randn(tokens, self.num_experts, dtype=torch.float32)
        topk_weights = torch.randn(tokens, 2, dtype=torch.float32)
        topk_ids = torch.randint(0, self.num_experts, (tokens, 2), dtype=torch.int64)
        mc2_mask = torch.tensor([1, 0, 1, 0], dtype=torch.bool)
        pertoken_scale = torch.randn(tokens, dtype=torch.float32)

        mock_select_experts.return_value = (topk_weights, topk_ids)

        mock_comm = Mock()
        mock_comm.fused_experts.return_value = torch.randn(tokens, hidden_size, dtype=torch.float32)
        mock_extra_ctx.moe_comm_method = mock_comm
        mock_extra_ctx.moe_comm_type = MoECommType.ALLGATHER

        self.quant_method.multistream_overlap_gate = False
        self.quant_method.in_dtype = torch.float32

        self.quant_method.apply(
            layer=layer,
            x=x,
            router_logits=router_logits,
            top_k=2,
            renormalize=True,
            global_num_experts=self.num_experts,
            activation="gelu",
            apply_router_weight_on_input=True,
            mc2_mask=mc2_mask,
            pertoken_scale=pertoken_scale,
        )

        fused_experts_input = mock_comm.fused_experts.call_args.kwargs["fused_experts_input"]
        self.assertEqual(fused_experts_input.activation, "gelu")
        self.assertTrue(fused_experts_input.routing.apply_router_weight_on_input)
        self.assertIs(fused_experts_input.routing.mc2_mask, mc2_mask)
        self.assertIs(fused_experts_input.routing.pertoken_scale, pertoken_scale)
        self.assertIs(fused_experts_input.topk_weights, topk_weights)
        self.assertIs(fused_experts_input.topk_ids, topk_ids)

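    # With multistream_overlap_gate enabled, topk results should be taken
    # from the flash-common3 context and select_experts must not be called.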
    @patch("vllm_ascend.quantization.methods.w8a8_dynamic.get_flash_common3_context")
    @patch("vllm_ascend.quantization.methods.w8a8_dynamic._EXTRA_CTX")
    @patch("vllm_ascend.quantization.methods.w8a8_dynamic.select_experts")
    def test_apply_overlap_gate_uses_fc3_context(
        self,
        mock_select_experts,
        mock_extra_ctx,
        mock_get_flash_common3_context,
    ):
        tokens = 4
        hidden_size = self.hidden_size

        layer = torch.nn.Module()
        layer.w13_weight = torch.randint(
            -8,
            8,
            (self.num_experts, 2 * self.intermediate_size, hidden_size),
            dtype=torch.int8,
        )
        layer.w2_weight = torch.randint(
            -8,
            8,
            (self.num_experts, hidden_size, self.intermediate_size),
            dtype=torch.int8,
        )
        layer.w13_weight_scale_fp32 = torch.ones(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32)
        layer.w2_weight_scale = torch.ones(self.num_experts, hidden_size, dtype=torch.float32)

        x = torch.randn(tokens, hidden_size, dtype=torch.float32)
        router_logits = torch.randn(tokens, self.num_experts, dtype=torch.float32)
        topk_weights = torch.randn(tokens, 2, dtype=torch.float32)
        topk_ids = torch.randint(0, self.num_experts, (tokens, 2), dtype=torch.int64)
        mc2_mask = torch.tensor([1, 0, 1, 0], dtype=torch.bool)
        pertoken_scale = torch.randn(tokens, dtype=torch.float32)

        self.quant_method.multistream_overlap_gate = True
        self.quant_method.in_dtype = torch.float32
        mock_get_flash_common3_context.return_value = Mock(topk_weights=topk_weights, topk_ids=topk_ids)

        mock_comm = Mock()
        mock_comm.fused_experts.return_value = torch.randn(tokens, hidden_size, dtype=torch.float32)
        mock_extra_ctx.moe_comm_method = mock_comm
        mock_extra_ctx.moe_comm_type = MoECommType.ALLGATHER

        self.quant_method.apply(
            layer=layer,
            x=x,
            router_logits=router_logits,
            top_k=2,
            renormalize=True,
            global_num_experts=self.num_experts,
            activation="gelu",
            apply_router_weight_on_input=True,
            mc2_mask=mc2_mask,
            pertoken_scale=pertoken_scale,
        )

        mock_select_experts.assert_not_called()
        fused_experts_input = mock_comm.fused_experts.call_args.kwargs["fused_experts_input"]
        self.assertEqual(fused_experts_input.activation, "gelu")
        self.assertTrue(fused_experts_input.routing.apply_router_weight_on_input)
        self.assertIs(fused_experts_input.routing.mc2_mask, mc2_mask)
        self.assertIs(fused_experts_input.routing.pertoken_scale, pertoken_scale)
        self.assertIs(fused_experts_input.topk_weights, topk_weights)
        self.assertIs(fused_experts_input.topk_ids, topk_ids)