[Lint] Style: Convert vllm-ascend/ to ruff format (Batch #7) (#6023)

### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
|`vllm_ascend/quantization/compressed_tensors/compressed_tensors.py`|
|`vllm_ascend/quantization/quant_config.py`|
|`vllm_ascend/quantization/utils.py`|
|`vllm_ascend/quantization/w4a16.py`|
|`vllm_ascend/quantization/w4a4_flatquant_dynamic.py`|
|`vllm_ascend/quantization/w4a8_dynamic.py`|
|`vllm_ascend/quantization/w8a16.py`|
|`vllm_ascend/quantization/w8a8.py`|
|`vllm_ascend/quantization/w8a8_dynamic.py`|
|`vllm_ascend/quantization/w8a8_pdmix.py`|
|`vllm_ascend/quantization/w8a8mxfp8.py`|
|`vllm_ascend/sample/rejection_sampler.py`|
|`vllm_ascend/sample/sampler.py`|
|`vllm_ascend/worker/block_table.py`|
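
A minimal local verification sketch, not part of the PR itself: assuming `ruff` is installed and the command is run from the repository root, it checks that the paths un-excluded in this batch already satisfy `ruff format` (the project's CI may invoke the formatter differently).

```python
import subprocess

# Paths newly covered by ruff format in this batch (matching the
# entries removed from the formatter's exclude list in the diff below).
BATCH_7_PATHS = [
    "vllm_ascend/quantization/",
    "vllm_ascend/sample/rejection_sampler.py",
    "vllm_ascend/sample/sampler.py",
    "vllm_ascend/worker/block_table.py",
]

# `ruff format --check` exits non-zero if any file would be reformatted.
subprocess.run(["ruff", "format", "--check", *BATCH_7_PATHS], check=True)
```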

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: 2c24bc6996

Signed-off-by: MrZ20 <2609716663@qq.com>
Authored by SILONG ZENG on 2026-02-06 14:56:53 +08:00; committed by GitHub
Parent: d0bc16859c
Commit: 99aedaff63
20 changed files with 997 additions and 1307 deletions


@@ -51,10 +51,7 @@ line-length = 120
# Folder to be modified
exclude = [
"tests/**",
# (7)
"vllm_ascend/quantization/**",
"vllm_ascend/sample/*.py",
"vllm_ascend/worker/block_table.py",
# (8)
"vllm_ascend/ops/__init__.py",
"vllm_ascend/ops/activation.py",
@@ -66,6 +63,7 @@ exclude = [
"vllm_ascend/ops/vocab_parallel_embedding.py",
"vllm_ascend/ops/weight_prefetch.py",
"vllm_ascend/spec_decode/**",
# (10)
"vllm_ascend/ops/*linear*.py",
"vllm_ascend/worker/worker.py",
@@ -76,6 +74,7 @@ exclude = [
"vllm_ascend/worker/v2/**",
"vllm_ascend/worker/npu_input_batch.py",
"vllm_ascend/ops/rotary_embedding.py",
# (11)
"vllm_ascend/ops/fused_moe/**",
]


@@ -29,6 +29,7 @@ Public API:
# LLM-Compressor (compressed_tensors) quantization config
from .compressed_tensors_config import AscendCompressedTensorsConfig
# ModelSlim quantization config
from .modelslim_config import AscendModelSlimConfig


@@ -17,23 +17,20 @@
#
"""LLM-Compressor (compressed_tensors) quantization configuration for Ascend."""
from typing import Any, Optional, Union, cast
from typing import Any, Optional, cast
import torch
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy, QuantizationType
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import (
QUANTIZATION_METHODS, register_quantization_config)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS, register_quantization_config
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig, QuantizeMethodBase
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target, is_activation_quantization_format,
should_ignore_layer)
find_matched_target,
is_activation_quantization_format,
should_ignore_layer,
)
from vllm.model_executor.models.utils import WeightsMapper
from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD
@@ -51,14 +48,13 @@ def _remove_quantization_method():
_remove_quantization_method()
QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str,
"QuantizationArgs"]]]
QUANTIZATION_SCHEME_MAP_TYPE = dict[str, dict[str, "QuantizationArgs"] | None]
@register_quantization_config(COMPRESSED_TENSORS_METHOD)
class AscendCompressedTensorsConfig(QuantizationConfig):
"""Config class for LLM-Compressor (compressed_tensors) quantization on Ascend.
This class adapts the compressed_tensors format to work with Ascend's
quantization implementations.
"""
@@ -68,7 +64,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
target_scheme_map: dict[str, Any],
ignore: list[str],
quant_format: str,
config: Optional[dict[str, Any]] = None,
config: dict[str, Any] | None = None,
):
super().__init__()
self.ignore = ignore
@@ -86,8 +82,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
raise NotImplementedError(
"Ascend hardware dose not support \"get_min_capability\" feature.")
raise NotImplementedError('Ascend hardware dose not support "get_min_capability" feature.')
@classmethod
def get_config_filenames(cls) -> list[str]:
@@ -100,18 +95,15 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
targeting 'Linear' needs to also match
FusedMoE modules.
"""
if ("Linear" not in self.target_scheme_map
or "FusedMoE" in self.target_scheme_map):
if "Linear" not in self.target_scheme_map or "FusedMoE" in self.target_scheme_map:
return
self.target_scheme_map["FusedMoE"] = self.target_scheme_map["Linear"]
@classmethod
def from_config(cls, config: dict[str,
Any]) -> "AscendCompressedTensorsConfig":
def from_config(cls, config: dict[str, Any]) -> "AscendCompressedTensorsConfig":
ignore: list[str] = cast(list[str], config.get("ignore", []))
quant_format = cast(str, config.get("format"))
target_scheme_map = cls._quantization_scheme_map_from_config(
config=config)
target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
return cls(
target_scheme_map=target_scheme_map,
@@ -121,10 +113,9 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
)
@classmethod
def _quantization_scheme_map_from_config(
cls, config: dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
def _quantization_scheme_map_from_config(cls, config: dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
"""Build target scheme map from config.
:param config: The `quantization_config` dictionary from config.json
:return: A dictionary mapping target layer names to their corresponding
quantization_args for weights and input activations
@@ -138,24 +129,22 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
targets = quant_config.get("targets")
for target in targets:
target_scheme_map[target] = {}
target_scheme_map[target][
"weights"] = QuantizationArgs.model_validate(
quant_config.get("weights"))
target_scheme_map[target]["weights"] = QuantizationArgs.model_validate(quant_config.get("weights"))
target_scheme_map[target]["input_activations"] = None
target_scheme_map[target]["format"] = quant_config.get(
"format")
target_scheme_map[target]["format"] = quant_config.get("format")
format = target_scheme_map[target].get("format")
# If no per-config format defined, use global format in config
act_quant_format = (
is_activation_quantization_format(format)
if format is not None else
is_activation_quantization_format(quant_format))
if format is not None
else is_activation_quantization_format(quant_format)
)
input_activations = quant_config.get("input_activations")
if act_quant_format and input_activations is not None:
target_scheme_map[target]["input_activations"] = (
QuantizationArgs.model_validate(
quant_config.get("input_activations")))
target_scheme_map[target]["input_activations"] = QuantizationArgs.model_validate(
quant_config.get("input_activations")
)
return target_scheme_map
def get_quant_method(
@@ -168,8 +157,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
if isinstance(layer, LinearBase):
layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD
# Get the scheme for this layer
linear_scheme = self._get_linear_scheme(layer=layer,
layer_name=prefix)
linear_scheme = self._get_linear_scheme(layer=layer, layer_name=prefix)
# Return unquantized method if no scheme found
if linear_scheme is None:
@@ -177,14 +165,12 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
# Store scheme on layer for reference (optional, for debugging)
layer.scheme = linear_scheme
logger.info_once(
"Using the vLLM Ascend llmcompressor Quantization now!")
logger.info_once("Using the vLLM Ascend llmcompressor Quantization now!")
return AscendLinearMethod(linear_scheme)
if isinstance(layer, FusedMoE):
# Delayed import to avoid circular import
from vllm_ascend.ops.fused_moe.fused_moe import \
AscendUnquantizedFusedMoEMethod
from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD
layer_name = prefix + ".0.gate_proj"
@@ -197,24 +183,19 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
# Store scheme on layer for reference (optional, for debugging)
layer.scheme = moe_scheme
logger.info_once(
"Using the vLLM Ascend llmcompressor Quantization now!")
logger.info_once("Using the vLLM Ascend llmcompressor Quantization now!")
return AscendFusedMoEMethod(moe_scheme, layer.moe_config)
return None
def _get_linear_scheme(
self,
layer: torch.nn.Module,
layer_name: Optional[str] = None) -> Optional[AscendLinearScheme]:
def _get_linear_scheme(self, layer: torch.nn.Module, layer_name: str | None = None) -> AscendLinearScheme | None:
"""Get the linear quantization scheme for a layer.
Returns:
An AscendLinearScheme instance, or None if the layer
should use unquantized method.
"""
weight_quant, input_quant, format = self._get_quant_args(
layer, layer_name)
weight_quant, input_quant, format = self._get_quant_args(layer, layer_name)
if weight_quant is None:
return None
@@ -226,12 +207,9 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
)
return cast(AscendLinearScheme, scheme)
def _get_moe_scheme(
self,
layer: torch.nn.Module,
layer_name: Optional[str] = None) -> Optional[AscendMoEScheme]:
def _get_moe_scheme(self, layer: torch.nn.Module, layer_name: str | None = None) -> AscendMoEScheme | None:
"""Get the MoE quantization scheme for a layer.
Returns:
An AscendMoEScheme instance, or None if the layer
should use unquantized method.
@@ -239,8 +217,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
# Add FusedMoE to target scheme map if needed
self._add_fused_moe_to_target_scheme_map()
weight_quant, input_quant, format = self._get_quant_args(
layer, layer_name)
weight_quant, input_quant, format = self._get_quant_args(layer, layer_name)
if weight_quant is None:
return None
@@ -253,13 +230,10 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
return cast(AscendMoEScheme, scheme)
def _get_quant_args(
self,
layer: torch.nn.Module,
layer_name: Optional[str] = None
) -> tuple[Optional["QuantizationArgs"], Optional["QuantizationArgs"],
Optional[str]]:
self, layer: torch.nn.Module, layer_name: str | None = None
) -> tuple[Optional["QuantizationArgs"], Optional["QuantizationArgs"], str | None]:
"""Extract quantization arguments for a layer.
compressed-tensors supports non uniform in the following way:
targets of config_groups: There can be N config_groups which each
@@ -269,7 +243,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
Detect whether a layer_name is found in any target and
use the quantization scheme corresponding to the matched target.
Returns:
A tuple of (weight_quant, input_quant, format). weight_quant is
None if the layer should use unquantized method.
@@ -284,16 +258,16 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
format = scheme_dict.get("format")
if weight_quant is None:
logger.warning_once("Acceleration for non-quantized schemes is "
"not supported by Compressed Tensors. "
"Falling back to UnquantizedLinearMethod")
logger.warning_once(
"Acceleration for non-quantized schemes is "
"not supported by Compressed Tensors. "
"Falling back to UnquantizedLinearMethod"
)
return weight_quant, input_quant, format
def get_scheme_dict(
self,
layer: torch.nn.Module,
layer_name: str | None = None
self, layer: torch.nn.Module, layer_name: str | None = None
) -> dict[str, QuantizationArgs | str | None] | None:
"""
Extract the QuantizationArgs for a given layer.
@@ -305,9 +279,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
"format": str | None
} | None
"""
if should_ignore_layer(layer_name,
ignore=self.ignore,
fused_mapping=self.packed_modules_mapping):
if should_ignore_layer(layer_name, ignore=self.ignore, fused_mapping=self.packed_modules_mapping):
return None
if self.target_scheme_map:
@@ -328,17 +300,17 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
self,
weight_quant: "QuantizationArgs",
input_quant: Optional["QuantizationArgs"],
format: Optional[str],
format: str | None,
layer_type: str,
) -> Union[AscendLinearScheme, AscendMoEScheme]:
) -> AscendLinearScheme | AscendMoEScheme:
"""Create the appropriate Ascend scheme based on quantization args and layer type.
Args:
weight_quant: Weight quantization arguments.
input_quant: Input activation quantization arguments.
format: Per-layer format, if defined.
layer_type: Type of layer ("linear" or "moe").
Returns:
An instance of the appropriate Ascend quantization scheme.
"""
@@ -352,7 +324,8 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
if scheme_cls is None:
raise NotImplementedError(
f"No compressed-tensors compatible scheme was found for "
f"quant_type={quant_type}, layer_type={layer_type}.")
f"quant_type={quant_type}, layer_type={layer_type}."
)
return scheme_cls()
@@ -360,15 +333,15 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
self,
weight_quant: "QuantizationArgs",
input_quant: Optional["QuantizationArgs"],
format: Optional[str],
format: str | None,
) -> str:
"""Detect the quantization type from quantization arguments.
Args:
weight_quant: Weight quantization arguments.
input_quant: Input activation quantization arguments.
format: Per-layer format, if defined.
Returns:
A string representing the quantization type (e.g., "W8A8", "W8A8_DYNAMIC").
"""
@@ -389,16 +362,12 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
if self._is_w4a16(weight_quant, input_quant):
return "W4A16"
raise NotImplementedError(
"No compressed-tensors compatible quantization type was found.")
raise NotImplementedError("No compressed-tensors compatible quantization type was found.")
def _is_static_tensor_w8a8(self, weight_quant: "QuantizationArgs",
input_quant: "QuantizationArgs") -> bool:
def _is_static_tensor_w8a8(self, weight_quant: "QuantizationArgs", input_quant: "QuantizationArgs") -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
weight_strategy = (
weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
is_tensor = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TENSOR.value)
weight_strategy = weight_quant.strategy == QuantizationStrategy.CHANNEL.value
is_tensor = weight_strategy and input_quant.strategy == QuantizationStrategy.TENSOR.value
is_static = not weight_quant.dynamic and not input_quant.dynamic
is_symmetric = weight_quant.symmetric and input_quant.symmetric
@@ -406,28 +375,24 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
# Only symmetric weight quantization supported.
return is_8_bits and is_tensor and is_symmetric and is_static
def _is_dynamic_token_w8a8(self, weight_quant: "QuantizationArgs",
input_quant: "QuantizationArgs") -> bool:
def _is_dynamic_token_w8a8(self, weight_quant: "QuantizationArgs", input_quant: "QuantizationArgs") -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
weight_strategy = (
weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
is_token = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TOKEN.value)
weight_strategy = weight_quant.strategy == QuantizationStrategy.CHANNEL.value
is_token = weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value
is_dynamic = not weight_quant.dynamic and input_quant.dynamic
is_symmetric = weight_quant.symmetric and input_quant.symmetric
# Only symmetric input quantization supported.
# Only symmetric weight quantization supported.
return is_8_bits and is_token and is_symmetric and is_dynamic
def _is_dynamic_token_w4a8(self, weight_quant: QuantizationArgs,
input_quant: QuantizationArgs) -> bool:
def _is_dynamic_token_w4a8(self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs) -> bool:
is_4_bits = weight_quant.num_bits == 4
is_8_bits = input_quant.num_bits == 8
weight_strategy = (
weight_quant.strategy == QuantizationStrategy.CHANNEL.value) or (weight_quant.strategy == QuantizationStrategy.GROUP.value)
is_token = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TOKEN.value)
weight_strategy = (weight_quant.strategy == QuantizationStrategy.CHANNEL.value) or (
weight_quant.strategy == QuantizationStrategy.GROUP.value
)
is_token = weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value
is_dynamic = not weight_quant.dynamic and input_quant.dynamic
is_symmetric = weight_quant.symmetric and input_quant.symmetric
@@ -435,7 +400,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
assert self.quant_description is not None, "quant_description should not be None"
if weight_strategy:
self.quant_description["group_size"] = weight_quant.group_size if weight_quant.group_size else 0
self.quant_description["version"] = "0"
self.quant_description["ascend_quant_method"] = COMPRESSED_TENSORS_METHOD
self.quant_description["weight_strategy"] = str(weight_quant.strategy)
@@ -444,8 +409,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
# Only symmetric weight quantization supported.
return is_4_bits and is_8_bits and is_token and is_symmetric and is_dynamic
def _is_w4a16(self, weight_quant: "QuantizationArgs",
input_quant: Optional["QuantizationArgs"]) -> bool:
def _is_w4a16(self, weight_quant: "QuantizationArgs", input_quant: Optional["QuantizationArgs"]) -> bool:
# Confirm weights quantized.
if weight_quant is None:
return False
@@ -456,12 +420,11 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
input_quant_none = input_quant is None
is_4_bits = weight_quant.num_bits == 4
is_group = (weight_quant.strategy == QuantizationStrategy.GROUP.value)
is_group = weight_quant.strategy == QuantizationStrategy.GROUP.value
is_static = not weight_quant.dynamic
return input_quant_none and is_4_bits and is_group and is_static
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
self.target_scheme_map = hf_to_vllm_mapper.apply_dict(
self.target_scheme_map)
self.target_scheme_map = hf_to_vllm_mapper.apply_dict(self.target_scheme_map)
self.ignore = hf_to_vllm_mapper.apply_list(self.ignore)


@@ -16,27 +16,22 @@
# This file is a part of the vllm-ascend project.
#
from typing import Callable, List, Optional
from collections.abc import Callable
import torch
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase, FusedMoeWeightScaleSupported
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.linear import (LinearMethodBase,
RowParallelLinear)
from vllm.model_executor.layers.linear import LinearMethodBase, RowParallelLinear
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.parameter import PerTensorScaleParameter
from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import (get_flashcomm2_otp_group,
get_mlp_tp_group,
get_otp_group)
from vllm_ascend.distributed.parallel_state import get_flashcomm2_otp_group, get_mlp_tp_group, get_otp_group
from vllm_ascend.utils import flashcomm2_enable, mlp_tp_enable, oproj_tp_enable
from .methods import (AscendAttentionScheme, AscendLinearScheme,
AscendMoEScheme, is_mx_quant_type)
from .methods import AscendAttentionScheme, AscendLinearScheme, AscendMoEScheme, is_mx_quant_type
class AscendLinearMethod(LinearMethodBase):
@@ -56,7 +51,7 @@ class AscendLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
@@ -65,9 +60,7 @@ class AscendLinearMethod(LinearMethodBase):
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
weight_dict = self.quant_method.get_weight(input_size_per_partition,
output_size_per_partition,
params_dtype)
weight_dict = self.quant_method.get_weight(input_size_per_partition, output_size_per_partition, params_dtype)
# Extract packing information (if present)
packed_dim = weight_dict.pop("_packed_dim", None)
@@ -79,25 +72,20 @@ class AscendLinearMethod(LinearMethodBase):
# Set packing attributes if the weight is packed
if packed_dim is not None and packed_factor is not None:
set_weight_attrs(param, {
"packed_dim": packed_dim,
"packed_factor": packed_factor
})
set_weight_attrs(param, {"packed_dim": packed_dim, "packed_factor": packed_factor})
layer.register_parameter(weight_name, param)
set_weight_attrs(param, extra_weight_attrs)
pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
for pertensor_name, pertensor_param in pertensor_dict.items():
param = PerTensorScaleParameter(data=pertensor_param,
weight_loader=weight_loader)
param = PerTensorScaleParameter(data=pertensor_param, weight_loader=weight_loader)
# disable warning
param.ignore_warning = True
layer.register_parameter(pertensor_name, param)
param.weight_loader = extra_weight_attrs.get("weight_loader")
perchannel_dict = self.quant_method.get_perchannel_param(
output_size_per_partition, params_dtype)
perchannel_dict = self.quant_method.get_perchannel_param(output_size_per_partition, params_dtype)
for perchannel_name, perchannel_param in perchannel_dict.items():
param = torch.nn.Parameter(perchannel_param, requires_grad=False)
set_weight_attrs(param, {"output_dim": 0})
@@ -107,22 +95,22 @@ class AscendLinearMethod(LinearMethodBase):
# NOTE: In w4a8 quantization implementation,
# for down_proj and o_proj scale_bias shape is [output_size, 16],
# others are [output_size, 1]
layer_type = "row" if isinstance(layer,
RowParallelLinear) else "others"
layer_type = "row" if isinstance(layer, RowParallelLinear) else "others"
pergroup_dict = self.quant_method.get_pergroup_param(
input_size_per_partition,
output_size_per_partition,
params_dtype,
layer_type=layer_type)
input_size_per_partition, output_size_per_partition, params_dtype, layer_type=layer_type
)
for pergroup_name, pergroup_param in pergroup_dict.items():
param = torch.nn.Parameter(pergroup_param, requires_grad=False)
set_weight_attrs(param, {"output_dim": 0})
layer.register_parameter(pergroup_name, param)
set_weight_attrs(param, extra_weight_attrs)
if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name \
or is_mx_quant_type(self.quant_method):
setattr(param, "input_dim", 1)
if (
"weight_scale_second" in pergroup_name
or "weight_offset_second" in pergroup_name
or is_mx_quant_type(self.quant_method)
):
param.input_dim = 1
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
@@ -133,17 +121,15 @@ class AscendLinearMethod(LinearMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
if isinstance(layer, RowParallelLinear):
if layer.prefix.find("o_proj") != -1 and oproj_tp_enable():
tp_rank = get_otp_group().rank_in_group
elif layer.prefix.find("down_proj") != -1 and mlp_tp_enable():
tp_rank = get_mlp_tp_group().rank_in_group
elif (layer.prefix.find("o_proj") != -1 or
layer.prefix.find("out_proj") != -1) and flashcomm2_enable():
if get_ascend_config(
).flashcomm2_oproj_tensor_parallel_size == 1:
elif (layer.prefix.find("o_proj") != -1 or layer.prefix.find("out_proj") != -1) and flashcomm2_enable():
if get_ascend_config().flashcomm2_oproj_tensor_parallel_size == 1:
tp_rank = 0
else:
tp_rank = get_flashcomm2_otp_group().rank_in_group
@@ -175,11 +161,19 @@ class AscendKVCacheMethod(BaseKVCacheMethod):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.quant_method.process_weights_after_loading(layer)
def apply(self, layer: torch.nn.Module, query: torch.Tensor,
key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata,
attn_type, scale, output) -> torch.Tensor:
return self.quant_method.apply(layer, query, key, value, kv_cache,
attn_metadata, attn_type, scale, output)
def apply(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache,
attn_metadata,
attn_type,
scale,
output,
) -> torch.Tensor:
return self.quant_method.apply(layer, query, key, value, kv_cache, attn_metadata, attn_type, scale, output)
class AscendFusedMoEMethod(FusedMoEMethodBase):
@@ -192,8 +186,7 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
moe_config: The FusedMoE configuration.
"""
def __init__(self, scheme: AscendMoEScheme,
moe_config: FusedMoEConfig) -> None:
def __init__(self, scheme: AscendMoEScheme, moe_config: FusedMoEConfig) -> None:
super().__init__(moe_config)
self.quant_method = scheme
@@ -207,30 +200,28 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
**extra_weight_attrs,
) -> None:
weight_param = self.quant_method.get_weight(
num_experts, intermediate_size_per_partition, hidden_size,
params_dtype)
num_experts, intermediate_size_per_partition, hidden_size, params_dtype
)
for param_key, param_value in weight_param.items():
param = torch.nn.Parameter(param_value, requires_grad=False)
layer.register_parameter(param_key, param)
set_weight_attrs(param, extra_weight_attrs)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
per_group_param = [
"weight_scale_second", "weight_offset_second", "scale_bias"
] + ["weight_scale", "weight_offset"] if hasattr(
self.quant_method,
"group_size") and self.quant_method.group_size > 0 else []
extra_weight_attrs.update({"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
per_group_param = (
["weight_scale_second", "weight_offset_second", "scale_bias"] + ["weight_scale", "weight_offset"]
if hasattr(self.quant_method, "group_size") and self.quant_method.group_size > 0
else []
)
dynamic_quant_param = self.quant_method.get_dynamic_quant_param(
num_experts, intermediate_size_per_partition, hidden_size,
params_dtype)
num_experts, intermediate_size_per_partition, hidden_size, params_dtype
)
for param_key, param_value in dynamic_quant_param.items():
param = torch.nn.Parameter(param_value, requires_grad=False)
layer.register_parameter(param_key, param)
set_weight_attrs(param, extra_weight_attrs)
if any(fields in param_key for fields in per_group_param):
setattr(param, "quant_method",
FusedMoeWeightScaleSupported.GROUP.value)
param.quant_method = FusedMoeWeightScaleSupported.GROUP.value
def apply(
self,
@@ -241,25 +232,40 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
renormalize: bool,
use_grouped_topk: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
is_prefill: bool = True,
enable_force_load_balance: bool = False,
log2phy: Optional[torch.Tensor] = None,
log2phy: torch.Tensor | None = None,
global_redundant_expert_num=0,
**kwargs,
) -> torch.Tensor:
return self.quant_method.apply(
layer, x, router_logits, top_k, renormalize, use_grouped_topk,
global_num_experts, expert_map, topk_group, num_expert_group,
custom_routing_function, scoring_func, routed_scaling_factor,
e_score_correction_bias, is_prefill, enable_force_load_balance,
log2phy, global_redundant_expert_num, **kwargs)
layer,
x,
router_logits,
top_k,
renormalize,
use_grouped_topk,
global_num_experts,
expert_map,
topk_group,
num_expert_group,
custom_routing_function,
scoring_func,
routed_scaling_factor,
e_score_correction_bias,
is_prefill,
enable_force_load_balance,
log2phy,
global_redundant_expert_num,
**kwargs,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):
@@ -276,7 +282,7 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
class AscendEmbeddingMethod(AscendLinearMethod):
"""Embedding method for Ascend quantization.
This is essentially the same as AscendLinearMethod, just with a different name
for clarity when used with VocabParallelEmbedding layers.


@@ -21,7 +21,7 @@ Schemes are automatically registered via the @register_scheme decorator.
Usage:
from vllm_ascend.quantization.methods import get_scheme_class
# Get a scheme class by quant_type and layer_type
scheme_cls = get_scheme_class("W8A8_DYNAMIC", "linear")
scheme = scheme_cls()
@@ -30,28 +30,26 @@ Usage:
from typing import Any
# Import base classes
from .base import (AscendAttentionScheme, AscendLinearScheme, AscendMoEScheme,
QuantType)
from .base import AscendAttentionScheme, AscendLinearScheme, AscendMoEScheme, QuantType
# Import registry functions
from .registry import get_scheme_class, register_scheme
# Import all scheme classes for external access
from .w4a4_flatquant import AscendW4A4FlatQuantDynamicLinearMethod
from .w4a4_laos_dynamic import AscendW4A4LaosDynamicLinearMethod
from .w4a8 import (AscendW4A8DynamicFusedMoEMethod,
AscendW4A8DynamicLinearMethod)
from .w4a8 import AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod
from .w4a16 import AscendW4A16FusedMoEMethod
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
AscendW8A8DynamicLinearMethod)
from .w8a8_dynamic import AscendW8A8DynamicFusedMoEMethod, AscendW8A8DynamicLinearMethod
from .w8a8_mxfp8 import AscendW8A8MXFP8DynamicLinearMethod
from .w8a8_pdmix import (AscendW8A8PDMixFusedMoeMethod,
AscendW8A8PDMixLinearMethod)
from .w8a8_pdmix import AscendW8A8PDMixFusedMoeMethod, AscendW8A8PDMixLinearMethod
from .w8a8_static import AscendW8A8LinearMethod
from .w8a16 import AscendW8A16LinearMethod
def is_mx_quant_type(instance: Any) -> bool:
"""Checks if the quantization method is a microscaling (MX) type."""
MX_QUANT_TYPES = (AscendW8A8MXFP8DynamicLinearMethod, )
MX_QUANT_TYPES = (AscendW8A8MXFP8DynamicLinearMethod,)
return isinstance(instance, MX_QUANT_TYPES)


@@ -17,14 +17,16 @@
"""Abstract base classes for Ascend quantization schemes."""
from abc import ABC, abstractmethod
from collections.abc import Callable
from enum import Enum
from typing import Any, Callable, Dict, Optional
from typing import Any
import torch
class QuantType(Enum):
"""Quantization type enum for MoE schemes."""
NONE = 0
W8A8 = 1
W4A8 = 2
@@ -32,84 +34,78 @@ class QuantType(Enum):
class AscendLinearScheme(ABC):
"""Base class for all linear quantization schemes.
Subclasses must implement get_weight() and apply() methods.
Other methods have default implementations that return empty dicts
or do nothing.
"""
@abstractmethod
def get_weight(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
"""Return weight tensor specifications.
Args:
input_size: Input dimension of the linear layer.
output_size: Output dimension of the linear layer.
params_dtype: Data type for parameters.
Returns:
Dictionary mapping parameter names to empty tensors with
the correct shape and dtype.
"""
...
def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]:
def get_pertensor_param(self, params_dtype: torch.dtype) -> dict[str, Any]:
"""Return per-tensor parameter specifications (e.g., input_scale).
Args:
params_dtype: Data type for parameters.
Returns:
Dictionary mapping parameter names to empty tensors.
"""
return {}
def get_perchannel_param(self, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
def get_perchannel_param(self, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
"""Return per-channel parameter specifications (e.g., weight_scale).
Args:
output_size: Output dimension of the linear layer.
params_dtype: Data type for parameters.
Returns:
Dictionary mapping parameter names to empty tensors.
"""
return {}
def get_pergroup_param(self,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
layer_type: Optional[str] = None) -> Dict[str, Any]:
def get_pergroup_param(
self, input_size: int, output_size: int, params_dtype: torch.dtype, layer_type: str | None = None
) -> dict[str, Any]:
"""Return per-group parameter specifications.
Args:
input_size: Input dimension of the linear layer.
output_size: Output dimension of the linear layer.
params_dtype: Data type for parameters.
layer_type: Type of layer (e.g., "row" for RowParallelLinear).
Returns:
Dictionary mapping parameter names to empty tensors.
"""
return {}
@abstractmethod
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0) -> torch.Tensor:
def apply(
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, tp_rank: int | None = 0
) -> torch.Tensor:
"""Forward computation.
Args:
layer: The linear layer module.
x: Input tensor.
bias: Optional bias tensor.
tp_rank: Tensor parallel rank.
Returns:
Output tensor after quantized linear operation.
"""
@@ -117,42 +113,51 @@ class AscendLinearScheme(ABC):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""Post-loading weight processing (transpose, format conversion, etc.).
Args:
layer: The linear layer module.
"""
pass
return
class AscendAttentionScheme(ABC):
"""Base class for all attention quantization schemes.
Subclasses must implement apply() method.
Other methods have default implementations.
"""
def create_weights(self, layer: torch.nn.Module) -> None:
"""Create weights for attention quantization.
Args:
layer: The attention layer module.
"""
pass
return
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""Post-loading weight processing for attention layer.
Args:
layer: The attention layer module.
"""
pass
return
@abstractmethod
def apply(self, layer: torch.nn.Module, query: torch.Tensor,
key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata,
attn_type, scale, output) -> torch.Tensor:
def apply(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache,
attn_metadata,
attn_type,
scale,
output,
) -> torch.Tensor:
"""Forward computation for attention layer.
Args:
layer: The attention layer module.
query: Query tensor.
@@ -163,7 +168,7 @@ class AscendAttentionScheme(ABC):
attn_type: Attention type.
scale: Scale factor.
output: Output tensor.
Returns:
Output tensor after attention computation.
"""
@@ -172,10 +177,10 @@ class AscendAttentionScheme(ABC):
class AscendMoEScheme(ABC):
"""Base class for all MoE quantization schemes.
Subclasses must implement get_weight(), get_dynamic_quant_param(),
and apply() methods.
Attributes:
quant_type: The quantization type for this scheme. Subclasses should
override this class attribute to declare their quant type.
@@ -185,35 +190,34 @@ class AscendMoEScheme(ABC):
quant_type: QuantType = QuantType.NONE
@abstractmethod
def get_weight(self, num_experts: int,
intermediate_size_per_partition: int, hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
def get_weight(
self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
) -> dict[str, Any]:
"""Return weight tensor specifications for MoE layer.
Args:
num_experts: Number of experts.
intermediate_size_per_partition: Intermediate size per partition.
hidden_sizes: Hidden dimension size.
params_dtype: Data type for parameters.
Returns:
Dictionary mapping parameter names to empty tensors.
"""
...
@abstractmethod
def get_dynamic_quant_param(self, num_experts: int,
intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
def get_dynamic_quant_param(
self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
) -> dict[str, Any]:
"""Return dynamic quantization parameters for MoE layer.
Args:
num_experts: Number of experts.
intermediate_size_per_partition: Intermediate size per partition.
hidden_sizes: Hidden dimension size.
params_dtype: Data type for parameters.
Returns:
Dictionary mapping parameter names to empty tensors.
"""
@@ -229,21 +233,21 @@ class AscendMoEScheme(ABC):
renormalize: bool,
use_grouped_topk: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
is_prefill: bool = True,
enable_force_load_balance: bool = False,
log2phy: Optional[torch.Tensor] = None,
log2phy: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
**kwargs,
) -> torch.Tensor:
"""Forward computation for MoE layer.
Args:
layer: The MoE layer module.
x: Input hidden states.
@@ -264,7 +268,7 @@ class AscendMoEScheme(ABC):
log2phy: Logical to physical expert mapping.
global_redundant_expert_num: Number of redundant experts.
**kwargs: Additional keyword arguments.
Returns:
Output tensor after MoE computation.
"""
@@ -272,8 +276,8 @@ class AscendMoEScheme(ABC):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""Post-loading weight processing for MoE layer.
Args:
layer: The MoE layer module.
"""
pass
return


@@ -15,47 +15,47 @@
# limitations under the License.
#
from typing import Any, Dict, Optional, Tuple, Type
from typing import Any
# Registry: maps (quant_type, layer_type) -> SchemeClass
_SCHEME_REGISTRY: Dict[Tuple[str, str], Type[Any]] = {}
_SCHEME_REGISTRY: dict[tuple[str, str], type[Any]] = {}
def register_scheme(quant_type: str, layer_type: str):
"""Decorator to register a quantization scheme.
Args:
quant_type: Quantization type (e.g., "W8A8", "W8A8_DYNAMIC").
layer_type: Layer type (e.g., "linear", "moe").
Returns:
Decorator function that registers the class.
Example:
@register_scheme("W8A8_DYNAMIC", "linear")
class W8A8DynamicLinearScheme(AscendLinearScheme):
...
"""
def decorator(cls: Type[Any]) -> Type[Any]:
def decorator(cls: type[Any]) -> type[Any]:
key = (quant_type, layer_type)
if key in _SCHEME_REGISTRY:
raise ValueError(
f"Scheme already registered for {quant_type}/{layer_type}: "
f"{_SCHEME_REGISTRY[key].__name__}")
f"Scheme already registered for {quant_type}/{layer_type}: {_SCHEME_REGISTRY[key].__name__}"
)
_SCHEME_REGISTRY[key] = cls
return cls
return decorator
def get_scheme_class(quant_type: str, layer_type: str) -> Optional[Type[Any]]:
def get_scheme_class(quant_type: str, layer_type: str) -> type[Any] | None:
"""Get scheme class for given quant_type and layer_type.
Args:
quant_type: Quantization type (e.g., "W8A8", "W8A8_DYNAMIC").
layer_type: Layer type (e.g., "linear", "moe").
Returns:
The registered scheme class, or None if not found.
"""


@@ -15,7 +15,8 @@
# limitations under the License.
#
from typing import Any, Callable, Dict, Optional
from collections.abc import Callable
from typing import Any
import torch
import torch_npu
@@ -56,8 +57,7 @@ def unpack_from_int32(
dtype=torch.int32,
)
for i in range(pack_factor):
unpacked_weight[:, i::pack_factor] = (weight >>
(num_bits * i)) & mask
unpacked_weight[:, i::pack_factor] = (weight >> (num_bits * i)) & mask
original_row_size = int(shape[1])
unpacked_weight = unpacked_weight[:, :original_row_size]
else:
@@ -67,8 +67,7 @@ def unpack_from_int32(
dtype=torch.int32,
)
for i in range(pack_factor):
unpacked_weight[i::pack_factor, :] = (weight >>
(num_bits * i)) & mask
unpacked_weight[i::pack_factor, :] = (weight >> (num_bits * i)) & mask
original_row_size = int(shape[0])
unpacked_weight = unpacked_weight[:original_row_size, :]
@@ -84,22 +83,17 @@ def pack_to_int32(weight: torch.Tensor) -> torch.Tensor:
:param weight: The 3D tensor to pack, must be int8 or int32 dtype
:return: Packed tensor with int32 dtype optimized for storage
"""
assert weight.dim(
) == 3, f"Expecting `weight.dim()` is 3 ([e, n, k] or [e, k, n]) but got {weight.dim()}."
assert weight.dtype in [
torch.int8, torch.int32
], f"Expecting `weight.dtype` is torch.int8 or torch.int32 bug got {weight.dtype}."
assert weight.dim() == 3, f"Expecting `weight.dim()` is 3 ([e, n, k] or [e, k, n]) but got {weight.dim()}."
assert weight.dtype in [torch.int8, torch.int32], (
f"Expecting `weight.dtype` is torch.int8 or torch.int32 bug got {weight.dtype}."
)
if weight.dtype == torch.int32:
assert weight.shape[
-1] % 8 == 0, "the last dim of weight needs to be divided by 8."
packed_weight = torch_npu.npu_convert_weight_to_int4pack(
weight.flatten(0, 1))
packed_weight = packed_weight.view(weight.shape[0], weight.shape[1],
-1)
assert weight.shape[-1] % 8 == 0, "the last dim of weight needs to be divided by 8."
packed_weight = torch_npu.npu_convert_weight_to_int4pack(weight.flatten(0, 1))
packed_weight = packed_weight.view(weight.shape[0], weight.shape[1], -1)
else:
assert weight.shape[
-1] % 4 == 0, "the last dim of weight needs to be divided by 4."
assert weight.shape[-1] % 4 == 0, "the last dim of weight needs to be divided by 4."
packed_weight = weight.view(torch.int32).contiguous()
return packed_weight
@@ -115,8 +109,7 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme):
self.pack_factor = 8 # pack 8 of torch.int4 tensors to torch.int32
vllm_config = get_current_vllm_config()
self.group_size = vllm_config.quant_config.quant_description.get(
"group_size", 32)
self.group_size = vllm_config.quant_config.quant_description.get("group_size", 32)
self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb
def get_weight(
@@ -125,22 +118,23 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme):
intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
assert intermediate_size_per_partition % self.pack_factor == 0, f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} can be divided by `pack_factor` {self.pack_factor}"
assert hidden_sizes % self.pack_factor == 0, f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `pack_factor` {self.pack_factor}"
) -> dict[str, Any]:
assert intermediate_size_per_partition % self.pack_factor == 0, (
f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} "
f"can be divided by `pack_factor` {self.pack_factor}"
)
assert hidden_sizes % self.pack_factor == 0, (
f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `pack_factor` {self.pack_factor}"
)
param_dict = {}
param_dict["w13_weight_packed"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.pack_factor,
dtype=torch.int32)
num_experts, 2 * intermediate_size_per_partition, hidden_sizes // self.pack_factor, dtype=torch.int32
)
param_dict["w2_weight_packed"] = torch.empty(
num_experts,
hidden_sizes,
intermediate_size_per_partition // self.pack_factor,
dtype=torch.int32)
num_experts, hidden_sizes, intermediate_size_per_partition // self.pack_factor, dtype=torch.int32
)
return param_dict
@@ -150,38 +144,31 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme):
intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
assert intermediate_size_per_partition % self.group_size == 0, f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} can be divided by `group_size` {self.group_size}"
assert hidden_sizes % self.group_size == 0, f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `group_size` {self.group_size}"
) -> dict[str, Any]:
assert intermediate_size_per_partition % self.group_size == 0, (
f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} "
f"can be divided by `group_size` {self.group_size}"
)
assert hidden_sizes % self.group_size == 0, (
f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `group_size` {self.group_size}"
)
param_dict = {}
param_dict["w13_weight_scale"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.group_size,
dtype=torch.bfloat16)
num_experts, 2 * intermediate_size_per_partition, hidden_sizes // self.group_size, dtype=torch.bfloat16
)
param_dict["w2_weight_scale"] = torch.empty(
num_experts,
hidden_sizes,
intermediate_size_per_partition // self.group_size,
dtype=torch.bfloat16)
param_dict["w13_weight_shape"] = torch.empty(num_experts,
2,
dtype=torch.int32)
param_dict["w2_weight_shape"] = torch.empty(num_experts,
2,
dtype=torch.int32)
num_experts, hidden_sizes, intermediate_size_per_partition // self.group_size, dtype=torch.bfloat16
)
param_dict["w13_weight_shape"] = torch.empty(num_experts, 2, dtype=torch.int32)
param_dict["w2_weight_shape"] = torch.empty(num_experts, 2, dtype=torch.int32)
param_dict["w13_weight_offset"] = torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.group_size,
dtype=torch.bfloat16)
num_experts, 2 * intermediate_size_per_partition, hidden_sizes // self.group_size, dtype=torch.bfloat16
)
param_dict["w2_weight_offset"] = torch.zeros(
num_experts,
hidden_sizes,
intermediate_size_per_partition // self.group_size,
dtype=torch.bfloat16)
num_experts, hidden_sizes, intermediate_size_per_partition // self.group_size, dtype=torch.bfloat16
)
return param_dict
@@ -194,21 +181,22 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme):
renormalize: bool,
use_grouped_topk: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
is_prefill: bool = True,
enable_force_load_balance: bool = True,
log2phy: Optional[torch.Tensor] = None,
log2phy: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
**kwargs,
) -> torch.Tensor:
assert router_logits.shape[
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, (
"Number of global experts mismatch (excluding redundancy)"
)
topk_weights, topk_ids = select_experts(
hidden_states=x,
@@ -221,7 +209,8 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme):
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
global_num_experts=global_num_experts,
)
topk_ids = topk_ids.to(torch.int32)
topk_weights = topk_weights.to(x.dtype)
@@ -241,38 +230,40 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme):
expert_map=expert_map,
log2phy=log2phy,
dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask", None))
mc2_mask=kwargs.get("mc2_mask"),
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if self.transpose_weight:
w13_shape = layer.w13_weight_packed.data.shape
w2_shape = layer.w2_weight_packed.data.shape
unpacked_w13_weight = (unpack_from_int32(
layer.w13_weight_packed.data.flatten(0, 1),
torch.Size([
w13_shape[0] * w13_shape[1],
w13_shape[2] * self.pack_factor
]),
self.num_bits,
).view(w13_shape[0], w13_shape[1],
-1).transpose(1, 2).contiguous().int())
unpacked_w2_weight = (unpack_from_int32(
layer.w2_weight_packed.data.flatten(0, 1),
torch.Size([
w2_shape[0] * w2_shape[1], w2_shape[2] * self.pack_factor
]),
self.num_bits,
).view(w2_shape[0], w2_shape[1],
-1).transpose(1, 2).contiguous().int())
unpacked_w13_weight = (
unpack_from_int32(
layer.w13_weight_packed.data.flatten(0, 1),
torch.Size([w13_shape[0] * w13_shape[1], w13_shape[2] * self.pack_factor]),
self.num_bits,
)
.view(w13_shape[0], w13_shape[1], -1)
.transpose(1, 2)
.contiguous()
.int()
)
unpacked_w2_weight = (
unpack_from_int32(
layer.w2_weight_packed.data.flatten(0, 1),
torch.Size([w2_shape[0] * w2_shape[1], w2_shape[2] * self.pack_factor]),
self.num_bits,
)
.view(w2_shape[0], w2_shape[1], -1)
.transpose(1, 2)
.contiguous()
.int()
)
layer.w13_weight_packed.data = pack_to_int32(unpacked_w13_weight)
layer.w2_weight_packed.data = pack_to_int32(unpacked_w2_weight)
layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(
1, 2).contiguous()
layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(
1, 2).contiguous()
layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(1, 2).contiguous()
layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(1, 2).contiguous()
layer.w13_weight_offset.data = layer.w13_weight_offset.data.transpose(
1, 2).contiguous()
layer.w2_weight_offset.data = layer.w2_weight_offset.data.transpose(
1, 2).contiguous()
layer.w13_weight_offset.data = layer.w13_weight_offset.data.transpose(1, 2).contiguous()
layer.w2_weight_offset.data = layer.w2_weight_offset.data.transpose(1, 2).contiguous()


@@ -16,7 +16,7 @@
#
import math
from typing import Any, Dict, Optional, Tuple
from typing import Any
import torch
import torch_npu
@@ -31,8 +31,7 @@ def pack_int4_weights(weight_tensor: torch.Tensor) -> torch.Tensor:
"""Pack int4 weights for NPU."""
original_device = weight_tensor.device
weight_tensor_npu = weight_tensor.npu()
weight_int4_packed = torch_npu.npu_convert_weight_to_int4pack(
weight_tensor_npu.to(torch.int32), inner_k_tiles=1)
weight_int4_packed = torch_npu.npu_convert_weight_to_int4pack(weight_tensor_npu.to(torch.int32), inner_k_tiles=1)
return weight_int4_packed.to(original_device)
@@ -58,22 +57,14 @@ def batched_kronecker_quant(
left_trans: torch.Tensor,
right_trans: torch.Tensor,
clip_ratio: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""Batched Kronecker quantization with batch size limit handling."""
batch_tokens = x.shape[0]
if batch_tokens <= KRONECKER_QUANT_MAX_BATCH_SIZE:
return torch_npu.npu_kronecker_quant(x,
left_trans,
right_trans,
clip_ratio=clip_ratio,
dst_dtype=torch.int32)
return torch_npu.npu_kronecker_quant(x, left_trans, right_trans, clip_ratio=clip_ratio, dst_dtype=torch.int32)
x_chunks = torch.split(x, KRONECKER_QUANT_MAX_BATCH_SIZE, dim=0)
processed_chunks = [
torch_npu.npu_kronecker_quant(chunk,
left_trans,
right_trans,
clip_ratio=clip_ratio,
dst_dtype=torch.int32)
torch_npu.npu_kronecker_quant(chunk, left_trans, right_trans, clip_ratio=clip_ratio, dst_dtype=torch.int32)
for chunk in x_chunks
]
quantized_list, scale_list = zip(*processed_chunks)
@@ -85,39 +76,32 @@ def batched_kronecker_quant(
@register_scheme("W4A4_FLATQUANT_DYNAMIC", "linear")
class AscendW4A4FlatQuantDynamicLinearMethod(AscendLinearScheme):
"""Linear method for Ascend W4A4_FLATQUANT_DYNAMIC.
This class implements W4A4 quantization with FlatQuant approach and dynamic activation quantization.
- Weight: 4-bit quantization (per-channel) with scale and offset, stored as int8 and packed to int32 during loading
- Activation: 4-bit dynamic quantization with FlatQuant transform matrices (left_trans, right_trans) for distribution smoothing
- Parameters: clip_ratio for controlling quantization clipping, weight_offset for asymmetric quantization, loaded from external weights
- Activation: 4-bit dynamic quantization with FlatQuant transform matrices (left_trans, right_trans) for
distribution smoothing
- Parameters: clip_ratio for controlling quantization clipping, weight_offset for asymmetric quantization, loaded
from external weights
"""
input_size = 0
def __init__(self):
self.sym = True
def get_weight(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
if input_size % 8 != 0:
raise ValueError(
f"input_size ({input_size}) must be divisible by 8 for int4 packing"
)
raise ValueError(f"input_size ({input_size}) must be divisible by 8 for int4 packing")
AscendW4A4FlatQuantDynamicLinearMethod.input_size = input_size
params_dict = {
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
}
params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
return params_dict
def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]:
def get_pertensor_param(self, params_dtype: torch.dtype) -> dict[str, Any]:
params_dict = {}
left_trans_dim, right_trans_dim = get_decompose_dim(
AscendW4A4FlatQuantDynamicLinearMethod.input_size)
params_dict["left_trans"] = torch.empty(left_trans_dim,
left_trans_dim,
dtype=params_dtype)
params_dict["right_trans"] = torch.empty(right_trans_dim,
right_trans_dim,
dtype=params_dtype)
left_trans_dim, right_trans_dim = get_decompose_dim(AscendW4A4FlatQuantDynamicLinearMethod.input_size)
params_dict["left_trans"] = torch.empty(left_trans_dim, left_trans_dim, dtype=params_dtype)
params_dict["right_trans"] = torch.empty(right_trans_dim, right_trans_dim, dtype=params_dtype)
params_dict["clip_ratio"] = torch.empty(1, dtype=torch.float32)
return params_dict
@@ -125,22 +109,18 @@ class AscendW4A4FlatQuantDynamicLinearMethod(AscendLinearScheme):
self,
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
) -> dict[str, Any]:
params_dict = {}
params_dict["weight_scale"] = torch.empty(output_size,
1,
dtype=torch.float32)
params_dict["weight_offset"] = torch.empty(output_size,
1,
dtype=torch.float32)
params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=torch.float32)
params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=torch.float32)
return params_dict
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0,
bias: torch.Tensor | None = None,
tp_rank: int | None = 0,
) -> torch.Tensor:
original_dtype = x.dtype
input_shape = x.shape
@@ -156,18 +136,18 @@ class AscendW4A4FlatQuantDynamicLinearMethod(AscendLinearScheme):
right_trans_matched = layer.right_trans.to(original_dtype)
x_reshaped = x.view(-1, left_dim, right_dim)
x_quantized_int4, activation_scale = batched_kronecker_quant(
x_reshaped, left_trans_matched, right_trans_matched,
layer.aclnn_clip_ratio)
x_quantized_reshaped = x_quantized_int4.view(-1,
left_dim * right_dim // 8)
x_reshaped, left_trans_matched, right_trans_matched, layer.aclnn_clip_ratio
)
x_quantized_reshaped = x_quantized_int4.view(-1, left_dim * right_dim // 8)
pertoken_scale = activation_scale.view(-1).to(torch.float32)
output = torch_npu.npu_quant_matmul(x_quantized_reshaped,
layer.weight_packed.t(),
layer.weight_scale.view(-1).to(
torch.float32),
pertoken_scale=pertoken_scale,
bias=None,
output_dtype=original_dtype)
output = torch_npu.npu_quant_matmul(
x_quantized_reshaped,
layer.weight_packed.t(),
layer.weight_scale.view(-1).to(torch.float32),
pertoken_scale=pertoken_scale,
bias=None,
output_dtype=original_dtype,
)
output = output.view(*input_shape[:-1], -1)
if bias is not None:
output = output + bias.to(original_dtype)
@@ -176,15 +156,11 @@ class AscendW4A4FlatQuantDynamicLinearMethod(AscendLinearScheme):
def process_weights_after_loading(self, layer):
# NOTE: Currently, w4a4 can't support weight nz
weight_packed = pack_int4_weights(layer.weight.data)
layer.register_parameter(
'weight_packed',
torch.nn.Parameter(weight_packed, requires_grad=False))
layer.register_parameter("weight_packed", torch.nn.Parameter(weight_packed, requires_grad=False))
del layer.weight
layer.weight_scale.data = layer.weight_scale.data.to(torch.float32)
layer.weight_offset.data = layer.weight_offset.data.to(torch.float32)
layer.left_trans = torch.nn.Parameter(
layer.left_trans.data.t().contiguous())
layer.left_trans = torch.nn.Parameter(layer.left_trans.data.t().contiguous())
layer.right_trans = torch.nn.Parameter(layer.right_trans.data)
layer.clip_ratio = torch.nn.Parameter(
layer.clip_ratio.data.to(torch.float32))
layer.clip_ratio = torch.nn.Parameter(layer.clip_ratio.data.to(torch.float32))
layer.aclnn_clip_ratio = layer.clip_ratio.item()


@@ -15,7 +15,7 @@
# limitations under the License.
#
from typing import Any, Dict, Optional
from typing import Any
import torch
import torch_npu
@@ -27,7 +27,7 @@ from .registry import register_scheme
@register_scheme("W4A4_DYNAMIC", "linear")
class AscendW4A4LaosDynamicLinearMethod(AscendLinearScheme):
"""Linear method for Ascend W4A4_DYNAMIC.
This class implements W4A4 quantization with the LAOS approach and dynamic activation quantization.
- Weight: 4-bit quantization (per-channel) with scale and offset, stored as int8.
- Activation: 4-bit dynamic quantization.
@@ -37,7 +37,7 @@ class AscendW4A4LaosDynamicLinearMethod(AscendLinearScheme):
self.transpose_weight = True
self.rotation_type = None
def set_rotation_config(self, prefix: str, metadata: Dict) -> Optional[str]:
def set_rotation_config(self, prefix: str, metadata: dict) -> str | None:
"""Set rotation config based on prefix and metadata."""
layer_idx = prefix.split(".")[2]
if prefix.endswith("o_proj"):
@@ -50,34 +50,22 @@ class AscendW4A4LaosDynamicLinearMethod(AscendLinearScheme):
return "kronecker_rotation"
return None
def get_weight(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
params_dict = {
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
}
def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
return params_dict
def get_perchannel_param(self, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
def get_perchannel_param(self, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
params_dict = {}
params_dict["weight_scale"] = torch.empty(output_size,
1,
dtype=torch.float32)
params_dict["weight_offset"] = torch.empty(output_size,
1,
dtype=torch.float32)
params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=torch.float32)
params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=torch.float32)
if self.rotation_type == "heads_rotation":
params_dict["heads_rotation"] = torch.zeros((64, 64),
dtype=torch.float32)
params_dict["heads_rotation"] = torch.zeros((64, 64), dtype=torch.float32)
if self.rotation_type == "kronecker_rotation":
params_dict["kronecker_rotation_n"] = torch.zeros(
(160, 160), dtype=torch.float32)
params_dict["kronecker_rotation_m"] = torch.zeros(
(160, 160), dtype=torch.float32)
params_dict["kronecker_rotation_n"] = torch.zeros((160, 160), dtype=torch.float32)
params_dict["kronecker_rotation_m"] = torch.zeros((160, 160), dtype=torch.float32)
return params_dict
def apply_rotation(self, layer: torch.nn.Module,
x: torch.Tensor) -> torch.Tensor:
def apply_rotation(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
"""Apply rotation transformation to input tensor."""
init_shape = x.shape
dtype = x.dtype
@@ -100,8 +88,8 @@ class AscendW4A4LaosDynamicLinearMethod(AscendLinearScheme):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0,
bias: torch.Tensor | None = None,
tp_rank: int | None = 0,
) -> torch.Tensor:
dtype = x.dtype
x, pertoken_scale = torch_npu.npu_dynamic_quant(x, dst_type=torch.quint4x2)
@@ -113,14 +101,14 @@ class AscendW4A4LaosDynamicLinearMethod(AscendLinearScheme):
scale=layer.weight_scale.data.view(-1),
pertoken_scale=pertoken_scale,
bias=None,
output_dtype=dtype)
output_dtype=dtype,
)
if bias is not None:
output = output + bias.to(dtype)
return output
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight_scale.data = layer.weight_scale.data.to(torch.float32)
layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
layer.weight.data.to(torch.int32))
layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(layer.weight.data.to(torch.int32))
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(-1, -2)

View File

@@ -15,7 +15,8 @@
# limitations under the License.
#
from typing import Any, Callable, Dict, Optional
from collections.abc import Callable
from typing import Any
import numpy as np
import torch
@@ -27,7 +28,7 @@ from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.utils import maybe_trans_nz, COMPRESSED_TENSORS_METHOD
from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD, maybe_trans_nz
from .base import AscendLinearScheme, AscendMoEScheme, QuantType
from .registry import register_scheme
@@ -39,19 +40,17 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme):
def __init__(self):
vllm_config = get_current_vllm_config()
self.group_size = vllm_config.quant_config.quant_description.get(
"group_size", 256)
quant_version = vllm_config.quant_config.quant_description.get(
"version", "0")
self.group_size = vllm_config.quant_config.quant_description.get("group_size", 256)
quant_version = vllm_config.quant_config.quant_description.get("version", "0")
self.new_quant_version = quant_version == "1.0.0"
from vllm.distributed import get_tensor_model_parallel_world_size
self.tp_size = get_tensor_model_parallel_world_size()
def get_weight(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
"""Create weight parameters.
For the new quantization version (two int4 values packed into one int8), the output dimension
is compressed by a factor of 2 (e.g., [2048, 3072] -> [1024, 3072]). The returned
dict includes "_packed_dim" and "_packed_factor" for vLLM's weight loader.
@@ -62,40 +61,26 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme):
# double int4 pack into int8: output dimension is compressed
pack_factor = 2
actual_output_size = output_size // pack_factor
params_dict["weight"] = torch.empty(actual_output_size,
input_size,
dtype=torch.int8)
params_dict["weight"] = torch.empty(actual_output_size, input_size, dtype=torch.int8)
# Add packing information for vLLM's weight_loader
params_dict["_packed_dim"] = 0
params_dict["_packed_factor"] = pack_factor
else:
params_dict["weight"] = torch.empty(output_size,
input_size,
dtype=torch.int8)
params_dict["weight"] = torch.empty(output_size, input_size, dtype=torch.int8)
return params_dict
def get_pergroup_param(self,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
layer_type: Optional[str] = None) -> Dict[str, Any]:
def get_pergroup_param(
self, input_size: int, output_size: int, params_dtype: torch.dtype, layer_type: str | None = None
) -> dict[str, Any]:
"""Create per-group quantization parameters."""
params_dict = {}
params_dict["weight_scale"] = torch.empty(output_size,
1,
dtype=params_dtype)
params_dict["weight_offset"] = torch.empty(output_size,
1,
dtype=params_dtype)
params_dict["weight_scale_second"] = torch.empty(output_size,
input_size //
self.group_size,
dtype=params_dtype)
params_dict["weight_offset_second"] = torch.empty(output_size,
input_size //
self.group_size,
dtype=params_dtype)
params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
params_dict["weight_scale_second"] = torch.empty(output_size, input_size // self.group_size, dtype=params_dtype)
params_dict["weight_offset_second"] = torch.empty(
output_size, input_size // self.group_size, dtype=params_dtype
)
# NOTE: In w4a8 quantization implementation,
# for down_proj and o_proj(layer_type == "row") scale_bias shape is [output_size, 16],
@@ -103,24 +88,21 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme):
if self.new_quant_version:
scale_bias_dim = 16 if layer_type == "row" else 1
params_dict["scale_bias"] = torch.empty(output_size,
scale_bias_dim,
dtype=torch.float32)
params_dict["scale_bias"] = torch.empty(output_size, scale_bias_dim, dtype=torch.float32)
return params_dict
@staticmethod
def process_scale_second(weight: torch.Tensor,
scale: torch.Tensor,
per_group_scale: torch.Tensor,
is_new_quant: bool = False):
def process_scale_second(
weight: torch.Tensor, scale: torch.Tensor, per_group_scale: torch.Tensor, is_new_quant: bool = False
):
"""Process the scale for second-level quantization.
Args:
weight: weight tensor [k, n] (in new version, n is already compressed to n/2)
scale: first-level quantization scale [output_size]
per_group_scale: second-level per-group quantization scale [group_num, n_scale]
is_new_quant: whether it's the new quantization version (weight already compressed)
Returns:
(antiquant_scale, bias): dequantization scale and bias (bias=None for new version)
"""
@@ -133,8 +115,7 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme):
bias = None
if not is_new_quant:
weight_high = weight.to(torch.float32).reshape(
group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
weight_high = weight.to(torch.float32).reshape(group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
weight_high = weight_high.reshape(k, n)
bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
# NOTE: scale_bias is not used currently
@@ -148,8 +129,8 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme):
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = None,
bias: torch.Tensor | None = None,
tp_rank: int | None = None,
) -> torch.Tensor:
return torch_npu.npu_weight_quant_batchmatmul(
x,
@@ -161,8 +142,7 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme):
def process_weights_after_loading(self, layer: torch.nn.Module):
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight.data = maybe_trans_nz(layer.weight.data)
layer.weight_scale.data = layer.weight_scale.data.flatten().to(
torch.float32)
layer.weight_scale.data = layer.weight_scale.data.flatten().to(torch.float32)
layer.weight_offset.data = layer.weight_offset.data.flatten()
layer.weight_scale_second.data, scale_bias = self.process_scale_second(
layer.weight.data,
@@ -187,15 +167,14 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme):
if self.new_quant_version:
# weights on disk are already in packed int4 format
# pack 4 int8(int4*2) to int32
assert layer.weight.data.shape[-1] % 4 == 0, \
assert layer.weight.data.shape[-1] % 4 == 0, (
f"the last dim of weight needs to be divided by 4, got shape {layer.weight.data.shape}"
layer.weight.data = layer.weight.data.view(
torch.int32).contiguous()
)
layer.weight.data = layer.weight.data.view(torch.int32).contiguous()
else:
# weights are not compressed
# need to be packed via npu_convert_weight_to_int4pack
layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
layer.weight.data.to(torch.int32))
layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(layer.weight.data.to(torch.int32))
@register_scheme("W4A8_DYNAMIC", "moe")
@@ -209,69 +188,56 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
self.ep_group = get_ep_group()
vllm_config = get_current_vllm_config()
self.group_size = vllm_config.quant_config.quant_description.get(
"group_size", 256)
self.group_size = vllm_config.quant_config.quant_description.get("group_size", 256)
# NOTE: the weights are quantized from bf16 to int4 through a per-channel quantization process
self.is_per_channel_weight = self.group_size == 0
quant_version = vllm_config.quant_config.quant_description.get(
"version", "0")
quant_version = vllm_config.quant_config.quant_description.get("version", "0")
# NOTE: new quantized weights: two int4 values packed into one int8
self.new_quant_version = quant_version == "1.0.0"
self.quant_method = vllm_config.quant_config.quant_description.get(
"ascend_quant_method", "")
self.quant_method = vllm_config.quant_config.quant_description.get("ascend_quant_method", "")
if self.quant_method == COMPRESSED_TENSORS_METHOD:
self.weight_strategy = vllm_config.quant_config.quant_description.get(
"weight_strategy", "group")
self.weight_strategy = vllm_config.quant_config.quant_description.get("weight_strategy", "group")
self.tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else self.ep_group.world_size
self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb
if self.new_quant_version and self.tp_size > 16:
raise ValueError(
"The current weight does not support moe part tp>16.")
raise ValueError("The current weight does not support moe part tp>16.")
try:
device_group = get_mc2_group().device_group
# TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu"))
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
local_rank)
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
except AttributeError:
self.moe_all_to_all_group_name = ""
def get_weight(self, num_experts: int,
intermediate_size_per_partition: int, hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
def get_weight(
self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
) -> dict[str, Any]:
if self.quant_method == COMPRESSED_TENSORS_METHOD:
return self.get_weight_compressed_tensors(
num_experts, intermediate_size_per_partition,
hidden_sizes, params_dtype)
num_experts, intermediate_size_per_partition, hidden_sizes, params_dtype
)
else:
return self.get_weight_modelslim(
num_experts, intermediate_size_per_partition,
hidden_sizes, params_dtype)
def get_weight_compressed_tensors(self, num_experts: int,
intermediate_size_per_partition: int, hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
return self.get_weight_modelslim(num_experts, intermediate_size_per_partition, hidden_sizes, params_dtype)
def get_weight_compressed_tensors(
self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
) -> dict[str, Any]:
param_dict = {}
E = num_experts
H = hidden_sizes
IN = intermediate_size_per_partition
g = self.group_size
param_dict["w13_weight"] = torch.empty(E, 2 * IN, H,
dtype=torch.int8)
param_dict["w2_weight"] = torch.empty(E, H, IN,
dtype=torch.int8)
param_dict["w13_weight"] = torch.empty(E, 2 * IN, H, dtype=torch.int8)
param_dict["w2_weight"] = torch.empty(E, H, IN, dtype=torch.int8)
return param_dict
def get_weight_modelslim(self, num_experts: int,
intermediate_size_per_partition: int, hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
def get_weight_modelslim(
self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
) -> dict[str, Any]:
param_dict = {}
if self.new_quant_version:
w13_output_size = intermediate_size_per_partition
@@ -280,33 +246,27 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
w13_output_size = 2 * intermediate_size_per_partition
w2_output_size = hidden_sizes
param_dict["w13_weight"] = torch.empty(num_experts,
w13_output_size,
hidden_sizes,
dtype=torch.int8)
param_dict["w2_weight"] = torch.empty(num_experts,
w2_output_size,
intermediate_size_per_partition,
dtype=torch.int8)
param_dict["w13_weight"] = torch.empty(num_experts, w13_output_size, hidden_sizes, dtype=torch.int8)
param_dict["w2_weight"] = torch.empty(
num_experts, w2_output_size, intermediate_size_per_partition, dtype=torch.int8
)
return param_dict
def get_dynamic_quant_param(self, num_experts: int,
intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
def get_dynamic_quant_param(
self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
) -> dict[str, Any]:
if self.quant_method == COMPRESSED_TENSORS_METHOD:
return self.get_dynamic_quant_param_compressed_tensors(
num_experts, intermediate_size_per_partition,
hidden_sizes, params_dtype)
num_experts, intermediate_size_per_partition, hidden_sizes, params_dtype
)
else:
return self.get_dynamic_quant_param_modelslim(
num_experts, intermediate_size_per_partition,
hidden_sizes, params_dtype)
def get_dynamic_quant_param_compressed_tensors(self, num_experts: int,
intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
num_experts, intermediate_size_per_partition, hidden_sizes, params_dtype
)
def get_dynamic_quant_param_compressed_tensors(
self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
) -> dict[str, Any]:
param_dict = {}
E = num_experts
@@ -318,72 +278,48 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
def _n_scale_cols(in_features: int) -> int:
return 1 if g <= 0 else (in_features // g)
param_dict["w13_weight_scale"] = torch.empty(
E, 2 * IN, _n_scale_cols(H), dtype=torch.bfloat16)
param_dict["w13_weight_scale"] = torch.empty(E, 2 * IN, _n_scale_cols(H), dtype=torch.bfloat16)
param_dict["w2_weight_scale"] = torch.empty(E, H, _n_scale_cols(IN),
dtype=torch.bfloat16)
param_dict["w2_weight_scale"] = torch.empty(E, H, _n_scale_cols(IN), dtype=torch.bfloat16)
return param_dict
def get_dynamic_quant_param_modelslim(self, num_experts: int,
intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
def get_dynamic_quant_param_modelslim(
self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
) -> dict[str, Any]:
param_dict = {}
param_dict["w13_weight_scale"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32)
num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
)
param_dict["w13_weight_offset"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32)
num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
)
param_dict["w2_weight_scale"] = torch.empty(num_experts,
hidden_sizes,
1,
dtype=torch.float32)
param_dict["w2_weight_offset"] = torch.empty(num_experts,
hidden_sizes,
1,
dtype=torch.float32)
param_dict["w2_weight_scale"] = torch.empty(num_experts, hidden_sizes, 1, dtype=torch.float32)
param_dict["w2_weight_offset"] = torch.empty(num_experts, hidden_sizes, 1, dtype=torch.float32)
if not self.is_per_channel_weight:
param_dict["w13_weight_scale_second"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.group_size,
dtype=torch.float32)
num_experts, 2 * intermediate_size_per_partition, hidden_sizes // self.group_size, dtype=torch.float32
)
param_dict["w13_weight_offset_second"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.group_size,
dtype=torch.float32)
num_experts, 2 * intermediate_size_per_partition, hidden_sizes // self.group_size, dtype=torch.float32
)
param_dict["w2_weight_scale_second"] = torch.empty(
num_experts,
hidden_sizes,
intermediate_size_per_partition // self.group_size,
dtype=torch.float32)
num_experts, hidden_sizes, intermediate_size_per_partition // self.group_size, dtype=torch.float32
)
param_dict["w2_weight_offset_second"] = torch.empty(
num_experts,
hidden_sizes,
intermediate_size_per_partition // self.group_size,
dtype=torch.float32)
num_experts, hidden_sizes, intermediate_size_per_partition // self.group_size, dtype=torch.float32
)
if self.new_quant_version:
param_dict["w13_scale_bias"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32)
param_dict["w2_scale_bias"] = torch.empty(num_experts,
hidden_sizes,
16 // self.tp_size,
dtype=torch.float32)
num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
)
param_dict["w2_scale_bias"] = torch.empty(
num_experts, hidden_sizes, 16 // self.tp_size, dtype=torch.float32
)
return param_dict
@@ -396,21 +332,22 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
renormalize: bool,
use_grouped_topk: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
is_prefill: bool = True,
enable_force_load_balance: bool = False,
log2phy: Optional[torch.Tensor] = None,
log2phy: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
**kwargs,
) -> torch.Tensor:
assert router_logits.shape[
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, (
"Number of global experts mismatch (excluding redundancy)"
)
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
topk_weights, topk_ids = select_experts(
@@ -424,18 +361,17 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
global_num_experts=global_num_experts,
)
# this is a naive implementation for experts load balance so as
# to avoid accumulating too many tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance:
random_matrix = torch.rand(topk_ids.size(0),
global_num_experts -
global_redundant_expert_num,
device=topk_ids.device)
topk_ids = torch.argsort(
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
random_matrix = torch.rand(
topk_ids.size(0), global_num_experts - global_redundant_expert_num, device=topk_ids.device
)
topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
topk_weights = topk_weights.to(x.dtype)
@@ -446,25 +382,23 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
w2=[layer.w2_weight],
w1_scale=[layer.w13_weight_scale],
w2_scale=[layer.w2_weight_scale],
w1_scale_bias=layer.w13_scale_bias if hasattr(
layer, "w13_scale_bias") else None,
w2_scale_bias=layer.w2_scale_bias if hasattr(
layer, "w2_scale_bias") else None,
w1_scale_bias=layer.w13_scale_bias if hasattr(layer, "w13_scale_bias") else None,
w2_scale_bias=layer.w2_scale_bias if hasattr(layer, "w2_scale_bias") else None,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_int4_w4a8=True,
expert_map=expert_map,
log2phy=log2phy,
dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask", None))
mc2_mask=kwargs.get("mc2_mask"),
)
def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
scale = scale.transpose(1, 2).contiguous()
if self.is_per_channel_weight:
scale_np = scale.cpu().numpy()
scale_np.dtype = np.uint32
scale_uint64_tensor = torch.from_numpy(scale_np.astype(
np.int64)).npu()
scale_uint64_tensor = torch.from_numpy(scale_np.astype(np.int64)).npu()
return scale_uint64_tensor, None
per_group_scale = per_group_scale.transpose(1, 2).contiguous()
group_num, k, n = weight.shape
@@ -475,32 +409,27 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
group_num, quantgroup_num, n = per_group_scale.shape
bias = None
if not self.new_quant_version:
weight_high = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \
per_group_scale.reshape([group_num, quantgroup_num, 1, n])
weight_high = weight.to(torch.float32).reshape(
[group_num, quantgroup_num, -1, n]
) * per_group_scale.reshape([group_num, quantgroup_num, 1, n])
weight_high = weight_high.reshape([group_num, k, n])
bias = 8 * (weight_high.to(torch.float32) * scale).sum(axis=1)
scale_fp32 = (scale * per_group_scale).to(torch.float16).to(
torch.float32)
scale_fp32 = (scale * per_group_scale).to(torch.float16).to(torch.float32)
scale_fp32_np = scale_fp32.cpu().numpy()
scale_fp32_np.dtype = np.uint32
sscale_uint64 = np.zeros((group_num, quantgroup_num, n * 2),
dtype=np.uint32)
sscale_uint64 = np.zeros((group_num, quantgroup_num, n * 2), dtype=np.uint32)
sscale_uint64[..., ::2] = scale_fp32_np
sscale_uint64_buffer = np.frombuffer(sscale_uint64.tobytes(),
dtype=np.int64).copy()
sscale_uint64_tensor = torch.from_numpy(sscale_uint64_buffer).reshape(
group_num, quantgroup_num, n)
sscale_uint64_buffer = np.frombuffer(sscale_uint64.tobytes(), dtype=np.int64).copy()
sscale_uint64_tensor = torch.from_numpy(sscale_uint64_buffer).reshape(group_num, quantgroup_num, n)
sscale_uint64_tensor = sscale_uint64_tensor.npu()
return sscale_uint64_tensor, bias
def update_bias(self, layer, w13_bias, w2_bias):
if self.new_quant_version:
layer.w13_scale_bias.data = layer.w13_scale_bias.data.transpose(
1, 2).contiguous().sum(axis=1)
layer.w2_scale_bias.data = layer.w2_scale_bias.data.transpose(
1, 2).contiguous().sum(axis=1)
layer.w13_scale_bias.data = layer.w13_scale_bias.data.transpose(1, 2).contiguous().sum(axis=1)
layer.w2_scale_bias.data = layer.w2_scale_bias.data.transpose(1, 2).contiguous().sum(axis=1)
else:
w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
layer.register_parameter("w13_scale_bias", w13_scale_bias)
@@ -510,13 +439,12 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
def pack_to_int32(self, weight: torch.Tensor):
if self.new_quant_version:
# pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4
assert weight.shape[
-1] % 4 == 0, "the last dim of weight needs to be divided by 4"
assert weight.shape[-1] % 4 == 0, "the last dim of weight needs to be divided by 4"
return weight.view(torch.int32).contiguous()
else:
return torch_npu.npu_quantize(weight.to(torch.float32),
torch.tensor([1.]).npu(), None,
torch.quint4x2, -1, False)
return torch_npu.npu_quantize(
weight.to(torch.float32), torch.tensor([1.0]).npu(), None, torch.quint4x2, -1, False
)
def process_weights_after_loading(self, layer):
if self.quant_method == COMPRESSED_TENSORS_METHOD:
@@ -524,23 +452,18 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
else:
self.process_weights_after_loading_modelslim(layer)
def process_weights_after_loading_compressed_tensors(self, layer):
layer.w13_weight.data = layer.w13_weight.data.transpose(
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(1,
2).contiguous()
layer.w13_weight.data = layer.w13_weight.data.transpose(1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(1, 2).contiguous()
def process_scale_compressed_tensors(scale: torch.Tensor):
scale = scale.transpose(1, 2).to(torch.float32).contiguous()
scale_np = scale.cpu().numpy()
scale_np.dtype = np.uint32
scale_uint64_tensor = torch.from_numpy(scale_np.astype(
np.int64)).npu()
scale_uint64_tensor = torch.from_numpy(scale_np.astype(np.int64)).npu()
return scale_uint64_tensor
def update_bias_compressed_tensors(weight: torch.Tensor,
scale: torch.Tensor, strategy:str):
def update_bias_compressed_tensors(weight: torch.Tensor, scale: torch.Tensor, strategy: str):
group_num, k, n = weight.shape
scale = scale.transpose(1, 2).contiguous()
scale = scale.reshape(group_num, -1, n)
@@ -548,8 +471,9 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
bias = None
if strategy == "group":
tmp = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \
scale.reshape([group_num, quantgroup_num, 1, n])
tmp = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * scale.reshape(
[group_num, quantgroup_num, 1, n]
)
tmp = tmp.reshape([group_num, k, n])
bias = 8 * tmp.sum(axis=1)
elif strategy == "channel":
@@ -558,19 +482,14 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
raise ValueError(f"Unsupported weight strategy: {strategy}")
return bias
w13_bias = update_bias_compressed_tensors(layer.w13_weight.data,
layer.w13_weight_scale.data,
self.weight_strategy)
w2_bias = update_bias_compressed_tensors(layer.w2_weight.data,
layer.w2_weight_scale.data,
self.weight_strategy)
w13_bias = update_bias_compressed_tensors(
layer.w13_weight.data, layer.w13_weight_scale.data, self.weight_strategy
)
w2_bias = update_bias_compressed_tensors(layer.w2_weight.data, layer.w2_weight_scale.data, self.weight_strategy)
layer.w13_weight_scale.data = process_scale_compressed_tensors(
layer.w13_weight_scale.data)
layer.w2_weight_scale.data = process_scale_compressed_tensors(
layer.w2_weight_scale.data)
layer.w13_weight_scale.data = process_scale_compressed_tensors(layer.w13_weight_scale.data)
layer.w2_weight_scale.data = process_scale_compressed_tensors(layer.w2_weight_scale.data)
w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
layer.register_parameter("w13_scale_bias", w13_scale_bias)
w2_scale_bias = torch.nn.Parameter(w2_bias, requires_grad=False)
@@ -583,21 +502,19 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)
def process_weights_after_loading_modelslim(self, layer):
layer.w13_weight.data = layer.w13_weight.data.transpose(
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(1,
2).contiguous()
layer.w13_weight.data = layer.w13_weight.data.transpose(1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(1, 2).contiguous()
w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr(
layer, "w13_weight_scale_second") else None
w2_weight_scale_second = layer.w2_weight_scale_second.data if hasattr(
layer, "w2_weight_scale_second") else None
w13_weight_scale_second = (
layer.w13_weight_scale_second.data if hasattr(layer, "w13_weight_scale_second") else None
)
w2_weight_scale_second = layer.w2_weight_scale_second.data if hasattr(layer, "w2_weight_scale_second") else None
layer.w13_weight_scale.data, w13_bias = self.process_scale(
layer.w13_weight, layer.w13_weight_scale.data,
w13_weight_scale_second)
layer.w13_weight, layer.w13_weight_scale.data, w13_weight_scale_second
)
layer.w2_weight_scale.data, w2_bias = self.process_scale(
layer.w2_weight, layer.w2_weight_scale.data,
w2_weight_scale_second)
layer.w2_weight, layer.w2_weight_scale.data, w2_weight_scale_second
)
if hasattr(layer, "w13_weight_scale_second"):
# scale_second is no longer used, release this part of the memory
del layer.w13_weight_scale_second
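An aside on the packed-weight layout that the docstrings and pack_to_int32() above describe: in the new quantization version two signed int4 values share one int8 byte (which is why the packed output dimension is halved), and four such bytes are then reinterpreted as a single int32 for the NPU kernels, hence the assertion that the last weight dim is divisible by 4. A rough CPU illustration; the nibble order and the sample values are assumptions for demonstration only:

    import torch

    # Two int4 values in [-8, 7] packed into one byte (nibble order assumed).
    lo, hi = 3, -2
    packed_byte = torch.tensor([(lo & 0xF) | ((hi & 0xF) << 4)], dtype=torch.uint8).view(torch.int8)

    # Four packed int8 bytes reinterpreted as one int32, mirroring weight.view(torch.int32).
    four_bytes = torch.tensor([17, 34, 51, 68], dtype=torch.int8)
    print(packed_byte, four_bytes.view(torch.int32))  # one int32 carries eight original int4 values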

View File

@@ -15,7 +15,7 @@
# limitations under the License.
#
from typing import Any, Dict, Optional
from typing import Any
import torch
import torch_npu
@@ -29,7 +29,7 @@ from .registry import register_scheme
@register_scheme("W8A16", "linear")
class AscendW8A16LinearMethod(AscendLinearScheme):
"""Linear method for Ascend W8A16.
This scheme uses 8-bit quantized weights with 16-bit activations.
"""
@@ -41,39 +41,34 @@ class AscendW8A16LinearMethod(AscendLinearScheme):
input_size: int,
output_size: int,
params_dtype: torch.dtype = torch.bfloat16,
) -> Dict[str, Any]:
params_dict = {
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
}
) -> dict[str, Any]:
params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
return params_dict
def get_perchannel_param(
self,
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
) -> dict[str, Any]:
params_dict = {}
params_dict["weight_scale"] = torch.empty(output_size,
1,
dtype=params_dtype)
params_dict["weight_offset"] = torch.empty(output_size,
1,
dtype=params_dtype)
params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
return params_dict
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0,
bias: torch.Tensor | None = None,
tp_rank: int | None = 0,
) -> torch.Tensor:
output = torch_npu.npu_weight_quant_batchmatmul(
x=x,
weight=layer.weight,
antiquant_scale=layer.weight_scale,
antiquant_offset=layer.weight_offset,
bias=bias)
bias=bias,
)
return output
def process_weights_after_loading(self, layer):

View File

@@ -15,7 +15,8 @@
# limitations under the License.
#
from typing import Any, Callable, Dict, Optional
from collections.abc import Callable
from typing import Any
import torch
import torch_npu
@@ -28,8 +29,7 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.flash_common3_context import get_flash_common3_context
from vllm_ascend.ops.fused_moe.experts_selector import (select_experts,
zero_experts_compute)
from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, maybe_trans_nz
from .base import AscendLinearScheme, AscendMoEScheme, QuantType
@@ -39,16 +39,17 @@ from .registry import register_scheme
def scale_from_float_to_int64(scale):
"""Convert float32 scale to int64 representation."""
import numpy as np
scale = torch.from_numpy(
np.frombuffer(scale.cpu().to(torch.float32).numpy().tobytes(),
dtype=np.int32).astype(np.int64)).to(scale.device)
np.frombuffer(scale.cpu().to(torch.float32).numpy().tobytes(), dtype=np.int32).astype(np.int64)
).to(scale.device)
return scale
@register_scheme("W8A8_DYNAMIC", "linear")
class AscendW8A8DynamicLinearMethod(AscendLinearScheme):
"""Linear method for Ascend W8A8_DYNAMIC.
This scheme uses dynamic per-token quantization for activations
and per-channel quantization for weights.
"""
@@ -56,33 +57,26 @@ class AscendW8A8DynamicLinearMethod(AscendLinearScheme):
def __init__(self):
pass
def get_weight(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
params_dict = {
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
}
def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
return params_dict
def get_perchannel_param(
self,
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
) -> dict[str, Any]:
params_dict = {}
params_dict["weight_scale"] = torch.empty(output_size,
1,
dtype=params_dtype)
params_dict["weight_offset"] = torch.empty(output_size,
1,
dtype=params_dtype)
params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
return params_dict
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0,
bias: torch.Tensor | None = None,
tp_rank: int | None = 0,
) -> torch.Tensor:
quantized_x, pertoken_scale = torch_npu.npu_dynamic_quant(x)
output = torch_npu.npu_quant_matmul(
@@ -116,9 +110,10 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
vllm_config = get_current_vllm_config()
ascend_config = get_ascend_config()
self.use_aclgraph = (vllm_config.compilation_config.mode
== CompilationMode.VLLM_COMPILE
and not vllm_config.model_config.enforce_eager)
self.use_aclgraph = (
vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE
and not vllm_config.model_config.enforce_eager
)
self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
self.dynamic_eplb = ascend_config.eplb_config.dynamic_eplb
@@ -130,49 +125,34 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
# TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu"))
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
local_rank)
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
except AttributeError:
self.moe_all_to_all_group_name = ""
def get_weight(self, num_experts: int,
intermediate_size_per_partition: int, hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
def get_weight(
self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
) -> dict[str, Any]:
param_dict = {}
param_dict["w13_weight"] = torch.empty(num_experts,
2 *
intermediate_size_per_partition,
hidden_sizes,
dtype=torch.int8)
param_dict["w2_weight"] = torch.empty(num_experts,
hidden_sizes,
intermediate_size_per_partition,
dtype=torch.int8)
param_dict["w13_weight"] = torch.empty(
num_experts, 2 * intermediate_size_per_partition, hidden_sizes, dtype=torch.int8
)
param_dict["w2_weight"] = torch.empty(
num_experts, hidden_sizes, intermediate_size_per_partition, dtype=torch.int8
)
return param_dict
def get_dynamic_quant_param(self, num_experts: int,
intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
def get_dynamic_quant_param(
self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
) -> dict[str, Any]:
param_dict = {}
param_dict["w13_weight_scale"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=params_dtype)
num_experts, 2 * intermediate_size_per_partition, 1, dtype=params_dtype
)
param_dict["w13_weight_offset"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=params_dtype)
param_dict["w2_weight_scale"] = torch.empty(num_experts,
hidden_sizes,
1,
dtype=params_dtype)
param_dict["w2_weight_offset"] = torch.empty(num_experts,
hidden_sizes,
1,
dtype=params_dtype)
num_experts, 2 * intermediate_size_per_partition, 1, dtype=params_dtype
)
param_dict["w2_weight_scale"] = torch.empty(num_experts, hidden_sizes, 1, dtype=params_dtype)
param_dict["w2_weight_offset"] = torch.empty(num_experts, hidden_sizes, 1, dtype=params_dtype)
return param_dict
def apply(
@@ -184,25 +164,26 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
renormalize: bool,
use_grouped_topk: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
expert_map: torch.Tensor | None = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
e_score_correction_bias: torch.Tensor | None = None,
is_prefill: bool = True,
enable_force_load_balance: bool = False,
log2phy: Optional[torch.Tensor] = None,
log2phy: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
pertoken_scale: Optional[Any] = None,
pertoken_scale: Any | None = None,
**kwargs,
) -> torch.Tensor:
zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None)
if zero_expert_num == 0 or zero_expert_type is None:
assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, \
assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, (
"Number of global experts mismatch (excluding redundancy)"
)
if self.multistream_overlap_gate:
fc3_context = get_flash_common3_context()
@@ -222,7 +203,8 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
global_num_experts=global_num_experts,
)
assert topk_ids is not None
assert topk_weights is not None
if zero_expert_num > 0 and zero_expert_type is not None:
@@ -237,12 +219,10 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
# to avoid accumulating too many tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance:
random_matrix = torch.rand(topk_ids.size(0),
global_num_experts -
global_redundant_expert_num,
device=topk_ids.device)
topk_ids = torch.argsort(
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
random_matrix = torch.rand(
topk_ids.size(0), global_num_experts - global_redundant_expert_num, device=topk_ids.device
)
topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
assert topk_weights is not None
topk_weights = topk_weights.to(self.in_dtype)
@@ -259,9 +239,10 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
w2 = [layer.w2_weight]
w2_scale = [layer.w2_weight_scale]
fused_scale_flag = (get_forward_context().moe_comm_type
== MoECommType.FUSED_MC2
and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1)
fused_scale_flag = (
get_forward_context().moe_comm_type == MoECommType.FUSED_MC2
and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1
)
final_hidden_states = moe_comm_method.fused_experts(
hidden_states=x,
pertoken_scale=pertoken_scale,
@@ -275,54 +256,35 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
expert_map=expert_map,
log2phy=log2phy,
dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask", None))
mc2_mask=kwargs.get("mc2_mask"),
)
if zero_expert_num > 0 and zero_expert_type is not None:
final_hidden_states += zero_expert_result
return final_hidden_states
def process_weights_after_loading(self, layer):
layer.w13_weight.data = layer.w13_weight.data.transpose(
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(1,
2).contiguous()
layer.w13_weight.data = layer.w13_weight.data.transpose(1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(1, 2).contiguous()
# TODO(zzzzwwjj): Currently, `torch_npu.npu_grouped_matmul_swiglu_quant`
# can only support weight nz.
layer.w13_weight.data = torch_npu.npu_format_cast(
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w2_weight.data = torch_npu.npu_format_cast(
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
layer.w13_weight_scale.data.shape[0], -1)
layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
torch.float32)
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
layer.w13_weight_offset.data.shape[0], -1)
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
layer.w2_weight_scale.data.shape[0], -1)
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
layer.w2_weight_offset.data.shape[0], -1)
layer.w13_weight.data = torch_npu.npu_format_cast(layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w2_weight.data = torch_npu.npu_format_cast(layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(layer.w13_weight_scale.data.shape[0], -1)
layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(torch.float32)
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(layer.w13_weight_offset.data.shape[0], -1)
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(layer.w2_weight_scale.data.shape[0], -1)
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(layer.w2_weight_offset.data.shape[0], -1)
layer.fused_w1_scale = scale_from_float_to_int64(
layer.w13_weight_scale.data)
layer.fused_w2_scale = scale_from_float_to_int64(
layer.w2_weight_scale.data)
layer.fused_w1_scale = scale_from_float_to_int64(layer.w13_weight_scale.data)
layer.fused_w2_scale = scale_from_float_to_int64(layer.w2_weight_scale.data)
if self.dynamic_eplb:
layer.w13_weight_list = [
weight.clone()
for weight in layer.w13_weight.data.unbind(dim=0)
]
layer.w2_weight_list = [
weight.clone() for weight in layer.w2_weight.data.unbind(dim=0)
]
layer.w13_weight_list = [weight.clone() for weight in layer.w13_weight.data.unbind(dim=0)]
layer.w2_weight_list = [weight.clone() for weight in layer.w2_weight.data.unbind(dim=0)]
layer.w13_weight_scale_fp32_list = [
weight.clone()
for weight in layer.w13_weight_scale_fp32.data.unbind(dim=0)
]
layer.w2_weight_scale_list = [
weight.clone()
for weight in layer.w2_weight_scale.data.unbind(dim=0)
weight.clone() for weight in layer.w13_weight_scale_fp32.data.unbind(dim=0)
]
layer.w2_weight_scale_list = [weight.clone() for weight in layer.w2_weight_scale.data.unbind(dim=0)]
del layer.w13_weight
del layer.w2_weight
del layer.w13_weight_scale
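For reference on the scale_from_float_to_int64() helper earlier in this file: the per-channel float32 scales are not cast numerically but reinterpreted bit-for-bit, i.e. the raw IEEE-754 int32 pattern is widened to int64, presumably because the downstream NPU matmul kernels consume the scale in that encoding. A small CPU-only sketch of the same reinterpretation:

    import numpy as np
    import torch

    scale = torch.tensor([0.5, 2.0], dtype=torch.float32)
    bits = np.frombuffer(scale.numpy().tobytes(), dtype=np.int32).astype(np.int64)
    print(bits)  # [1056964608, 1073741824], the IEEE-754 bit patterns of 0.5 and 2.0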

View File

@@ -15,7 +15,7 @@
# limitations under the License.
#
from typing import Any, Dict, Optional
from typing import Any
import torch
import torch_npu
@@ -28,48 +28,37 @@ from .registry import register_scheme
@register_scheme("W8A8_MXFP8", "linear")
class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme):
"""Linear method for Ascend W8A8_MXFP8 (Microscaling FP8) quantization.
This scheme uses microscaling FP8 quantization with per-group scales.
The activation is dynamically quantized to FP8 (E4M3FN format) with
microscaling, and weights are stored in FP8 format with per-group scales.
"""
model_dtype = None
def __init__(self):
vllm_config = get_current_vllm_config()
self.group_size = vllm_config.quant_config.quant_description.get(
"group_size", 32)
self.group_size = vllm_config.quant_config.quant_description.get("group_size", 32)
def get_weight(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
params_dict = {
"weight":
torch.empty(output_size, input_size, dtype=torch.float8_e4m3fn)
}
def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.float8_e4m3fn)}
return params_dict
def get_pergroup_param(self,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
layer_type: Optional[str] = None) -> Dict[str, Any]:
def get_pergroup_param(
self, input_size: int, output_size: int, params_dtype: torch.dtype, layer_type: str | None = None
) -> dict[str, Any]:
params_dict = {}
params_dict["weight_scale"] = torch.empty(output_size,
input_size //
self.group_size,
dtype=torch.uint8)
params_dict["weight_scale"] = torch.empty(output_size, input_size // self.group_size, dtype=torch.uint8)
return params_dict
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0,
bias: torch.Tensor | None = None,
tp_rank: int | None = 0,
) -> torch.Tensor:
quantized_x, dynamic_scale = torch_npu.npu_dynamic_mx_quant(
x, dst_type=torch.float8_e4m3fn)
quantized_x, dynamic_scale = torch_npu.npu_dynamic_mx_quant(x, dst_type=torch.float8_e4m3fn)
pertoken_scale = dynamic_scale
output_dtype = x.dtype
@@ -82,13 +71,13 @@ class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme):
pertoken_scale_dtype=torch_npu.float8_e8m0fnu,
bias=bias,
output_dtype=output_dtype,
group_sizes=[1, 1, self.group_size])
group_sizes=[1, 1, self.group_size],
)
return output
def process_weights_after_loading(self, layer):
n_dim, k_dim = layer.weight_scale.data.shape
layer.weight_scale.data = layer.weight_scale.data.reshape(
n_dim, k_dim // 2, 2)
layer.weight_scale.data = layer.weight_scale.data.reshape(n_dim, k_dim // 2, 2)
layer.weight.data = layer.weight.data.transpose(0, 1)
layer.weight_scale.data = layer.weight_scale.data.transpose(0, 1)

View File

@@ -22,15 +22,14 @@ for prefill and decode phases:
- Decode (KV consumer): Uses static W8A8 quantization
"""
from typing import Any, Dict, Optional
from typing import Any
import torch
from vllm.config import get_current_vllm_config
from .base import AscendLinearScheme
from .registry import register_scheme
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
AscendW8A8DynamicLinearMethod)
from .w8a8_dynamic import AscendW8A8DynamicFusedMoEMethod, AscendW8A8DynamicLinearMethod
from .w8a8_static import AscendW8A8LinearMethod
@@ -53,31 +52,27 @@ class AscendW8A8PDMixLinearMethod(AscendLinearScheme):
self._dynamic_method = AscendW8A8DynamicLinearMethod()
kv_transfer_config = get_current_vllm_config().kv_transfer_config
self._is_kv_consumer = (kv_transfer_config is not None
and kv_transfer_config.is_kv_consumer)
self._is_kv_consumer = kv_transfer_config is not None and kv_transfer_config.is_kv_consumer
def get_weight(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
return self._static_method.get_weight(input_size, output_size,
params_dtype)
def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
return self._static_method.get_weight(input_size, output_size, params_dtype)
def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]:
def get_pertensor_param(self, params_dtype: torch.dtype) -> dict[str, Any]:
return self._static_method.get_pertensor_param(params_dtype)
def get_perchannel_param(
self,
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
return self._static_method.get_perchannel_param(
output_size, params_dtype)
) -> dict[str, Any]:
return self._static_method.get_perchannel_param(output_size, params_dtype)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0,
bias: torch.Tensor | None = None,
tp_rank: int | None = 0,
) -> torch.Tensor:
if layer.is_kv_consumer:
return self._static_method.apply(layer, x, bias, tp_rank)
@@ -92,26 +87,15 @@ class AscendW8A8PDMixLinearMethod(AscendLinearScheme):
@register_scheme("W8A8_MIX", "moe")
class AscendW8A8PDMixFusedMoeMethod(AscendW8A8DynamicFusedMoEMethod):
def get_dynamic_quant_param(self, num_experts: int,
intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
def get_dynamic_quant_param(
self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
) -> dict[str, Any]:
param_dict = super().get_dynamic_quant_param(
num_experts, intermediate_size_per_partition, hidden_sizes,
params_dtype)
param_dict["w2_deq_scale"] = torch.empty(num_experts,
hidden_sizes,
dtype=torch.float32)
param_dict["w13_deq_scale"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
dtype=torch.float32)
param_dict["w2_input_offset"] = torch.empty(num_experts,
1,
dtype=torch.int8)
param_dict["w13_input_offset"] = torch.empty(num_experts,
1,
dtype=torch.int8)
num_experts, intermediate_size_per_partition, hidden_sizes, params_dtype
)
param_dict["w2_deq_scale"] = torch.empty(num_experts, hidden_sizes, dtype=torch.float32)
param_dict["w13_deq_scale"] = torch.empty(num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32)
param_dict["w2_input_offset"] = torch.empty(num_experts, 1, dtype=torch.int8)
param_dict["w13_input_offset"] = torch.empty(num_experts, 1, dtype=torch.int8)
return param_dict

View File

@@ -15,14 +15,16 @@
# limitations under the License.
#
from typing import Any, Dict, Optional
from typing import Any
import torch
import torch_npu
from vllm_ascend.utils import (COMPRESSED_TENSORS_METHOD, AscendDeviceType,
get_ascend_device_type,
get_weight_prefetch_method, maybe_trans_nz)
from vllm_ascend.utils import (
COMPRESSED_TENSORS_METHOD,
get_weight_prefetch_method,
maybe_trans_nz,
)
from .base import AscendLinearScheme
from .registry import register_scheme
@@ -44,13 +46,11 @@ class AscendW8A8LinearMethod(AscendLinearScheme):
input_size: int,
output_size: int,
params_dtype: torch.dtype = torch.bfloat16,
) -> Dict[str, Any]:
params_dict = {
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
}
) -> dict[str, Any]:
params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
return params_dict
def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]:
def get_pertensor_param(self, params_dtype: torch.dtype) -> dict[str, Any]:
params_dict = {}
params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
@@ -60,29 +60,23 @@ class AscendW8A8LinearMethod(AscendLinearScheme):
self,
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
) -> dict[str, Any]:
params_dict = {}
params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
if params_dtype == torch.bfloat16:
params_dict["deq_scale"] = torch.empty(output_size,
dtype=torch.float32)
params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.float32)
elif params_dtype == torch.float16:
params_dict["deq_scale"] = torch.empty(output_size,
dtype=torch.int64)
params_dict["weight_scale"] = torch.empty(output_size,
1,
dtype=params_dtype)
params_dict["weight_offset"] = torch.empty(output_size,
1,
dtype=params_dtype)
params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.int64)
params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
return params_dict
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0,
bias: torch.Tensor | None = None,
tp_rank: int | None = 0,
) -> torch.Tensor:
if x.dtype != torch.int8:
layer_cls_name = layer.__class__.__name__
@@ -95,15 +89,15 @@ class AscendW8A8LinearMethod(AscendLinearScheme):
start_flag=x,
)
try:
quant_comm_config = getattr(layer, "_quant_comm_config")
quant_comm_config = layer._quant_comm_config
except AttributeError:
quant_comm_config = {}
comm_fn = quant_comm_config.get("communication_fn")
enable_flashcomm2_quant_comm = comm_fn is not None and (
"o_proj" in layer.prefix or "out_proj" in layer.prefix)
"o_proj" in layer.prefix or "out_proj" in layer.prefix
)
if enable_flashcomm2_quant_comm:
quant_input_x = x.contiguous().view(
-1, layer.aclnn_input_scale_reciprocal.size(0))
quant_input_x = x.contiguous().view(-1, layer.aclnn_input_scale_reciprocal.size(0))
quant_x = torch.ops.vllm.quantize(
quant_input_x,
layer.aclnn_input_scale,
@@ -132,7 +126,7 @@ class AscendW8A8LinearMethod(AscendLinearScheme):
quant_bias = layer.quant_bias if tp_rank == 0 else None
try:
ascend_quant_method = getattr(layer, "ascend_quant_method")
ascend_quant_method = layer.ascend_quant_method
except AttributeError:
ascend_quant_method = ""
if ascend_quant_method == COMPRESSED_TENSORS_METHOD:
@@ -150,14 +144,14 @@ class AscendW8A8LinearMethod(AscendLinearScheme):
def process_weights_after_loading(self, layer):
expanding_factor = layer.weight.data.shape[1]
layer.aclnn_input_scale = torch.nn.Parameter(
layer.input_scale.data.repeat(expanding_factor),
requires_grad=False)
layer.input_scale.data.repeat(expanding_factor), requires_grad=False
)
layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
layer.input_scale.data.repeat(expanding_factor),
requires_grad=False)
layer.input_scale.data.repeat(expanding_factor), requires_grad=False
)
layer.aclnn_input_offset = torch.nn.Parameter(
layer.input_offset.data.repeat(expanding_factor),
requires_grad=False).to(layer.aclnn_input_scale.dtype)
layer.input_offset.data.repeat(expanding_factor), requires_grad=False
).to(layer.aclnn_input_scale.dtype)
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight.data = maybe_trans_nz(layer.weight.data)
@@ -166,5 +160,4 @@ class AscendW8A8LinearMethod(AscendLinearScheme):
ascend_quant_method = getattr(layer, "ascend_quant_method", "")
if ascend_quant_method == COMPRESSED_TENSORS_METHOD:
deq_scale = layer.input_scale.data * layer.weight_scale.data
layer.deq_scale = torch.nn.Parameter(deq_scale,
requires_grad=False)
layer.deq_scale = torch.nn.Parameter(deq_scale, requires_grad=False)
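One more note on the deq_scale computed just above for compressed-tensors checkpoints: with static W8A8, dequantizing the integer accumulator only needs the product of the activation scale and the per-channel weight scale, which is exactly what input_scale * weight_scale provides. A toy CPU sketch of that relationship (scales and shapes made up, input offset ignored for brevity):

    import torch

    s_x, s_w = 0.03, 0.03                      # hypothetical activation / weight scales
    x, w = torch.randn(4, 8), torch.randn(8, 3)
    x_q = torch.clamp((x / s_x).round(), -128, 127)
    w_q = torch.clamp((w / s_w).round(), -128, 127)
    y_int = x_q @ w_q                          # integer-style accumulation (kept as float here)
    y_deq = y_int * (s_x * s_w)                # the single fused dequantization scale
    print((y_deq - x @ w).abs().max())         # only a small quantization error remains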

View File

@@ -21,20 +21,18 @@ This module provides the AscendModelSlimConfig class for parsing quantization
configs generated by the ModelSlim tool, along with model-specific mappings.
"""
from collections.abc import Mapping
from types import MappingProxyType
from typing import Any, Dict, List, Mapping, Optional
from typing import Any, Optional
import torch
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization import \
register_quantization_config
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.vocab_parallel_embedding import (
UnquantizedEmbeddingMethod, VocabParallelEmbedding)
from vllm.model_executor.layers.quantization import register_quantization_config
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig, QuantizeMethodBase
from vllm.model_executor.layers.vocab_parallel_embedding import UnquantizedEmbeddingMethod, VocabParallelEmbedding
from vllm.model_executor.models.utils import WeightsMapper
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
@@ -45,7 +43,7 @@ logger = init_logger(__name__)
# key: model_type
# value: orig_to_new_prefix
QUANT_MODEL_PREFIX_MAPPINGS: Dict[str, Dict[str, str]] = {
QUANT_MODEL_PREFIX_MAPPINGS: dict[str, dict[str, str]] = {
"qwen3_vl_moe": {
"visual.": "model.visual.",
"language_model.lm_head.": "lm_head.",
@@ -60,7 +58,7 @@ QUANT_MODEL_PREFIX_MAPPINGS: Dict[str, Dict[str, str]] = {
# key: model_type
# value: dict of fused module name -> list of original module names
packed_modules_model_mapping: Dict[str, Dict[str, List[str]]] = {
packed_modules_model_mapping: dict[str, dict[str, list[str]]] = {
"qwen3_moe": {
"qkv_proj": [
"q_proj",
@@ -71,52 +69,44 @@ packed_modules_model_mapping: Dict[str, Dict[str, List[str]]] = {
"gate_proj",
"up_proj",
],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
},
"deepseek_v2": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
"experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"],
},
"deepseek_v3": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
"experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"],
},
"pangu_ultra_moe": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
"experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"],
},
"kimi_k2": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
"experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"],
},
"deepseek_v32": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
"experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"],
},
# NOTE 1. For quantized DeepSeek models, the MTP layer itself is left unquantized on the NPU;
# NOTE 2. The description file generated by the current msmodelslim tool does not have
# MTP layer info. Please manually add it and set the value to FLOAT.
"deepseek_mtp": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
"experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
},
"pangu_ultra_moe_mtp": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
"experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"],
},
"qwen3_next": {
"qkv_proj": [
@@ -126,8 +116,7 @@ packed_modules_model_mapping: Dict[str, Dict[str, List[str]]] = {
],
"gate_up_proj": ["gate_proj", "up_proj"],
"in_proj": ["in_proj_qkvz", "in_proj_ba"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
"experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
},
"qwen2_5_vl": {
"qkv_proj": [
@@ -150,8 +139,7 @@ packed_modules_model_mapping: Dict[str, Dict[str, List[str]]] = {
"gate_proj",
"up_proj",
],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
},
"glm4_moe": {
"qkv_proj": [
@@ -163,20 +151,17 @@ packed_modules_model_mapping: Dict[str, Dict[str, List[str]]] = {
"gate_proj",
"up_proj",
],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
"experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
},
"glm4_moe_lite": {
"glm4_moe_lite": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
"experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"],
},
"longcat_flash": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
"experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"],
},
"minimax_m2": {
"qkv_proj": [
@@ -184,17 +169,17 @@ packed_modules_model_mapping: Dict[str, Dict[str, List[str]]] = {
"k_proj",
"v_proj",
],
"experts": ["experts.0.w1", "experts.0.w2", "experts.0.w3"]
}
"experts": ["experts.0.w1", "experts.0.w2", "experts.0.w3"],
},
}
def get_packed_modules_mapping(model_type: str) -> Dict[str, List[str]]:
def get_packed_modules_mapping(model_type: str) -> dict[str, list[str]]:
"""Get packed modules mapping for a model type.
Args:
model_type: The model type string (e.g., "deepseek_v3").
Returns:
Dictionary mapping fused module names to their component module names.
Returns empty dict if model_type is not found.
@@ -202,12 +187,12 @@ def get_packed_modules_mapping(model_type: str) -> Dict[str, List[str]]:
return packed_modules_model_mapping.get(model_type, {})
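
A minimal usage sketch of the lookup above, using the "deepseek_v3" entry defined earlier in this module (it assumes the helpers in this file are in scope):

mapping = get_packed_modules_mapping("deepseek_v3")
assert mapping["gate_up_proj"] == ["gate_proj", "up_proj"]
assert get_packed_modules_mapping("unknown_model_type") == {}
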
def get_prefix_mapping(model_type: str) -> Dict[str, str]:
def get_prefix_mapping(model_type: str) -> dict[str, str]:
"""Get prefix mapping for a model type.
Args:
model_type: The model type string (e.g., "qwen3_vl_moe").
Returns:
Dictionary mapping original prefixes to new prefixes.
Returns empty dict if model_type is not found.
@@ -216,15 +201,15 @@ def get_prefix_mapping(model_type: str) -> Dict[str, str]:
def get_linear_quant_type(
quant_description: Dict[str, Any], prefix: str,
packed_modules_mapping: Dict[str, Any]) -> Optional[str]:
quant_description: dict[str, Any], prefix: str, packed_modules_mapping: dict[str, Any]
) -> str | None:
"""Determine the quantization type for a linear layer.
Args:
quant_description: The quantization description dictionary.
prefix: The layer prefix.
packed_modules_mapping: Mapping for packed/fused modules.
Returns:
The quantization type string (e.g., "W8A8_DYNAMIC").
"""
@@ -232,11 +217,10 @@ def get_linear_quant_type(
if proj_name in packed_modules_mapping:
quant_type = None
shard_prefixes = [
prefix.replace(proj_name, shard_proj_name)
for shard_proj_name in packed_modules_mapping[proj_name]
prefix.replace(proj_name, shard_proj_name) for shard_proj_name in packed_modules_mapping[proj_name]
]
for shard_prefix in shard_prefixes:
shard_quant_type = quant_description[shard_prefix + '.weight']
shard_quant_type = quant_description[shard_prefix + ".weight"]
if quant_type is None:
quant_type = shard_quant_type
@@ -244,72 +228,68 @@ def get_linear_quant_type(
raise ValueError(
f"Not all shards of {prefix} are quantized with same quant type."
f"Shard {proj_name} uses {shard_quant_type}, but another shard"
f"use {quant_type}. Please check quantization config.")
f"use {quant_type}. Please check quantization config."
)
else:
quant_type = quant_description[prefix + '.weight']
quant_type = quant_description[prefix + ".weight"]
return quant_type
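
As a sketch, resolving a packed "qkv_proj" prefix against a toy quant description (layer names below are illustrative, and this assumes the fused key is taken from the trailing module name, as the shard expansion above implies):

toy_description = {
    "model.layers.0.self_attn.q_proj.weight": "W8A8",
    "model.layers.0.self_attn.k_proj.weight": "W8A8",
    "model.layers.0.self_attn.v_proj.weight": "W8A8",
}
toy_packed = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
quant_type = get_linear_quant_type(
    toy_description, "model.layers.0.self_attn.qkv_proj", toy_packed
)
assert quant_type == "W8A8"
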
def get_quant_type_for_layer(
quant_description: Dict[str, Any],
prefix: str,
layer_type: str,
packed_modules_mapping: Optional[Dict[str,
Any]] = None) -> Optional[str]:
quant_description: dict[str, Any],
prefix: str,
layer_type: str,
packed_modules_mapping: dict[str, Any] | None = None,
) -> str | None:
"""Determine the quantization type for a layer.
Args:
quant_description: The quantization description dictionary.
prefix: The layer prefix.
layer_type: The type of layer ("linear", "moe", "attention").
packed_modules_mapping: Mapping for packed/fused modules.
Returns:
The quantization type string (e.g., "W8A8_DYNAMIC").
"""
if packed_modules_mapping is None:
packed_modules_mapping = dict()
# Attention
if layer_type == "attention" and 'fa_quant_type' in quant_description.keys(
):
return quant_description['fa_quant_type']
if layer_type == "attention" and "fa_quant_type" in quant_description:
return quant_description["fa_quant_type"]
# Linear / MoE
return get_linear_quant_type(quant_description, prefix,
packed_modules_mapping)
return get_linear_quant_type(quant_description, prefix, packed_modules_mapping)
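
And for attention layers, the top-level "fa_quant_type" entry takes precedence when present, as in this sketch (the value and key below are placeholders):

toy_description = {"fa_quant_type": "FAQuant", "model.layers.0.attn.weight": "FLOAT"}
assert (
    get_quant_type_for_layer(toy_description, "model.layers.0.attn", "attention")
    == "FAQuant"
)
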
def create_scheme_for_layer(
quant_description: Dict[str, Any],
prefix: str,
layer_type: str,
packed_modules_mapping: Optional[Dict[str, Any]] = None):
quant_description: dict[str, Any],
prefix: str,
layer_type: str,
packed_modules_mapping: dict[str, Any] | None = None,
):
"""Create a quantization scheme instance for a layer.
Args:
quant_description: The quantization description dictionary.
prefix: The layer prefix.
layer_type: The type of layer ("linear", "moe", "attention").
packed_modules_mapping: Mapping for packed/fused modules.
Returns:
An instance of the appropriate quantization scheme class.
"""
logger.info_once("Using the vLLM Ascend modelslim Quantization now!")
quant_type = get_quant_type_for_layer(quant_description, prefix,
layer_type, packed_modules_mapping)
quant_type = get_quant_type_for_layer(quant_description, prefix, layer_type, packed_modules_mapping)
if quant_type is None:
raise ValueError(
f"Could not determine quantization type for layer {prefix}.")
raise ValueError(f"Could not determine quantization type for layer {prefix}.")
# Use registry to get scheme class
scheme_cls = get_scheme_class(quant_type, layer_type)
if scheme_cls is not None:
return scheme_cls()
raise NotImplementedError(
f"Currently, vLLM Ascend doesn't support {quant_type} for {layer_type}."
)
raise NotImplementedError(f"Currently, vLLM Ascend doesn't support {quant_type} for {layer_type}.")
@register_quantization_config(ASCEND_QUANTIZATION_METHOD)
@@ -321,13 +301,13 @@ class AscendModelSlimConfig(QuantizationConfig):
quantized using the ModelSlim tool.
"""
def __init__(self, quant_config: Dict[str, Any]):
def __init__(self, quant_config: dict[str, Any]):
super().__init__()
self.quant_description = quant_config
# TODO(whx): remove this adaptation after adding "shared_head"
# to prefix of DeepSeekShareHead in vLLM.
extra_quant_dict = {}
for k in self.quant_description.keys():
for k in self.quant_description:
if "shared_head" in k:
new_k = k.replace(".shared_head.", ".")
extra_quant_dict[new_k] = self.quant_description[k]
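
For reference, the ".shared_head." adaptation amounts to a key rewrite like the following (the key is illustrative):

key = "model.layers.61.shared_head.head.weight"
assert key.replace(".shared_head.", ".") == "model.layers.61.head.weight"
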
@@ -344,25 +324,23 @@ class AscendModelSlimConfig(QuantizationConfig):
return ASCEND_QUANTIZATION_METHOD
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.int8, torch.float16, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
raise NotImplementedError(
"Ascend hardware dose not support \"get_min_capability\" feature.")
raise NotImplementedError('Ascend hardware does not support "get_min_capability" feature.')
@classmethod
def get_config_filenames(cls) -> List[str]:
def get_config_filenames(cls) -> list[str]:
return ["quant_model_description.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AscendModelSlimConfig":
def from_config(cls, config: dict[str, Any]) -> "AscendModelSlimConfig":
return cls(config)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> str | None:
if hf_quant_cfg is not None:
quant_method = hf_quant_cfg.get("quant_method", None)
if not quant_method and torch.npu.is_available():
@@ -373,15 +351,17 @@ class AscendModelSlimConfig(QuantizationConfig):
# TODO (Levi-JQ): will be removed when QuantizationConfig.apply_vllm_mapper is implemented
prefix_mapping = QUANT_MODEL_PREFIX_MAPPINGS.get(model_type)
if prefix_mapping:
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix=prefix_mapping)
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix=prefix_mapping)
return hf_to_vllm_mapper._map_name(prefix)
return prefix
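
As a sketch of the prefix remapping, assuming WeightsMapper applies the prefix substitutions registered in QUANT_MODEL_PREFIX_MAPPINGS as used above (the layer name is illustrative):

mapper = WeightsMapper(
    orig_to_new_prefix=QUANT_MODEL_PREFIX_MAPPINGS["qwen3_vl_moe"]
)
# "visual." is rewritten to "model.visual." before the quant description lookup
remapped = mapper._map_name("visual.blocks.0.attn.qkv")
assert remapped == "model.visual.blocks.0.attn.qkv"
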
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
from .method_adapters import (AscendEmbeddingMethod, AscendFusedMoEMethod,
AscendKVCacheMethod, AscendLinearMethod)
def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]:
from .method_adapters import (
AscendEmbeddingMethod,
AscendFusedMoEMethod,
AscendKVCacheMethod,
AscendLinearMethod,
)
vllm_config = get_current_vllm_config()
model_type = vllm_config.model_config.hf_config.model_type
@@ -390,81 +370,67 @@ class AscendModelSlimConfig(QuantizationConfig):
# Adapt to Minimax architecture: update layer names to MoE convention
prefix = prefix.replace("mlp", "block_sparse_moe")
# Normalize the prefix by stripping specific expert indices (e.g., 'experts.0' -> 'experts')
parts = prefix.split('.')
parts = prefix.split(".")
if "experts" in parts and len(parts) > 2:
exp_idx = parts.index("experts")
if exp_idx + 1 < len(parts) and parts[exp_idx + 1].isdigit():
parts = parts[:exp_idx + 1]
parts = parts[: exp_idx + 1]
prefix = ".".join(parts)
if model_type in packed_modules_model_mapping:
self.packed_modules_mapping = packed_modules_model_mapping[
model_type]
self.packed_modules_mapping = packed_modules_model_mapping[model_type]
prefix = self.quant_prefix_mapper(model_type, prefix)
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("v0.15.0"):
from vllm.attention.layer import Attention # type: ignore
from vllm.attention.layer import Attention # type: ignore
else:
from vllm.model_executor.layers.attention import Attention
if prefix.startswith("language_model"):
prefix = prefix.split('.', 1)[-1]
prefix = prefix.split(".", 1)[-1]
if isinstance(layer, LinearBase):
if self.is_layer_skipped_ascend(prefix,
self.packed_modules_mapping):
if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping):
# Delayed import to avoid circular import
from vllm_ascend.ops.linear import \
AscendUnquantizedLinearMethod
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
return AscendUnquantizedLinearMethod()
scheme = create_scheme_for_layer(self.quant_description, prefix,
"linear",
self.packed_modules_mapping)
scheme = create_scheme_for_layer(self.quant_description, prefix, "linear", self.packed_modules_mapping)
return AscendLinearMethod(scheme)
elif isinstance(layer, Attention) and \
'fa_quant_type' in self.quant_description.keys() and \
self.quant_description['fa_quant_type'] is not None:
scheme = create_scheme_for_layer(self.quant_description, prefix,
"attention",
self.packed_modules_mapping)
elif (
isinstance(layer, Attention)
and "fa_quant_type" in self.quant_description
and self.quant_description["fa_quant_type"] is not None
):
scheme = create_scheme_for_layer(self.quant_description, prefix, "attention", self.packed_modules_mapping)
return AscendKVCacheMethod(scheme)
elif isinstance(layer, FusedMoE):
if self.is_layer_skipped_ascend(prefix,
self.packed_modules_mapping):
if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping):
# Delayed import to avoid circular import
from vllm_ascend.ops.fused_moe.fused_moe import \
AscendUnquantizedFusedMoEMethod
from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
return AscendUnquantizedFusedMoEMethod(layer.moe_config)
scheme = create_scheme_for_layer(self.quant_description, prefix,
"moe",
self.packed_modules_mapping)
scheme = create_scheme_for_layer(self.quant_description, prefix, "moe", self.packed_modules_mapping)
return AscendFusedMoEMethod(scheme, layer.moe_config)
elif isinstance(layer, VocabParallelEmbedding):
if self.is_layer_skipped_ascend(prefix,
self.packed_modules_mapping):
if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping):
return UnquantizedEmbeddingMethod()
scheme = create_scheme_for_layer(self.quant_description, prefix,
"linear",
self.packed_modules_mapping)
scheme = create_scheme_for_layer(self.quant_description, prefix, "linear", self.packed_modules_mapping)
return AscendEmbeddingMethod(scheme)
return None
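
The expert-index normalization at the top of this method behaves as in this sketch (the prefix is illustrative):

prefix = "model.layers.3.mlp.experts.57.down_proj"
parts = prefix.split(".")
exp_idx = parts.index("experts")
if exp_idx + 1 < len(parts) and parts[exp_idx + 1].isdigit():
    parts = parts[: exp_idx + 1]
assert ".".join(parts) == "model.layers.3.mlp.experts"
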
def is_layer_skipped_ascend(
self,
prefix: str,
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})):
def is_layer_skipped_ascend(self, prefix: str, fused_mapping: Mapping[str, list[str]] = MappingProxyType({})):
# adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped
proj_name = prefix.split(".")[-1]
if proj_name in fused_mapping:
shard_prefixes = [
prefix.replace(proj_name, shard_proj_name)
for shard_proj_name in fused_mapping[proj_name]
prefix.replace(proj_name, shard_proj_name) for shard_proj_name in fused_mapping[proj_name]
]
is_skipped = None
for shard_prefix in shard_prefixes:
is_shard_skipped = self.quant_description[shard_prefix +
'.weight'] == "FLOAT"
is_shard_skipped = self.quant_description[shard_prefix + ".weight"] == "FLOAT"
if is_skipped is None:
is_skipped = is_shard_skipped
@@ -472,12 +438,13 @@ class AscendModelSlimConfig(QuantizationConfig):
raise ValueError(
f"Detected some but not all shards of {prefix} "
"are quantized. All shards of fused layers "
"to have the same precision.")
"to have the same precision."
)
else:
is_skipped = self.quant_description[prefix + '.weight'] == "FLOAT"
is_skipped = self.quant_description[prefix + ".weight"] == "FLOAT"
assert is_skipped is not None
return is_skipped
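
The skip check above reduces to the following sketch for a fused layer whose shards are all marked FLOAT (keys are illustrative):

toy_description = {
    "model.layers.0.mlp.gate_proj.weight": "FLOAT",
    "model.layers.0.mlp.up_proj.weight": "FLOAT",
}
fused = {"gate_up_proj": ["gate_proj", "up_proj"]}
prefix = "model.layers.0.mlp.gate_up_proj"
shard_prefixes = [prefix.replace("gate_up_proj", s) for s in fused["gate_up_proj"]]
assert all(toy_description[p + ".weight"] == "FLOAT" for p in shard_prefixes)  # skipped
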
def get_scaled_act_names(self) -> List[str]:
def get_scaled_act_names(self) -> list[str]:
return []

View File

@@ -1,18 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
from vllm.triton_utils import HAS_TRITON, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import (GREEDY_TEMPERATURE, MAX_SPEC_LEN,
PLACEHOLDER_TOKEN_ID,
generate_uniform_probs)
from vllm.v1.sample.rejection_sampler import (
GREEDY_TEMPERATURE,
MAX_SPEC_LEN,
PLACEHOLDER_TOKEN_ID,
generate_uniform_probs,
)
from vllm_ascend.ops.triton.reject_sample import (
cal_grid_and_block_size, expand_triton,
cal_grid_and_block_size,
expand_triton,
rejection_greedy_sample_with_triton,
rejection_random_sample_block_verify_kernel,
rejection_random_sample_kernel, sample_recovered_tokens_kernel)
rejection_random_sample_kernel,
sample_recovered_tokens_kernel,
)
from vllm_ascend.sample.sampler import apply_top_k_top_p
@@ -83,7 +88,7 @@ def rejection_sample(
# [batch_size]
cu_num_draft_tokens: torch.Tensor,
# [num_tokens, vocab_size]
draft_probs: Optional[torch.Tensor],
draft_probs: torch.Tensor | None,
# [num_tokens, vocab_size]
target_probs: torch.Tensor,
# [batch_size, 1]
@@ -126,15 +131,20 @@ def rejection_sample(
# Rejection sampling for greedy sampling requests.
target_argmax = target_probs.argmax(dim=-1)
if HAS_TRITON:
rejection_greedy_sample_with_triton(output_token_ids,
num_draft_tokens,
cu_num_draft_tokens,
draft_token_ids, target_argmax,
bonus_token_ids, is_greedy,
max_spec_len, grid, block_size)
rejection_greedy_sample_with_triton(
output_token_ids,
num_draft_tokens,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
max_spec_len,
grid,
block_size,
)
else:
if min(num_draft_tokens) == 1 and max(
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
if min(num_draft_tokens) == 1 and max(num_draft_tokens) == 1 and sampling_metadata.all_greedy:
rejection_greedy_sample_spec_len_1_pytorch(
output_token_ids,
draft_token_ids,
@@ -179,7 +189,7 @@ def rejection_sample(
if not using_block_verify:
# Rejection sampling for random sampling requests.
if HAS_TRITON:
rejection_random_sample_kernel[(grid, )](
rejection_random_sample_kernel[(grid,)](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
@@ -214,7 +224,7 @@ def rejection_sample(
else:
# MagicMTP: Improving acceptance rate with Block Verify.
if HAS_TRITON:
rejection_random_sample_block_verify_kernel[(grid, )](
rejection_random_sample_block_verify_kernel[(grid,)](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
@@ -231,19 +241,20 @@ def rejection_sample(
BLOCK_SIZE=block_size,
)
else:
rejection_random_sample_block_verify_pytorch(output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
bonus_token_ids,
recovered_token_ids,
uniform_probs,
is_greedy,
max_spec_len,
vocab_size,
IS_NGRAM=draft_probs
is None)
rejection_random_sample_block_verify_pytorch(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
bonus_token_ids,
recovered_token_ids,
uniform_probs,
is_greedy,
max_spec_len,
vocab_size,
IS_NGRAM=draft_probs is None,
)
return output_token_ids
@@ -277,13 +288,7 @@ def expand_batch_to_tokens(
assert cu_num_tokens.shape[0] == batch_size
expanded_x = x.new_empty(num_tokens)
if HAS_TRITON:
expand_triton(batch_size,
expanded_x,
x,
cu_num_tokens,
replace_from,
replace_to,
max_num_tokens=MAX_SPEC_LEN)
expand_triton(batch_size, expanded_x, x, cu_num_tokens, replace_from, replace_to, max_num_tokens=MAX_SPEC_LEN)
else:
expand_pytorch(
expanded_x,
@@ -301,7 +306,7 @@ def sample_recovered_tokens(
num_draft_tokens: list[int],
cu_num_draft_tokens: torch.Tensor,
draft_token_ids: torch.Tensor,
draft_probs: Optional[torch.Tensor],
draft_probs: torch.Tensor | None,
target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata,
device: torch.device,
@@ -316,9 +321,7 @@ def sample_recovered_tokens(
)
q.exponential_()
num_draft_tensor = torch.tensor(num_draft_tokens,
pin_memory=True).to(device,
non_blocking=True)
num_draft_tensor = torch.tensor(num_draft_tokens, pin_memory=True).to(device, non_blocking=True)
has_draft_mask = num_draft_tensor > 0
for i, generator in sampling_metadata.generators.items():
@@ -357,10 +360,10 @@ def sample_recovered_tokens(
def rejection_greedy_sample_spec_len_1_pytorch(
output_token_ids, # [batch_size, 2]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
output_token_ids, # [batch_size, 2]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
):
batch_size = output_token_ids.size(0)
num_tokens = draft_token_ids.size(0)
@@ -368,73 +371,56 @@ def rejection_greedy_sample_spec_len_1_pytorch(
accept_req_mask = draft_token_ids == target_argmax
output_token_ids[:, 0] = target_argmax
bonus_token_ids = bonus_token_ids.squeeze(1)
output_token_ids[:, 1] = torch.where(accept_req_mask, bonus_token_ids,
output_token_ids[:, 1])
output_token_ids[:, 1] = torch.where(accept_req_mask, bonus_token_ids, output_token_ids[:, 1])
def rejection_greedy_sample_pytorch(
output_token_ids, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens, # [batch_size]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
draft_tokens_per_req, # [batch_size], list
max_spec_len,
is_greedy=None, # [batch_size] or None
output_token_ids, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens, # [batch_size]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
draft_tokens_per_req, # [batch_size], list
max_spec_len,
is_greedy=None, # [batch_size] or None
):
batch_size = output_token_ids.size(0)
num_tokens = draft_token_ids.size(0)
device = output_token_ids.device
draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to(
device, non_blocking=True)
draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to(device, non_blocking=True)
if is_greedy is None:
is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device)
start_indices = cu_num_draft_tokens - draft_tokens_per_req
req_ids = torch.arange(batch_size, device=device)
token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req)
token_positions = torch.arange(
num_tokens, device=device) - start_indices[token_req_ids]
token_positions = torch.arange(num_tokens, device=device) - start_indices[token_req_ids]
# Find the first mismatch position of each request.
mismatch_global = (draft_token_ids != target_argmax)
mismatch_global = draft_token_ids != target_argmax
if max_spec_len == 0:
first_mismatch_pos_per_req = torch.zeros(batch_size,
dtype=torch.long,
device=device)
first_mismatch_pos_per_req = torch.zeros(batch_size, dtype=torch.long, device=device)
else:
# [bs, max_spec_len]
pos_matrix = torch.full((batch_size, max_spec_len),
-1,
dtype=torch.long,
device=device)
pos_matrix = torch.full((batch_size, max_spec_len), -1, dtype=torch.long, device=device)
pos_matrix[token_req_ids, token_positions] = token_positions
mismatch_matrix = torch.full((batch_size, max_spec_len),
False,
dtype=torch.bool,
device=device)
mismatch_matrix = torch.full((batch_size, max_spec_len), False, dtype=torch.bool, device=device)
mismatch_matrix[token_req_ids, token_positions] = mismatch_global
mismatch_positions = torch.where(mismatch_matrix, pos_matrix,
max_spec_len * 2)
mismatch_positions = torch.where(mismatch_matrix, pos_matrix, max_spec_len * 2)
first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1)
no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2)
first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[
no_mismatch_mask]
no_mismatch_mask = first_mismatch_pos_per_req == max_spec_len * 2
first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[no_mismatch_mask]
# Copy matched target tokens into output.
copy_len = torch.minimum(first_mismatch_pos_per_req + 1,
draft_tokens_per_req)
copy_indices = torch.arange(max_spec_len + 1,
device=device).expand(batch_size, -1)
copy_len = torch.minimum(first_mismatch_pos_per_req + 1, draft_tokens_per_req)
copy_indices = torch.arange(max_spec_len + 1, device=device).expand(batch_size, -1)
copy_mask = copy_indices < copy_len.unsqueeze(1)
greedy_mask = is_greedy.unsqueeze(1)
final_copy_mask = copy_mask & greedy_mask
global_idx = start_indices.unsqueeze(1) + copy_indices
output_token_ids[final_copy_mask] = target_argmax[
global_idx[final_copy_mask]].to(output_token_ids.dtype)
output_token_ids[final_copy_mask] = target_argmax[global_idx[final_copy_mask]].to(output_token_ids.dtype)
# Fill bonus token.
needs_bonus = is_greedy & (first_mismatch_pos_per_req
>= draft_tokens_per_req)
needs_bonus = is_greedy & (first_mismatch_pos_per_req >= draft_tokens_per_req)
if torch.any(needs_bonus):
bonus_rows = torch.where(needs_bonus)[0]
bonus_cols = draft_tokens_per_req[bonus_rows]
@@ -458,24 +444,24 @@ def rejection_random_sample_pytorch(
):
"""
This function implements the Speculative Decoding rejection sampling step.
Instead of looping through each request and each token (which causes high
Instead of looping through each request and each token (which causes high
overhead), it uses a fully vectorized approach:
1. **Index Mapping**: Converts the flattened 1D token arrays into a 2D
[batch_size, max_draft_len] grid using 'cu_num_draft_tokens' to handle
1. **Index Mapping**: Converts the flattened 1D token arrays into a 2D
[batch_size, max_draft_len] grid using 'cu_num_draft_tokens' to handle
variable-length sequences in the batch.
2. **Parallel Validation**: Calculates the acceptance condition
(target_prob / draft_prob >= uniform_sample) for ALL draft tokens
2. **Parallel Validation**: Calculates the acceptance condition
(target_prob / draft_prob >= uniform_sample) for ALL draft tokens
simultaneously across the entire batch.
3. **Short-circuit Simulation**: In the loop version, once a token is rejected,
subsequent tokens are ignored. Here, we simulate this by finding the
'first_reject_pos' using argmax on the rejection mask and creating a
3. **Short-circuit Simulation**: In the loop version, once a token is rejected,
subsequent tokens are ignored. Here, we simulate this by finding the
'first_reject_pos' using argmax on the rejection mask and creating a
'should_skip' mask for all indices after the first failure.
4. **Token Selection**: Uses 'torch.where' to select:
- Draft tokens (if accepted)
- Recovered tokens (at the point of first rejection)
- Bonus tokens (if all tokens in a sequence were accepted)
5. **Masking**: Ensures operations only apply to non-greedy requests and
5. **Masking**: Ensures operations only apply to non-greedy requests and
within valid sequence lengths.
"""
@@ -495,15 +481,12 @@ def rejection_random_sample_pytorch(
valid_mask = pos_indices < num_draft_per_batch[:, None]
global_token_indices = cu_start[:, None] + pos_indices
global_token_indices = global_token_indices.clamp(
0, draft_token_ids.shape[0] - 1)
draft_tokens = draft_token_ids[
global_token_indices] # [batch_size, max_draft_len]
global_token_indices = global_token_indices.clamp(0, draft_token_ids.shape[0] - 1)
draft_tokens = draft_token_ids[global_token_indices] # [batch_size, max_draft_len]
if IS_NGRAM:
ones_cpu = torch.ones(1, pin_memory=True, dtype=torch.float32)
draft_token_probs = ones_cpu.to(
device, non_blocking=True).expand_as(draft_tokens)
draft_token_probs = ones_cpu.to(device, non_blocking=True).expand_as(draft_tokens)
else:
flat_indices = global_token_indices.flatten()
flat_draft_tokens = draft_tokens.flatten()
@@ -518,24 +501,21 @@ def rejection_random_sample_pytorch(
uniform_token_probs = uniform_probs[global_token_indices]
recovered_tokens = recovered_token_ids[global_token_indices]
zero_threshold_cpu = torch.tensor([0.0],
pin_memory=True,
dtype=torch.float32)
zero_threshold_cpu = torch.tensor([0.0], pin_memory=True, dtype=torch.float32)
zero_threshold = zero_threshold_cpu.to(device, non_blocking=True)
acceptance_condition = (draft_token_probs > zero_threshold) & (
target_token_probs / draft_token_probs >= uniform_token_probs)
target_token_probs / draft_token_probs >= uniform_token_probs
)
first_rejection = (~acceptance_condition) & valid_mask
default_pos_cpu = torch.full([batch_size, 1],
max_draft_len,
pin_memory=True)
default_pos_cpu = torch.full([batch_size, 1], max_draft_len, pin_memory=True)
default_pos = default_pos_cpu.to(device, non_blocking=True)
first_reject_pos = torch.where(
first_rejection.any(dim=1, keepdim=True),
first_rejection.float().argmax(dim=1, keepdim=True), default_pos)
first_rejection.any(dim=1, keepdim=True), first_rejection.float().argmax(dim=1, keepdim=True), default_pos
)
pos_mask = pos_indices >= first_reject_pos
should_skip = pos_mask & valid_mask
@@ -543,16 +523,17 @@ def rejection_random_sample_pytorch(
non_greedy_mask = ~is_greedy
update_mask = non_greedy_mask[:, None] & valid_mask & (~should_skip)
first_reject_mask = (pos_indices == first_reject_pos
) & valid_mask & non_greedy_mask[:, None]
first_reject_mask = (pos_indices == first_reject_pos) & valid_mask & non_greedy_mask[:, None]
final_update_mask = update_mask | first_reject_mask
final_tokens = torch.where(
first_reject_mask, recovered_tokens,
torch.where(final_acceptance, draft_tokens,
output_token_ids[:, :max_draft_len]))
first_reject_mask,
recovered_tokens,
torch.where(final_acceptance, draft_tokens, output_token_ids[:, :max_draft_len]),
)
output_token_ids[:, :max_draft_len] = torch.where(
final_update_mask, final_tokens, output_token_ids[:, :max_draft_len])
final_update_mask, final_tokens, output_token_ids[:, :max_draft_len]
)
no_rejection = first_reject_pos.squeeze(1) >= num_draft_per_batch
should_add_bonus = non_greedy_mask & no_rejection
@@ -561,8 +542,7 @@ def rejection_random_sample_pytorch(
seq_len = output_token_ids.shape[1]
all_positions_cpu = torch.arange(seq_len, pin_memory=True)
all_positions = all_positions_cpu.to(
device, non_blocking=True)[None, :] # [1, seq_len]
all_positions = all_positions_cpu.to(device, non_blocking=True)[None, :] # [1, seq_len]
batch_bonus_positions = bonus_positions[:, None] # [batch_size, 1]
@@ -572,12 +552,11 @@ def rejection_random_sample_pytorch(
valid_bonus_pos = bonus_positions < (max_spec_len_device + 1)
final_bonus_mask = should_add_bonus & valid_bonus_pos
bonus_pos_match = (all_positions == batch_bonus_positions)
bonus_pos_match = all_positions == batch_bonus_positions
bonus_pos_mask = bonus_pos_match & final_bonus_mask[:, None]
bonus_values_expanded = bonus_token_ids.view(-1, 1).expand(-1, seq_len)
output_token_ids[:] = torch.where(bonus_pos_mask, bonus_values_expanded,
output_token_ids)
output_token_ids[:] = torch.where(bonus_pos_mask, bonus_values_expanded, output_token_ids)
def expand_pytorch(
@@ -589,17 +568,17 @@ def expand_pytorch(
MAX_NUM_TOKENS,
):
"""
This function broadcasts batch-level values (input_ptr) to token-level
positions (output_ptr) based on cumulative token offsets. It acts like
This function broadcasts batch-level values (input_ptr) to token-level
positions (output_ptr) based on cumulative token offsets. It acts like
a "scatter" or "repeat_interleave" operation but with custom logic:
1. **Range Broadcasting**: It creates a boolean matrix 'in_range' of size
[num_tokens, batch_size] that identifies which batch index each token
1. **Range Broadcasting**: It creates a boolean matrix 'in_range' of size
[num_tokens, batch_size] that identifies which batch index each token
belongs to by checking if the token index falls between cu_start and cu_end.
2. **Conditional Replacement**: Before expansion, it replaces specific values
2. **Conditional Replacement**: Before expansion, it replaces specific values
(e.g., padding or special markers) in the input to prepare the data.
3. **Matrix-based Mapping**: It uses 'torch.einsum' to perform a weighted
sum that effectively "picks" the correct batch value for every token position
3. **Matrix-based Mapping**: It uses 'torch.einsum' to perform a weighted
sum that effectively "picks" the correct batch value for every token position
simultaneously, avoiding a Python loop over the batch.
"""
device = cu_num_tokens_ptr.device
@@ -609,21 +588,16 @@ def expand_pytorch(
if batch_size == 0 or num_tokens == 0:
return
cu_start = torch.cat([
torch.tensor([0], pin_memory=True).to(device, non_blocking=True),
cu_num_tokens_ptr[:-1]
])
cu_start = torch.cat([torch.tensor([0], pin_memory=True).to(device, non_blocking=True), cu_num_tokens_ptr[:-1]])
cu_end = cu_num_tokens_ptr
token_indices = torch.arange(num_tokens,
device=device)[:, None] # [num_tokens, 1]
token_indices = torch.arange(num_tokens, device=device)[:, None] # [num_tokens, 1]
cu_start_exp = cu_start[None, :] # [1, batch_size]
cu_end_exp = cu_end[None, :] # [1, batch_size]
in_range = (token_indices >= cu_start_exp) & (token_indices < cu_end_exp)
replaced_input = torch.where(input_ptr == replace_from, replace_to,
input_ptr).float()
replaced_input = torch.where(input_ptr == replace_from, replace_to, input_ptr).float()
token_values = torch.einsum("tb,b->t", in_range.float(), replaced_input)
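
A toy sketch of the in_range/einsum broadcast described in the docstring above, for two requests with 2 and 3 tokens respectively (values are invented):

import torch

cu_num_tokens = torch.tensor([2, 5])
cu_start = torch.cat([torch.tensor([0]), cu_num_tokens[:-1]])   # [0, 2]
token_idx = torch.arange(5)[:, None]                            # [num_tokens, 1]
in_range = (token_idx >= cu_start[None, :]) & (token_idx < cu_num_tokens[None, :])
batch_values = torch.tensor([10.0, 20.0])
expanded = torch.einsum("tb,b->t", in_range.float(), batch_values)
assert expanded.tolist() == [10.0, 10.0, 20.0, 20.0, 20.0]
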
@@ -643,21 +617,21 @@ def sample_recovered_tokens_pytorch(
IS_NGRAM=False,
):
"""
When a draft token is rejected, we must sample a "recovered" token from
a modified distribution. This function calculates that distribution across
When a draft token is rejected, we must sample a "recovered" token from
a modified distribution. This function calculates that distribution across
the entire flattened batch.
1. **Token-to-Batch Mapping**: Using the cumulative draft token counts, it
determines which request in the batch each token belongs to. This is
1. **Token-to-Batch Mapping**: Using the cumulative draft token counts, it
determines which request in the batch each token belongs to. This is
necessary because 'q' (normalization factor) is stored per-request.
2. **Probability Adjustment**:
2. **Probability Adjustment**:
- If N-GRAM: It zeroes out the draft token's probability in the target.
- If Probabilistic: It calculates max(0, target_probs - draft_probs)
- If Probabilistic: It calculates max(0, target_probs - draft_probs)
as per the standard speculative decoding algorithm.
3. **Normalization & Sampling**: It divides the adjusted probabilities
by the normalization distribution 'q'. To remain vectorized, it
3. **Normalization & Sampling**: It divides the adjusted probabilities
by the normalization distribution 'q'. To remain vectorized, it
broadcasts 'q' from [batch_size, vocab] to [num_tokens, vocab].
4. **Argmax Selection**: It selects the best recovery token for every
4. **Argmax Selection**: It selects the best recovery token for every
position in one pass using torch.argmax.
"""
device = output_token_ids.device
@@ -666,10 +640,12 @@ def sample_recovered_tokens_pytorch(
if num_tokens == 0:
return
cu_start = torch.cat([
torch.tensor([0], pin_memory=True).to(device, non_blocking=True),
cu_num_draft_tokens[:-1],
])
cu_start = torch.cat(
[
torch.tensor([0], pin_memory=True).to(device, non_blocking=True),
cu_num_draft_tokens[:-1],
]
)
cu_end = cu_num_draft_tokens
token_indices = torch.arange(num_tokens, device=device) # [num_tokens]
@@ -678,8 +654,7 @@ def sample_recovered_tokens_pytorch(
cu_start_expanded = cu_start[None, :] # [1, batch_size]
cu_end_expanded = cu_end[None, :] # [1, batch_size]
in_range_mask = (token_indices_expanded >= cu_start_expanded) & (
token_indices_expanded < cu_end_expanded)
in_range_mask = (token_indices_expanded >= cu_start_expanded) & (token_indices_expanded < cu_end_expanded)
token_to_batch = torch.argmax(in_range_mask.int(), dim=1)
@@ -707,8 +682,7 @@ def sample_recovered_tokens_pytorch(
prob_over_q = prob / q_values_safe
prob_over_q = torch.where((q_values == 0) | torch.isinf(q_values), -1e10,
prob_over_q)
prob_over_q = torch.where((q_values == 0) | torch.isinf(q_values), -1e10, prob_over_q)
recovered_ids = torch.argmax(prob_over_q, dim=1)
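
A toy sketch of the recovery distribution described above: max(0, target - draft) scaled by the noise q, followed by an argmax (values are invented):

import torch

target = torch.tensor([[0.5, 0.3, 0.2]])
draft = torch.tensor([[0.6, 0.1, 0.3]])
q = torch.tensor([[1.0, 2.0, 0.5]])
adjusted = torch.clamp(target - draft, min=0)   # [[0.0, 0.2, 0.0]]
recovered_id = torch.argmax(adjusted / q, dim=1)
assert recovered_id.item() == 1
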
@@ -742,14 +716,12 @@ def rejection_random_sample_block_verify_pytorch(
pos_indices = pos_indices_cpu.to(device, non_blocking=True)[None, :]
valid_mask = pos_indices < num_draft_per_batch
global_token_indices = cu_start[:, None] + pos_indices
global_token_indices = global_token_indices.clamp(
0, draft_token_ids.shape[0] - 1)
global_token_indices = global_token_indices.clamp(0, draft_token_ids.shape[0] - 1)
draft_tokens = draft_token_ids[global_token_indices]
if IS_NGRAM:
ones_cpu = torch.ones(1, pin_memory=True, dtype=torch.float32)
draft_token_probs = ones_cpu.to(
device, non_blocking=True).expand_as(draft_tokens)
draft_token_probs = ones_cpu.to(device, non_blocking=True).expand_as(draft_tokens)
else:
flat_indices = global_token_indices.flatten()
flat_draft_tokens = draft_tokens.flatten()
@@ -772,27 +744,21 @@ def rejection_random_sample_block_verify_pytorch(
last_accept_pos = torch.where(
legal_mask.any(dim=-1, keepdim=True),
(max_spec_len -
legal_mask.flip(dims=[-1]).float().argmax(dim=-1, keepdim=True) - 1),
-1)
(max_spec_len - legal_mask.flip(dims=[-1]).float().argmax(dim=-1, keepdim=True) - 1),
-1,
)
non_greedy_mask = (~is_greedy)[:, None]
accept_mask = (pos_indices
<= last_accept_pos) & valid_mask & non_greedy_mask
output_token_ids[:, :max_spec_len] = torch.where(
accept_mask, draft_tokens, output_token_ids[:, :max_spec_len])
accept_mask = (pos_indices <= last_accept_pos) & valid_mask & non_greedy_mask
output_token_ids[:, :max_spec_len] = torch.where(accept_mask, draft_tokens, output_token_ids[:, :max_spec_len])
reject_mask = (pos_indices
== last_accept_pos + 1) & valid_mask & non_greedy_mask
output_token_ids[:, :max_spec_len] = torch.where(
reject_mask, recovered_tokens, output_token_ids[:, :max_spec_len])
reject_mask = (pos_indices == last_accept_pos + 1) & valid_mask & non_greedy_mask
output_token_ids[:, :max_spec_len] = torch.where(reject_mask, recovered_tokens, output_token_ids[:, :max_spec_len])
bonus_mask = (last_accept_pos + 1 >= num_draft_per_batch) & non_greedy_mask
all_positions_cpu = torch.arange(max_spec_len + 1, pin_memory=True)
all_positions = all_positions_cpu.to(device, non_blocking=True)[None, :]
bonus_pos_match = (all_positions == num_draft_per_batch)
bonus_pos_match = all_positions == num_draft_per_batch
bonus_mask = bonus_mask & bonus_pos_match
bonus_values_expanded = bonus_token_ids.view(-1, 1).expand(
-1, max_spec_len + 1)
output_token_ids[:] = torch.where(bonus_mask, bonus_values_expanded,
output_token_ids)
bonus_values_expanded = bonus_token_ids.view(-1, 1).expand(-1, max_spec_len + 1)
output_token_ids[:] = torch.where(bonus_mask, bonus_values_expanded, output_token_ids)

View File

@@ -35,7 +35,6 @@ def random_sample(
class AscendSampler(Sampler):
def __init__(self, logprobs_mode=DEFAULT_LOGPROBS_MODE):
# TODO: support logprobs_mode in vllm-ascend
super().__init__(logprobs_mode=logprobs_mode)
@@ -62,7 +61,6 @@ class AscendSampler(Sampler):
class AscendTopKTopPSampler(TopKTopPSampler):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.apply_top_k_top_p = apply_top_k_top_p
@@ -135,4 +133,9 @@ def _apply_top_k_top_p_ascendc(
return logits
return torch.ops._C_ascend.npu_apply_top_k_top_p(logits, k=k, p=p)
apply_top_k_top_p = _apply_top_k_top_p_ascendc if get_ascend_device_type() in [AscendDeviceType.A2, AscendDeviceType.A3] else _apply_top_k_top_p_pytorch
apply_top_k_top_p = (
_apply_top_k_top_p_ascendc
if get_ascend_device_type() in [AscendDeviceType.A2, AscendDeviceType.A3]
else _apply_top_k_top_p_pytorch
)

View File

@@ -1,5 +1,3 @@
from typing import Optional, Union
import numpy as np
import torch
from vllm.distributed import get_dcp_group, get_pcp_group
@@ -8,17 +6,18 @@ from vllm.v1.utils import CpuGpuBuffer
class BlockTable:
def __init__(self,
block_size: int,
max_num_reqs: int,
max_num_blocks_per_req: int,
max_num_batched_tokens: int,
pin_memory: bool,
device: torch.device,
kernel_sizes: Union[list[int], None] = None,
cp_kv_cache_interleave_size: int = 1,
num_speculative_tokens: int = 0):
def __init__(
self,
block_size: int,
max_num_reqs: int,
max_num_blocks_per_req: int,
max_num_batched_tokens: int,
pin_memory: bool,
device: torch.device,
kernel_sizes: list[int] | None = None,
cp_kv_cache_interleave_size: int = 1,
num_speculative_tokens: int = 0,
):
self.max_num_reqs = max_num_reqs
self.max_num_blocks_per_req = max_num_blocks_per_req
self.max_num_batched_tokens = max_num_batched_tokens
@@ -28,8 +27,7 @@ class BlockTable:
try:
self.pcp_world_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group(
).rank_in_group if self.pcp_world_size > 1 else 0
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_world_size > 1 else 0
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
except AssertionError:
@@ -49,42 +47,37 @@ class BlockTable:
# Find the first kernel size that divides physical_block_size evenly
selected_kernel_size = None
for kernel_size in kernel_sizes:
if kernel_size > 0 \
and self.physical_block_size % kernel_size == 0:
if kernel_size > 0 and self.physical_block_size % kernel_size == 0:
selected_kernel_size = kernel_size
break
if selected_kernel_size is None:
raise ValueError(
f"None of the kernel sizes {kernel_sizes} can divide "
f"physical block size {self.physical_block_size} evenly")
f"physical block size {self.physical_block_size} evenly"
)
self.block_size = selected_kernel_size
self.logical_block_size = selected_kernel_size
self.blocks_per_phys_block = (self.physical_block_size //
self.logical_block_size)
self.blocks_per_phys_block = self.physical_block_size // self.logical_block_size
if self.blocks_per_phys_block > 1:
self.use_hybrid_blocks = True
else:
self.use_hybrid_blocks = False
if self.use_hybrid_blocks:
logical_table_size = (max_num_blocks_per_req *
self.blocks_per_phys_block)
logical_table_size = max_num_blocks_per_req * self.blocks_per_phys_block
else:
logical_table_size = max_num_blocks_per_req
duplicate_size = 1
if self.pcp_world_size * self.dcp_world_size > 1:
duplicate_size += num_speculative_tokens
self.block_table = self._make_buffer(max_num_reqs * duplicate_size,
logical_table_size,
dtype=torch.int32)
self.block_table = self._make_buffer(max_num_reqs * duplicate_size, logical_table_size, dtype=torch.int32)
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
self.slot_mapping = self._make_buffer(
self.max_num_batched_tokens +
2 * self.pcp_world_size * self.max_num_reqs,
dtype=torch.int32)
self.max_num_batched_tokens + 2 * self.pcp_world_size * self.max_num_reqs, dtype=torch.int32
)
self.kernel_sizes = kernel_sizes
self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
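
A worked sketch of the kernel-size selection above (numbers are invented): with a 128-token physical block and kernel_sizes = [48, 64], 48 is rejected because it does not divide 128 evenly, 64 is selected, and each physical block maps to two logical blocks, which enables the hybrid-block path.

physical_block_size, kernel_sizes = 128, [48, 64]
selected = next(k for k in kernel_sizes if k > 0 and physical_block_size % k == 0)
assert selected == 64
assert physical_block_size // selected == 2  # blocks_per_phys_block -> use_hybrid_blocks
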
@@ -103,7 +96,7 @@ class BlockTable:
num_blocks = len(block_ids)
start = self.num_blocks_per_row[row_idx]
self.block_table.np[row_idx, start:start + num_blocks] = block_ids
self.block_table.np[row_idx, start : start + num_blocks] = block_ids
self.num_blocks_per_row[row_idx] += num_blocks
def add_row(self, block_ids: list[int], row_idx: int) -> None:
@@ -112,8 +105,7 @@ class BlockTable:
def move_row(self, src: int, tgt: int) -> None:
num_blocks = self.num_blocks_per_row[src]
self.block_table.np[tgt, :num_blocks] = self.block_table.np[
src, :num_blocks]
self.block_table.np[tgt, :num_blocks] = self.block_table.np[src, :num_blocks]
self.num_blocks_per_row[tgt] = num_blocks
def swap_row(self, src: int, tgt: int) -> None:
@@ -124,8 +116,7 @@ class BlockTable:
self.block_table.np[[src, tgt]] = self.block_table.np[[tgt, src]]
def compute_slot_mapping(self, req_indices: np.ndarray,
positions: np.ndarray) -> None:
def compute_slot_mapping(self, req_indices: np.ndarray, positions: np.ndarray) -> None:
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
@@ -150,27 +141,30 @@ class BlockTable:
# (always needed with unified tensor)
# Each physical block is split into multiple logical blocks
# The logical table has been expanded to accommodate this
block_table_indices = (req_indices * self.max_num_blocks_per_req *
self.blocks_per_phys_block +
logical_block_idx)
block_table_indices = (
req_indices * self.max_num_blocks_per_req * self.blocks_per_phys_block + logical_block_idx
)
block_numbers = self.block_table.np.ravel()[block_table_indices]
# Use virtual_block_size for mask calculation, which marks local
# tokens.
virtual_block_offsets = positions % virtual_block_size
self.current_rank = self.dcp_world_size * self.pcp_rank + self.dcp_rank
mask = (virtual_block_offsets // self.cp_kv_cache_interleave_size %
(self.dcp_world_size *
self.pcp_world_size) == self.current_rank)
mask = (
virtual_block_offsets // self.cp_kv_cache_interleave_size % (self.dcp_world_size * self.pcp_world_size)
== self.current_rank
)
# Calculate local block_offsets
block_offsets = virtual_block_offsets \
// (self.dcp_world_size * self.pcp_world_size * self.cp_kv_cache_interleave_size) \
* self.cp_kv_cache_interleave_size + virtual_block_offsets % self.cp_kv_cache_interleave_size
block_offsets = (
virtual_block_offsets
// (self.dcp_world_size * self.pcp_world_size * self.cp_kv_cache_interleave_size)
* self.cp_kv_cache_interleave_size
+ virtual_block_offsets % self.cp_kv_cache_interleave_size
)
# Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets
# Write final slots, use -1 for not-local
self.slot_mapping.np[:req_indices.shape[0]] = np.where(
mask, slot_mapping, -1)
self.slot_mapping.np[: req_indices.shape[0]] = np.where(mask, slot_mapping, -1)
else:
assert self.kernel_sizes is not None
if self.block_size == self.kernel_sizes[0]:
@@ -183,15 +177,12 @@ class BlockTable:
# Each physical block is split into multiple logical blocks
# The logical table has been expanded to accommodate this
block_table_indices = (
req_indices * self.max_num_blocks_per_req *
self.blocks_per_phys_block + logical_block_idx)
req_indices * self.max_num_blocks_per_req * self.blocks_per_phys_block + logical_block_idx
)
block_numbers = self.block_table.np.ravel(
)[block_table_indices]
block_numbers = self.block_table.np.ravel()[block_table_indices]
block_offsets = positions % self.block_size
np.add(block_numbers * self.block_size,
block_offsets,
out=self.slot_mapping.np[:req_indices.shape[0]])
np.add(block_numbers * self.block_size, block_offsets, out=self.slot_mapping.np[: req_indices.shape[0]])
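
A worked sketch of the simple (non-context-parallel) slot-mapping branch above, mirroring the block-size-2 example in the comment at the top of this method (numbers are invented):

import numpy as np

block_size = 2
block_table_row = np.array([7, 9])          # blocks assigned to one request
positions = np.array([0, 1, 2, 3])          # token positions within that request
block_numbers = block_table_row[positions // block_size]
slot_mapping = block_numbers * block_size + positions % block_size
assert slot_mapping.tolist() == [14, 15, 18, 19]
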
def commit_block_table(self, num_reqs: int) -> None:
self.block_table.copy_to_gpu(num_reqs)
@@ -203,8 +194,7 @@ class BlockTable:
self.block_table.fill_(0)
self.block_table.cpu.fill_(0)
def _convert_physical_to_logical_blocks(
self, physical_blocks: np.ndarray) -> np.ndarray:
def _convert_physical_to_logical_blocks(self, physical_blocks: np.ndarray) -> np.ndarray:
"""Convert physical block IDs to logical block IDs."""
if not self.use_hybrid_blocks:
return physical_blocks
@@ -217,8 +207,7 @@ class BlockTable:
# [1*split_ratio, 1*split_ratio+1, ...]
# But we need to account for the fact that block 0 is special
base_logical = phys_block * self.blocks_per_phys_block
logical_blocks.extend(
range(base_logical, base_logical + self.blocks_per_phys_block))
logical_blocks.extend(range(base_logical, base_logical + self.blocks_per_phys_block))
return np.array(logical_blocks, dtype=np.int32)
@@ -234,27 +223,25 @@ class BlockTable:
"""Returns the numpy array of the block table."""
return self.block_table.np
def _make_buffer(self, *size: int | torch.SymInt,
dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(*size,
dtype=dtype,
device=self.device,
pin_memory=self.pin_memory)
def _make_buffer(self, *size: int | torch.SymInt, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(*size, dtype=dtype, device=self.device, pin_memory=self.pin_memory)
class MultiGroupBlockTable:
"""The BlockTables for each KV cache group."""
def __init__(self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
pin_memory: bool,
device: torch.device,
block_sizes: list[int],
num_speculative_tokens: int = 0,
kernel_sizes: Optional[list[list[int]]] = None,
cp_kv_cache_interleave_size: int = 1) -> None:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
pin_memory: bool,
device: torch.device,
block_sizes: list[int],
num_speculative_tokens: int = 0,
kernel_sizes: list[list[int]] | None = None,
cp_kv_cache_interleave_size: int = 1,
) -> None:
# Note(hc): each dcp rank only stores
# (max_model_len//dcp_world_size) tokens in kvcache,
# so the block_size which used for calc max_num_blocks_per_req
@@ -274,24 +261,26 @@ class MultiGroupBlockTable:
kernel_sizes = kernel_sizes * len(block_sizes)
elif len(kernel_sizes) != len(block_sizes):
raise ValueError(
f"kernel_sizes length ({len(kernel_sizes)}) must match "
f"block_sizes length ({len(block_sizes)})")
f"kernel_sizes length ({len(kernel_sizes)}) must match block_sizes length ({len(block_sizes)})"
)
# Use zip to pair block_sizes with kernel_sizes one-to-one
self.block_tables = [
BlockTable(
block_size, max_num_reqs,
max(
cdiv(max_model_len,
block_size * dcp_world_size * pcp_world_size),
1 + num_speculative_tokens), max_num_batched_tokens,
pin_memory, device, kernel_size_list,
cp_kv_cache_interleave_size, num_speculative_tokens)
block_size,
max_num_reqs,
max(cdiv(max_model_len, block_size * dcp_world_size * pcp_world_size), 1 + num_speculative_tokens),
max_num_batched_tokens,
pin_memory,
device,
kernel_size_list,
cp_kv_cache_interleave_size,
num_speculative_tokens,
)
for block_size, kernel_size_list in zip(block_sizes, kernel_sizes)
]
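
A worked sketch of the per-request block count above (numbers are invented): with max_model_len = 4096, block_size = 128, dcp_world_size = 2 and pcp_world_size = 1, each rank holds cdiv(4096, 256) = 16 blocks per request, floored at 1 + num_speculative_tokens.

def _cdiv(a: int, b: int) -> int:  # local stand-in for vLLM's cdiv helper
    return -(-a // b)

num_speculative_tokens = 3
max_num_blocks_per_req = max(_cdiv(4096, 128 * 2 * 1), 1 + num_speculative_tokens)
assert max_num_blocks_per_req == 16
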
def append_row(self, block_ids: tuple[list[int], ...],
row_idx: int) -> None:
def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
for i, block_table in enumerate(self.block_tables):
block_table.append_row(block_ids[i], row_idx)
@@ -307,8 +296,7 @@ class MultiGroupBlockTable:
for block_table in self.block_tables:
block_table.swap_row(src, tgt)
def compute_slot_mapping(self, req_indices: np.ndarray,
positions: np.ndarray) -> None:
def compute_slot_mapping(self, req_indices: np.ndarray, positions: np.ndarray) -> None:
for block_table in self.block_tables:
block_table.compute_slot_mapping(req_indices, positions)