xc-llm-ascend/vllm_ascend/quantization/methods/w8a8_dynamic.py

[quantization] Support w8a8 quantization (#580) ### What this PR does / why we need it? Add a `VLLMAscendQuantizer` to support w8a8 static (W8A8) and dynamic on linear and moe (W8A8_DYNAMIC), the quantizer will be enable if a model has [quantize filed](https://huggingface.co/vllm-ascend/Qwen2.5-0.5B-Instruct-w8a8/blob/main/config.json#L27). If MindIE Turbo is installed, the MindIE Turbo Quantizer will apply, otherwise will use VLLMAscendQuantizer directly. - This patch fix installation docs to make installation work - This patch enable norm quantization by patch `RMSNorm.__init__`, `RMSNorm.forward_oot`, `NPUModelRunnerBase.load_model` - Add `AscendW8A8LinearMethod` for W8A8 - Add `AscendW8A8DynamicLinearMethod` and `AscendW8A8DynamicFusedMoEMethod` for W8A8_DYNAMIC - Add a e2e test for `vllm-ascend/Qwen2.5-0.5B-Instruct-w8a8` ### Does this PR introduce _any_ user-facing change? Yes, support w8a8 quantization. After this patch supported, users can use below commands to run w8a8 models: ``` vllm serve /root/.cache/modelscope/hub/Qwen/Qwen2.5-7B-Instruct-w8a8 --served-model-name "qwen2.5-7B" ``` ### How was this patch tested? 0. CI passed: add e2e test for `vllm-ascend/Qwen2.5-0.5B-Instruct-w8a8` 1. From @Yikun: I test Qwen2.5-0.5B-Instruct-w8a8 for functional test all is well, pls refer to https://github.com/vllm-project/vllm-ascend/pull/580#issuecomment-2816747613 2. From @dingdingchaomian : Use qwen2.5-72b-instruct model and deepseek-v2-lite-chat tested, both models were quantized using Ascend's msmodelslim tool: - Qwen2.5-72b-instruct were tested twice, one for w8a8 static and one for w8a8 dynamic. - Deepseek-v2-lite-chat were tested once because its quantization used both static and dynamic w8a8. Models were tested using both off line inference and online serving, and both work well. The inference codes are exactly the same with the examples in https://vllm-ascend.readthedocs.io/en/latest/quick_start.html, with model path and tensor parallel number changed. 
--------- Signed-off-by: dingdingchaomian <wangce21@huawei.com> Signed-off-by: Yikun Jiang <yikunkero@gmail.com> Co-authored-by: dingdingchaomian <wangce21@huawei.com> Co-authored-by: Angazenn <zengyanjia@huawei.com> Co-authored-by: liujiaxu <liujiaxu4@huawei.com> Co-authored-by: ApsarasX <apsarax@outlook.com> Co-authored-by: ganyi1996ppo <pleaplusone.gy@gmail.com>
2025-04-20 18:14:05 +08:00
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections.abc import Callable
from typing import Any
import torch
import torch_npu
from vllm.config import CompilationMode, get_current_vllm_config
from vllm.distributed import get_ep_group
import vllm_ascend.envs as envs_ascend
[3/N][Feat][Graph] Support `all-to-all` and quantized models with ACL Graph (#2614) ### What this PR does / why we need it? * **Unify execution paths:** Consolidates the quantized and non-quantized execution paths into a single `fused_experts` function, removing duplicated logic and making the control flow clearer and easier to maintain. * **W8A8 dynamic quantization:** Adds support for W8A8 dynamic quantization inside the unified MoE kernel. Communication routines are updated to correctly handle dynamic quantization scales for activations. * **Weight pre-processing:** Prae-transpose the `w13` and `w2` weight matrices (as implemented in PR #2025) so that quantized and non-quantized models follow the same code path for the MoE gating, up-projection, and down-projection operations. * **All-to-all communication:** Adds an `all-to-all` collective communication pattern. For large token counts on modern hardware, `all-to-all` is more efficient than the previous `all-gather` strategy. However, `all-to-all` is not really captured and replayed due to multiple D2H operations which will trigger synchronization, and thus raise error when capture graphs. We only use `all-to-all` when fallback to `compiled_graph_for_general_shape`. * **Dynamic communication selection:** The model runner now selects the optimal MoE communication method (`mc2`, `allgather`, or `alltoall`) at runtime based on token count and the Ascend SoC version. * **Limitation:** `all-gather` is not yet supported for quantized models, which means there is still something left to do on A2. ### Does this PR introduce _any_ user-facing change? None. ### How was this patch tested? No further test cases needed. - vLLM version: v0.10.1.1 - vLLM main: https://github.com/vllm-project/vllm/commit/d660c98c1b59580af97d6c7dd162c7c8894d40ed --------- Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-08-30 11:00:35 +08:00
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.flash_common3_context import get_flash_common3_context
from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, maybe_trans_nz
[Refactor] Quantization Module Refactor (#5738) ### Summary This PR refactors the `vllm_ascend/quantization` module to improve code organization, maintainability, and extensibility. The refactoring introduces a clear separation of concerns with a registry-based scheme discovery pattern, abstract base classes for quantization schemes, and dedicated wrapper classes. ### Key Changes #### 1. **Modular Directory Structure** | Before | After | |--------|-------| | Flat file structure with mixed responsibilities | Organized into `methods/` subpackage for schemes | | Single `quant_config.py` (600+ lines) | Separate config files: `modelslim_config.py`, `compressed_tensors_config.py` | | `utils.py` with scheme lookup logic | `methods/registry.py` with decorator-based registration | #### 2. **Registry-Based Scheme Discovery** Replaced hardcoded `ASCEND_QUANTIZATION_METHOD_MAP` dictionary with a decorator-based registry pattern: ```python # Before: Manual dictionary mapping ASCEND_QUANTIZATION_METHOD_MAP = { "W8A8_DYNAMIC": {"linear": AscendW8A8DynamicLinearMethod, ...}, ... } # After: Decorator-based registration @register_scheme("W8A8_DYNAMIC", "linear") class AscendW8A8DynamicLinearMethod(AscendLinearScheme): ... ``` #### 3. **Abstract Base Classes** Introduced three abstract base classes in `methods/base.py`: - `AscendLinearScheme` - Base for linear layer quantization - `AscendMoEScheme` - Base for MoE layer quantization - `AscendAttentionScheme` - Base for attention layer quantization #### 4. **Separated Config and Wrapper Classes** - **Config classes** (`AscendModelSlimConfig`, `AscendCompressedTensorsConfig`): Handle config parsing and scheme selection - **Wrapper classes** (`AscendLinearMethod`, `AscendFusedMoEMethod`, etc.): Implement vLLM interfaces and delegate to schemes #### 5. 
**Cleaner Public API** ```python # New clean module interface from vllm_ascend.quantization import ( AscendModelSlimConfig, AscendCompressedTensorsConfig, ) from vllm_ascend.quantization.methods import get_scheme_class ``` ### Architecture Diagram ```mermaid classDiagram direction TB class QuantizationConfig { <<vLLM Interface>> +get_quant_method() } class AscendModelSlimConfig { +quant_description +get_quant_method() -create_scheme_for_layer() } class AscendCompressedTensorsConfig { +target_scheme_map +get_quant_method() -_get_scheme_from_parts() } class AscendLinearMethod { <<Wrapper>> +quant_method: AscendLinearScheme +create_weights() +apply() } class AscendFusedMoEMethod { <<Wrapper>> +quant_method: AscendMoEScheme +create_weights() +apply() } class AscendLinearScheme { <<Abstract>> +get_weight()* +apply()* +get_pertensor_param() +get_perchannel_param() } class AscendMoEScheme { <<Abstract>> +get_weight()* +get_dynamic_quant_param()* +apply()* } class W8A8DynamicLinear { +get_weight() +apply() } class W8A8DynamicMoE { +get_weight() +apply() } QuantizationConfig <|-- AscendModelSlimConfig QuantizationConfig <|-- AscendCompressedTensorsConfig AscendModelSlimConfig ..> AscendLinearMethod : creates AscendModelSlimConfig ..> AscendFusedMoEMethod : creates AscendCompressedTensorsConfig ..> AscendLinearMethod : creates AscendCompressedTensorsConfig ..> AscendFusedMoEMethod : creates AscendLinearMethod o-- AscendLinearScheme : delegates to AscendFusedMoEMethod o-- AscendMoEScheme : delegates to AscendLinearScheme <|-- W8A8DynamicLinear AscendMoEScheme <|-- W8A8DynamicMoE ``` ### Scheme Registration Flow ```mermaid sequenceDiagram participant Module as Scheme Module participant Registry as _SCHEME_REGISTRY participant Config as QuantConfig participant Wrapper as Wrapper Class Note over Module: At import time Module->>Registry: @register_scheme("W8A8_DYNAMIC", "linear") Registry->>Registry: Store (quant_type, layer_type) -> Class Note over Config: At runtime 
Config->>Config: Determine quant_type from description Config->>Registry: get_scheme_class(quant_type, layer_type) Registry-->>Config: Return scheme class Config->>Config: scheme = scheme_cls() Config->>Wrapper: Create wrapper with scheme Wrapper-->>Config: Return wrapper instance ``` ### File Changes Summary | Original Files | Refactored Files | |----------------|------------------| | `__init__.py` (empty) | `__init__.py` (exports public API) | | `quant_config.py` | `modelslim_config.py` + `wrappers.py` | | `compressed_tensors/` | `compressed_tensors_config.py` | | `utils.py` | `methods/registry.py` | | `w8a8_dynamic.py` | `methods/w8a8_dynamic.py` | | `w8a8.py` | `methods/w8a8_static.py` | | `w4a4_flatquant_dynamic.py` | `methods/w4a4_flatquant.py` | | ... | `methods/base.py` (new) | ### Benefits 1. **Extensibility**: Adding new quantization schemes only requires implementing the base class and adding `@register_scheme` decorator 2. **Maintainability**: Clear separation between config parsing, wrapper logic, and scheme implementation 3. **Testability**: Abstract base classes enable easier unit testing and mocking 4. **Discoverability**: Registry pattern makes it easy to list all supported schemes 5. **Reduced Coupling**: Config classes no longer need to know about all scheme implementations ___ - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/2f4e6548efec402b913ffddc8726230d9311948d --------- Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
2026-01-23 14:13:47 +08:00
from .base import AscendLinearScheme, AscendMoEScheme, QuantType
from .registry import register_scheme
def scale_from_float_to_int64(scale):
    """Convert float32 scale to int64 representation.

    The float32 bit pattern of each element is reinterpreted as int32 and
    then widened to int64; the numeric value is not rounded or cast.
    The result is a 1-D tensor on the same device as the input.
    """
    import numpy as np

    scale = torch.from_numpy(
        np.frombuffer(scale.cpu().to(torch.float32).numpy().tobytes(), dtype=np.int32).astype(np.int64)
    ).to(scale.device)
    return scale
@register_scheme("W8A8_DYNAMIC", "linear")
class AscendW8A8DynamicLinearMethod(AscendLinearScheme):
    """Linear method for Ascend W8A8_DYNAMIC.

    This scheme uses dynamic per-token quantization for activations
    and per-channel quantization for weights.
    """

    def __init__(self):
        pass

    def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
        params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
        return params_dict

    def get_perchannel_param(
        self,
        output_size: int,
        params_dtype: torch.dtype,
    ) -> dict[str, Any]:
        params_dict = {}
        params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
        params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
        return params_dict

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
        tp_rank: int | None = 0,
    ) -> torch.Tensor:
        quantized_x, pertoken_scale = torch_npu.npu_dynamic_quant(x)
        need_unsqz = False
        if pertoken_scale.dim() == 2:
            need_unsqz = True
            quantized_x = quantized_x.squeeze(dim=1)
            pertoken_scale = pertoken_scale.squeeze(dim=1)
        output = torch_npu.npu_quant_matmul(
            quantized_x,
            layer.weight,
            layer.weight_scale,
            pertoken_scale=pertoken_scale,
            bias=bias,
            output_dtype=x.dtype,
        )
        if need_unsqz:
            output = output.unsqueeze(dim=1)
        return output
    def process_weights_after_loading(self, layer):
        layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
        # cast quantized weight tensors in NZ format for higher inference speed
        layer.weight.data = maybe_trans_nz(layer.weight.data)
        layer.weight_scale.data = layer.weight_scale.data.flatten()
        layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
        layer.weight_offset.data = layer.weight_offset.data.flatten()
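# Illustrative sketch only, not part of the scheme above: a pure-Python reference
# for the two steps that `apply()` delegates to `torch_npu.npu_dynamic_quant` and
# `torch_npu.npu_quant_matmul`. It assumes symmetric per-token int8 quantization
# (one scale per row, largest magnitude mapped to 127) and a bias added after
# dequantization; the real NPU kernels may differ in rounding and bias handling.
# The helper names `quantize_per_token` and `quant_matmul` are hypothetical.

```python
def quantize_per_token(rows):
    """Symmetric int8 quantization with one scale per row (token).

    Assumed semantics: scale = max(|row|) / 127, values rounded and
    clamped to the int8 range. An all-zero row gets scale 1.0.
    """
    quantized, scales = [], []
    for row in rows:
        max_abs = max(abs(v) for v in row)
        scale = max_abs / 127.0 if max_abs > 0 else 1.0
        quantized.append([max(-128, min(127, round(v / scale))) for v in row])
        scales.append(scale)
    return quantized, scales


def quant_matmul(q_x, q_w, weight_scale, pertoken_scale, bias=None):
    """Integer matmul followed by dequantization.

    q_w is laid out as [in_features][out_features], i.e. already transposed,
    and weight_scale holds one entry per output channel (a flat list).
    """
    out = []
    for q_row, tok_scale in zip(q_x, pertoken_scale):
        out_row = []
        for j, w_scale in enumerate(weight_scale):
            # Accumulate in integers, then rescale by both scales.
            acc = sum(q_row[k] * q_w[k][j] for k in range(len(q_row)))
            val = acc * tok_scale * w_scale
            if bias is not None:
                val += bias[j]
            out_row.append(val)
        out.append(out_row)
    return out
```

# In this sketch the weight layout and the flat per-channel scale correspond to
# what `process_weights_after_loading` prepares with `transpose(0, 1)` and
# `flatten()`; the NZ format cast has no counterpart here.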
[Refactor] Quantization Module Refactor (#5738) ### Summary This PR refactors the `vllm_ascend/quantization` module to improve code organization, maintainability, and extensibility. The refactoring introduces a clear separation of concerns with a registry-based scheme discovery pattern, abstract base classes for quantization schemes, and dedicated wrapper classes. ### Key Changes #### 1. **Modular Directory Structure** | Before | After | |--------|-------| | Flat file structure with mixed responsibilities | Organized into `methods/` subpackage for schemes | | Single `quant_config.py` (600+ lines) | Separate config files: `modelslim_config.py`, `compressed_tensors_config.py` | | `utils.py` with scheme lookup logic | `methods/registry.py` with decorator-based registration | #### 2. **Registry-Based Scheme Discovery** Replaced hardcoded `ASCEND_QUANTIZATION_METHOD_MAP` dictionary with a decorator-based registry pattern: ```python # Before: Manual dictionary mapping ASCEND_QUANTIZATION_METHOD_MAP = { "W8A8_DYNAMIC": {"linear": AscendW8A8DynamicLinearMethod, ...}, ... } # After: Decorator-based registration @register_scheme("W8A8_DYNAMIC", "linear") class AscendW8A8DynamicLinearMethod(AscendLinearScheme): ... ``` #### 3. **Abstract Base Classes** Introduced three abstract base classes in `methods/base.py`: - `AscendLinearScheme` - Base for linear layer quantization - `AscendMoEScheme` - Base for MoE layer quantization - `AscendAttentionScheme` - Base for attention layer quantization #### 4. **Separated Config and Wrapper Classes** - **Config classes** (`AscendModelSlimConfig`, `AscendCompressedTensorsConfig`): Handle config parsing and scheme selection - **Wrapper classes** (`AscendLinearMethod`, `AscendFusedMoEMethod`, etc.): Implement vLLM interfaces and delegate to schemes #### 5. 
**Cleaner Public API** ```python # New clean module interface from vllm_ascend.quantization import ( AscendModelSlimConfig, AscendCompressedTensorsConfig, ) from vllm_ascend.quantization.methods import get_scheme_class ``` ### Architecture Diagram ```mermaid classDiagram direction TB class QuantizationConfig { <<vLLM Interface>> +get_quant_method() } class AscendModelSlimConfig { +quant_description +get_quant_method() -create_scheme_for_layer() } class AscendCompressedTensorsConfig { +target_scheme_map +get_quant_method() -_get_scheme_from_parts() } class AscendLinearMethod { <<Wrapper>> +quant_method: AscendLinearScheme +create_weights() +apply() } class AscendFusedMoEMethod { <<Wrapper>> +quant_method: AscendMoEScheme +create_weights() +apply() } class AscendLinearScheme { <<Abstract>> +get_weight()* +apply()* +get_pertensor_param() +get_perchannel_param() } class AscendMoEScheme { <<Abstract>> +get_weight()* +get_dynamic_quant_param()* +apply()* } class W8A8DynamicLinear { +get_weight() +apply() } class W8A8DynamicMoE { +get_weight() +apply() } QuantizationConfig <|-- AscendModelSlimConfig QuantizationConfig <|-- AscendCompressedTensorsConfig AscendModelSlimConfig ..> AscendLinearMethod : creates AscendModelSlimConfig ..> AscendFusedMoEMethod : creates AscendCompressedTensorsConfig ..> AscendLinearMethod : creates AscendCompressedTensorsConfig ..> AscendFusedMoEMethod : creates AscendLinearMethod o-- AscendLinearScheme : delegates to AscendFusedMoEMethod o-- AscendMoEScheme : delegates to AscendLinearScheme <|-- W8A8DynamicLinear AscendMoEScheme <|-- W8A8DynamicMoE ``` ### Scheme Registration Flow ```mermaid sequenceDiagram participant Module as Scheme Module participant Registry as _SCHEME_REGISTRY participant Config as QuantConfig participant Wrapper as Wrapper Class Note over Module: At import time Module->>Registry: @register_scheme("W8A8_DYNAMIC", "linear") Registry->>Registry: Store (quant_type, layer_type) -> Class Note over Config: At runtime 
Config->>Config: Determine quant_type from description Config->>Registry: get_scheme_class(quant_type, layer_type) Registry-->>Config: Return scheme class Config->>Config: scheme = scheme_cls() Config->>Wrapper: Create wrapper with scheme Wrapper-->>Config: Return wrapper instance ``` ### File Changes Summary | Original Files | Refactored Files | |----------------|------------------| | `__init__.py` (empty) | `__init__.py` (exports public API) | | `quant_config.py` | `modelslim_config.py` + `wrappers.py` | | `compressed_tensors/` | `compressed_tensors_config.py` | | `utils.py` | `methods/registry.py` | | `w8a8_dynamic.py` | `methods/w8a8_dynamic.py` | | `w8a8.py` | `methods/w8a8_static.py` | | `w4a4_flatquant_dynamic.py` | `methods/w4a4_flatquant.py` | | ... | `methods/base.py` (new) | ### Benefits 1. **Extensibility**: Adding new quantization schemes only requires implementing the base class and adding `@register_scheme` decorator 2. **Maintainability**: Clear separation between config parsing, wrapper logic, and scheme implementation 3. **Testability**: Abstract base classes enable easier unit testing and mocking 4. **Discoverability**: Registry pattern makes it easy to list all supported schemes 5. **Reduced Coupling**: Config classes no longer need to know about all scheme implementations ___ - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/2f4e6548efec402b913ffddc8726230d9311948d --------- Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
2026-01-23 14:13:47 +08:00
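The scheme registration flow described in the commit message above can be sketched as a small decorator-based registry. This is only an illustration under stated assumptions: the names `_SCHEME_REGISTRY`, `register_scheme`, and `get_scheme_class` come from the commit message, but the bodies below and the `DemoMoEScheme` class are hypothetical and may differ from the actual code in `methods/registry.py`.

```python
from typing import Callable, Dict, Optional, Tuple

# Hypothetical stand-in for the registry named in the commit message;
# the real vllm_ascend implementation may differ.
_SCHEME_REGISTRY: Dict[Tuple[str, str], type] = {}


def register_scheme(quant_type: str, layer_type: str) -> Callable[[type], type]:
    """Class decorator that stores a scheme under (quant_type, layer_type)."""

    def decorator(cls: type) -> type:
        key = (quant_type, layer_type)
        if key in _SCHEME_REGISTRY:
            raise ValueError(f"A scheme is already registered for {key}")
        _SCHEME_REGISTRY[key] = cls
        return cls

    return decorator


def get_scheme_class(quant_type: str, layer_type: str) -> Optional[type]:
    """Return the registered class, or None when nothing matches."""
    return _SCHEME_REGISTRY.get((quant_type, layer_type))


# Registration happens at import time, when the decorator runs.
@register_scheme("W8A8_DYNAMIC", "moe")
class DemoMoEScheme:
    """Placeholder scheme used only for this illustration."""


# Lookup happens at runtime, when a config resolves a layer's scheme.
scheme_cls = get_scheme_class("W8A8_DYNAMIC", "moe")
scheme = scheme_cls()
```

Keying the registry on the `(quant_type, layer_type)` pair lets a config class resolve a scheme without importing each implementation by name, which is the reduced coupling the commit message lists as a benefit.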
@register_scheme("W8A8_DYNAMIC", "moe")
class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
    """FusedMoE method for Ascend W8A8_DYNAMIC."""

    # Declare the quantization type for this scheme
    quant_type: QuantType = QuantType.W8A8
    def __init__(self):
        self.ep_group = get_ep_group()
[3/N][Feat][Graph] Support `all-to-all` and quantized models with ACL Graph (#2614)

### What this PR does / why we need it?

* **Unify execution paths:** Consolidates the quantized and non-quantized execution paths into a single `fused_experts` function, removing duplicated logic and making the control flow clearer and easier to maintain.
* **W8A8 dynamic quantization:** Adds support for W8A8 dynamic quantization inside the unified MoE kernel. Communication routines are updated to correctly handle dynamic quantization scales for activations.
* **Weight pre-processing:** Pre-transposes the `w13` and `w2` weight matrices (as implemented in PR #2025) so that quantized and non-quantized models follow the same code path for the MoE gating, up-projection, and down-projection operations.
* **All-to-all communication:** Adds an `all-to-all` collective communication pattern. For large token counts on modern hardware, `all-to-all` is more efficient than the previous `all-gather` strategy. However, `all-to-all` is not actually captured and replayed, because its multiple D2H operations trigger synchronization and therefore raise an error when capturing graphs. `all-to-all` is used only when falling back to `compiled_graph_for_general_shape`.
* **Dynamic communication selection:** The model runner now selects the optimal MoE communication method (`mc2`, `allgather`, or `alltoall`) at runtime based on token count and the Ascend SoC version.
* **Limitation:** `all-gather` is not yet supported for quantized models, which means there is still work left to do on A2.

### Does this PR introduce _any_ user-facing change?

None.

### How was this patch tested?

No further test cases needed.

- vLLM version: v0.10.1.1
- vLLM main: https://github.com/vllm-project/vllm/commit/d660c98c1b59580af97d6c7dd162c7c8894d40ed

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-08-30 11:00:35 +08:00
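The "dynamic quantization scales for activations" mentioned in the commit message above refer to scales computed at runtime rather than stored in the checkpoint. As a rough illustration, here is a pure-Python sketch of symmetric per-token int8 quantization. It assumes one scale per token row derived from that row's maximum absolute value; the helper names `quantize_per_token` and `dequantize_per_token` are hypothetical, and the actual NPU kernels used by vllm-ascend are not shown and may behave differently.

```python
from typing import List, Tuple


def quantize_per_token(rows: List[List[float]]) -> Tuple[List[List[int]], List[float]]:
    """Quantize each row (token) to int8 with its own symmetric scale."""
    quantized: List[List[int]] = []
    scales: List[float] = []
    for row in rows:
        max_abs = max(abs(v) for v in row)
        # An all-zero row gets a scale of 1.0 to avoid dividing by zero.
        scale = max_abs / 127.0 if max_abs > 0 else 1.0
        # Clamp to the symmetric int8 range [-127, 127].
        quantized.append([max(-127, min(127, round(v / scale))) for v in row])
        scales.append(scale)
    return quantized, scales


def dequantize_per_token(quantized: List[List[int]], scales: List[float]) -> List[List[float]]:
    """Recover approximate float values from int8 values and per-token scales."""
    return [[q * s for q in row] for row, s in zip(quantized, scales)]


activations = [[1.0, -2.0, 0.5], [0.0, 0.0, 0.0]]
q_values, token_scales = quantize_per_token(activations)
restored = dequantize_per_token(q_values, token_scales)
```

Because each token carries its own scale, any communication step that moves quantized activations between ranks also has to move the matching scales, which is presumably why the commit message notes that the communication routines were updated.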
        vllm_config = get_current_vllm_config()
        ascend_config = get_ascend_config()
        self.use_aclgraph = (
            vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE
            and not vllm_config.model_config.enforce_eager
        )
        self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
[1/N][Refactor] Refactor code to adapt with vllm main (#3612)

### What this PR does / why we need it?

This is step 1 of refactoring the code to adapt to vllm main. This PR is aligned with https://github.com/vllm-project/vllm/commit/17c540a993af88204ad1b78345c8a865cf58ce44

1. Refactor deepseek to the latest code architecture as of https://github.com/vllm-project/vllm/commit/17c540a993af88204ad1b78345c8a865cf58ce44
2. A batch of fixes required by vllm changes:
- Fix `AscendScheduler` `__post_init__`, caused by https://github.com/vllm-project/vllm/pull/25075
- Fix `AscendScheduler` init got an unexpected arg `block_size`, caused by https://github.com/vllm-project/vllm/pull/26296
- Fix `KVCacheManager` `get_num_common_prefix_blocks` arg, caused by https://github.com/vllm-project/vllm/pull/23485
- Fix `MLAAttention` import, caused by https://github.com/vllm-project/vllm/pull/25103
- Fix `SharedFusedMoE` import, caused by https://github.com/vllm-project/vllm/pull/26145
- Fix `LazyLoader` import, caused by https://github.com/vllm-project/vllm/pull/27022
- Fix `vllm.utils.swap_dict_values` import, caused by https://github.com/vllm-project/vllm/pull/26990
- Fix `Backend` enum import, caused by https://github.com/vllm-project/vllm/pull/25893
- Fix `CompilationLevel` renaming to `CompilationMode` issue introduced by https://github.com/vllm-project/vllm/pull/26355
- Fix fused_moe ops, caused by https://github.com/vllm-project/vllm/pull/24097
- Fix bert model because of `inputs_embeds`, caused by https://github.com/vllm-project/vllm/pull/25922
- Fix MRope because of the rename of `get_input_positions_tensor` to `get_mrope_input_positions`, caused by https://github.com/vllm-project/vllm/pull/24172
- Fix `splitting_ops` changes introduced by https://github.com/vllm-project/vllm/pull/25845
- Fix multi-modality changes introduced by https://github.com/vllm-project/vllm/issues/16229
- Fix lora bias dropping issue introduced by https://github.com/vllm-project/vllm/pull/25807
- Fix structured output break introduced by https://github.com/vllm-project/vllm/issues/26737

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

CI passed with existing test.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: Icey <1790571317@qq.com>
Co-authored-by: Icey <1790571317@qq.com>
2025-10-24 16:55:08 +08:00
[EPLB]Eplb Config Renaming (#5533)

### What this PR does / why we need it?

1. Rename num_iterations_eplb_update to expert_heat_collection_interval.
2. Rename num_wait_worker_iterations to algorithm_execution_interval.
3. Rename init_redundancy_expert to num_redundant_experts because the variable with the same meaning in vLLM is named this way.
4. Delete gate_eplb because we don't need this feature.
5. Move eplb config into a dict in additional config.
6. Depend on pr5817

### Does this PR introduce _any_ user-facing change?

before this pr:
`--additional-config '{"dynamic_eplb":true, "num_iterations_eplb_update": 4000, "num_wait_worker_iterations": 150, "init_redundancy_expert": 16, "expert_map_path": "xxx.json"}'`

after this pr:
`--additional-config '{"eplb_config":{"dynamic_eplb":true,"expert_heat_collection_interval":4000, "algorithm_execution_interval":150,"num_redundant_experts": 16, "expert_map_path": "xxx.json"}}'`

### How was this patch tested?

#### test qwen3-235b eplb num_redundant_experts=16

without pr5817

| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| aime2024 | 604a78 | accuracy | gen | 83.33 |

with pr5817

| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| aime2024 | 604a78 | accuracy | gen | 86.67 |

- vLLM version: v0.13.0
- vLLM main: https://github.com/vllm-project/vllm/commit/45c1ca1ca1ee8fa06df263c8715e8a412ff408d4

Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
2026-01-15 10:26:44 +08:00
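The key renames listed in the commit message above can be expressed as a small mapping. The sketch below converts an old-style `--additional-config` dictionary into the new nested form. The helper `migrate_eplb_config` is hypothetical and is not part of vllm-ascend; it covers only the keys shown in the before/after example and leaves every other key untouched.

```python
import json

# Old top-level key -> new key under "eplb_config", per the commit message.
_RENAMED_KEYS = {
    "num_iterations_eplb_update": "expert_heat_collection_interval",
    "num_wait_worker_iterations": "algorithm_execution_interval",
    "init_redundancy_expert": "num_redundant_experts",
}
# Keys that keep their name but move under "eplb_config".
_MOVED_KEYS = ("dynamic_eplb", "expert_map_path")


def migrate_eplb_config(old: dict) -> dict:
    """Return a copy of `old` with EPLB settings nested under "eplb_config"."""
    new: dict = {}
    eplb: dict = {}
    for key, value in old.items():
        if key in _RENAMED_KEYS:
            eplb[_RENAMED_KEYS[key]] = value
        elif key in _MOVED_KEYS:
            eplb[key] = value
        else:
            # Unrelated settings stay at the top level.
            new[key] = value
    if eplb:
        new["eplb_config"] = eplb
    return new


old_config = json.loads(
    '{"dynamic_eplb":true, "num_iterations_eplb_update": 4000, '
    '"num_wait_worker_iterations": 150, "init_redundancy_expert": 16, '
    '"expert_map_path": "xxx.json"}'
)
new_config = migrate_eplb_config(old_config)
```

With this layout the class reads the flag as `ascend_config.eplb_config.dynamic_eplb`, which matches the attribute access in `__init__`.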
        self.dynamic_eplb = ascend_config.eplb_config.dynamic_eplb
        self.in_dtype = vllm_config.model_config.dtype
[EPLB][refactor] Modification of the initialization logic for expert_map and log2phy(depend on pr5285) (#5311) ### What this PR does / why we need it? Unify the loading logic for expert_map and log2phy. 1. The map generated when enabling the redundancy expert is incorrect. The community generation map function only accepts the number of global experts. When we pass in the number of logical experts plus redundant experts, the local expert ID of the last card will index to an expert ID that does not exist. Now we ensure that the index points to a real existing expert ID, and each expert can be accessed. Moreover, when redundant experts are not enabled, the output of our function remains consistent with the community's function. 2. The map we generate is based on the length of the physical expert, but in reality, we only need to use the length of the logical expert. Later on, we will need to pad it accordingly, so we can simply generate a map with the length of the logical [expert.] 3. Unify the initialization logic across different scenarios and simplify the code for fused_moe. **Before refactoring** - map path is not None: expert map: get_rank_placement_map from _'expert_load_balancer.py'_, maintains the map for all ranks and all layers. log2phy: get_rank_log2phy_map from _'expert_load_balancer.py'_, maintains the map for all ranks and all layers. - map path is None: expert map: determine_expert_map from '_vllm.laye_r', The function does not support the redundant experts of vllm-ascend. log2phy: determine_default_log2phy_map from _'eplb_utils.py'_. The function does not support the redundant experts of vllm-ascend. **Refactoring** eplb_utils.py &nbsp;&nbsp;&nbsp;&nbsp;init_eplb_config &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; generate placement &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; generate expert map &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; generate log2phy ### Does this PR introduce _any_ user-facing change? 
No ### How was this patch tested? Expert Mapping Test Generation: ep size: 16, num of experts: 256, num of redundant experts: 16 +++++++++++++++++++++++++++++++++++++++++ Expert Mapping (Non-1 indicates the expert responsible for this rank) for Rank 15: vllm map: [-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16] +++++++++++++++++++++++++++++++++++++++++ Improved map: [16 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15] Expert Mapping Test Generation: ep size: 16, num of experts: 256, num of redundant experts: 0 +++++++++++++++++++++++++++++++++++++++++ Expert Mapping (Non-1 indicates the expert responsible for this rank) for Rank 15: vllm map: [-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15] +++++++++++++++++++++++++++++++++++++++ Improved map: [-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15] dsr1 baselie: | dataset | version | metric | mode | vllm-api-general-chat | |----- | ----- | ----- | ----- | -----| | gsm8k-lite | 7cd45e | accuracy | gen | 100.00 | dsr1 eplb: | dataset | version | metric | mode | vllm-api-general-chat | |----- | ----- | ----- | ----- | -----| | gsm8k-lite | 7cd45e | accuracy | gen | 100.00 | - vLLM version: release/v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/5fbfa8d9ef15948599631baeb91e8220b2ee9bcc Signed-off-by: shenchuxiaofugui <1311027364@qq.com> Co-authored-by: weijinqian0 <1184188277@qq.com>
2025-12-29 09:26:14 +08:00
        self.supports_eplb = True
        try:
            device_group = get_mc2_group().device_group
            # TODO: Try local_rank = ep_group.rank_in_group
            local_rank = torch.distributed.get_rank(group=device_group)
            backend = device_group._get_backend(torch.device("npu"))
            self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
        except AttributeError:
            self.moe_all_to_all_group_name = ""

    def get_weight(
        self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
    ) -> dict[str, Any]:
        param_dict = {}
        param_dict["w13_weight"] = torch.empty(
            num_experts, 2 * intermediate_size_per_partition, hidden_sizes, dtype=torch.int8
        )
        param_dict["w2_weight"] = torch.empty(
            num_experts, hidden_sizes, intermediate_size_per_partition, dtype=torch.int8
        )
        return param_dict

    def get_dynamic_quant_param(
        self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
    ) -> dict[str, Any]:
        param_dict = {}
        param_dict["w13_weight_scale"] = torch.empty(
            num_experts, 2 * intermediate_size_per_partition, 1, dtype=params_dtype
        )
        param_dict["w13_weight_offset"] = torch.empty(
            num_experts, 2 * intermediate_size_per_partition, 1, dtype=params_dtype
        )
        param_dict["w2_weight_scale"] = torch.empty(num_experts, hidden_sizes, 1, dtype=params_dtype)
        param_dict["w2_weight_offset"] = torch.empty(num_experts, hidden_sizes, 1, dtype=params_dtype)
        return param_dict

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
is_prefill: bool = True,
enable_force_load_balance: bool = False,
log2phy: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
pertoken_scale: Any | None = None,
**kwargs,
) -> torch.Tensor:
zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None)
if zero_expert_num == 0 or zero_expert_type is None:
assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, (
"Number of global experts mismatch (excluding redundancy)"
)
if self.multistream_overlap_gate:
fc3_context = get_flash_common3_context()
assert fc3_context is not None
topk_weights = fc3_context.topk_weights
topk_ids = fc3_context.topk_ids
else:
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts,
)
assert topk_ids is not None
assert topk_weights is not None
if zero_expert_num > 0 and zero_expert_type is not None:
topk_ids, topk_weights, zero_expert_result = zero_experts_compute(
expert_indices=topk_ids,
expert_scales=topk_weights,
num_experts=global_num_experts,
zero_expert_type=zero_expert_type,
hidden_states=x,
)
# This is a naive implementation of expert load balancing that avoids
# accumulating too many tokens on a single rank.
# Currently it is only activated during profile runs.
if enable_force_load_balance:
random_matrix = torch.rand(
topk_ids.size(0), global_num_experts - global_redundant_expert_num, device=topk_ids.device
)
topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
assert topk_weights is not None
topk_weights = topk_weights.to(self.in_dtype)
moe_comm_method = _EXTRA_CTX.moe_comm_method
fused_scale_flag = (
_EXTRA_CTX.moe_comm_type == MoECommType.FUSED_MC2 and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1
)
[EPLB][Ops] Integerate grouped_matmul_swiglu_quant_weight_nz_tensor_list operator into dynamic EPLB (#4216) ### What this PR does / why we need it? Integerate grouped_matmul_swiglu_quant_weight_nz_tensor_list into dynamic EPLB to support list-type parameters This PR also modify the logic of loading model in dynamic-eplb scenario. The operator is based on this pr: https://github.com/vllm-project/vllm-ascend/pull/3804 ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? ``` vllm serve /home/weight/DeepSeek-V3.1_w8a8mix_mtp \ --max_num_seqs 8 \ --max-model-len 8192 \ --max-num-batched-tokens 16384 \ --tensor-parallel-size 8 \ --data-parallel-size 2 \ --enable-expert-parallel \ --served-model-name ds_r1 \ --enable-auto-tool-choice \ --tool-call-parser hermes \ --no-enable-prefix-caching \ --port 8999 \ --quantization "ascend" \ --gpu-memory-utilization 0.85 \ --trust-remote-code \ --compilation_config '{"cudagraph_capture_sizes":[1,2,4,8,16,32]}' \ --additional-config='{"dynamic_eplb":true, "num_iterations_eplb_update":100, "num_wait_worker_iterations":100}' ``` input&output: 2k 2k This PR: <img width="1318" height="695" alt="fusion" src="https://github.com/user-attachments/assets/f8657813-0c02-42f4-8396-d99e730f48cd" /> Baseline: <img width="1323" height="690" alt="baseline" src="https://github.com/user-attachments/assets/e1323a78-af26-4523-820c-e20e5642a38e" /> - vLLM version: v0.11.2 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2 --------- Signed-off-by: 白永斌 <baiyongbin3@h-partners.com> Signed-off-by: 欧派果奶我还要 <845473182@qq.com> Co-authored-by: 白永斌 <baiyongbin3@h-partners.com>
2025-11-30 22:52:05 +08:00
if self.dynamic_eplb:
w1 = layer.w13_weight_list
w1_scale = layer.fused_w1_scale_list if fused_scale_flag else layer.w13_weight_scale_fp32_list
w2 = layer.w2_weight_list
w2_scale = layer.fused_w2_scale_list if fused_scale_flag else layer.w2_weight_scale_list
else:
w1 = [layer.w13_weight]
w1_scale = [layer.fused_w1_scale] if fused_scale_flag else [layer.w13_weight_scale_fp32]
w2 = [layer.w2_weight]
w2_scale = [layer.fused_w2_scale] if fused_scale_flag else [layer.w2_weight_scale]
final_hidden_states = moe_comm_method.fused_experts(
hidden_states=x,
pertoken_scale=pertoken_scale,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_int8_w8a8=True,
expert_map=expert_map,
log2phy=log2phy,
dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask"),
)
if zero_expert_num > 0 and zero_expert_type is not None:
final_hidden_states += zero_expert_result
return final_hidden_states
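# Illustrative sketch (not part of this module): the forced load-balance branch
# in the method above draws one random score per (token, expert) pair, argsorts
# each row, and keeps the first `top_k` indices, which gives every token `top_k`
# distinct, randomly chosen expert ids. The pure-Python version below mirrors
# that idea without torch. The name `random_topk_ids` and its parameters are
# hypothetical and chosen only for this example.

```python
import random
from typing import Optional


def random_topk_ids(
    num_tokens: int, num_experts: int, top_k: int, seed: Optional[int] = None
) -> list[list[int]]:
    """Pick `top_k` distinct random expert ids for each token."""
    if top_k > num_experts:
        raise ValueError("top_k must not exceed num_experts")
    rng = random.Random(seed)
    topk_ids = []
    for _ in range(num_tokens):
        # One random score per expert, analogous to one row of torch.rand(...).
        scores = [rng.random() for _ in range(num_experts)]
        # argsort: expert indices ordered by their random score.
        order = sorted(range(num_experts), key=scores.__getitem__)
        # The first top_k entries of a permutation are always distinct.
        topk_ids.append(order[:top_k])
    return topk_ids
```

# Because each row is a prefix of a random permutation, no token is routed to
# the same expert twice, which is why argsort is used rather than sampling ids
# independently.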
def process_weights_after_loading(self, layer):
layer.w13_weight.data = layer.w13_weight.data.transpose(1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(1, 2).contiguous()
# TODO(zzzzwwjj): Currently, `torch_npu.npu_grouped_matmul_swiglu_quant`
# only supports weights in NZ format.
layer.w13_weight.data = torch_npu.npu_format_cast(layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w2_weight.data = torch_npu.npu_format_cast(layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(layer.w13_weight_scale.data.shape[0], -1)
layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(torch.float32)
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(layer.w13_weight_offset.data.shape[0], -1)
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(layer.w2_weight_scale.data.shape[0], -1)
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(layer.w2_weight_offset.data.shape[0], -1)
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
layer.fused_w1_scale = scale_from_float_to_int64(layer.w13_weight_scale.data)
layer.fused_w2_scale = scale_from_float_to_int64(layer.w2_weight_scale.data)
add `dispatch_gmm_combine` kernel (#3532) ### What this PR does / why we need it? This PR introduces the Ascend implementation of the `dispatch_ffn_combine` kernel and wires it into the vLLM-Ascend runtime, together with follow‑up fixes to ensure the kernel builds and runs correctly in CI. - Add full host and device implementation of the `dispatch_ffn_combine` kernel under `csrc/dispatch_ffn_combine`, including tiling logic, MOE routing helpers, and kernel utilities for quantized FFN dispatch. - Integrate the new kernel with the PyTorch binding (csrc/torch_binding.cpp, csrc/torch_binding_meta.cpp) and the Ascend runtime (vllm_ascend/ascend_forward_context.py, vllm_ascend/worker/model_runner_v1.py). - Extend fused MoE communication and token dispatch support in `vllm_ascend/ops/fused_moe`, adding methods/utilities needed by the new dispatch path. - Update quantization logic in vllm_ascend/quantization/w8a8_dynamic.py to support the new FFN dispatch flow. - Fix kernel build issues by adjusting `csrc/build_aclnn.sh`, CMake configuration, and include/namespace usage in the new kernel files. - Add an end‑to‑end nightly test `tests/e2e/nightly/ops/test_dispatch_ffn_combine.py` and helper utilities in `vllm_ascend/utils.py` to validate the new kernel. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.12.0 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.12.0 --------- Signed-off-by: mojave2 <chenchen145@huawei.com> Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
2025-12-04 23:00:59 +08:00
if self.dynamic_eplb:
layer.w13_weight_list = [weight.clone() for weight in layer.w13_weight.data.unbind(dim=0)]
layer.w2_weight_list = [weight.clone() for weight in layer.w2_weight.data.unbind(dim=0)]
layer.w13_weight_scale_fp32_list = [
weight.clone() for weight in layer.w13_weight_scale_fp32.data.unbind(dim=0)
]
layer.w2_weight_scale_list = [weight.clone() for weight in layer.w2_weight_scale.data.unbind(dim=0)]
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
layer.fused_w1_scale_list = [
weight.clone()
for weight in layer.fused_w1_scale.view(len(layer.w13_weight_list), -1).data.unbind(dim=0)
]
layer.fused_w2_scale_list = [
weight.clone()
for weight in layer.fused_w2_scale.view(len(layer.w2_weight_list), -1).data.unbind(dim=0)
]
del layer.w13_weight
del layer.w2_weight
del layer.w13_weight_scale
del layer.w13_weight_scale_fp32
del layer.w2_weight_scale
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
del layer.fused_w1_scale
del layer.fused_w2_scale
torch.npu.empty_cache()
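# Illustrative sketch (not part of this module): in the dynamic EPLB branch
# above, each stacked tensor is split along dim 0 and every slice is cloned
# before the stacked attribute is deleted. `unbind` returns views that share
# storage with the stacked tensor, so cloning gives each expert its own storage.
# The helper name `split_per_expert` is hypothetical, and the example uses a
# small CPU tensor rather than NPU weights.

```python
import torch


def split_per_expert(stacked: torch.Tensor) -> list[torch.Tensor]:
    """Split a [num_experts, ...] tensor into independent per-expert copies."""
    # unbind(dim=0) yields views into `stacked`; clone() copies each one
    # into separate storage.
    return [expert.clone() for expert in stacked.unbind(dim=0)]


stacked = torch.arange(6, dtype=torch.float32).reshape(3, 2)
per_expert = split_per_expert(stacked)
stacked.zero_()  # in-place change to the stacked tensor
print(per_expert[1].tolist())  # the clone keeps its original values
```

# Without the clone, the list entries would remain views into the stacked
# tensor and would change whenever it is modified in place.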