[Lint]Style: Convert vllm-ascend/ to ruff format(Batch #7) (#6023)

### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
|` vllm_ascend/quantization/compressed_tensors/compressed_tensors.py`|
|` vllm_ascend/quantization/quant_config.py`|
|` vllm_ascend/quantization/utils.py`|
|` vllm_ascend/quantization/w4a16.py`|
|` vllm_ascend/quantization/w4a4_flatquant_dynamic.py`|
|` vllm_ascend/quantization/w4a8_dynamic.py`|
|` vllm_ascend/quantization/w8a16.py`|
|` vllm_ascend/quantization/w8a8.py`|
|` vllm_ascend/quantization/w8a8_dynamic.py`|
|` vllm_ascend/quantization/w8a8_pdmix.py`|
|` vllm_ascend/quantization/w8a8mxfp8.py`|
|` vllm_ascend/sample/rejection_sampler.py`|
|` vllm_ascend/sample/sampler.py`|
|` vllm_ascend/worker/block_table.py`|

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
2c24bc6996

Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
SILONG ZENG
2026-02-06 14:56:53 +08:00
committed by GitHub
parent d0bc16859c
commit 99aedaff63
20 changed files with 997 additions and 1307 deletions

View File

@@ -51,10 +51,7 @@ line-length = 120
# Folder to be modified # Folder to be modified
exclude = [ exclude = [
"tests/**", "tests/**",
# (7)
"vllm_ascend/quantization/**",
"vllm_ascend/sample/*.py",
"vllm_ascend/worker/block_table.py",
# (8) # (8)
"vllm_ascend/ops/__init__.py", "vllm_ascend/ops/__init__.py",
"vllm_ascend/ops/activation.py", "vllm_ascend/ops/activation.py",
@@ -66,6 +63,7 @@ exclude = [
"vllm_ascend/ops/vocab_parallel_embedding.py", "vllm_ascend/ops/vocab_parallel_embedding.py",
"vllm_ascend/ops/weight_prefetch.py", "vllm_ascend/ops/weight_prefetch.py",
"vllm_ascend/spec_decode/**", "vllm_ascend/spec_decode/**",
# (10) # (10)
"vllm_ascend/ops/*linear*.py", "vllm_ascend/ops/*linear*.py",
"vllm_ascend/worker/worker.py", "vllm_ascend/worker/worker.py",
@@ -76,6 +74,7 @@ exclude = [
"vllm_ascend/worker/v2/**", "vllm_ascend/worker/v2/**",
"vllm_ascend/worker/npu_input_batch.py", "vllm_ascend/worker/npu_input_batch.py",
"vllm_ascend/ops/rotary_embedding.py", "vllm_ascend/ops/rotary_embedding.py",
# (11) # (11)
"vllm_ascend/ops/fused_moe/**", "vllm_ascend/ops/fused_moe/**",
] ]

View File

@@ -29,6 +29,7 @@ Public API:
# LLM-Compressor (compressed_tensors) quantization config # LLM-Compressor (compressed_tensors) quantization config
from .compressed_tensors_config import AscendCompressedTensorsConfig from .compressed_tensors_config import AscendCompressedTensorsConfig
# ModelSlim quantization config # ModelSlim quantization config
from .modelslim_config import AscendModelSlimConfig from .modelslim_config import AscendModelSlimConfig

View File

@@ -17,23 +17,20 @@
# #
"""LLM-Compressor (compressed_tensors) quantization configuration for Ascend.""" """LLM-Compressor (compressed_tensors) quantization configuration for Ascend."""
from typing import Any, Optional, Union, cast from typing import Any, Optional, cast
import torch import torch
from compressed_tensors.quantization import (QuantizationArgs, from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy, QuantizationType
QuantizationStrategy,
QuantizationType)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS, register_quantization_config
from vllm.model_executor.layers.quantization import ( from vllm.model_executor.layers.quantization.base_config import QuantizationConfig, QuantizeMethodBase
QUANTIZATION_METHODS, register_quantization_config)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target, is_activation_quantization_format, find_matched_target,
should_ignore_layer) is_activation_quantization_format,
should_ignore_layer,
)
from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.models.utils import WeightsMapper
from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD
@@ -51,8 +48,7 @@ def _remove_quantization_method():
_remove_quantization_method() _remove_quantization_method()
QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QUANTIZATION_SCHEME_MAP_TYPE = dict[str, dict[str, "QuantizationArgs"] | None]
"QuantizationArgs"]]]
@register_quantization_config(COMPRESSED_TENSORS_METHOD) @register_quantization_config(COMPRESSED_TENSORS_METHOD)
@@ -68,7 +64,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
target_scheme_map: dict[str, Any], target_scheme_map: dict[str, Any],
ignore: list[str], ignore: list[str],
quant_format: str, quant_format: str,
config: Optional[dict[str, Any]] = None, config: dict[str, Any] | None = None,
): ):
super().__init__() super().__init__()
self.ignore = ignore self.ignore = ignore
@@ -86,8 +82,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
raise NotImplementedError( raise NotImplementedError('Ascend hardware dose not support "get_min_capability" feature.')
"Ascend hardware dose not support \"get_min_capability\" feature.")
@classmethod @classmethod
def get_config_filenames(cls) -> list[str]: def get_config_filenames(cls) -> list[str]:
@@ -100,18 +95,15 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
targeting 'Linear' needs to also match targeting 'Linear' needs to also match
FusedMoE modules. FusedMoE modules.
""" """
if ("Linear" not in self.target_scheme_map if "Linear" not in self.target_scheme_map or "FusedMoE" in self.target_scheme_map:
or "FusedMoE" in self.target_scheme_map):
return return
self.target_scheme_map["FusedMoE"] = self.target_scheme_map["Linear"] self.target_scheme_map["FusedMoE"] = self.target_scheme_map["Linear"]
@classmethod @classmethod
def from_config(cls, config: dict[str, def from_config(cls, config: dict[str, Any]) -> "AscendCompressedTensorsConfig":
Any]) -> "AscendCompressedTensorsConfig":
ignore: list[str] = cast(list[str], config.get("ignore", [])) ignore: list[str] = cast(list[str], config.get("ignore", []))
quant_format = cast(str, config.get("format")) quant_format = cast(str, config.get("format"))
target_scheme_map = cls._quantization_scheme_map_from_config( target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
config=config)
return cls( return cls(
target_scheme_map=target_scheme_map, target_scheme_map=target_scheme_map,
@@ -121,8 +113,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
) )
@classmethod @classmethod
def _quantization_scheme_map_from_config( def _quantization_scheme_map_from_config(cls, config: dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
cls, config: dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
"""Build target scheme map from config. """Build target scheme map from config.
:param config: The `quantization_config` dictionary from config.json :param config: The `quantization_config` dictionary from config.json
@@ -138,24 +129,22 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
targets = quant_config.get("targets") targets = quant_config.get("targets")
for target in targets: for target in targets:
target_scheme_map[target] = {} target_scheme_map[target] = {}
target_scheme_map[target][ target_scheme_map[target]["weights"] = QuantizationArgs.model_validate(quant_config.get("weights"))
"weights"] = QuantizationArgs.model_validate(
quant_config.get("weights"))
target_scheme_map[target]["input_activations"] = None target_scheme_map[target]["input_activations"] = None
target_scheme_map[target]["format"] = quant_config.get( target_scheme_map[target]["format"] = quant_config.get("format")
"format")
format = target_scheme_map[target].get("format") format = target_scheme_map[target].get("format")
# If no per-config format defined, use global format in config # If no per-config format defined, use global format in config
act_quant_format = ( act_quant_format = (
is_activation_quantization_format(format) is_activation_quantization_format(format)
if format is not None else if format is not None
is_activation_quantization_format(quant_format)) else is_activation_quantization_format(quant_format)
)
input_activations = quant_config.get("input_activations") input_activations = quant_config.get("input_activations")
if act_quant_format and input_activations is not None: if act_quant_format and input_activations is not None:
target_scheme_map[target]["input_activations"] = ( target_scheme_map[target]["input_activations"] = QuantizationArgs.model_validate(
QuantizationArgs.model_validate( quant_config.get("input_activations")
quant_config.get("input_activations"))) )
return target_scheme_map return target_scheme_map
def get_quant_method( def get_quant_method(
@@ -168,8 +157,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD
# Get the scheme for this layer # Get the scheme for this layer
linear_scheme = self._get_linear_scheme(layer=layer, linear_scheme = self._get_linear_scheme(layer=layer, layer_name=prefix)
layer_name=prefix)
# Return unquantized method if no scheme found # Return unquantized method if no scheme found
if linear_scheme is None: if linear_scheme is None:
@@ -177,14 +165,12 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
# Store scheme on layer for reference (optional, for debugging) # Store scheme on layer for reference (optional, for debugging)
layer.scheme = linear_scheme layer.scheme = linear_scheme
logger.info_once( logger.info_once("Using the vLLM Ascend llmcompressor Quantization now!")
"Using the vLLM Ascend llmcompressor Quantization now!")
return AscendLinearMethod(linear_scheme) return AscendLinearMethod(linear_scheme)
if isinstance(layer, FusedMoE): if isinstance(layer, FusedMoE):
# Delayed import to avoid circular import # Delayed import to avoid circular import
from vllm_ascend.ops.fused_moe.fused_moe import \ from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
AscendUnquantizedFusedMoEMethod
layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD
layer_name = prefix + ".0.gate_proj" layer_name = prefix + ".0.gate_proj"
@@ -197,24 +183,19 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
# Store scheme on layer for reference (optional, for debugging) # Store scheme on layer for reference (optional, for debugging)
layer.scheme = moe_scheme layer.scheme = moe_scheme
logger.info_once( logger.info_once("Using the vLLM Ascend llmcompressor Quantization now!")
"Using the vLLM Ascend llmcompressor Quantization now!")
return AscendFusedMoEMethod(moe_scheme, layer.moe_config) return AscendFusedMoEMethod(moe_scheme, layer.moe_config)
return None return None
def _get_linear_scheme( def _get_linear_scheme(self, layer: torch.nn.Module, layer_name: str | None = None) -> AscendLinearScheme | None:
self,
layer: torch.nn.Module,
layer_name: Optional[str] = None) -> Optional[AscendLinearScheme]:
"""Get the linear quantization scheme for a layer. """Get the linear quantization scheme for a layer.
Returns: Returns:
An AscendLinearScheme instance, or None if the layer An AscendLinearScheme instance, or None if the layer
should use unquantized method. should use unquantized method.
""" """
weight_quant, input_quant, format = self._get_quant_args( weight_quant, input_quant, format = self._get_quant_args(layer, layer_name)
layer, layer_name)
if weight_quant is None: if weight_quant is None:
return None return None
@@ -226,10 +207,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
) )
return cast(AscendLinearScheme, scheme) return cast(AscendLinearScheme, scheme)
def _get_moe_scheme( def _get_moe_scheme(self, layer: torch.nn.Module, layer_name: str | None = None) -> AscendMoEScheme | None:
self,
layer: torch.nn.Module,
layer_name: Optional[str] = None) -> Optional[AscendMoEScheme]:
"""Get the MoE quantization scheme for a layer. """Get the MoE quantization scheme for a layer.
Returns: Returns:
@@ -239,8 +217,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
# Add FusedMoE to target scheme map if needed # Add FusedMoE to target scheme map if needed
self._add_fused_moe_to_target_scheme_map() self._add_fused_moe_to_target_scheme_map()
weight_quant, input_quant, format = self._get_quant_args( weight_quant, input_quant, format = self._get_quant_args(layer, layer_name)
layer, layer_name)
if weight_quant is None: if weight_quant is None:
return None return None
@@ -253,11 +230,8 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
return cast(AscendMoEScheme, scheme) return cast(AscendMoEScheme, scheme)
def _get_quant_args( def _get_quant_args(
self, self, layer: torch.nn.Module, layer_name: str | None = None
layer: torch.nn.Module, ) -> tuple[Optional["QuantizationArgs"], Optional["QuantizationArgs"], str | None]:
layer_name: Optional[str] = None
) -> tuple[Optional["QuantizationArgs"], Optional["QuantizationArgs"],
Optional[str]]:
"""Extract quantization arguments for a layer. """Extract quantization arguments for a layer.
compressed-tensors supports non uniform in the following way: compressed-tensors supports non uniform in the following way:
@@ -284,16 +258,16 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
format = scheme_dict.get("format") format = scheme_dict.get("format")
if weight_quant is None: if weight_quant is None:
logger.warning_once("Acceleration for non-quantized schemes is " logger.warning_once(
"not supported by Compressed Tensors. " "Acceleration for non-quantized schemes is "
"Falling back to UnquantizedLinearMethod") "not supported by Compressed Tensors. "
"Falling back to UnquantizedLinearMethod"
)
return weight_quant, input_quant, format return weight_quant, input_quant, format
def get_scheme_dict( def get_scheme_dict(
self, self, layer: torch.nn.Module, layer_name: str | None = None
layer: torch.nn.Module,
layer_name: str | None = None
) -> dict[str, QuantizationArgs | str | None] | None: ) -> dict[str, QuantizationArgs | str | None] | None:
""" """
Extract the QuantizationArgs for a given layer. Extract the QuantizationArgs for a given layer.
@@ -305,9 +279,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
"format": str | None "format": str | None
} | None } | None
""" """
if should_ignore_layer(layer_name, if should_ignore_layer(layer_name, ignore=self.ignore, fused_mapping=self.packed_modules_mapping):
ignore=self.ignore,
fused_mapping=self.packed_modules_mapping):
return None return None
if self.target_scheme_map: if self.target_scheme_map:
@@ -328,9 +300,9 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
self, self,
weight_quant: "QuantizationArgs", weight_quant: "QuantizationArgs",
input_quant: Optional["QuantizationArgs"], input_quant: Optional["QuantizationArgs"],
format: Optional[str], format: str | None,
layer_type: str, layer_type: str,
) -> Union[AscendLinearScheme, AscendMoEScheme]: ) -> AscendLinearScheme | AscendMoEScheme:
"""Create the appropriate Ascend scheme based on quantization args and layer type. """Create the appropriate Ascend scheme based on quantization args and layer type.
Args: Args:
@@ -352,7 +324,8 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
if scheme_cls is None: if scheme_cls is None:
raise NotImplementedError( raise NotImplementedError(
f"No compressed-tensors compatible scheme was found for " f"No compressed-tensors compatible scheme was found for "
f"quant_type={quant_type}, layer_type={layer_type}.") f"quant_type={quant_type}, layer_type={layer_type}."
)
return scheme_cls() return scheme_cls()
@@ -360,7 +333,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
self, self,
weight_quant: "QuantizationArgs", weight_quant: "QuantizationArgs",
input_quant: Optional["QuantizationArgs"], input_quant: Optional["QuantizationArgs"],
format: Optional[str], format: str | None,
) -> str: ) -> str:
"""Detect the quantization type from quantization arguments. """Detect the quantization type from quantization arguments.
@@ -389,16 +362,12 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
if self._is_w4a16(weight_quant, input_quant): if self._is_w4a16(weight_quant, input_quant):
return "W4A16" return "W4A16"
raise NotImplementedError( raise NotImplementedError("No compressed-tensors compatible quantization type was found.")
"No compressed-tensors compatible quantization type was found.")
def _is_static_tensor_w8a8(self, weight_quant: "QuantizationArgs", def _is_static_tensor_w8a8(self, weight_quant: "QuantizationArgs", input_quant: "QuantizationArgs") -> bool:
input_quant: "QuantizationArgs") -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
weight_strategy = ( weight_strategy = weight_quant.strategy == QuantizationStrategy.CHANNEL.value
weight_quant.strategy == QuantizationStrategy.CHANNEL.value) is_tensor = weight_strategy and input_quant.strategy == QuantizationStrategy.TENSOR.value
is_tensor = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TENSOR.value)
is_static = not weight_quant.dynamic and not input_quant.dynamic is_static = not weight_quant.dynamic and not input_quant.dynamic
is_symmetric = weight_quant.symmetric and input_quant.symmetric is_symmetric = weight_quant.symmetric and input_quant.symmetric
@@ -406,13 +375,10 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
# Only symmetric weight quantization supported. # Only symmetric weight quantization supported.
return is_8_bits and is_tensor and is_symmetric and is_static return is_8_bits and is_tensor and is_symmetric and is_static
def _is_dynamic_token_w8a8(self, weight_quant: "QuantizationArgs", def _is_dynamic_token_w8a8(self, weight_quant: "QuantizationArgs", input_quant: "QuantizationArgs") -> bool:
input_quant: "QuantizationArgs") -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
weight_strategy = ( weight_strategy = weight_quant.strategy == QuantizationStrategy.CHANNEL.value
weight_quant.strategy == QuantizationStrategy.CHANNEL.value) is_token = weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value
is_token = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TOKEN.value)
is_dynamic = not weight_quant.dynamic and input_quant.dynamic is_dynamic = not weight_quant.dynamic and input_quant.dynamic
is_symmetric = weight_quant.symmetric and input_quant.symmetric is_symmetric = weight_quant.symmetric and input_quant.symmetric
@@ -420,14 +386,13 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
# Only symmetric weight quantization supported. # Only symmetric weight quantization supported.
return is_8_bits and is_token and is_symmetric and is_dynamic return is_8_bits and is_token and is_symmetric and is_dynamic
def _is_dynamic_token_w4a8(self, weight_quant: QuantizationArgs, def _is_dynamic_token_w4a8(self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs) -> bool:
input_quant: QuantizationArgs) -> bool:
is_4_bits = weight_quant.num_bits == 4 is_4_bits = weight_quant.num_bits == 4
is_8_bits = input_quant.num_bits == 8 is_8_bits = input_quant.num_bits == 8
weight_strategy = ( weight_strategy = (weight_quant.strategy == QuantizationStrategy.CHANNEL.value) or (
weight_quant.strategy == QuantizationStrategy.CHANNEL.value) or (weight_quant.strategy == QuantizationStrategy.GROUP.value) weight_quant.strategy == QuantizationStrategy.GROUP.value
is_token = (weight_strategy and input_quant.strategy )
== QuantizationStrategy.TOKEN.value) is_token = weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value
is_dynamic = not weight_quant.dynamic and input_quant.dynamic is_dynamic = not weight_quant.dynamic and input_quant.dynamic
is_symmetric = weight_quant.symmetric and input_quant.symmetric is_symmetric = weight_quant.symmetric and input_quant.symmetric
@@ -444,8 +409,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
# Only symmetric weight quantization supported. # Only symmetric weight quantization supported.
return is_4_bits and is_8_bits and is_token and is_symmetric and is_dynamic return is_4_bits and is_8_bits and is_token and is_symmetric and is_dynamic
def _is_w4a16(self, weight_quant: "QuantizationArgs", def _is_w4a16(self, weight_quant: "QuantizationArgs", input_quant: Optional["QuantizationArgs"]) -> bool:
input_quant: Optional["QuantizationArgs"]) -> bool:
# Confirm weights quantized. # Confirm weights quantized.
if weight_quant is None: if weight_quant is None:
return False return False
@@ -456,12 +420,11 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
input_quant_none = input_quant is None input_quant_none = input_quant is None
is_4_bits = weight_quant.num_bits == 4 is_4_bits = weight_quant.num_bits == 4
is_group = (weight_quant.strategy == QuantizationStrategy.GROUP.value) is_group = weight_quant.strategy == QuantizationStrategy.GROUP.value
is_static = not weight_quant.dynamic is_static = not weight_quant.dynamic
return input_quant_none and is_4_bits and is_group and is_static return input_quant_none and is_4_bits and is_group and is_static
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
self.target_scheme_map = hf_to_vllm_mapper.apply_dict( self.target_scheme_map = hf_to_vllm_mapper.apply_dict(self.target_scheme_map)
self.target_scheme_map)
self.ignore = hf_to_vllm_mapper.apply_list(self.ignore) self.ignore = hf_to_vllm_mapper.apply_list(self.ignore)

View File

@@ -16,27 +16,22 @@
# This file is a part of the vllm-ascend project. # This file is a part of the vllm-ascend project.
# #
from typing import Callable, List, Optional from collections.abc import Callable
import torch import torch
from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase, from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase, FusedMoeWeightScaleSupported
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import LinearMethodBase, RowParallelLinear
RowParallelLinear)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.parameter import PerTensorScaleParameter from vllm.model_executor.parameter import PerTensorScaleParameter
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import (get_flashcomm2_otp_group, from vllm_ascend.distributed.parallel_state import get_flashcomm2_otp_group, get_mlp_tp_group, get_otp_group
get_mlp_tp_group,
get_otp_group)
from vllm_ascend.utils import flashcomm2_enable, mlp_tp_enable, oproj_tp_enable from vllm_ascend.utils import flashcomm2_enable, mlp_tp_enable, oproj_tp_enable
from .methods import (AscendAttentionScheme, AscendLinearScheme, from .methods import AscendAttentionScheme, AscendLinearScheme, AscendMoEScheme, is_mx_quant_type
AscendMoEScheme, is_mx_quant_type)
class AscendLinearMethod(LinearMethodBase): class AscendLinearMethod(LinearMethodBase):
@@ -56,7 +51,7 @@ class AscendLinearMethod(LinearMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: List[int], output_partition_sizes: list[int],
input_size: int, input_size: int,
output_size: int, output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
@@ -65,9 +60,7 @@ class AscendLinearMethod(LinearMethodBase):
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader") weight_loader = extra_weight_attrs.get("weight_loader")
weight_dict = self.quant_method.get_weight(input_size_per_partition, weight_dict = self.quant_method.get_weight(input_size_per_partition, output_size_per_partition, params_dtype)
output_size_per_partition,
params_dtype)
# Extract packing information (if present) # Extract packing information (if present)
packed_dim = weight_dict.pop("_packed_dim", None) packed_dim = weight_dict.pop("_packed_dim", None)
@@ -79,25 +72,20 @@ class AscendLinearMethod(LinearMethodBase):
# Set packing attributes if the weight is packed # Set packing attributes if the weight is packed
if packed_dim is not None and packed_factor is not None: if packed_dim is not None and packed_factor is not None:
set_weight_attrs(param, { set_weight_attrs(param, {"packed_dim": packed_dim, "packed_factor": packed_factor})
"packed_dim": packed_dim,
"packed_factor": packed_factor
})
layer.register_parameter(weight_name, param) layer.register_parameter(weight_name, param)
set_weight_attrs(param, extra_weight_attrs) set_weight_attrs(param, extra_weight_attrs)
pertensor_dict = self.quant_method.get_pertensor_param(params_dtype) pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
for pertensor_name, pertensor_param in pertensor_dict.items(): for pertensor_name, pertensor_param in pertensor_dict.items():
param = PerTensorScaleParameter(data=pertensor_param, param = PerTensorScaleParameter(data=pertensor_param, weight_loader=weight_loader)
weight_loader=weight_loader)
# disable warning # disable warning
param.ignore_warning = True param.ignore_warning = True
layer.register_parameter(pertensor_name, param) layer.register_parameter(pertensor_name, param)
param.weight_loader = extra_weight_attrs.get("weight_loader") param.weight_loader = extra_weight_attrs.get("weight_loader")
perchannel_dict = self.quant_method.get_perchannel_param( perchannel_dict = self.quant_method.get_perchannel_param(output_size_per_partition, params_dtype)
output_size_per_partition, params_dtype)
for perchannel_name, perchannel_param in perchannel_dict.items(): for perchannel_name, perchannel_param in perchannel_dict.items():
param = torch.nn.Parameter(perchannel_param, requires_grad=False) param = torch.nn.Parameter(perchannel_param, requires_grad=False)
set_weight_attrs(param, {"output_dim": 0}) set_weight_attrs(param, {"output_dim": 0})
@@ -107,22 +95,22 @@ class AscendLinearMethod(LinearMethodBase):
# NOTE: In w4a8 quantization implementation, # NOTE: In w4a8 quantization implementation,
# for down_proj and o_proj scale_bias shape is [output_size, 16], # for down_proj and o_proj scale_bias shape is [output_size, 16],
# others are [output_size, 1] # others are [output_size, 1]
layer_type = "row" if isinstance(layer, layer_type = "row" if isinstance(layer, RowParallelLinear) else "others"
RowParallelLinear) else "others"
pergroup_dict = self.quant_method.get_pergroup_param( pergroup_dict = self.quant_method.get_pergroup_param(
input_size_per_partition, input_size_per_partition, output_size_per_partition, params_dtype, layer_type=layer_type
output_size_per_partition, )
params_dtype,
layer_type=layer_type)
for pergroup_name, pergroup_param in pergroup_dict.items(): for pergroup_name, pergroup_param in pergroup_dict.items():
param = torch.nn.Parameter(pergroup_param, requires_grad=False) param = torch.nn.Parameter(pergroup_param, requires_grad=False)
set_weight_attrs(param, {"output_dim": 0}) set_weight_attrs(param, {"output_dim": 0})
layer.register_parameter(pergroup_name, param) layer.register_parameter(pergroup_name, param)
set_weight_attrs(param, extra_weight_attrs) set_weight_attrs(param, extra_weight_attrs)
if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name \ if (
or is_mx_quant_type(self.quant_method): "weight_scale_second" in pergroup_name
setattr(param, "input_dim", 1) or "weight_offset_second" in pergroup_name
or is_mx_quant_type(self.quant_method)
):
param.input_dim = 1
param.input_dim = 1 param.input_dim = 1
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
@@ -133,17 +121,15 @@ class AscendLinearMethod(LinearMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
if isinstance(layer, RowParallelLinear): if isinstance(layer, RowParallelLinear):
if layer.prefix.find("o_proj") != -1 and oproj_tp_enable(): if layer.prefix.find("o_proj") != -1 and oproj_tp_enable():
tp_rank = get_otp_group().rank_in_group tp_rank = get_otp_group().rank_in_group
elif layer.prefix.find("down_proj") != -1 and mlp_tp_enable(): elif layer.prefix.find("down_proj") != -1 and mlp_tp_enable():
tp_rank = get_mlp_tp_group().rank_in_group tp_rank = get_mlp_tp_group().rank_in_group
elif (layer.prefix.find("o_proj") != -1 or elif (layer.prefix.find("o_proj") != -1 or layer.prefix.find("out_proj") != -1) and flashcomm2_enable():
layer.prefix.find("out_proj") != -1) and flashcomm2_enable(): if get_ascend_config().flashcomm2_oproj_tensor_parallel_size == 1:
if get_ascend_config(
).flashcomm2_oproj_tensor_parallel_size == 1:
tp_rank = 0 tp_rank = 0
else: else:
tp_rank = get_flashcomm2_otp_group().rank_in_group tp_rank = get_flashcomm2_otp_group().rank_in_group
@@ -175,11 +161,19 @@ class AscendKVCacheMethod(BaseKVCacheMethod):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.quant_method.process_weights_after_loading(layer) self.quant_method.process_weights_after_loading(layer)
def apply(self, layer: torch.nn.Module, query: torch.Tensor, def apply(
key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata, self,
attn_type, scale, output) -> torch.Tensor: layer: torch.nn.Module,
return self.quant_method.apply(layer, query, key, value, kv_cache, query: torch.Tensor,
attn_metadata, attn_type, scale, output) key: torch.Tensor,
value: torch.Tensor,
kv_cache,
attn_metadata,
attn_type,
scale,
output,
) -> torch.Tensor:
return self.quant_method.apply(layer, query, key, value, kv_cache, attn_metadata, attn_type, scale, output)
class AscendFusedMoEMethod(FusedMoEMethodBase): class AscendFusedMoEMethod(FusedMoEMethodBase):
@@ -192,8 +186,7 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
moe_config: The FusedMoE configuration. moe_config: The FusedMoE configuration.
""" """
def __init__(self, scheme: AscendMoEScheme, def __init__(self, scheme: AscendMoEScheme, moe_config: FusedMoEConfig) -> None:
moe_config: FusedMoEConfig) -> None:
super().__init__(moe_config) super().__init__(moe_config)
self.quant_method = scheme self.quant_method = scheme
@@ -207,30 +200,28 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
**extra_weight_attrs, **extra_weight_attrs,
) -> None: ) -> None:
weight_param = self.quant_method.get_weight( weight_param = self.quant_method.get_weight(
num_experts, intermediate_size_per_partition, hidden_size, num_experts, intermediate_size_per_partition, hidden_size, params_dtype
params_dtype) )
for param_key, param_value in weight_param.items(): for param_key, param_value in weight_param.items():
param = torch.nn.Parameter(param_value, requires_grad=False) param = torch.nn.Parameter(param_value, requires_grad=False)
layer.register_parameter(param_key, param) layer.register_parameter(param_key, param)
set_weight_attrs(param, extra_weight_attrs) set_weight_attrs(param, extra_weight_attrs)
extra_weight_attrs.update( extra_weight_attrs.update({"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) per_group_param = (
per_group_param = [ ["weight_scale_second", "weight_offset_second", "scale_bias"] + ["weight_scale", "weight_offset"]
"weight_scale_second", "weight_offset_second", "scale_bias" if hasattr(self.quant_method, "group_size") and self.quant_method.group_size > 0
] + ["weight_scale", "weight_offset"] if hasattr( else []
self.quant_method, )
"group_size") and self.quant_method.group_size > 0 else []
dynamic_quant_param = self.quant_method.get_dynamic_quant_param( dynamic_quant_param = self.quant_method.get_dynamic_quant_param(
num_experts, intermediate_size_per_partition, hidden_size, num_experts, intermediate_size_per_partition, hidden_size, params_dtype
params_dtype) )
for param_key, param_value in dynamic_quant_param.items(): for param_key, param_value in dynamic_quant_param.items():
param = torch.nn.Parameter(param_value, requires_grad=False) param = torch.nn.Parameter(param_value, requires_grad=False)
layer.register_parameter(param_key, param) layer.register_parameter(param_key, param)
set_weight_attrs(param, extra_weight_attrs) set_weight_attrs(param, extra_weight_attrs)
if any(fields in param_key for fields in per_group_param): if any(fields in param_key for fields in per_group_param):
setattr(param, "quant_method", param.quant_method = FusedMoeWeightScaleSupported.GROUP.value
FusedMoeWeightScaleSupported.GROUP.value)
def apply( def apply(
self, self,
@@ -241,25 +232,40 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
renormalize: bool, renormalize: bool,
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: torch.Tensor | None = None,
topk_group: Optional[int] = None, topk_group: int | None = None,
num_expert_group: Optional[int] = None, num_expert_group: int | None = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Callable | None = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0, routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: torch.Tensor | None = None,
is_prefill: bool = True, is_prefill: bool = True,
enable_force_load_balance: bool = False, enable_force_load_balance: bool = False,
log2phy: Optional[torch.Tensor] = None, log2phy: torch.Tensor | None = None,
global_redundant_expert_num=0, global_redundant_expert_num=0,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
return self.quant_method.apply( return self.quant_method.apply(
layer, x, router_logits, top_k, renormalize, use_grouped_topk, layer,
global_num_experts, expert_map, topk_group, num_expert_group, x,
custom_routing_function, scoring_func, routed_scaling_factor, router_logits,
e_score_correction_bias, is_prefill, enable_force_load_balance, top_k,
log2phy, global_redundant_expert_num, **kwargs) renormalize,
use_grouped_topk,
global_num_experts,
expert_map,
topk_group,
num_expert_group,
custom_routing_function,
scoring_func,
routed_scaling_factor,
e_score_correction_bias,
is_prefill,
enable_force_load_balance,
log2phy,
global_redundant_expert_num,
**kwargs,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"): if hasattr(self.quant_method, "process_weights_after_loading"):

View File

@@ -30,28 +30,26 @@ Usage:
from typing import Any from typing import Any
# Import base classes # Import base classes
from .base import (AscendAttentionScheme, AscendLinearScheme, AscendMoEScheme, from .base import AscendAttentionScheme, AscendLinearScheme, AscendMoEScheme, QuantType
QuantType)
# Import registry functions # Import registry functions
from .registry import get_scheme_class, register_scheme from .registry import get_scheme_class, register_scheme
# Import all scheme classes for external access # Import all scheme classes for external access
from .w4a4_flatquant import AscendW4A4FlatQuantDynamicLinearMethod from .w4a4_flatquant import AscendW4A4FlatQuantDynamicLinearMethod
from .w4a4_laos_dynamic import AscendW4A4LaosDynamicLinearMethod from .w4a4_laos_dynamic import AscendW4A4LaosDynamicLinearMethod
from .w4a8 import (AscendW4A8DynamicFusedMoEMethod, from .w4a8 import AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod
AscendW4A8DynamicLinearMethod)
from .w4a16 import AscendW4A16FusedMoEMethod from .w4a16 import AscendW4A16FusedMoEMethod
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, from .w8a8_dynamic import AscendW8A8DynamicFusedMoEMethod, AscendW8A8DynamicLinearMethod
AscendW8A8DynamicLinearMethod)
from .w8a8_mxfp8 import AscendW8A8MXFP8DynamicLinearMethod from .w8a8_mxfp8 import AscendW8A8MXFP8DynamicLinearMethod
from .w8a8_pdmix import (AscendW8A8PDMixFusedMoeMethod, from .w8a8_pdmix import AscendW8A8PDMixFusedMoeMethod, AscendW8A8PDMixLinearMethod
AscendW8A8PDMixLinearMethod)
from .w8a8_static import AscendW8A8LinearMethod from .w8a8_static import AscendW8A8LinearMethod
from .w8a16 import AscendW8A16LinearMethod from .w8a16 import AscendW8A16LinearMethod
def is_mx_quant_type(instance: Any) -> bool: def is_mx_quant_type(instance: Any) -> bool:
"""Checks if the quantization method is a microscaling (MX) type.""" """Checks if the quantization method is a microscaling (MX) type."""
MX_QUANT_TYPES = (AscendW8A8MXFP8DynamicLinearMethod, ) MX_QUANT_TYPES = (AscendW8A8MXFP8DynamicLinearMethod,)
return isinstance(instance, MX_QUANT_TYPES) return isinstance(instance, MX_QUANT_TYPES)

View File

@@ -17,14 +17,16 @@
"""Abstract base classes for Ascend quantization schemes.""" """Abstract base classes for Ascend quantization schemes."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable
from enum import Enum from enum import Enum
from typing import Any, Callable, Dict, Optional from typing import Any
import torch import torch
class QuantType(Enum): class QuantType(Enum):
"""Quantization type enum for MoE schemes.""" """Quantization type enum for MoE schemes."""
NONE = 0 NONE = 0
W8A8 = 1 W8A8 = 1
W4A8 = 2 W4A8 = 2
@@ -39,8 +41,7 @@ class AscendLinearScheme(ABC):
""" """
@abstractmethod @abstractmethod
def get_weight(self, input_size: int, output_size: int, def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
params_dtype: torch.dtype) -> Dict[str, Any]:
"""Return weight tensor specifications. """Return weight tensor specifications.
Args: Args:
@@ -54,7 +55,7 @@ class AscendLinearScheme(ABC):
""" """
... ...
def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]: def get_pertensor_param(self, params_dtype: torch.dtype) -> dict[str, Any]:
"""Return per-tensor parameter specifications (e.g., input_scale). """Return per-tensor parameter specifications (e.g., input_scale).
Args: Args:
@@ -65,8 +66,7 @@ class AscendLinearScheme(ABC):
""" """
return {} return {}
def get_perchannel_param(self, output_size: int, def get_perchannel_param(self, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
params_dtype: torch.dtype) -> Dict[str, Any]:
"""Return per-channel parameter specifications (e.g., weight_scale). """Return per-channel parameter specifications (e.g., weight_scale).
Args: Args:
@@ -78,11 +78,9 @@ class AscendLinearScheme(ABC):
""" """
return {} return {}
def get_pergroup_param(self, def get_pergroup_param(
input_size: int, self, input_size: int, output_size: int, params_dtype: torch.dtype, layer_type: str | None = None
output_size: int, ) -> dict[str, Any]:
params_dtype: torch.dtype,
layer_type: Optional[str] = None) -> Dict[str, Any]:
"""Return per-group parameter specifications. """Return per-group parameter specifications.
Args: Args:
@@ -97,11 +95,9 @@ class AscendLinearScheme(ABC):
return {} return {}
@abstractmethod @abstractmethod
def apply(self, def apply(
layer: torch.nn.Module, self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, tp_rank: int | None = 0
x: torch.Tensor, ) -> torch.Tensor:
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0) -> torch.Tensor:
"""Forward computation. """Forward computation.
Args: Args:
@@ -121,7 +117,7 @@ class AscendLinearScheme(ABC):
Args: Args:
layer: The linear layer module. layer: The linear layer module.
""" """
pass return
class AscendAttentionScheme(ABC): class AscendAttentionScheme(ABC):
@@ -137,7 +133,7 @@ class AscendAttentionScheme(ABC):
Args: Args:
layer: The attention layer module. layer: The attention layer module.
""" """
pass return
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""Post-loading weight processing for attention layer. """Post-loading weight processing for attention layer.
@@ -145,12 +141,21 @@ class AscendAttentionScheme(ABC):
Args: Args:
layer: The attention layer module. layer: The attention layer module.
""" """
pass return
@abstractmethod @abstractmethod
def apply(self, layer: torch.nn.Module, query: torch.Tensor, def apply(
key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata, self,
attn_type, scale, output) -> torch.Tensor: layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache,
attn_metadata,
attn_type,
scale,
output,
) -> torch.Tensor:
"""Forward computation for attention layer. """Forward computation for attention layer.
Args: Args:
@@ -185,9 +190,9 @@ class AscendMoEScheme(ABC):
quant_type: QuantType = QuantType.NONE quant_type: QuantType = QuantType.NONE
@abstractmethod @abstractmethod
def get_weight(self, num_experts: int, def get_weight(
intermediate_size_per_partition: int, hidden_sizes: int, self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
params_dtype: torch.dtype) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Return weight tensor specifications for MoE layer. """Return weight tensor specifications for MoE layer.
Args: Args:
@@ -202,10 +207,9 @@ class AscendMoEScheme(ABC):
... ...
@abstractmethod @abstractmethod
def get_dynamic_quant_param(self, num_experts: int, def get_dynamic_quant_param(
intermediate_size_per_partition: int, self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
hidden_sizes: int, ) -> dict[str, Any]:
params_dtype: torch.dtype) -> Dict[str, Any]:
"""Return dynamic quantization parameters for MoE layer. """Return dynamic quantization parameters for MoE layer.
Args: Args:
@@ -229,16 +233,16 @@ class AscendMoEScheme(ABC):
renormalize: bool, renormalize: bool,
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: torch.Tensor | None = None,
topk_group: Optional[int] = None, topk_group: int | None = None,
num_expert_group: Optional[int] = None, num_expert_group: int | None = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Callable | None = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0, routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: torch.Tensor | None = None,
is_prefill: bool = True, is_prefill: bool = True,
enable_force_load_balance: bool = False, enable_force_load_balance: bool = False,
log2phy: Optional[torch.Tensor] = None, log2phy: torch.Tensor | None = None,
global_redundant_expert_num: int = 0, global_redundant_expert_num: int = 0,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
@@ -276,4 +280,4 @@ class AscendMoEScheme(ABC):
Args: Args:
layer: The MoE layer module. layer: The MoE layer module.
""" """
pass return

View File

@@ -15,10 +15,10 @@
# limitations under the License. # limitations under the License.
# #
from typing import Any, Dict, Optional, Tuple, Type from typing import Any
# Registry: maps (quant_type, layer_type) -> SchemeClass # Registry: maps (quant_type, layer_type) -> SchemeClass
_SCHEME_REGISTRY: Dict[Tuple[str, str], Type[Any]] = {} _SCHEME_REGISTRY: dict[tuple[str, str], type[Any]] = {}
def register_scheme(quant_type: str, layer_type: str): def register_scheme(quant_type: str, layer_type: str):
@@ -37,19 +37,19 @@ def register_scheme(quant_type: str, layer_type: str):
... ...
""" """
def decorator(cls: Type[Any]) -> Type[Any]: def decorator(cls: type[Any]) -> type[Any]:
key = (quant_type, layer_type) key = (quant_type, layer_type)
if key in _SCHEME_REGISTRY: if key in _SCHEME_REGISTRY:
raise ValueError( raise ValueError(
f"Scheme already registered for {quant_type}/{layer_type}: " f"Scheme already registered for {quant_type}/{layer_type}: {_SCHEME_REGISTRY[key].__name__}"
f"{_SCHEME_REGISTRY[key].__name__}") )
_SCHEME_REGISTRY[key] = cls _SCHEME_REGISTRY[key] = cls
return cls return cls
return decorator return decorator
def get_scheme_class(quant_type: str, layer_type: str) -> Optional[Type[Any]]: def get_scheme_class(quant_type: str, layer_type: str) -> type[Any] | None:
"""Get scheme class for given quant_type and layer_type. """Get scheme class for given quant_type and layer_type.
Args: Args:

View File

@@ -15,7 +15,8 @@
# limitations under the License. # limitations under the License.
# #
from typing import Any, Callable, Dict, Optional from collections.abc import Callable
from typing import Any
import torch import torch
import torch_npu import torch_npu
@@ -56,8 +57,7 @@ def unpack_from_int32(
dtype=torch.int32, dtype=torch.int32,
) )
for i in range(pack_factor): for i in range(pack_factor):
unpacked_weight[:, i::pack_factor] = (weight >> unpacked_weight[:, i::pack_factor] = (weight >> (num_bits * i)) & mask
(num_bits * i)) & mask
original_row_size = int(shape[1]) original_row_size = int(shape[1])
unpacked_weight = unpacked_weight[:, :original_row_size] unpacked_weight = unpacked_weight[:, :original_row_size]
else: else:
@@ -67,8 +67,7 @@ def unpack_from_int32(
dtype=torch.int32, dtype=torch.int32,
) )
for i in range(pack_factor): for i in range(pack_factor):
unpacked_weight[i::pack_factor, :] = (weight >> unpacked_weight[i::pack_factor, :] = (weight >> (num_bits * i)) & mask
(num_bits * i)) & mask
original_row_size = int(shape[0]) original_row_size = int(shape[0])
unpacked_weight = unpacked_weight[:original_row_size, :] unpacked_weight = unpacked_weight[:original_row_size, :]
@@ -84,22 +83,17 @@ def pack_to_int32(weight: torch.Tensor) -> torch.Tensor:
:param weight: The 3D tensor to pack, must be int8 or int32 dtype :param weight: The 3D tensor to pack, must be int8 or int32 dtype
:return: Packed tensor with int32 dtype optimized for storage :return: Packed tensor with int32 dtype optimized for storage
""" """
assert weight.dim( assert weight.dim() == 3, f"Expecting `weight.dim()` is 3 ([e, n, k] or [e, k, n]) but got {weight.dim()}."
) == 3, f"Expecting `weight.dim()` is 3 ([e, n, k] or [e, k, n]) but got {weight.dim()}." assert weight.dtype in [torch.int8, torch.int32], (
assert weight.dtype in [ f"Expecting `weight.dtype` is torch.int8 or torch.int32 bug got {weight.dtype}."
torch.int8, torch.int32 )
], f"Expecting `weight.dtype` is torch.int8 or torch.int32 bug got {weight.dtype}."
if weight.dtype == torch.int32: if weight.dtype == torch.int32:
assert weight.shape[ assert weight.shape[-1] % 8 == 0, "the last dim of weight needs to be divided by 8."
-1] % 8 == 0, "the last dim of weight needs to be divided by 8." packed_weight = torch_npu.npu_convert_weight_to_int4pack(weight.flatten(0, 1))
packed_weight = torch_npu.npu_convert_weight_to_int4pack( packed_weight = packed_weight.view(weight.shape[0], weight.shape[1], -1)
weight.flatten(0, 1))
packed_weight = packed_weight.view(weight.shape[0], weight.shape[1],
-1)
else: else:
assert weight.shape[ assert weight.shape[-1] % 4 == 0, "the last dim of weight needs to be divided by 4."
-1] % 4 == 0, "the last dim of weight needs to be divided by 4."
packed_weight = weight.view(torch.int32).contiguous() packed_weight = weight.view(torch.int32).contiguous()
return packed_weight return packed_weight
@@ -115,8 +109,7 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme):
self.pack_factor = 8 # pack 8 of torch.int4 tensors to torch.int32 self.pack_factor = 8 # pack 8 of torch.int4 tensors to torch.int32
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
self.group_size = vllm_config.quant_config.quant_description.get( self.group_size = vllm_config.quant_config.quant_description.get("group_size", 32)
"group_size", 32)
self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb
def get_weight( def get_weight(
@@ -125,22 +118,23 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme):
intermediate_size_per_partition: int, intermediate_size_per_partition: int,
hidden_sizes: int, hidden_sizes: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
) -> Dict[str, Any]: ) -> dict[str, Any]:
assert intermediate_size_per_partition % self.pack_factor == 0, f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} can be divided by `pack_factor` {self.pack_factor}" assert intermediate_size_per_partition % self.pack_factor == 0, (
assert hidden_sizes % self.pack_factor == 0, f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `pack_factor` {self.pack_factor}" f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} "
f"can be divided by `pack_factor` {self.pack_factor}"
)
assert hidden_sizes % self.pack_factor == 0, (
f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `pack_factor` {self.pack_factor}"
)
param_dict = {} param_dict = {}
param_dict["w13_weight_packed"] = torch.empty( param_dict["w13_weight_packed"] = torch.empty(
num_experts, num_experts, 2 * intermediate_size_per_partition, hidden_sizes // self.pack_factor, dtype=torch.int32
2 * intermediate_size_per_partition, )
hidden_sizes // self.pack_factor,
dtype=torch.int32)
param_dict["w2_weight_packed"] = torch.empty( param_dict["w2_weight_packed"] = torch.empty(
num_experts, num_experts, hidden_sizes, intermediate_size_per_partition // self.pack_factor, dtype=torch.int32
hidden_sizes, )
intermediate_size_per_partition // self.pack_factor,
dtype=torch.int32)
return param_dict return param_dict
@@ -150,38 +144,31 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme):
intermediate_size_per_partition: int, intermediate_size_per_partition: int,
hidden_sizes: int, hidden_sizes: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
) -> Dict[str, Any]: ) -> dict[str, Any]:
assert intermediate_size_per_partition % self.group_size == 0, f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} can be divided by `group_size` {self.group_size}" assert intermediate_size_per_partition % self.group_size == 0, (
assert hidden_sizes % self.group_size == 0, f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `group_size` {self.group_size}" f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} "
f"can be divided by `group_size` {self.group_size}"
)
assert hidden_sizes % self.group_size == 0, (
f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `group_size` {self.group_size}"
)
param_dict = {} param_dict = {}
param_dict["w13_weight_scale"] = torch.empty( param_dict["w13_weight_scale"] = torch.empty(
num_experts, num_experts, 2 * intermediate_size_per_partition, hidden_sizes // self.group_size, dtype=torch.bfloat16
2 * intermediate_size_per_partition, )
hidden_sizes // self.group_size,
dtype=torch.bfloat16)
param_dict["w2_weight_scale"] = torch.empty( param_dict["w2_weight_scale"] = torch.empty(
num_experts, num_experts, hidden_sizes, intermediate_size_per_partition // self.group_size, dtype=torch.bfloat16
hidden_sizes, )
intermediate_size_per_partition // self.group_size, param_dict["w13_weight_shape"] = torch.empty(num_experts, 2, dtype=torch.int32)
dtype=torch.bfloat16) param_dict["w2_weight_shape"] = torch.empty(num_experts, 2, dtype=torch.int32)
param_dict["w13_weight_shape"] = torch.empty(num_experts,
2,
dtype=torch.int32)
param_dict["w2_weight_shape"] = torch.empty(num_experts,
2,
dtype=torch.int32)
param_dict["w13_weight_offset"] = torch.zeros( param_dict["w13_weight_offset"] = torch.zeros(
num_experts, num_experts, 2 * intermediate_size_per_partition, hidden_sizes // self.group_size, dtype=torch.bfloat16
2 * intermediate_size_per_partition, )
hidden_sizes // self.group_size,
dtype=torch.bfloat16)
param_dict["w2_weight_offset"] = torch.zeros( param_dict["w2_weight_offset"] = torch.zeros(
num_experts, num_experts, hidden_sizes, intermediate_size_per_partition // self.group_size, dtype=torch.bfloat16
hidden_sizes, )
intermediate_size_per_partition // self.group_size,
dtype=torch.bfloat16)
return param_dict return param_dict
@@ -194,21 +181,22 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme):
renormalize: bool, renormalize: bool,
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: torch.Tensor | None = None,
topk_group: Optional[int] = None, topk_group: int | None = None,
num_expert_group: Optional[int] = None, num_expert_group: int | None = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Callable | None = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0, routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: torch.Tensor | None = None,
is_prefill: bool = True, is_prefill: bool = True,
enable_force_load_balance: bool = True, enable_force_load_balance: bool = True,
log2phy: Optional[torch.Tensor] = None, log2phy: torch.Tensor | None = None,
global_redundant_expert_num: int = 0, global_redundant_expert_num: int = 0,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
assert router_logits.shape[ assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, (
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)" "Number of global experts mismatch (excluding redundancy)"
)
topk_weights, topk_ids = select_experts( topk_weights, topk_ids = select_experts(
hidden_states=x, hidden_states=x,
@@ -221,7 +209,8 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme):
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts) global_num_experts=global_num_experts,
)
topk_ids = topk_ids.to(torch.int32) topk_ids = topk_ids.to(torch.int32)
topk_weights = topk_weights.to(x.dtype) topk_weights = topk_weights.to(x.dtype)
@@ -241,38 +230,40 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme):
expert_map=expert_map, expert_map=expert_map,
log2phy=log2phy, log2phy=log2phy,
dynamic_eplb=self.dynamic_eplb, dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask", None)) mc2_mask=kwargs.get("mc2_mask"),
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if self.transpose_weight: if self.transpose_weight:
w13_shape = layer.w13_weight_packed.data.shape w13_shape = layer.w13_weight_packed.data.shape
w2_shape = layer.w2_weight_packed.data.shape w2_shape = layer.w2_weight_packed.data.shape
unpacked_w13_weight = (unpack_from_int32( unpacked_w13_weight = (
layer.w13_weight_packed.data.flatten(0, 1), unpack_from_int32(
torch.Size([ layer.w13_weight_packed.data.flatten(0, 1),
w13_shape[0] * w13_shape[1], torch.Size([w13_shape[0] * w13_shape[1], w13_shape[2] * self.pack_factor]),
w13_shape[2] * self.pack_factor self.num_bits,
]), )
self.num_bits, .view(w13_shape[0], w13_shape[1], -1)
).view(w13_shape[0], w13_shape[1], .transpose(1, 2)
-1).transpose(1, 2).contiguous().int()) .contiguous()
unpacked_w2_weight = (unpack_from_int32( .int()
layer.w2_weight_packed.data.flatten(0, 1), )
torch.Size([ unpacked_w2_weight = (
w2_shape[0] * w2_shape[1], w2_shape[2] * self.pack_factor unpack_from_int32(
]), layer.w2_weight_packed.data.flatten(0, 1),
self.num_bits, torch.Size([w2_shape[0] * w2_shape[1], w2_shape[2] * self.pack_factor]),
).view(w2_shape[0], w2_shape[1], self.num_bits,
-1).transpose(1, 2).contiguous().int()) )
.view(w2_shape[0], w2_shape[1], -1)
.transpose(1, 2)
.contiguous()
.int()
)
layer.w13_weight_packed.data = pack_to_int32(unpacked_w13_weight) layer.w13_weight_packed.data = pack_to_int32(unpacked_w13_weight)
layer.w2_weight_packed.data = pack_to_int32(unpacked_w2_weight) layer.w2_weight_packed.data = pack_to_int32(unpacked_w2_weight)
layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose( layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(1, 2).contiguous()
1, 2).contiguous() layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(1, 2).contiguous()
layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(
1, 2).contiguous()
layer.w13_weight_offset.data = layer.w13_weight_offset.data.transpose( layer.w13_weight_offset.data = layer.w13_weight_offset.data.transpose(1, 2).contiguous()
1, 2).contiguous() layer.w2_weight_offset.data = layer.w2_weight_offset.data.transpose(1, 2).contiguous()
layer.w2_weight_offset.data = layer.w2_weight_offset.data.transpose(
1, 2).contiguous()

View File

@@ -16,7 +16,7 @@
# #
import math import math
from typing import Any, Dict, Optional, Tuple from typing import Any
import torch import torch
import torch_npu import torch_npu
@@ -31,8 +31,7 @@ def pack_int4_weights(weight_tensor: torch.Tensor) -> torch.Tensor:
"""Pack int4 weights for NPU.""" """Pack int4 weights for NPU."""
original_device = weight_tensor.device original_device = weight_tensor.device
weight_tensor_npu = weight_tensor.npu() weight_tensor_npu = weight_tensor.npu()
weight_int4_packed = torch_npu.npu_convert_weight_to_int4pack( weight_int4_packed = torch_npu.npu_convert_weight_to_int4pack(weight_tensor_npu.to(torch.int32), inner_k_tiles=1)
weight_tensor_npu.to(torch.int32), inner_k_tiles=1)
return weight_int4_packed.to(original_device) return weight_int4_packed.to(original_device)
@@ -58,22 +57,14 @@ def batched_kronecker_quant(
left_trans: torch.Tensor, left_trans: torch.Tensor,
right_trans: torch.Tensor, right_trans: torch.Tensor,
clip_ratio: float, clip_ratio: float,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
"""Batched Kronecker quantization with batch size limit handling.""" """Batched Kronecker quantization with batch size limit handling."""
batch_tokens = x.shape[0] batch_tokens = x.shape[0]
if batch_tokens <= KRONECKER_QUANT_MAX_BATCH_SIZE: if batch_tokens <= KRONECKER_QUANT_MAX_BATCH_SIZE:
return torch_npu.npu_kronecker_quant(x, return torch_npu.npu_kronecker_quant(x, left_trans, right_trans, clip_ratio=clip_ratio, dst_dtype=torch.int32)
left_trans,
right_trans,
clip_ratio=clip_ratio,
dst_dtype=torch.int32)
x_chunks = torch.split(x, KRONECKER_QUANT_MAX_BATCH_SIZE, dim=0) x_chunks = torch.split(x, KRONECKER_QUANT_MAX_BATCH_SIZE, dim=0)
processed_chunks = [ processed_chunks = [
torch_npu.npu_kronecker_quant(chunk, torch_npu.npu_kronecker_quant(chunk, left_trans, right_trans, clip_ratio=clip_ratio, dst_dtype=torch.int32)
left_trans,
right_trans,
clip_ratio=clip_ratio,
dst_dtype=torch.int32)
for chunk in x_chunks for chunk in x_chunks
] ]
quantized_list, scale_list = zip(*processed_chunks) quantized_list, scale_list = zip(*processed_chunks)
@@ -88,36 +79,29 @@ class AscendW4A4FlatQuantDynamicLinearMethod(AscendLinearScheme):
This class implements W4A4 quantization with FlatQuant approach and dynamic activation quantization. This class implements W4A4 quantization with FlatQuant approach and dynamic activation quantization.
- Weight: 4-bit quantization (per-channel) with scale and offset, stored as int8 and packed to int32 during loading - Weight: 4-bit quantization (per-channel) with scale and offset, stored as int8 and packed to int32 during loading
- Activation: 4-bit dynamic quantization with FlatQuant transform matrices (left_trans, right_trans) for distribution smoothing - Activation: 4-bit dynamic quantization with FlatQuant transform matrices (left_trans, right_trans) for
- Parameters: clip_ratio for controlling quantization clipping, weight_offset for asymmetric quantization, loaded from external weights distribution smoothing
- Parameters: clip_ratio for controlling quantization clipping, weight_offset for asymmetric quantization, loaded
from external weights
""" """
input_size = 0 input_size = 0
def __init__(self): def __init__(self):
self.sym = True self.sym = True
def get_weight(self, input_size: int, output_size: int, def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
params_dtype: torch.dtype) -> Dict[str, Any]:
if input_size % 8 != 0: if input_size % 8 != 0:
raise ValueError( raise ValueError(f"input_size ({input_size}) must be divisible by 8 for int4 packing")
f"input_size ({input_size}) must be divisible by 8 for int4 packing"
)
AscendW4A4FlatQuantDynamicLinearMethod.input_size = input_size AscendW4A4FlatQuantDynamicLinearMethod.input_size = input_size
params_dict = { params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
}
return params_dict return params_dict
def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]: def get_pertensor_param(self, params_dtype: torch.dtype) -> dict[str, Any]:
params_dict = {} params_dict = {}
left_trans_dim, right_trans_dim = get_decompose_dim( left_trans_dim, right_trans_dim = get_decompose_dim(AscendW4A4FlatQuantDynamicLinearMethod.input_size)
AscendW4A4FlatQuantDynamicLinearMethod.input_size) params_dict["left_trans"] = torch.empty(left_trans_dim, left_trans_dim, dtype=params_dtype)
params_dict["left_trans"] = torch.empty(left_trans_dim, params_dict["right_trans"] = torch.empty(right_trans_dim, right_trans_dim, dtype=params_dtype)
left_trans_dim,
dtype=params_dtype)
params_dict["right_trans"] = torch.empty(right_trans_dim,
right_trans_dim,
dtype=params_dtype)
params_dict["clip_ratio"] = torch.empty(1, dtype=torch.float32) params_dict["clip_ratio"] = torch.empty(1, dtype=torch.float32)
return params_dict return params_dict
@@ -125,22 +109,18 @@ class AscendW4A4FlatQuantDynamicLinearMethod(AscendLinearScheme):
self, self,
output_size: int, output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
) -> Dict[str, Any]: ) -> dict[str, Any]:
params_dict = {} params_dict = {}
params_dict["weight_scale"] = torch.empty(output_size, params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=torch.float32)
1, params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=torch.float32)
dtype=torch.float32)
params_dict["weight_offset"] = torch.empty(output_size,
1,
dtype=torch.float32)
return params_dict return params_dict
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None, bias: torch.Tensor | None = None,
tp_rank: Optional[int] = 0, tp_rank: int | None = 0,
) -> torch.Tensor: ) -> torch.Tensor:
original_dtype = x.dtype original_dtype = x.dtype
input_shape = x.shape input_shape = x.shape
@@ -156,18 +136,18 @@ class AscendW4A4FlatQuantDynamicLinearMethod(AscendLinearScheme):
right_trans_matched = layer.right_trans.to(original_dtype) right_trans_matched = layer.right_trans.to(original_dtype)
x_reshaped = x.view(-1, left_dim, right_dim) x_reshaped = x.view(-1, left_dim, right_dim)
x_quantized_int4, activation_scale = batched_kronecker_quant( x_quantized_int4, activation_scale = batched_kronecker_quant(
x_reshaped, left_trans_matched, right_trans_matched, x_reshaped, left_trans_matched, right_trans_matched, layer.aclnn_clip_ratio
layer.aclnn_clip_ratio) )
x_quantized_reshaped = x_quantized_int4.view(-1, x_quantized_reshaped = x_quantized_int4.view(-1, left_dim * right_dim // 8)
left_dim * right_dim // 8)
pertoken_scale = activation_scale.view(-1).to(torch.float32) pertoken_scale = activation_scale.view(-1).to(torch.float32)
output = torch_npu.npu_quant_matmul(x_quantized_reshaped, output = torch_npu.npu_quant_matmul(
layer.weight_packed.t(), x_quantized_reshaped,
layer.weight_scale.view(-1).to( layer.weight_packed.t(),
torch.float32), layer.weight_scale.view(-1).to(torch.float32),
pertoken_scale=pertoken_scale, pertoken_scale=pertoken_scale,
bias=None, bias=None,
output_dtype=original_dtype) output_dtype=original_dtype,
)
output = output.view(*input_shape[:-1], -1) output = output.view(*input_shape[:-1], -1)
if bias is not None: if bias is not None:
output = output + bias.to(original_dtype) output = output + bias.to(original_dtype)
@@ -176,15 +156,11 @@ class AscendW4A4FlatQuantDynamicLinearMethod(AscendLinearScheme):
def process_weights_after_loading(self, layer): def process_weights_after_loading(self, layer):
# NOTE: Currently, w4a4 can't support weight nz # NOTE: Currently, w4a4 can't support weight nz
weight_packed = pack_int4_weights(layer.weight.data) weight_packed = pack_int4_weights(layer.weight.data)
layer.register_parameter( layer.register_parameter("weight_packed", torch.nn.Parameter(weight_packed, requires_grad=False))
'weight_packed',
torch.nn.Parameter(weight_packed, requires_grad=False))
del layer.weight del layer.weight
layer.weight_scale.data = layer.weight_scale.data.to(torch.float32) layer.weight_scale.data = layer.weight_scale.data.to(torch.float32)
layer.weight_offset.data = layer.weight_offset.data.to(torch.float32) layer.weight_offset.data = layer.weight_offset.data.to(torch.float32)
layer.left_trans = torch.nn.Parameter( layer.left_trans = torch.nn.Parameter(layer.left_trans.data.t().contiguous())
layer.left_trans.data.t().contiguous())
layer.right_trans = torch.nn.Parameter(layer.right_trans.data) layer.right_trans = torch.nn.Parameter(layer.right_trans.data)
layer.clip_ratio = torch.nn.Parameter( layer.clip_ratio = torch.nn.Parameter(layer.clip_ratio.data.to(torch.float32))
layer.clip_ratio.data.to(torch.float32))
layer.aclnn_clip_ratio = layer.clip_ratio.item() layer.aclnn_clip_ratio = layer.clip_ratio.item()

View File

@@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
# #
from typing import Any, Dict, Optional from typing import Any
import torch import torch
import torch_npu import torch_npu
@@ -37,7 +37,7 @@ class AscendW4A4LaosDynamicLinearMethod(AscendLinearScheme):
self.transpose_weight = True self.transpose_weight = True
self.rotation_type = None self.rotation_type = None
def set_rotation_config(self, prefix: str, metadata: Dict) -> Optional[str]: def set_rotation_config(self, prefix: str, metadata: dict) -> str | None:
"""Set rotation config based on prefix and metadata.""" """Set rotation config based on prefix and metadata."""
layer_idx = prefix.split(".")[2] layer_idx = prefix.split(".")[2]
if prefix.endswith("o_proj"): if prefix.endswith("o_proj"):
@@ -50,34 +50,22 @@ class AscendW4A4LaosDynamicLinearMethod(AscendLinearScheme):
return "kronecker_rotation" return "kronecker_rotation"
return None return None
def get_weight(self, input_size: int, output_size: int, def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
params_dtype: torch.dtype) -> Dict[str, Any]: params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
params_dict = {
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
}
return params_dict return params_dict
def get_perchannel_param(self, output_size: int, def get_perchannel_param(self, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
params_dtype: torch.dtype) -> Dict[str, Any]:
params_dict = {} params_dict = {}
params_dict["weight_scale"] = torch.empty(output_size, params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=torch.float32)
1, params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=torch.float32)
dtype=torch.float32)
params_dict["weight_offset"] = torch.empty(output_size,
1,
dtype=torch.float32)
if self.rotation_type == "heads_rotation": if self.rotation_type == "heads_rotation":
params_dict["heads_rotation"] = torch.zeros((64, 64), params_dict["heads_rotation"] = torch.zeros((64, 64), dtype=torch.float32)
dtype=torch.float32)
if self.rotation_type == "kronecker_rotation": if self.rotation_type == "kronecker_rotation":
params_dict["kronecker_rotation_n"] = torch.zeros( params_dict["kronecker_rotation_n"] = torch.zeros((160, 160), dtype=torch.float32)
(160, 160), dtype=torch.float32) params_dict["kronecker_rotation_m"] = torch.zeros((160, 160), dtype=torch.float32)
params_dict["kronecker_rotation_m"] = torch.zeros(
(160, 160), dtype=torch.float32)
return params_dict return params_dict
def apply_rotation(self, layer: torch.nn.Module, def apply_rotation(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
x: torch.Tensor) -> torch.Tensor:
"""Apply rotation transformation to input tensor.""" """Apply rotation transformation to input tensor."""
init_shape = x.shape init_shape = x.shape
dtype = x.dtype dtype = x.dtype
@@ -100,8 +88,8 @@ class AscendW4A4LaosDynamicLinearMethod(AscendLinearScheme):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None, bias: torch.Tensor | None = None,
tp_rank: Optional[int] = 0, tp_rank: int | None = 0,
) -> torch.Tensor: ) -> torch.Tensor:
dtype = x.dtype dtype = x.dtype
x, pertoken_scale = torch_npu.npu_dynamic_quant(x, dst_type=torch.quint4x2) x, pertoken_scale = torch_npu.npu_dynamic_quant(x, dst_type=torch.quint4x2)
@@ -113,14 +101,14 @@ class AscendW4A4LaosDynamicLinearMethod(AscendLinearScheme):
scale=layer.weight_scale.data.view(-1), scale=layer.weight_scale.data.view(-1),
pertoken_scale=pertoken_scale, pertoken_scale=pertoken_scale,
bias=None, bias=None,
output_dtype=dtype) output_dtype=dtype,
)
if bias is not None: if bias is not None:
output = output + bias.to(dtype) output = output + bias.to(dtype)
return output return output
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight_scale.data = layer.weight_scale.data.to(torch.float32) layer.weight_scale.data = layer.weight_scale.data.to(torch.float32)
layer.weight.data = torch_npu.npu_convert_weight_to_int4pack( layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(layer.weight.data.to(torch.int32))
layer.weight.data.to(torch.int32))
if self.transpose_weight: if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(-1, -2) layer.weight.data = layer.weight.data.transpose(-1, -2)

View File

@@ -15,7 +15,8 @@
# limitations under the License. # limitations under the License.
# #
from typing import Any, Callable, Dict, Optional from collections.abc import Callable
from typing import Any
import numpy as np import numpy as np
import torch import torch
@@ -27,7 +28,7 @@ from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe.experts_selector import select_experts from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.utils import maybe_trans_nz, COMPRESSED_TENSORS_METHOD from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD, maybe_trans_nz
from .base import AscendLinearScheme, AscendMoEScheme, QuantType from .base import AscendLinearScheme, AscendMoEScheme, QuantType
from .registry import register_scheme from .registry import register_scheme
@@ -39,17 +40,15 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme):
def __init__(self): def __init__(self):
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
self.group_size = vllm_config.quant_config.quant_description.get( self.group_size = vllm_config.quant_config.quant_description.get("group_size", 256)
"group_size", 256) quant_version = vllm_config.quant_config.quant_description.get("version", "0")
quant_version = vllm_config.quant_config.quant_description.get(
"version", "0")
self.new_quant_version = quant_version == "1.0.0" self.new_quant_version = quant_version == "1.0.0"
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
def get_weight(self, input_size: int, output_size: int, def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
params_dtype: torch.dtype) -> Dict[str, Any]:
"""Create weight parameters. """Create weight parameters.
For new quantization version (double int4 pack into int8), the output dimension For new quantization version (double int4 pack into int8), the output dimension
@@ -62,40 +61,26 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme):
# double int4 pack into int8: output dimension is compressed # double int4 pack into int8: output dimension is compressed
pack_factor = 2 pack_factor = 2
actual_output_size = output_size // pack_factor actual_output_size = output_size // pack_factor
params_dict["weight"] = torch.empty(actual_output_size, params_dict["weight"] = torch.empty(actual_output_size, input_size, dtype=torch.int8)
input_size,
dtype=torch.int8)
# Add packing information for vLLM's weight_loader # Add packing information for vLLM's weight_loader
params_dict["_packed_dim"] = 0 params_dict["_packed_dim"] = 0
params_dict["_packed_factor"] = pack_factor params_dict["_packed_factor"] = pack_factor
else: else:
params_dict["weight"] = torch.empty(output_size, params_dict["weight"] = torch.empty(output_size, input_size, dtype=torch.int8)
input_size,
dtype=torch.int8)
return params_dict return params_dict
def get_pergroup_param(self, def get_pergroup_param(
input_size: int, self, input_size: int, output_size: int, params_dtype: torch.dtype, layer_type: str | None = None
output_size: int, ) -> dict[str, Any]:
params_dtype: torch.dtype,
layer_type: Optional[str] = None) -> Dict[str, Any]:
"""Create per-group quantization parameters.""" """Create per-group quantization parameters."""
params_dict = {} params_dict = {}
params_dict["weight_scale"] = torch.empty(output_size, params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
1, params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
dtype=params_dtype) params_dict["weight_scale_second"] = torch.empty(output_size, input_size // self.group_size, dtype=params_dtype)
params_dict["weight_offset"] = torch.empty(output_size, params_dict["weight_offset_second"] = torch.empty(
1, output_size, input_size // self.group_size, dtype=params_dtype
dtype=params_dtype) )
params_dict["weight_scale_second"] = torch.empty(output_size,
input_size //
self.group_size,
dtype=params_dtype)
params_dict["weight_offset_second"] = torch.empty(output_size,
input_size //
self.group_size,
dtype=params_dtype)
# NOTE: In w4a8 quantization implementation, # NOTE: In w4a8 quantization implementation,
# for down_proj and o_proj(layer_type == "row") scale_bias shape is [output_size, 16], # for down_proj and o_proj(layer_type == "row") scale_bias shape is [output_size, 16],
@@ -103,16 +88,13 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme):
if self.new_quant_version: if self.new_quant_version:
scale_bias_dim = 16 if layer_type == "row" else 1 scale_bias_dim = 16 if layer_type == "row" else 1
params_dict["scale_bias"] = torch.empty(output_size, params_dict["scale_bias"] = torch.empty(output_size, scale_bias_dim, dtype=torch.float32)
scale_bias_dim,
dtype=torch.float32)
return params_dict return params_dict
@staticmethod @staticmethod
def process_scale_second(weight: torch.Tensor, def process_scale_second(
scale: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, per_group_scale: torch.Tensor, is_new_quant: bool = False
per_group_scale: torch.Tensor, ):
is_new_quant: bool = False):
"""Process the scale for second-level quantization. """Process the scale for second-level quantization.
Args: Args:
@@ -133,8 +115,7 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme):
bias = None bias = None
if not is_new_quant: if not is_new_quant:
weight_high = weight.to(torch.float32).reshape( weight_high = weight.to(torch.float32).reshape(group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
group_num, -1, n) * per_group_scale.reshape(group_num, 1, n)
weight_high = weight_high.reshape(k, n) weight_high = weight_high.reshape(k, n)
bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0) bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0)
# NOTE: scale_bias is not used currently # NOTE: scale_bias is not used currently
@@ -148,8 +129,8 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None, bias: torch.Tensor | None = None,
tp_rank: Optional[int] = None, tp_rank: int | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
return torch_npu.npu_weight_quant_batchmatmul( return torch_npu.npu_weight_quant_batchmatmul(
x, x,
@@ -161,8 +142,7 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme):
def process_weights_after_loading(self, layer: torch.nn.Module): def process_weights_after_loading(self, layer: torch.nn.Module):
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight.data = maybe_trans_nz(layer.weight.data) layer.weight.data = maybe_trans_nz(layer.weight.data)
layer.weight_scale.data = layer.weight_scale.data.flatten().to( layer.weight_scale.data = layer.weight_scale.data.flatten().to(torch.float32)
torch.float32)
layer.weight_offset.data = layer.weight_offset.data.flatten() layer.weight_offset.data = layer.weight_offset.data.flatten()
layer.weight_scale_second.data, scale_bias = self.process_scale_second( layer.weight_scale_second.data, scale_bias = self.process_scale_second(
layer.weight.data, layer.weight.data,
@@ -187,15 +167,14 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme):
if self.new_quant_version: if self.new_quant_version:
# weights on disk are already in packed int4 format # weights on disk are already in packed int4 format
# pack 4 int8(int4*2) to int32 # pack 4 int8(int4*2) to int32
assert layer.weight.data.shape[-1] % 4 == 0, \ assert layer.weight.data.shape[-1] % 4 == 0, (
f"the last dim of weight needs to be divided by 4, got shape {layer.weight.data.shape}" f"the last dim of weight needs to be divided by 4, got shape {layer.weight.data.shape}"
layer.weight.data = layer.weight.data.view( )
torch.int32).contiguous() layer.weight.data = layer.weight.data.view(torch.int32).contiguous()
else: else:
# weights are not compressed # weights are not compressed
# need to be packed via npu_convert_weight_to_int4pack # need to be packed via npu_convert_weight_to_int4pack
layer.weight.data = torch_npu.npu_convert_weight_to_int4pack( layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(layer.weight.data.to(torch.int32))
layer.weight.data.to(torch.int32))
@register_scheme("W4A8_DYNAMIC", "moe") @register_scheme("W4A8_DYNAMIC", "moe")
@@ -209,69 +188,56 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
self.ep_group = get_ep_group() self.ep_group = get_ep_group()
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
self.group_size = vllm_config.quant_config.quant_description.get( self.group_size = vllm_config.quant_config.quant_description.get("group_size", 256)
"group_size", 256)
# NOTE: the weights are quantized from bf16 to int4 through a per-channel quantization process # NOTE: the weights are quantized from bf16 to int4 through a per-channel quantization process
self.is_per_channel_weight = self.group_size == 0 self.is_per_channel_weight = self.group_size == 0
quant_version = vllm_config.quant_config.quant_description.get( quant_version = vllm_config.quant_config.quant_description.get("version", "0")
"version", "0")
# NOTE: new quantize weights: 2 int4 pack into int8 # NOTE: new quantize weights: 2 int4 pack into int8
self.new_quant_version = quant_version == "1.0.0" self.new_quant_version = quant_version == "1.0.0"
self.quant_method = vllm_config.quant_config.quant_description.get( self.quant_method = vllm_config.quant_config.quant_description.get("ascend_quant_method", "")
"ascend_quant_method", "")
if self.quant_method == COMPRESSED_TENSORS_METHOD: if self.quant_method == COMPRESSED_TENSORS_METHOD:
self.weight_strategy = vllm_config.quant_config.quant_description.get( self.weight_strategy = vllm_config.quant_config.quant_description.get("weight_strategy", "group")
"weight_strategy", "group")
self.tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else self.ep_group.world_size self.tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else self.ep_group.world_size
self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb
if self.new_quant_version and self.tp_size > 16: if self.new_quant_version and self.tp_size > 16:
raise ValueError( raise ValueError("The current weight does not support moe part tp>16.")
"The current weight does not support moe part tp>16.")
try: try:
device_group = get_mc2_group().device_group device_group = get_mc2_group().device_group
# TODO: Try local_rank = ep_group.rank_in_group # TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=device_group) local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu")) backend = device_group._get_backend(torch.device("npu"))
self.moe_all_to_all_group_name = backend.get_hccl_comm_name( self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
local_rank)
except AttributeError: except AttributeError:
self.moe_all_to_all_group_name = "" self.moe_all_to_all_group_name = ""
def get_weight(self, num_experts: int, def get_weight(
intermediate_size_per_partition: int, hidden_sizes: int, self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
params_dtype: torch.dtype) -> Dict[str, Any]: ) -> dict[str, Any]:
if self.quant_method == COMPRESSED_TENSORS_METHOD: if self.quant_method == COMPRESSED_TENSORS_METHOD:
return self.get_weight_compressed_tensors( return self.get_weight_compressed_tensors(
num_experts, intermediate_size_per_partition, num_experts, intermediate_size_per_partition, hidden_sizes, params_dtype
hidden_sizes, params_dtype) )
else: else:
return self.get_weight_modelslim( return self.get_weight_modelslim(num_experts, intermediate_size_per_partition, hidden_sizes, params_dtype)
num_experts, intermediate_size_per_partition,
hidden_sizes, params_dtype)
def get_weight_compressed_tensors(self, num_experts: int,
intermediate_size_per_partition: int, hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
def get_weight_compressed_tensors(
self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
) -> dict[str, Any]:
param_dict = {} param_dict = {}
E = num_experts E = num_experts
H = hidden_sizes H = hidden_sizes
IN = intermediate_size_per_partition IN = intermediate_size_per_partition
g = self.group_size
param_dict["w13_weight"] = torch.empty(E, 2 * IN, H, param_dict["w13_weight"] = torch.empty(E, 2 * IN, H, dtype=torch.int8)
dtype=torch.int8) param_dict["w2_weight"] = torch.empty(E, H, IN, dtype=torch.int8)
param_dict["w2_weight"] = torch.empty(E, H, IN,
dtype=torch.int8)
return param_dict return param_dict
def get_weight_modelslim(
def get_weight_modelslim(self, num_experts: int, self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
intermediate_size_per_partition: int, hidden_sizes: int, ) -> dict[str, Any]:
params_dtype: torch.dtype) -> Dict[str, Any]:
param_dict = {} param_dict = {}
if self.new_quant_version: if self.new_quant_version:
w13_output_size = intermediate_size_per_partition w13_output_size = intermediate_size_per_partition
@@ -280,33 +246,27 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
w13_output_size = 2 * intermediate_size_per_partition w13_output_size = 2 * intermediate_size_per_partition
w2_output_size = hidden_sizes w2_output_size = hidden_sizes
param_dict["w13_weight"] = torch.empty(num_experts, param_dict["w13_weight"] = torch.empty(num_experts, w13_output_size, hidden_sizes, dtype=torch.int8)
w13_output_size, param_dict["w2_weight"] = torch.empty(
hidden_sizes, num_experts, w2_output_size, intermediate_size_per_partition, dtype=torch.int8
dtype=torch.int8) )
param_dict["w2_weight"] = torch.empty(num_experts,
w2_output_size,
intermediate_size_per_partition,
dtype=torch.int8)
return param_dict return param_dict
def get_dynamic_quant_param(self, num_experts: int, def get_dynamic_quant_param(
intermediate_size_per_partition: int, self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
hidden_sizes: int, ) -> dict[str, Any]:
params_dtype: torch.dtype) -> Dict[str, Any]:
if self.quant_method == COMPRESSED_TENSORS_METHOD: if self.quant_method == COMPRESSED_TENSORS_METHOD:
return self.get_dynamic_quant_param_compressed_tensors( return self.get_dynamic_quant_param_compressed_tensors(
num_experts, intermediate_size_per_partition, num_experts, intermediate_size_per_partition, hidden_sizes, params_dtype
hidden_sizes, params_dtype) )
else: else:
return self.get_dynamic_quant_param_modelslim( return self.get_dynamic_quant_param_modelslim(
num_experts, intermediate_size_per_partition, num_experts, intermediate_size_per_partition, hidden_sizes, params_dtype
hidden_sizes, params_dtype) )
def get_dynamic_quant_param_compressed_tensors(self, num_experts: int, def get_dynamic_quant_param_compressed_tensors(
intermediate_size_per_partition: int, self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
hidden_sizes: int, ) -> dict[str, Any]:
params_dtype: torch.dtype) -> Dict[str, Any]:
param_dict = {} param_dict = {}
E = num_experts E = num_experts
@@ -318,72 +278,48 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
def _n_scale_cols(in_features: int) -> int: def _n_scale_cols(in_features: int) -> int:
return 1 if g <= 0 else (in_features // g) return 1 if g <= 0 else (in_features // g)
param_dict["w13_weight_scale"] = torch.empty( param_dict["w13_weight_scale"] = torch.empty(E, 2 * IN, _n_scale_cols(H), dtype=torch.bfloat16)
E, 2 * IN, _n_scale_cols(H), dtype=torch.bfloat16)
param_dict["w2_weight_scale"] = torch.empty(E, H, _n_scale_cols(IN), param_dict["w2_weight_scale"] = torch.empty(E, H, _n_scale_cols(IN), dtype=torch.bfloat16)
dtype=torch.bfloat16)
return param_dict return param_dict
def get_dynamic_quant_param_modelslim(self, num_experts: int, def get_dynamic_quant_param_modelslim(
intermediate_size_per_partition: int, self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
hidden_sizes: int, ) -> dict[str, Any]:
params_dtype: torch.dtype) -> Dict[str, Any]:
param_dict = {} param_dict = {}
param_dict["w13_weight_scale"] = torch.empty( param_dict["w13_weight_scale"] = torch.empty(
num_experts, num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
2 * intermediate_size_per_partition, )
1,
dtype=torch.float32)
param_dict["w13_weight_offset"] = torch.empty( param_dict["w13_weight_offset"] = torch.empty(
num_experts, num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
2 * intermediate_size_per_partition, )
1,
dtype=torch.float32)
param_dict["w2_weight_scale"] = torch.empty(num_experts, param_dict["w2_weight_scale"] = torch.empty(num_experts, hidden_sizes, 1, dtype=torch.float32)
hidden_sizes, param_dict["w2_weight_offset"] = torch.empty(num_experts, hidden_sizes, 1, dtype=torch.float32)
1,
dtype=torch.float32)
param_dict["w2_weight_offset"] = torch.empty(num_experts,
hidden_sizes,
1,
dtype=torch.float32)
if not self.is_per_channel_weight: if not self.is_per_channel_weight:
param_dict["w13_weight_scale_second"] = torch.empty( param_dict["w13_weight_scale_second"] = torch.empty(
num_experts, num_experts, 2 * intermediate_size_per_partition, hidden_sizes // self.group_size, dtype=torch.float32
2 * intermediate_size_per_partition, )
hidden_sizes // self.group_size,
dtype=torch.float32)
param_dict["w13_weight_offset_second"] = torch.empty( param_dict["w13_weight_offset_second"] = torch.empty(
num_experts, num_experts, 2 * intermediate_size_per_partition, hidden_sizes // self.group_size, dtype=torch.float32
2 * intermediate_size_per_partition, )
hidden_sizes // self.group_size,
dtype=torch.float32)
param_dict["w2_weight_scale_second"] = torch.empty( param_dict["w2_weight_scale_second"] = torch.empty(
num_experts, num_experts, hidden_sizes, intermediate_size_per_partition // self.group_size, dtype=torch.float32
hidden_sizes, )
intermediate_size_per_partition // self.group_size,
dtype=torch.float32)
param_dict["w2_weight_offset_second"] = torch.empty( param_dict["w2_weight_offset_second"] = torch.empty(
num_experts, num_experts, hidden_sizes, intermediate_size_per_partition // self.group_size, dtype=torch.float32
hidden_sizes, )
intermediate_size_per_partition // self.group_size,
dtype=torch.float32)
if self.new_quant_version: if self.new_quant_version:
param_dict["w13_scale_bias"] = torch.empty( param_dict["w13_scale_bias"] = torch.empty(
num_experts, num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
2 * intermediate_size_per_partition, )
1, param_dict["w2_scale_bias"] = torch.empty(
dtype=torch.float32) num_experts, hidden_sizes, 16 // self.tp_size, dtype=torch.float32
param_dict["w2_scale_bias"] = torch.empty(num_experts, )
hidden_sizes,
16 // self.tp_size,
dtype=torch.float32)
return param_dict return param_dict
@@ -396,21 +332,22 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
renormalize: bool, renormalize: bool,
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: torch.Tensor | None = None,
topk_group: Optional[int] = None, topk_group: int | None = None,
num_expert_group: Optional[int] = None, num_expert_group: int | None = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Callable | None = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0, routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: torch.Tensor | None = None,
is_prefill: bool = True, is_prefill: bool = True,
enable_force_load_balance: bool = False, enable_force_load_balance: bool = False,
log2phy: Optional[torch.Tensor] = None, log2phy: torch.Tensor | None = None,
global_redundant_expert_num: int = 0, global_redundant_expert_num: int = 0,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
assert router_logits.shape[ assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, (
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)" "Number of global experts mismatch (excluding redundancy)"
)
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
topk_weights, topk_ids = select_experts( topk_weights, topk_ids = select_experts(
@@ -424,18 +361,17 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts) global_num_experts=global_num_experts,
)
# this is a naive implementation for experts load balance so as # this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank. # to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs. # currently it is only activated when doing profile runs.
if enable_force_load_balance: if enable_force_load_balance:
random_matrix = torch.rand(topk_ids.size(0), random_matrix = torch.rand(
global_num_experts - topk_ids.size(0), global_num_experts - global_redundant_expert_num, device=topk_ids.device
global_redundant_expert_num, )
device=topk_ids.device) topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
topk_ids = torch.argsort(
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
topk_weights = topk_weights.to(x.dtype) topk_weights = topk_weights.to(x.dtype)
@@ -446,25 +382,23 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
w2=[layer.w2_weight], w2=[layer.w2_weight],
w1_scale=[layer.w13_weight_scale], w1_scale=[layer.w13_weight_scale],
w2_scale=[layer.w2_weight_scale], w2_scale=[layer.w2_weight_scale],
w1_scale_bias=layer.w13_scale_bias if hasattr( w1_scale_bias=layer.w13_scale_bias if hasattr(layer, "w13_scale_bias") else None,
layer, "w13_scale_bias") else None, w2_scale_bias=layer.w2_scale_bias if hasattr(layer, "w2_scale_bias") else None,
w2_scale_bias=layer.w2_scale_bias if hasattr(
layer, "w2_scale_bias") else None,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
use_int4_w4a8=True, use_int4_w4a8=True,
expert_map=expert_map, expert_map=expert_map,
log2phy=log2phy, log2phy=log2phy,
dynamic_eplb=self.dynamic_eplb, dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask", None)) mc2_mask=kwargs.get("mc2_mask"),
)
def process_scale(self, weight: torch.Tensor, scale, per_group_scale): def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
scale = scale.transpose(1, 2).contiguous() scale = scale.transpose(1, 2).contiguous()
if self.is_per_channel_weight: if self.is_per_channel_weight:
scale_np = scale.cpu().numpy() scale_np = scale.cpu().numpy()
scale_np.dtype = np.uint32 scale_np.dtype = np.uint32
scale_uint64_tensor = torch.from_numpy(scale_np.astype( scale_uint64_tensor = torch.from_numpy(scale_np.astype(np.int64)).npu()
np.int64)).npu()
return scale_uint64_tensor, None return scale_uint64_tensor, None
per_group_scale = per_group_scale.transpose(1, 2).contiguous() per_group_scale = per_group_scale.transpose(1, 2).contiguous()
group_num, k, n = weight.shape group_num, k, n = weight.shape
@@ -475,32 +409,27 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
group_num, quantgroup_num, n = per_group_scale.shape group_num, quantgroup_num, n = per_group_scale.shape
bias = None bias = None
if not self.new_quant_version: if not self.new_quant_version:
weight_high = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \ weight_high = weight.to(torch.float32).reshape(
per_group_scale.reshape([group_num, quantgroup_num, 1, n]) [group_num, quantgroup_num, -1, n]
) * per_group_scale.reshape([group_num, quantgroup_num, 1, n])
weight_high = weight_high.reshape([group_num, k, n]) weight_high = weight_high.reshape([group_num, k, n])
bias = 8 * (weight_high.to(torch.float32) * scale).sum(axis=1) bias = 8 * (weight_high.to(torch.float32) * scale).sum(axis=1)
scale_fp32 = (scale * per_group_scale).to(torch.float16).to( scale_fp32 = (scale * per_group_scale).to(torch.float16).to(torch.float32)
torch.float32)
scale_fp32_np = scale_fp32.cpu().numpy() scale_fp32_np = scale_fp32.cpu().numpy()
scale_fp32_np.dtype = np.uint32 scale_fp32_np.dtype = np.uint32
sscale_uint64 = np.zeros((group_num, quantgroup_num, n * 2), sscale_uint64 = np.zeros((group_num, quantgroup_num, n * 2), dtype=np.uint32)
dtype=np.uint32)
sscale_uint64[..., ::2] = scale_fp32_np sscale_uint64[..., ::2] = scale_fp32_np
sscale_uint64_buffer = np.frombuffer(sscale_uint64.tobytes(), sscale_uint64_buffer = np.frombuffer(sscale_uint64.tobytes(), dtype=np.int64).copy()
dtype=np.int64).copy() sscale_uint64_tensor = torch.from_numpy(sscale_uint64_buffer).reshape(group_num, quantgroup_num, n)
sscale_uint64_tensor = torch.from_numpy(sscale_uint64_buffer).reshape(
group_num, quantgroup_num, n)
sscale_uint64_tensor = sscale_uint64_tensor.npu() sscale_uint64_tensor = sscale_uint64_tensor.npu()
return sscale_uint64_tensor, bias return sscale_uint64_tensor, bias
def update_bias(self, layer, w13_bias, w2_bias): def update_bias(self, layer, w13_bias, w2_bias):
if self.new_quant_version: if self.new_quant_version:
layer.w13_scale_bias.data = layer.w13_scale_bias.data.transpose( layer.w13_scale_bias.data = layer.w13_scale_bias.data.transpose(1, 2).contiguous().sum(axis=1)
1, 2).contiguous().sum(axis=1) layer.w2_scale_bias.data = layer.w2_scale_bias.data.transpose(1, 2).contiguous().sum(axis=1)
layer.w2_scale_bias.data = layer.w2_scale_bias.data.transpose(
1, 2).contiguous().sum(axis=1)
else: else:
w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False) w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
layer.register_parameter("w13_scale_bias", w13_scale_bias) layer.register_parameter("w13_scale_bias", w13_scale_bias)
@@ -510,13 +439,12 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
def pack_to_int32(self, weight: torch.Tensor): def pack_to_int32(self, weight: torch.Tensor):
if self.new_quant_version: if self.new_quant_version:
# pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4 # pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4
assert weight.shape[ assert weight.shape[-1] % 4 == 0, "the last dim of weight needs to be divided by 4"
-1] % 4 == 0, "the last dim of weight needs to be divided by 4"
return weight.view(torch.int32).contiguous() return weight.view(torch.int32).contiguous()
else: else:
return torch_npu.npu_quantize(weight.to(torch.float32), return torch_npu.npu_quantize(
torch.tensor([1.]).npu(), None, weight.to(torch.float32), torch.tensor([1.0]).npu(), None, torch.quint4x2, -1, False
torch.quint4x2, -1, False) )
def process_weights_after_loading(self, layer): def process_weights_after_loading(self, layer):
if self.quant_method == COMPRESSED_TENSORS_METHOD: if self.quant_method == COMPRESSED_TENSORS_METHOD:
@@ -524,23 +452,18 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
else: else:
self.process_weights_after_loading_modelslim(layer) self.process_weights_after_loading_modelslim(layer)
def process_weights_after_loading_compressed_tensors(self, layer): def process_weights_after_loading_compressed_tensors(self, layer):
layer.w13_weight.data = layer.w13_weight.data.transpose( layer.w13_weight.data = layer.w13_weight.data.transpose(1, 2).contiguous()
1, 2).contiguous() layer.w2_weight.data = layer.w2_weight.data.transpose(1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(1,
2).contiguous()
def process_scale_compressed_tensors(scale: torch.Tensor): def process_scale_compressed_tensors(scale: torch.Tensor):
scale = scale.transpose(1, 2).to(torch.float32).contiguous() scale = scale.transpose(1, 2).to(torch.float32).contiguous()
scale_np = scale.cpu().numpy() scale_np = scale.cpu().numpy()
scale_np.dtype = np.uint32 scale_np.dtype = np.uint32
scale_uint64_tensor = torch.from_numpy(scale_np.astype( scale_uint64_tensor = torch.from_numpy(scale_np.astype(np.int64)).npu()
np.int64)).npu()
return scale_uint64_tensor return scale_uint64_tensor
def update_bias_compressed_tensors(weight: torch.Tensor, def update_bias_compressed_tensors(weight: torch.Tensor, scale: torch.Tensor, strategy: str):
scale: torch.Tensor, strategy:str):
group_num, k, n = weight.shape group_num, k, n = weight.shape
scale = scale.transpose(1, 2).contiguous() scale = scale.transpose(1, 2).contiguous()
scale = scale.reshape(group_num, -1, n) scale = scale.reshape(group_num, -1, n)
@@ -548,8 +471,9 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
bias = None bias = None
if strategy == "group": if strategy == "group":
tmp = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \ tmp = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * scale.reshape(
scale.reshape([group_num, quantgroup_num, 1, n]) [group_num, quantgroup_num, 1, n]
)
tmp = tmp.reshape([group_num, k, n]) tmp = tmp.reshape([group_num, k, n])
bias = 8 * tmp.sum(axis=1) bias = 8 * tmp.sum(axis=1)
elif strategy == "channel": elif strategy == "channel":
@@ -558,18 +482,13 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
raise ValueError(f"Unsupported weight strategy: {strategy}") raise ValueError(f"Unsupported weight strategy: {strategy}")
return bias return bias
w13_bias = update_bias_compressed_tensors(layer.w13_weight.data, w13_bias = update_bias_compressed_tensors(
layer.w13_weight_scale.data, layer.w13_weight.data, layer.w13_weight_scale.data, self.weight_strategy
self.weight_strategy) )
w2_bias = update_bias_compressed_tensors(layer.w2_weight.data, w2_bias = update_bias_compressed_tensors(layer.w2_weight.data, layer.w2_weight_scale.data, self.weight_strategy)
layer.w2_weight_scale.data,
self.weight_strategy)
layer.w13_weight_scale.data = process_scale_compressed_tensors(
layer.w13_weight_scale.data)
layer.w2_weight_scale.data = process_scale_compressed_tensors(
layer.w2_weight_scale.data)
layer.w13_weight_scale.data = process_scale_compressed_tensors(layer.w13_weight_scale.data)
layer.w2_weight_scale.data = process_scale_compressed_tensors(layer.w2_weight_scale.data)
w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False) w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
layer.register_parameter("w13_scale_bias", w13_scale_bias) layer.register_parameter("w13_scale_bias", w13_scale_bias)
@@ -583,21 +502,19 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data) layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)
def process_weights_after_loading_modelslim(self, layer): def process_weights_after_loading_modelslim(self, layer):
layer.w13_weight.data = layer.w13_weight.data.transpose( layer.w13_weight.data = layer.w13_weight.data.transpose(1, 2).contiguous()
1, 2).contiguous() layer.w2_weight.data = layer.w2_weight.data.transpose(1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(1,
2).contiguous()
w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr( w13_weight_scale_second = (
layer, "w13_weight_scale_second") else None layer.w13_weight_scale_second.data if hasattr(layer, "w13_weight_scale_second") else None
w2_weight_scale_second = layer.w2_weight_scale_second.data if hasattr( )
layer, "w2_weight_scale_second") else None w2_weight_scale_second = layer.w2_weight_scale_second.data if hasattr(layer, "w2_weight_scale_second") else None
layer.w13_weight_scale.data, w13_bias = self.process_scale( layer.w13_weight_scale.data, w13_bias = self.process_scale(
layer.w13_weight, layer.w13_weight_scale.data, layer.w13_weight, layer.w13_weight_scale.data, w13_weight_scale_second
w13_weight_scale_second) )
layer.w2_weight_scale.data, w2_bias = self.process_scale( layer.w2_weight_scale.data, w2_bias = self.process_scale(
layer.w2_weight, layer.w2_weight_scale.data, layer.w2_weight, layer.w2_weight_scale.data, w2_weight_scale_second
w2_weight_scale_second) )
if hasattr(layer, "w13_weight_scale_second"): if hasattr(layer, "w13_weight_scale_second"):
# scale_second is no longer used, release this part of the memory # scale_second is no longer used, release this part of the memory
del layer.w13_weight_scale_second del layer.w13_weight_scale_second

View File

@@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
# #
from typing import Any, Dict, Optional from typing import Any
import torch import torch
import torch_npu import torch_npu
@@ -41,39 +41,34 @@ class AscendW8A16LinearMethod(AscendLinearScheme):
input_size: int, input_size: int,
output_size: int, output_size: int,
params_dtype: torch.dtype = torch.bfloat16, params_dtype: torch.dtype = torch.bfloat16,
) -> Dict[str, Any]: ) -> dict[str, Any]:
params_dict = { params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
}
return params_dict return params_dict
def get_perchannel_param( def get_perchannel_param(
self, self,
output_size: int, output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
) -> Dict[str, Any]: ) -> dict[str, Any]:
params_dict = {} params_dict = {}
params_dict["weight_scale"] = torch.empty(output_size, params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
1, params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
dtype=params_dtype)
params_dict["weight_offset"] = torch.empty(output_size,
1,
dtype=params_dtype)
return params_dict return params_dict
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None, bias: torch.Tensor | None = None,
tp_rank: Optional[int] = 0, tp_rank: int | None = 0,
) -> torch.Tensor: ) -> torch.Tensor:
output = torch_npu.npu_weight_quant_batchmatmul( output = torch_npu.npu_weight_quant_batchmatmul(
x=x, x=x,
weight=layer.weight, weight=layer.weight,
antiquant_scale=layer.weight_scale, antiquant_scale=layer.weight_scale,
antiquant_offset=layer.weight_offset, antiquant_offset=layer.weight_offset,
bias=bias) bias=bias,
)
return output return output
def process_weights_after_loading(self, layer): def process_weights_after_loading(self, layer):

View File

@@ -15,7 +15,8 @@
# limitations under the License. # limitations under the License.
# #
from typing import Any, Callable, Dict, Optional from collections.abc import Callable
from typing import Any
import torch import torch
import torch_npu import torch_npu
@@ -28,8 +29,7 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.flash_common3_context import get_flash_common3_context from vllm_ascend.flash_common3_context import get_flash_common3_context
from vllm_ascend.ops.fused_moe.experts_selector import (select_experts, from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute
zero_experts_compute)
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, maybe_trans_nz from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, maybe_trans_nz
from .base import AscendLinearScheme, AscendMoEScheme, QuantType from .base import AscendLinearScheme, AscendMoEScheme, QuantType
@@ -39,9 +39,10 @@ from .registry import register_scheme
def scale_from_float_to_int64(scale): def scale_from_float_to_int64(scale):
"""Convert float32 scale to int64 representation.""" """Convert float32 scale to int64 representation."""
import numpy as np import numpy as np
scale = torch.from_numpy( scale = torch.from_numpy(
np.frombuffer(scale.cpu().to(torch.float32).numpy().tobytes(), np.frombuffer(scale.cpu().to(torch.float32).numpy().tobytes(), dtype=np.int32).astype(np.int64)
dtype=np.int32).astype(np.int64)).to(scale.device) ).to(scale.device)
return scale return scale
@@ -56,33 +57,26 @@ class AscendW8A8DynamicLinearMethod(AscendLinearScheme):
def __init__(self): def __init__(self):
pass pass
def get_weight(self, input_size: int, output_size: int, def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
params_dtype: torch.dtype) -> Dict[str, Any]: params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
params_dict = {
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
}
return params_dict return params_dict
def get_perchannel_param( def get_perchannel_param(
self, self,
output_size: int, output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
) -> Dict[str, Any]: ) -> dict[str, Any]:
params_dict = {} params_dict = {}
params_dict["weight_scale"] = torch.empty(output_size, params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
1, params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
dtype=params_dtype)
params_dict["weight_offset"] = torch.empty(output_size,
1,
dtype=params_dtype)
return params_dict return params_dict
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None, bias: torch.Tensor | None = None,
tp_rank: Optional[int] = 0, tp_rank: int | None = 0,
) -> torch.Tensor: ) -> torch.Tensor:
quantized_x, pertoken_scale = torch_npu.npu_dynamic_quant(x) quantized_x, pertoken_scale = torch_npu.npu_dynamic_quant(x)
output = torch_npu.npu_quant_matmul( output = torch_npu.npu_quant_matmul(
@@ -116,9 +110,10 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
ascend_config = get_ascend_config() ascend_config = get_ascend_config()
self.use_aclgraph = (vllm_config.compilation_config.mode self.use_aclgraph = (
== CompilationMode.VLLM_COMPILE vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE
and not vllm_config.model_config.enforce_eager) and not vllm_config.model_config.enforce_eager
)
self.multistream_overlap_gate = ascend_config.multistream_overlap_gate self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
self.dynamic_eplb = ascend_config.eplb_config.dynamic_eplb self.dynamic_eplb = ascend_config.eplb_config.dynamic_eplb
@@ -130,49 +125,34 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
# TODO: Try local_rank = ep_group.rank_in_group # TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=device_group) local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu")) backend = device_group._get_backend(torch.device("npu"))
self.moe_all_to_all_group_name = backend.get_hccl_comm_name( self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
local_rank)
except AttributeError: except AttributeError:
self.moe_all_to_all_group_name = "" self.moe_all_to_all_group_name = ""
def get_weight(self, num_experts: int, def get_weight(
intermediate_size_per_partition: int, hidden_sizes: int, self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
params_dtype: torch.dtype) -> Dict[str, Any]: ) -> dict[str, Any]:
param_dict = {} param_dict = {}
param_dict["w13_weight"] = torch.empty(num_experts, param_dict["w13_weight"] = torch.empty(
2 * num_experts, 2 * intermediate_size_per_partition, hidden_sizes, dtype=torch.int8
intermediate_size_per_partition, )
hidden_sizes, param_dict["w2_weight"] = torch.empty(
dtype=torch.int8) num_experts, hidden_sizes, intermediate_size_per_partition, dtype=torch.int8
param_dict["w2_weight"] = torch.empty(num_experts, )
hidden_sizes,
intermediate_size_per_partition,
dtype=torch.int8)
return param_dict return param_dict
def get_dynamic_quant_param(self, num_experts: int, def get_dynamic_quant_param(
intermediate_size_per_partition: int, self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
hidden_sizes: int, ) -> dict[str, Any]:
params_dtype: torch.dtype) -> Dict[str, Any]:
param_dict = {} param_dict = {}
param_dict["w13_weight_scale"] = torch.empty( param_dict["w13_weight_scale"] = torch.empty(
num_experts, num_experts, 2 * intermediate_size_per_partition, 1, dtype=params_dtype
2 * intermediate_size_per_partition, )
1,
dtype=params_dtype)
param_dict["w13_weight_offset"] = torch.empty( param_dict["w13_weight_offset"] = torch.empty(
num_experts, num_experts, 2 * intermediate_size_per_partition, 1, dtype=params_dtype
2 * intermediate_size_per_partition, )
1, param_dict["w2_weight_scale"] = torch.empty(num_experts, hidden_sizes, 1, dtype=params_dtype)
dtype=params_dtype) param_dict["w2_weight_offset"] = torch.empty(num_experts, hidden_sizes, 1, dtype=params_dtype)
param_dict["w2_weight_scale"] = torch.empty(num_experts,
hidden_sizes,
1,
dtype=params_dtype)
param_dict["w2_weight_offset"] = torch.empty(num_experts,
hidden_sizes,
1,
dtype=params_dtype)
return param_dict return param_dict
def apply( def apply(
@@ -184,25 +164,26 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
renormalize: bool, renormalize: bool,
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: torch.Tensor | None = None,
topk_group: Optional[int] = None, topk_group: int | None = None,
num_expert_group: Optional[int] = None, num_expert_group: int | None = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Callable | None = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0, routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: torch.Tensor | None = None,
is_prefill: bool = True, is_prefill: bool = True,
enable_force_load_balance: bool = False, enable_force_load_balance: bool = False,
log2phy: Optional[torch.Tensor] = None, log2phy: torch.Tensor | None = None,
global_redundant_expert_num: int = 0, global_redundant_expert_num: int = 0,
pertoken_scale: Optional[Any] = None, pertoken_scale: Any | None = None,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
zero_expert_num = getattr(layer, "zero_expert_num", 0) zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None) zero_expert_type = getattr(layer, "zero_expert_type", None)
if zero_expert_num == 0 or zero_expert_type is None: if zero_expert_num == 0 or zero_expert_type is None:
assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, \ assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, (
"Number of global experts mismatch (excluding redundancy)" "Number of global experts mismatch (excluding redundancy)"
)
if self.multistream_overlap_gate: if self.multistream_overlap_gate:
fc3_context = get_flash_common3_context() fc3_context = get_flash_common3_context()
@@ -222,7 +203,8 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
scoring_func=scoring_func, scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts) global_num_experts=global_num_experts,
)
assert topk_ids is not None assert topk_ids is not None
assert topk_weights is not None assert topk_weights is not None
if zero_expert_num > 0 and zero_expert_type is not None: if zero_expert_num > 0 and zero_expert_type is not None:
@@ -237,12 +219,10 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
# to avoid accumulating too much tokens on a single rank. # to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs. # currently it is only activated when doing profile runs.
if enable_force_load_balance: if enable_force_load_balance:
random_matrix = torch.rand(topk_ids.size(0), random_matrix = torch.rand(
global_num_experts - topk_ids.size(0), global_num_experts - global_redundant_expert_num, device=topk_ids.device
global_redundant_expert_num, )
device=topk_ids.device) topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
topk_ids = torch.argsort(
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
assert topk_weights is not None assert topk_weights is not None
topk_weights = topk_weights.to(self.in_dtype) topk_weights = topk_weights.to(self.in_dtype)
@@ -259,9 +239,10 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
w2 = [layer.w2_weight] w2 = [layer.w2_weight]
w2_scale = [layer.w2_weight_scale] w2_scale = [layer.w2_weight_scale]
fused_scale_flag = (get_forward_context().moe_comm_type fused_scale_flag = (
== MoECommType.FUSED_MC2 get_forward_context().moe_comm_type == MoECommType.FUSED_MC2
and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1) and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1
)
final_hidden_states = moe_comm_method.fused_experts( final_hidden_states = moe_comm_method.fused_experts(
hidden_states=x, hidden_states=x,
pertoken_scale=pertoken_scale, pertoken_scale=pertoken_scale,
@@ -275,54 +256,35 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
expert_map=expert_map, expert_map=expert_map,
log2phy=log2phy, log2phy=log2phy,
dynamic_eplb=self.dynamic_eplb, dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask", None)) mc2_mask=kwargs.get("mc2_mask"),
)
if zero_expert_num > 0 and zero_expert_type is not None: if zero_expert_num > 0 and zero_expert_type is not None:
final_hidden_states += zero_expert_result final_hidden_states += zero_expert_result
return final_hidden_states return final_hidden_states
def process_weights_after_loading(self, layer): def process_weights_after_loading(self, layer):
layer.w13_weight.data = layer.w13_weight.data.transpose( layer.w13_weight.data = layer.w13_weight.data.transpose(1, 2).contiguous()
1, 2).contiguous() layer.w2_weight.data = layer.w2_weight.data.transpose(1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(1,
2).contiguous()
# TODO(zzzzwwjj): Currently, `torch_npu.npu_grouped_matmul_swiglu_quant` # TODO(zzzzwwjj): Currently, `torch_npu.npu_grouped_matmul_swiglu_quant`
# can only support weight nz. # can only support weight nz.
layer.w13_weight.data = torch_npu.npu_format_cast( layer.w13_weight.data = torch_npu.npu_format_cast(layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ) layer.w2_weight.data = torch_npu.npu_format_cast(layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w2_weight.data = torch_npu.npu_format_cast( layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(layer.w13_weight_scale.data.shape[0], -1)
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ) layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(torch.float32)
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view( layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(layer.w13_weight_offset.data.shape[0], -1)
layer.w13_weight_scale.data.shape[0], -1) layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(layer.w2_weight_scale.data.shape[0], -1)
layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to( layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(layer.w2_weight_offset.data.shape[0], -1)
torch.float32)
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
layer.w13_weight_offset.data.shape[0], -1)
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
layer.w2_weight_scale.data.shape[0], -1)
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
layer.w2_weight_offset.data.shape[0], -1)
layer.fused_w1_scale = scale_from_float_to_int64( layer.fused_w1_scale = scale_from_float_to_int64(layer.w13_weight_scale.data)
layer.w13_weight_scale.data) layer.fused_w2_scale = scale_from_float_to_int64(layer.w2_weight_scale.data)
layer.fused_w2_scale = scale_from_float_to_int64(
layer.w2_weight_scale.data)
if self.dynamic_eplb: if self.dynamic_eplb:
layer.w13_weight_list = [ layer.w13_weight_list = [weight.clone() for weight in layer.w13_weight.data.unbind(dim=0)]
weight.clone() layer.w2_weight_list = [weight.clone() for weight in layer.w2_weight.data.unbind(dim=0)]
for weight in layer.w13_weight.data.unbind(dim=0)
]
layer.w2_weight_list = [
weight.clone() for weight in layer.w2_weight.data.unbind(dim=0)
]
layer.w13_weight_scale_fp32_list = [ layer.w13_weight_scale_fp32_list = [
weight.clone() weight.clone() for weight in layer.w13_weight_scale_fp32.data.unbind(dim=0)
for weight in layer.w13_weight_scale_fp32.data.unbind(dim=0)
]
layer.w2_weight_scale_list = [
weight.clone()
for weight in layer.w2_weight_scale.data.unbind(dim=0)
] ]
layer.w2_weight_scale_list = [weight.clone() for weight in layer.w2_weight_scale.data.unbind(dim=0)]
del layer.w13_weight del layer.w13_weight
del layer.w2_weight del layer.w2_weight
del layer.w13_weight_scale del layer.w13_weight_scale

View File

@@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
# #
from typing import Any, Dict, Optional from typing import Any
import torch import torch
import torch_npu import torch_npu
@@ -33,43 +33,32 @@ class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme):
The activation is dynamically quantized to FP8 (E4M3FN format) with The activation is dynamically quantized to FP8 (E4M3FN format) with
microscaling, and weights are stored in FP8 format with per-group scales. microscaling, and weights are stored in FP8 format with per-group scales.
""" """
model_dtype = None model_dtype = None
def __init__(self): def __init__(self):
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
self.group_size = vllm_config.quant_config.quant_description.get( self.group_size = vllm_config.quant_config.quant_description.get("group_size", 32)
"group_size", 32)
def get_weight(self, input_size: int, output_size: int, def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
params_dtype: torch.dtype) -> Dict[str, Any]: params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.float8_e4m3fn)}
params_dict = {
"weight":
torch.empty(output_size, input_size, dtype=torch.float8_e4m3fn)
}
return params_dict return params_dict
def get_pergroup_param(self, def get_pergroup_param(
input_size: int, self, input_size: int, output_size: int, params_dtype: torch.dtype, layer_type: str | None = None
output_size: int, ) -> dict[str, Any]:
params_dtype: torch.dtype,
layer_type: Optional[str] = None) -> Dict[str, Any]:
params_dict = {} params_dict = {}
params_dict["weight_scale"] = torch.empty(output_size, params_dict["weight_scale"] = torch.empty(output_size, input_size // self.group_size, dtype=torch.uint8)
input_size //
self.group_size,
dtype=torch.uint8)
return params_dict return params_dict
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None, bias: torch.Tensor | None = None,
tp_rank: Optional[int] = 0, tp_rank: int | None = 0,
) -> torch.Tensor: ) -> torch.Tensor:
quantized_x, dynamic_scale = torch_npu.npu_dynamic_mx_quant(x, dst_type=torch.float8_e4m3fn)
quantized_x, dynamic_scale = torch_npu.npu_dynamic_mx_quant(
x, dst_type=torch.float8_e4m3fn)
pertoken_scale = dynamic_scale pertoken_scale = dynamic_scale
output_dtype = x.dtype output_dtype = x.dtype
@@ -82,13 +71,13 @@ class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme):
pertoken_scale_dtype=torch_npu.float8_e8m0fnu, pertoken_scale_dtype=torch_npu.float8_e8m0fnu,
bias=bias, bias=bias,
output_dtype=output_dtype, output_dtype=output_dtype,
group_sizes=[1, 1, self.group_size]) group_sizes=[1, 1, self.group_size],
)
return output return output
def process_weights_after_loading(self, layer): def process_weights_after_loading(self, layer):
n_dim, k_dim = layer.weight_scale.data.shape n_dim, k_dim = layer.weight_scale.data.shape
layer.weight_scale.data = layer.weight_scale.data.reshape( layer.weight_scale.data = layer.weight_scale.data.reshape(n_dim, k_dim // 2, 2)
n_dim, k_dim // 2, 2)
layer.weight.data = layer.weight.data.transpose(0, 1) layer.weight.data = layer.weight.data.transpose(0, 1)
layer.weight_scale.data = layer.weight_scale.data.transpose(0, 1) layer.weight_scale.data = layer.weight_scale.data.transpose(0, 1)

View File

@@ -22,15 +22,14 @@ for prefill and decode phases:
- Decode (KV consumer): Uses static W8A8 quantization - Decode (KV consumer): Uses static W8A8 quantization
""" """
from typing import Any, Dict, Optional from typing import Any
import torch import torch
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from .base import AscendLinearScheme from .base import AscendLinearScheme
from .registry import register_scheme from .registry import register_scheme
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, from .w8a8_dynamic import AscendW8A8DynamicFusedMoEMethod, AscendW8A8DynamicLinearMethod
AscendW8A8DynamicLinearMethod)
from .w8a8_static import AscendW8A8LinearMethod from .w8a8_static import AscendW8A8LinearMethod
@@ -53,31 +52,27 @@ class AscendW8A8PDMixLinearMethod(AscendLinearScheme):
self._dynamic_method = AscendW8A8DynamicLinearMethod() self._dynamic_method = AscendW8A8DynamicLinearMethod()
kv_transfer_config = get_current_vllm_config().kv_transfer_config kv_transfer_config = get_current_vllm_config().kv_transfer_config
self._is_kv_consumer = (kv_transfer_config is not None self._is_kv_consumer = kv_transfer_config is not None and kv_transfer_config.is_kv_consumer
and kv_transfer_config.is_kv_consumer)
def get_weight(self, input_size: int, output_size: int, def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
params_dtype: torch.dtype) -> Dict[str, Any]: return self._static_method.get_weight(input_size, output_size, params_dtype)
return self._static_method.get_weight(input_size, output_size,
params_dtype)
def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]: def get_pertensor_param(self, params_dtype: torch.dtype) -> dict[str, Any]:
return self._static_method.get_pertensor_param(params_dtype) return self._static_method.get_pertensor_param(params_dtype)
def get_perchannel_param( def get_perchannel_param(
self, self,
output_size: int, output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
) -> Dict[str, Any]: ) -> dict[str, Any]:
return self._static_method.get_perchannel_param( return self._static_method.get_perchannel_param(output_size, params_dtype)
output_size, params_dtype)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None, bias: torch.Tensor | None = None,
tp_rank: Optional[int] = 0, tp_rank: int | None = 0,
) -> torch.Tensor: ) -> torch.Tensor:
if layer.is_kv_consumer: if layer.is_kv_consumer:
return self._static_method.apply(layer, x, bias, tp_rank) return self._static_method.apply(layer, x, bias, tp_rank)
@@ -92,26 +87,15 @@ class AscendW8A8PDMixLinearMethod(AscendLinearScheme):
@register_scheme("W8A8_MIX", "moe") @register_scheme("W8A8_MIX", "moe")
class AscendW8A8PDMixFusedMoeMethod(AscendW8A8DynamicFusedMoEMethod): class AscendW8A8PDMixFusedMoeMethod(AscendW8A8DynamicFusedMoEMethod):
def get_dynamic_quant_param(
def get_dynamic_quant_param(self, num_experts: int, self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
intermediate_size_per_partition: int, ) -> dict[str, Any]:
hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
param_dict = super().get_dynamic_quant_param( param_dict = super().get_dynamic_quant_param(
num_experts, intermediate_size_per_partition, hidden_sizes, num_experts, intermediate_size_per_partition, hidden_sizes, params_dtype
params_dtype) )
param_dict["w2_deq_scale"] = torch.empty(num_experts, param_dict["w2_deq_scale"] = torch.empty(num_experts, hidden_sizes, dtype=torch.float32)
hidden_sizes, param_dict["w13_deq_scale"] = torch.empty(num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32)
dtype=torch.float32) param_dict["w2_input_offset"] = torch.empty(num_experts, 1, dtype=torch.int8)
param_dict["w13_deq_scale"] = torch.empty( param_dict["w13_input_offset"] = torch.empty(num_experts, 1, dtype=torch.int8)
num_experts,
2 * intermediate_size_per_partition,
dtype=torch.float32)
param_dict["w2_input_offset"] = torch.empty(num_experts,
1,
dtype=torch.int8)
param_dict["w13_input_offset"] = torch.empty(num_experts,
1,
dtype=torch.int8)
return param_dict return param_dict

View File

@@ -15,14 +15,16 @@
# limitations under the License. # limitations under the License.
# #
from typing import Any, Dict, Optional from typing import Any
import torch import torch
import torch_npu import torch_npu
from vllm_ascend.utils import (COMPRESSED_TENSORS_METHOD, AscendDeviceType, from vllm_ascend.utils import (
get_ascend_device_type, COMPRESSED_TENSORS_METHOD,
get_weight_prefetch_method, maybe_trans_nz) get_weight_prefetch_method,
maybe_trans_nz,
)
from .base import AscendLinearScheme from .base import AscendLinearScheme
from .registry import register_scheme from .registry import register_scheme
@@ -44,13 +46,11 @@ class AscendW8A8LinearMethod(AscendLinearScheme):
input_size: int, input_size: int,
output_size: int, output_size: int,
params_dtype: torch.dtype = torch.bfloat16, params_dtype: torch.dtype = torch.bfloat16,
) -> Dict[str, Any]: ) -> dict[str, Any]:
params_dict = { params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
}
return params_dict return params_dict
def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]: def get_pertensor_param(self, params_dtype: torch.dtype) -> dict[str, Any]:
params_dict = {} params_dict = {}
params_dict["input_scale"] = torch.empty(1, dtype=params_dtype) params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
params_dict["input_offset"] = torch.empty(1, dtype=torch.int8) params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
@@ -60,29 +60,23 @@ class AscendW8A8LinearMethod(AscendLinearScheme):
self, self,
output_size: int, output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
) -> Dict[str, Any]: ) -> dict[str, Any]:
params_dict = {} params_dict = {}
params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32) params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
if params_dtype == torch.bfloat16: if params_dtype == torch.bfloat16:
params_dict["deq_scale"] = torch.empty(output_size, params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.float32)
dtype=torch.float32)
elif params_dtype == torch.float16: elif params_dtype == torch.float16:
params_dict["deq_scale"] = torch.empty(output_size, params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.int64)
dtype=torch.int64) params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
params_dict["weight_scale"] = torch.empty(output_size, params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
1,
dtype=params_dtype)
params_dict["weight_offset"] = torch.empty(output_size,
1,
dtype=params_dtype)
return params_dict return params_dict
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None, bias: torch.Tensor | None = None,
tp_rank: Optional[int] = 0, tp_rank: int | None = 0,
) -> torch.Tensor: ) -> torch.Tensor:
if x.dtype != torch.int8: if x.dtype != torch.int8:
layer_cls_name = layer.__class__.__name__ layer_cls_name = layer.__class__.__name__
@@ -95,15 +89,15 @@ class AscendW8A8LinearMethod(AscendLinearScheme):
start_flag=x, start_flag=x,
) )
try: try:
quant_comm_config = getattr(layer, "_quant_comm_config") quant_comm_config = layer._quant_comm_config
except AttributeError: except AttributeError:
quant_comm_config = {} quant_comm_config = {}
comm_fn = quant_comm_config.get("communication_fn") comm_fn = quant_comm_config.get("communication_fn")
enable_flashcomm2_quant_comm = comm_fn is not None and ( enable_flashcomm2_quant_comm = comm_fn is not None and (
"o_proj" in layer.prefix or "out_proj" in layer.prefix) "o_proj" in layer.prefix or "out_proj" in layer.prefix
)
if enable_flashcomm2_quant_comm: if enable_flashcomm2_quant_comm:
quant_input_x = x.contiguous().view( quant_input_x = x.contiguous().view(-1, layer.aclnn_input_scale_reciprocal.size(0))
-1, layer.aclnn_input_scale_reciprocal.size(0))
quant_x = torch.ops.vllm.quantize( quant_x = torch.ops.vllm.quantize(
quant_input_x, quant_input_x,
layer.aclnn_input_scale, layer.aclnn_input_scale,
@@ -132,7 +126,7 @@ class AscendW8A8LinearMethod(AscendLinearScheme):
quant_bias = layer.quant_bias if tp_rank == 0 else None quant_bias = layer.quant_bias if tp_rank == 0 else None
try: try:
ascend_quant_method = getattr(layer, "ascend_quant_method") ascend_quant_method = layer.ascend_quant_method
except AttributeError: except AttributeError:
ascend_quant_method = "" ascend_quant_method = ""
if ascend_quant_method == COMPRESSED_TENSORS_METHOD: if ascend_quant_method == COMPRESSED_TENSORS_METHOD:
@@ -150,14 +144,14 @@ class AscendW8A8LinearMethod(AscendLinearScheme):
def process_weights_after_loading(self, layer): def process_weights_after_loading(self, layer):
expanding_factor = layer.weight.data.shape[1] expanding_factor = layer.weight.data.shape[1]
layer.aclnn_input_scale = torch.nn.Parameter( layer.aclnn_input_scale = torch.nn.Parameter(
layer.input_scale.data.repeat(expanding_factor), layer.input_scale.data.repeat(expanding_factor), requires_grad=False
requires_grad=False) )
layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter( layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
layer.input_scale.data.repeat(expanding_factor), layer.input_scale.data.repeat(expanding_factor), requires_grad=False
requires_grad=False) )
layer.aclnn_input_offset = torch.nn.Parameter( layer.aclnn_input_offset = torch.nn.Parameter(
layer.input_offset.data.repeat(expanding_factor), layer.input_offset.data.repeat(expanding_factor), requires_grad=False
requires_grad=False).to(layer.aclnn_input_scale.dtype) ).to(layer.aclnn_input_scale.dtype)
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight.data = maybe_trans_nz(layer.weight.data) layer.weight.data = maybe_trans_nz(layer.weight.data)
@@ -166,5 +160,4 @@ class AscendW8A8LinearMethod(AscendLinearScheme):
ascend_quant_method = getattr(layer, "ascend_quant_method", "") ascend_quant_method = getattr(layer, "ascend_quant_method", "")
if ascend_quant_method == COMPRESSED_TENSORS_METHOD: if ascend_quant_method == COMPRESSED_TENSORS_METHOD:
deq_scale = layer.input_scale.data * layer.weight_scale.data deq_scale = layer.input_scale.data * layer.weight_scale.data
layer.deq_scale = torch.nn.Parameter(deq_scale, layer.deq_scale = torch.nn.Parameter(deq_scale, requires_grad=False)
requires_grad=False)

View File

@@ -21,20 +21,18 @@ This module provides the AscendModelSlimConfig class for parsing quantization
configs generated by the ModelSlim tool, along with model-specific mappings. configs generated by the ModelSlim tool, along with model-specific mappings.
""" """
from collections.abc import Mapping
from types import MappingProxyType from types import MappingProxyType
from typing import Any, Dict, List, Mapping, Optional from typing import Any, Optional
import torch import torch
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization import \ from vllm.model_executor.layers.quantization import register_quantization_config
register_quantization_config from vllm.model_executor.layers.quantization.base_config import QuantizationConfig, QuantizeMethodBase
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.vocab_parallel_embedding import UnquantizedEmbeddingMethod, VocabParallelEmbedding
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.vocab_parallel_embedding import (
UnquantizedEmbeddingMethod, VocabParallelEmbedding)
from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.models.utils import WeightsMapper
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
@@ -45,7 +43,7 @@ logger = init_logger(__name__)
# key: model_type # key: model_type
# value: orig_to_new_prefix # value: orig_to_new_prefix
QUANT_MODEL_PREFIX_MAPPINGS: Dict[str, Dict[str, str]] = { QUANT_MODEL_PREFIX_MAPPINGS: dict[str, dict[str, str]] = {
"qwen3_vl_moe": { "qwen3_vl_moe": {
"visual.": "model.visual.", "visual.": "model.visual.",
"language_model.lm_head.": "lm_head.", "language_model.lm_head.": "lm_head.",
@@ -60,7 +58,7 @@ QUANT_MODEL_PREFIX_MAPPINGS: Dict[str, Dict[str, str]] = {
# key: model_type # key: model_type
# value: dict of fused module name -> list of original module names # value: dict of fused module name -> list of original module names
packed_modules_model_mapping: Dict[str, Dict[str, List[str]]] = { packed_modules_model_mapping: dict[str, dict[str, list[str]]] = {
"qwen3_moe": { "qwen3_moe": {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
@@ -71,52 +69,44 @@ packed_modules_model_mapping: Dict[str, Dict[str, List[str]]] = {
"gate_proj", "gate_proj",
"up_proj", "up_proj",
], ],
"experts": "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
}, },
"deepseek_v2": { "deepseek_v2": {
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
"experts": "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
}, },
"deepseek_v3": { "deepseek_v3": {
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
"experts": "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
}, },
"pangu_ultra_moe": { "pangu_ultra_moe": {
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
"experts": "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
}, },
"kimi_k2": { "kimi_k2": {
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
"experts": "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
}, },
"deepseek_v32": { "deepseek_v32": {
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
"experts": "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
}, },
# NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized; # NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized;
# NOTE 2.The description file generated by the current msmodelslim tool does not have # NOTE 2.The description file generated by the current msmodelslim tool does not have
# MTP layer info. Please manually add it and set the value to FLOAT. # MTP layer info. Please manually add it and set the value to FLOAT.
"deepseek_mtp": { "deepseek_mtp": {
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
"experts": "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
}, },
"pangu_ultra_moe_mtp": { "pangu_ultra_moe_mtp": {
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
"experts": "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
}, },
"qwen3_next": { "qwen3_next": {
"qkv_proj": [ "qkv_proj": [
@@ -126,8 +116,7 @@ packed_modules_model_mapping: Dict[str, Dict[str, List[str]]] = {
], ],
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
"in_proj": ["in_proj_qkvz", "in_proj_ba"], "in_proj": ["in_proj_qkvz", "in_proj_ba"],
"experts": "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
}, },
"qwen2_5_vl": { "qwen2_5_vl": {
"qkv_proj": [ "qkv_proj": [
@@ -150,8 +139,7 @@ packed_modules_model_mapping: Dict[str, Dict[str, List[str]]] = {
"gate_proj", "gate_proj",
"up_proj", "up_proj",
], ],
"experts": "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
}, },
"glm4_moe": { "glm4_moe": {
"qkv_proj": [ "qkv_proj": [
@@ -163,20 +151,17 @@ packed_modules_model_mapping: Dict[str, Dict[str, List[str]]] = {
"gate_proj", "gate_proj",
"up_proj", "up_proj",
], ],
"experts": "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
}, },
"glm4_moe_lite": { "glm4_moe_lite": {
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
"experts": "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
}, },
"longcat_flash": { "longcat_flash": {
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
"experts": "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
}, },
"minimax_m2": { "minimax_m2": {
"qkv_proj": [ "qkv_proj": [
@@ -184,12 +169,12 @@ packed_modules_model_mapping: Dict[str, Dict[str, List[str]]] = {
"k_proj", "k_proj",
"v_proj", "v_proj",
], ],
"experts": ["experts.0.w1", "experts.0.w2", "experts.0.w3"] "experts": ["experts.0.w1", "experts.0.w2", "experts.0.w3"],
} },
} }
def get_packed_modules_mapping(model_type: str) -> Dict[str, List[str]]: def get_packed_modules_mapping(model_type: str) -> dict[str, list[str]]:
"""Get packed modules mapping for a model type. """Get packed modules mapping for a model type.
Args: Args:
@@ -202,7 +187,7 @@ def get_packed_modules_mapping(model_type: str) -> Dict[str, List[str]]:
return packed_modules_model_mapping.get(model_type, {}) return packed_modules_model_mapping.get(model_type, {})
def get_prefix_mapping(model_type: str) -> Dict[str, str]: def get_prefix_mapping(model_type: str) -> dict[str, str]:
"""Get prefix mapping for a model type. """Get prefix mapping for a model type.
Args: Args:
@@ -216,8 +201,8 @@ def get_prefix_mapping(model_type: str) -> Dict[str, str]:
def get_linear_quant_type( def get_linear_quant_type(
quant_description: Dict[str, Any], prefix: str, quant_description: dict[str, Any], prefix: str, packed_modules_mapping: dict[str, Any]
packed_modules_mapping: Dict[str, Any]) -> Optional[str]: ) -> str | None:
"""Determine the quantization type for a linear layer. """Determine the quantization type for a linear layer.
Args: Args:
@@ -232,11 +217,10 @@ def get_linear_quant_type(
if proj_name in packed_modules_mapping: if proj_name in packed_modules_mapping:
quant_type = None quant_type = None
shard_prefixes = [ shard_prefixes = [
prefix.replace(proj_name, shard_proj_name) prefix.replace(proj_name, shard_proj_name) for shard_proj_name in packed_modules_mapping[proj_name]
for shard_proj_name in packed_modules_mapping[proj_name]
] ]
for shard_prefix in shard_prefixes: for shard_prefix in shard_prefixes:
shard_quant_type = quant_description[shard_prefix + '.weight'] shard_quant_type = quant_description[shard_prefix + ".weight"]
if quant_type is None: if quant_type is None:
quant_type = shard_quant_type quant_type = shard_quant_type
@@ -244,18 +228,19 @@ def get_linear_quant_type(
raise ValueError( raise ValueError(
f"Not all shards of {prefix} are quantized with same quant type." f"Not all shards of {prefix} are quantized with same quant type."
f"Shard {proj_name} uses {shard_quant_type}, but another shard" f"Shard {proj_name} uses {shard_quant_type}, but another shard"
f"use {quant_type}. Please check quantization config.") f"use {quant_type}. Please check quantization config."
)
else: else:
quant_type = quant_description[prefix + '.weight'] quant_type = quant_description[prefix + ".weight"]
return quant_type return quant_type
def get_quant_type_for_layer( def get_quant_type_for_layer(
quant_description: Dict[str, Any], quant_description: dict[str, Any],
prefix: str, prefix: str,
layer_type: str, layer_type: str,
packed_modules_mapping: Optional[Dict[str, packed_modules_mapping: dict[str, Any] | None = None,
Any]] = None) -> Optional[str]: ) -> str | None:
"""Determine the quantization type for a layer. """Determine the quantization type for a layer.
Args: Args:
@@ -270,19 +255,18 @@ def get_quant_type_for_layer(
if packed_modules_mapping is None: if packed_modules_mapping is None:
packed_modules_mapping = dict() packed_modules_mapping = dict()
# Attention # Attention
if layer_type == "attention" and 'fa_quant_type' in quant_description.keys( if layer_type == "attention" and "fa_quant_type" in quant_description:
): return quant_description["fa_quant_type"]
return quant_description['fa_quant_type']
# Linear / MoE # Linear / MoE
return get_linear_quant_type(quant_description, prefix, return get_linear_quant_type(quant_description, prefix, packed_modules_mapping)
packed_modules_mapping)
def create_scheme_for_layer( def create_scheme_for_layer(
quant_description: Dict[str, Any], quant_description: dict[str, Any],
prefix: str, prefix: str,
layer_type: str, layer_type: str,
packed_modules_mapping: Optional[Dict[str, Any]] = None): packed_modules_mapping: dict[str, Any] | None = None,
):
"""Create a quantization scheme instance for a layer. """Create a quantization scheme instance for a layer.
Args: Args:
@@ -295,21 +279,17 @@ def create_scheme_for_layer(
An instance of the appropriate quantization scheme class. An instance of the appropriate quantization scheme class.
""" """
logger.info_once("Using the vLLM Ascend modelslim Quantization now!") logger.info_once("Using the vLLM Ascend modelslim Quantization now!")
quant_type = get_quant_type_for_layer(quant_description, prefix, quant_type = get_quant_type_for_layer(quant_description, prefix, layer_type, packed_modules_mapping)
layer_type, packed_modules_mapping)
if quant_type is None: if quant_type is None:
raise ValueError( raise ValueError(f"Could not determine quantization type for layer {prefix}.")
f"Could not determine quantization type for layer {prefix}.")
# Use registry to get scheme class # Use registry to get scheme class
scheme_cls = get_scheme_class(quant_type, layer_type) scheme_cls = get_scheme_class(quant_type, layer_type)
if scheme_cls is not None: if scheme_cls is not None:
return scheme_cls() return scheme_cls()
raise NotImplementedError( raise NotImplementedError(f"Currently, vLLM Ascend doesn't support {quant_type} for {layer_type}.")
f"Currently, vLLM Ascend doesn't support {quant_type} for {layer_type}."
)
@register_quantization_config(ASCEND_QUANTIZATION_METHOD) @register_quantization_config(ASCEND_QUANTIZATION_METHOD)
@@ -321,13 +301,13 @@ class AscendModelSlimConfig(QuantizationConfig):
quantized using the ModelSlim tool. quantized using the ModelSlim tool.
""" """
def __init__(self, quant_config: Dict[str, Any]): def __init__(self, quant_config: dict[str, Any]):
super().__init__() super().__init__()
self.quant_description = quant_config self.quant_description = quant_config
# TODO(whx): remove this adaptation after adding "shared_head" # TODO(whx): remove this adaptation after adding "shared_head"
# to prefix of DeepSeekShareHead in vLLM. # to prefix of DeepSeekShareHead in vLLM.
extra_quant_dict = {} extra_quant_dict = {}
for k in self.quant_description.keys(): for k in self.quant_description:
if "shared_head" in k: if "shared_head" in k:
new_k = k.replace(".shared_head.", ".") new_k = k.replace(".shared_head.", ".")
extra_quant_dict[new_k] = self.quant_description[k] extra_quant_dict[new_k] = self.quant_description[k]
@@ -344,25 +324,23 @@ class AscendModelSlimConfig(QuantizationConfig):
return ASCEND_QUANTIZATION_METHOD return ASCEND_QUANTIZATION_METHOD
@classmethod @classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]: def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.int8, torch.float16, torch.bfloat16] return [torch.int8, torch.float16, torch.bfloat16]
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
raise NotImplementedError( raise NotImplementedError('Ascend hardware dose not support "get_min_capability" feature.')
"Ascend hardware dose not support \"get_min_capability\" feature.")
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> list[str]:
return ["quant_model_description.json"] return ["quant_model_description.json"]
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "AscendModelSlimConfig": def from_config(cls, config: dict[str, Any]) -> "AscendModelSlimConfig":
return cls(config) return cls(config)
@classmethod @classmethod
def override_quantization_method(cls, hf_quant_cfg, def override_quantization_method(cls, hf_quant_cfg, user_quant) -> str | None:
user_quant) -> Optional[str]:
if hf_quant_cfg is not None: if hf_quant_cfg is not None:
quant_method = hf_quant_cfg.get("quant_method", None) quant_method = hf_quant_cfg.get("quant_method", None)
if not quant_method and torch.npu.is_available(): if not quant_method and torch.npu.is_available():
@@ -373,15 +351,17 @@ class AscendModelSlimConfig(QuantizationConfig):
# TODO (Levi-JQ): will be removed when QuantizationConfig.apply_vllm_mapper is implemented # TODO (Levi-JQ): will be removed when QuantizationConfig.apply_vllm_mapper is implemented
prefix_mapping = QUANT_MODEL_PREFIX_MAPPINGS.get(model_type) prefix_mapping = QUANT_MODEL_PREFIX_MAPPINGS.get(model_type)
if prefix_mapping: if prefix_mapping:
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix=prefix_mapping)
orig_to_new_prefix=prefix_mapping)
return hf_to_vllm_mapper._map_name(prefix) return hf_to_vllm_mapper._map_name(prefix)
return prefix return prefix
def get_quant_method(self, layer: torch.nn.Module, def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]:
prefix: str) -> Optional["QuantizeMethodBase"]: from .method_adapters import (
from .method_adapters import (AscendEmbeddingMethod, AscendFusedMoEMethod, AscendEmbeddingMethod,
AscendKVCacheMethod, AscendLinearMethod) AscendFusedMoEMethod,
AscendKVCacheMethod,
AscendLinearMethod,
)
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
model_type = vllm_config.model_config.hf_config.model_type model_type = vllm_config.model_config.hf_config.model_type
@@ -390,81 +370,67 @@ class AscendModelSlimConfig(QuantizationConfig):
# Adapt to Minimax architecture: update layer names to MoE convention # Adapt to Minimax architecture: update layer names to MoE convention
prefix = prefix.replace("mlp", "block_sparse_moe") prefix = prefix.replace("mlp", "block_sparse_moe")
# Normalize the prefix by stripping specific expert indices (e.g., 'experts.0' -> 'experts') # Normalize the prefix by stripping specific expert indices (e.g., 'experts.0' -> 'experts')
parts = prefix.split('.') parts = prefix.split(".")
if "experts" in parts and len(parts) > 2: if "experts" in parts and len(parts) > 2:
exp_idx = parts.index("experts") exp_idx = parts.index("experts")
if exp_idx + 1 < len(parts) and parts[exp_idx + 1].isdigit(): if exp_idx + 1 < len(parts) and parts[exp_idx + 1].isdigit():
parts = parts[:exp_idx + 1] parts = parts[: exp_idx + 1]
prefix = ".".join(parts) prefix = ".".join(parts)
if model_type in packed_modules_model_mapping: if model_type in packed_modules_model_mapping:
self.packed_modules_mapping = packed_modules_model_mapping[ self.packed_modules_mapping = packed_modules_model_mapping[model_type]
model_type]
prefix = self.quant_prefix_mapper(model_type, prefix) prefix = self.quant_prefix_mapper(model_type, prefix)
from vllm_ascend.utils import vllm_version_is from vllm_ascend.utils import vllm_version_is
if vllm_version_is("v0.15.0"): if vllm_version_is("v0.15.0"):
from vllm.attention.layer import Attention # type: ignore from vllm.attention.layer import Attention # type: ignore
else: else:
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
if prefix.startswith("language_model"): if prefix.startswith("language_model"):
prefix = prefix.split('.', 1)[-1] prefix = prefix.split(".", 1)[-1]
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
if self.is_layer_skipped_ascend(prefix, if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping):
self.packed_modules_mapping):
# Delayed import to avoid circular import # Delayed import to avoid circular import
from vllm_ascend.ops.linear import \ from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
AscendUnquantizedLinearMethod
return AscendUnquantizedLinearMethod() return AscendUnquantizedLinearMethod()
scheme = create_scheme_for_layer(self.quant_description, prefix, scheme = create_scheme_for_layer(self.quant_description, prefix, "linear", self.packed_modules_mapping)
"linear",
self.packed_modules_mapping)
return AscendLinearMethod(scheme) return AscendLinearMethod(scheme)
elif isinstance(layer, Attention) and \ elif (
'fa_quant_type' in self.quant_description.keys() and \ isinstance(layer, Attention)
self.quant_description['fa_quant_type'] is not None: and "fa_quant_type" in self.quant_description
scheme = create_scheme_for_layer(self.quant_description, prefix, and self.quant_description["fa_quant_type"] is not None
"attention", ):
self.packed_modules_mapping) scheme = create_scheme_for_layer(self.quant_description, prefix, "attention", self.packed_modules_mapping)
return AscendKVCacheMethod(scheme) return AscendKVCacheMethod(scheme)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
if self.is_layer_skipped_ascend(prefix, if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping):
self.packed_modules_mapping):
# Delayed import to avoid circular import # Delayed import to avoid circular import
from vllm_ascend.ops.fused_moe.fused_moe import \ from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
AscendUnquantizedFusedMoEMethod
return AscendUnquantizedFusedMoEMethod(layer.moe_config) return AscendUnquantizedFusedMoEMethod(layer.moe_config)
scheme = create_scheme_for_layer(self.quant_description, prefix, scheme = create_scheme_for_layer(self.quant_description, prefix, "moe", self.packed_modules_mapping)
"moe",
self.packed_modules_mapping)
return AscendFusedMoEMethod(scheme, layer.moe_config) return AscendFusedMoEMethod(scheme, layer.moe_config)
elif isinstance(layer, VocabParallelEmbedding): elif isinstance(layer, VocabParallelEmbedding):
if self.is_layer_skipped_ascend(prefix, if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping):
self.packed_modules_mapping):
return UnquantizedEmbeddingMethod() return UnquantizedEmbeddingMethod()
scheme = create_scheme_for_layer(self.quant_description, prefix, scheme = create_scheme_for_layer(self.quant_description, prefix, "linear", self.packed_modules_mapping)
"linear",
self.packed_modules_mapping)
return AscendEmbeddingMethod(scheme) return AscendEmbeddingMethod(scheme)
return None return None
def is_layer_skipped_ascend( def is_layer_skipped_ascend(self, prefix: str, fused_mapping: Mapping[str, list[str]] = MappingProxyType({})):
self,
prefix: str,
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})):
# adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped # adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped
proj_name = prefix.split(".")[-1] proj_name = prefix.split(".")[-1]
if proj_name in fused_mapping: if proj_name in fused_mapping:
shard_prefixes = [ shard_prefixes = [
prefix.replace(proj_name, shard_proj_name) prefix.replace(proj_name, shard_proj_name) for shard_proj_name in fused_mapping[proj_name]
for shard_proj_name in fused_mapping[proj_name]
] ]
is_skipped = None is_skipped = None
for shard_prefix in shard_prefixes: for shard_prefix in shard_prefixes:
is_shard_skipped = self.quant_description[shard_prefix + is_shard_skipped = self.quant_description[shard_prefix + ".weight"] == "FLOAT"
'.weight'] == "FLOAT"
if is_skipped is None: if is_skipped is None:
is_skipped = is_shard_skipped is_skipped = is_shard_skipped
@@ -472,12 +438,13 @@ class AscendModelSlimConfig(QuantizationConfig):
raise ValueError( raise ValueError(
f"Detected some but not all shards of {prefix} " f"Detected some but not all shards of {prefix} "
"are quantized. All shards of fused layers " "are quantized. All shards of fused layers "
"to have the same precision.") "to have the same precision."
)
else: else:
is_skipped = self.quant_description[prefix + '.weight'] == "FLOAT" is_skipped = self.quant_description[prefix + ".weight"] == "FLOAT"
assert is_skipped is not None assert is_skipped is not None
return is_skipped return is_skipped
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> list[str]:
return [] return []

View File

@@ -1,18 +1,23 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch import torch
from vllm.triton_utils import HAS_TRITON, triton from vllm.triton_utils import HAS_TRITON, triton
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import (GREEDY_TEMPERATURE, MAX_SPEC_LEN, from vllm.v1.sample.rejection_sampler import (
PLACEHOLDER_TOKEN_ID, GREEDY_TEMPERATURE,
generate_uniform_probs) MAX_SPEC_LEN,
PLACEHOLDER_TOKEN_ID,
generate_uniform_probs,
)
from vllm_ascend.ops.triton.reject_sample import ( from vllm_ascend.ops.triton.reject_sample import (
cal_grid_and_block_size, expand_triton, cal_grid_and_block_size,
expand_triton,
rejection_greedy_sample_with_triton, rejection_greedy_sample_with_triton,
rejection_random_sample_block_verify_kernel, rejection_random_sample_block_verify_kernel,
rejection_random_sample_kernel, sample_recovered_tokens_kernel) rejection_random_sample_kernel,
sample_recovered_tokens_kernel,
)
from vllm_ascend.sample.sampler import apply_top_k_top_p from vllm_ascend.sample.sampler import apply_top_k_top_p
@@ -83,7 +88,7 @@ def rejection_sample(
# [batch_size] # [batch_size]
cu_num_draft_tokens: torch.Tensor, cu_num_draft_tokens: torch.Tensor,
# [num_tokens, vocab_size] # [num_tokens, vocab_size]
draft_probs: Optional[torch.Tensor], draft_probs: torch.Tensor | None,
# [num_tokens, vocab_size] # [num_tokens, vocab_size]
target_probs: torch.Tensor, target_probs: torch.Tensor,
# [batch_size, 1] # [batch_size, 1]
@@ -126,15 +131,20 @@ def rejection_sample(
# Rejection sampling for greedy sampling requests. # Rejection sampling for greedy sampling requests.
target_argmax = target_probs.argmax(dim=-1) target_argmax = target_probs.argmax(dim=-1)
if HAS_TRITON: if HAS_TRITON:
rejection_greedy_sample_with_triton(output_token_ids, rejection_greedy_sample_with_triton(
num_draft_tokens, output_token_ids,
cu_num_draft_tokens, num_draft_tokens,
draft_token_ids, target_argmax, cu_num_draft_tokens,
bonus_token_ids, is_greedy, draft_token_ids,
max_spec_len, grid, block_size) target_argmax,
bonus_token_ids,
is_greedy,
max_spec_len,
grid,
block_size,
)
else: else:
if min(num_draft_tokens) == 1 and max( if min(num_draft_tokens) == 1 and max(num_draft_tokens) == 1 and sampling_metadata.all_greedy:
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
rejection_greedy_sample_spec_len_1_pytorch( rejection_greedy_sample_spec_len_1_pytorch(
output_token_ids, output_token_ids,
draft_token_ids, draft_token_ids,
@@ -179,7 +189,7 @@ def rejection_sample(
if not using_block_verify: if not using_block_verify:
# Rejection sampling for random sampling requests. # Rejection sampling for random sampling requests.
if HAS_TRITON: if HAS_TRITON:
rejection_random_sample_kernel[(grid, )]( rejection_random_sample_kernel[(grid,)](
output_token_ids, output_token_ids,
cu_num_draft_tokens, cu_num_draft_tokens,
draft_token_ids, draft_token_ids,
@@ -214,7 +224,7 @@ def rejection_sample(
else: else:
# MagicMTP: Improving acceptance rate with Block Verify. # MagicMTP: Improving acceptance rate with Block Verify.
if HAS_TRITON: if HAS_TRITON:
rejection_random_sample_block_verify_kernel[(grid, )]( rejection_random_sample_block_verify_kernel[(grid,)](
output_token_ids, output_token_ids,
cu_num_draft_tokens, cu_num_draft_tokens,
draft_token_ids, draft_token_ids,
@@ -231,19 +241,20 @@ def rejection_sample(
BLOCK_SIZE=block_size, BLOCK_SIZE=block_size,
) )
else: else:
rejection_random_sample_block_verify_pytorch(output_token_ids, rejection_random_sample_block_verify_pytorch(
cu_num_draft_tokens, output_token_ids,
draft_token_ids, cu_num_draft_tokens,
draft_probs, draft_token_ids,
target_probs, draft_probs,
bonus_token_ids, target_probs,
recovered_token_ids, bonus_token_ids,
uniform_probs, recovered_token_ids,
is_greedy, uniform_probs,
max_spec_len, is_greedy,
vocab_size, max_spec_len,
IS_NGRAM=draft_probs vocab_size,
is None) IS_NGRAM=draft_probs is None,
)
return output_token_ids return output_token_ids
@@ -277,13 +288,7 @@ def expand_batch_to_tokens(
assert cu_num_tokens.shape[0] == batch_size assert cu_num_tokens.shape[0] == batch_size
expanded_x = x.new_empty(num_tokens) expanded_x = x.new_empty(num_tokens)
if HAS_TRITON: if HAS_TRITON:
expand_triton(batch_size, expand_triton(batch_size, expanded_x, x, cu_num_tokens, replace_from, replace_to, max_num_tokens=MAX_SPEC_LEN)
expanded_x,
x,
cu_num_tokens,
replace_from,
replace_to,
max_num_tokens=MAX_SPEC_LEN)
else: else:
expand_pytorch( expand_pytorch(
expanded_x, expanded_x,
@@ -301,7 +306,7 @@ def sample_recovered_tokens(
num_draft_tokens: list[int], num_draft_tokens: list[int],
cu_num_draft_tokens: torch.Tensor, cu_num_draft_tokens: torch.Tensor,
draft_token_ids: torch.Tensor, draft_token_ids: torch.Tensor,
draft_probs: Optional[torch.Tensor], draft_probs: torch.Tensor | None,
target_probs: torch.Tensor, target_probs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
device: torch.device, device: torch.device,
@@ -316,9 +321,7 @@ def sample_recovered_tokens(
) )
q.exponential_() q.exponential_()
num_draft_tensor = torch.tensor(num_draft_tokens, num_draft_tensor = torch.tensor(num_draft_tokens, pin_memory=True).to(device, non_blocking=True)
pin_memory=True).to(device,
non_blocking=True)
has_draft_mask = num_draft_tensor > 0 has_draft_mask = num_draft_tensor > 0
for i, generator in sampling_metadata.generators.items(): for i, generator in sampling_metadata.generators.items():
@@ -357,10 +360,10 @@ def sample_recovered_tokens(
def rejection_greedy_sample_spec_len_1_pytorch( def rejection_greedy_sample_spec_len_1_pytorch(
output_token_ids, # [batch_size, 2] output_token_ids, # [batch_size, 2]
draft_token_ids, # [num_tokens] draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens] target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size] bonus_token_ids, # [batch_size]
): ):
batch_size = output_token_ids.size(0) batch_size = output_token_ids.size(0)
num_tokens = draft_token_ids.size(0) num_tokens = draft_token_ids.size(0)
@@ -368,73 +371,56 @@ def rejection_greedy_sample_spec_len_1_pytorch(
accept_req_mask = draft_token_ids == target_argmax accept_req_mask = draft_token_ids == target_argmax
output_token_ids[:, 0] = target_argmax output_token_ids[:, 0] = target_argmax
bonus_token_ids = bonus_token_ids.squeeze(1) bonus_token_ids = bonus_token_ids.squeeze(1)
output_token_ids[:, 1] = torch.where(accept_req_mask, bonus_token_ids, output_token_ids[:, 1] = torch.where(accept_req_mask, bonus_token_ids, output_token_ids[:, 1])
output_token_ids[:, 1])
def rejection_greedy_sample_pytorch( def rejection_greedy_sample_pytorch(
output_token_ids, # [batch_size, max_spec_len + 1] output_token_ids, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens, # [batch_size] cu_num_draft_tokens, # [batch_size]
draft_token_ids, # [num_tokens] draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens] target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size] bonus_token_ids, # [batch_size]
draft_tokens_per_req, # [batch_size], list draft_tokens_per_req, # [batch_size], list
max_spec_len, max_spec_len,
is_greedy=None, # [batch_size] or None is_greedy=None, # [batch_size] or None
): ):
batch_size = output_token_ids.size(0) batch_size = output_token_ids.size(0)
num_tokens = draft_token_ids.size(0) num_tokens = draft_token_ids.size(0)
device = output_token_ids.device device = output_token_ids.device
draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to( draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to(device, non_blocking=True)
device, non_blocking=True)
if is_greedy is None: if is_greedy is None:
is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device) is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device)
start_indices = cu_num_draft_tokens - draft_tokens_per_req start_indices = cu_num_draft_tokens - draft_tokens_per_req
req_ids = torch.arange(batch_size, device=device) req_ids = torch.arange(batch_size, device=device)
token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req) token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req)
token_positions = torch.arange( token_positions = torch.arange(num_tokens, device=device) - start_indices[token_req_ids]
num_tokens, device=device) - start_indices[token_req_ids]
# Find the first mismatch position of each request. # Find the first mismatch position of each request.
mismatch_global = (draft_token_ids != target_argmax) mismatch_global = draft_token_ids != target_argmax
if max_spec_len == 0: if max_spec_len == 0:
first_mismatch_pos_per_req = torch.zeros(batch_size, first_mismatch_pos_per_req = torch.zeros(batch_size, dtype=torch.long, device=device)
dtype=torch.long,
device=device)
else: else:
# [bs, max_spec_len] # [bs, max_spec_len]
pos_matrix = torch.full((batch_size, max_spec_len), pos_matrix = torch.full((batch_size, max_spec_len), -1, dtype=torch.long, device=device)
-1,
dtype=torch.long,
device=device)
pos_matrix[token_req_ids, token_positions] = token_positions pos_matrix[token_req_ids, token_positions] = token_positions
mismatch_matrix = torch.full((batch_size, max_spec_len), mismatch_matrix = torch.full((batch_size, max_spec_len), False, dtype=torch.bool, device=device)
False,
dtype=torch.bool,
device=device)
mismatch_matrix[token_req_ids, token_positions] = mismatch_global mismatch_matrix[token_req_ids, token_positions] = mismatch_global
mismatch_positions = torch.where(mismatch_matrix, pos_matrix, mismatch_positions = torch.where(mismatch_matrix, pos_matrix, max_spec_len * 2)
max_spec_len * 2)
first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1) first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1)
no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2) no_mismatch_mask = first_mismatch_pos_per_req == max_spec_len * 2
first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[ first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[no_mismatch_mask]
no_mismatch_mask]
# Copy matched target tokens into output. # Copy matched target tokens into output.
copy_len = torch.minimum(first_mismatch_pos_per_req + 1, copy_len = torch.minimum(first_mismatch_pos_per_req + 1, draft_tokens_per_req)
draft_tokens_per_req) copy_indices = torch.arange(max_spec_len + 1, device=device).expand(batch_size, -1)
copy_indices = torch.arange(max_spec_len + 1,
device=device).expand(batch_size, -1)
copy_mask = copy_indices < copy_len.unsqueeze(1) copy_mask = copy_indices < copy_len.unsqueeze(1)
greedy_mask = is_greedy.unsqueeze(1) greedy_mask = is_greedy.unsqueeze(1)
final_copy_mask = copy_mask & greedy_mask final_copy_mask = copy_mask & greedy_mask
global_idx = start_indices.unsqueeze(1) + copy_indices global_idx = start_indices.unsqueeze(1) + copy_indices
output_token_ids[final_copy_mask] = target_argmax[ output_token_ids[final_copy_mask] = target_argmax[global_idx[final_copy_mask]].to(output_token_ids.dtype)
global_idx[final_copy_mask]].to(output_token_ids.dtype)
# Fill bonus token. # Fill bonus token.
needs_bonus = is_greedy & (first_mismatch_pos_per_req needs_bonus = is_greedy & (first_mismatch_pos_per_req >= draft_tokens_per_req)
>= draft_tokens_per_req)
if torch.any(needs_bonus): if torch.any(needs_bonus):
bonus_rows = torch.where(needs_bonus)[0] bonus_rows = torch.where(needs_bonus)[0]
bonus_cols = draft_tokens_per_req[bonus_rows] bonus_cols = draft_tokens_per_req[bonus_rows]
@@ -495,15 +481,12 @@ def rejection_random_sample_pytorch(
valid_mask = pos_indices < num_draft_per_batch[:, None] valid_mask = pos_indices < num_draft_per_batch[:, None]
global_token_indices = cu_start[:, None] + pos_indices global_token_indices = cu_start[:, None] + pos_indices
global_token_indices = global_token_indices.clamp( global_token_indices = global_token_indices.clamp(0, draft_token_ids.shape[0] - 1)
0, draft_token_ids.shape[0] - 1) draft_tokens = draft_token_ids[global_token_indices] # [batch_size, max_draft_len]
draft_tokens = draft_token_ids[
global_token_indices] # [batch_size, max_draft_len]
if IS_NGRAM: if IS_NGRAM:
ones_cpu = torch.ones(1, pin_memory=True, dtype=torch.float32) ones_cpu = torch.ones(1, pin_memory=True, dtype=torch.float32)
draft_token_probs = ones_cpu.to( draft_token_probs = ones_cpu.to(device, non_blocking=True).expand_as(draft_tokens)
device, non_blocking=True).expand_as(draft_tokens)
else: else:
flat_indices = global_token_indices.flatten() flat_indices = global_token_indices.flatten()
flat_draft_tokens = draft_tokens.flatten() flat_draft_tokens = draft_tokens.flatten()
@@ -518,24 +501,21 @@ def rejection_random_sample_pytorch(
uniform_token_probs = uniform_probs[global_token_indices] uniform_token_probs = uniform_probs[global_token_indices]
recovered_tokens = recovered_token_ids[global_token_indices] recovered_tokens = recovered_token_ids[global_token_indices]
zero_threshold_cpu = torch.tensor([0.0], zero_threshold_cpu = torch.tensor([0.0], pin_memory=True, dtype=torch.float32)
pin_memory=True,
dtype=torch.float32)
zero_threshold = zero_threshold_cpu.to(device, non_blocking=True) zero_threshold = zero_threshold_cpu.to(device, non_blocking=True)
acceptance_condition = (draft_token_probs > zero_threshold) & ( acceptance_condition = (draft_token_probs > zero_threshold) & (
target_token_probs / draft_token_probs >= uniform_token_probs) target_token_probs / draft_token_probs >= uniform_token_probs
)
first_rejection = (~acceptance_condition) & valid_mask first_rejection = (~acceptance_condition) & valid_mask
default_pos_cpu = torch.full([batch_size, 1], default_pos_cpu = torch.full([batch_size, 1], max_draft_len, pin_memory=True)
max_draft_len,
pin_memory=True)
default_pos = default_pos_cpu.to(device, non_blocking=True) default_pos = default_pos_cpu.to(device, non_blocking=True)
first_reject_pos = torch.where( first_reject_pos = torch.where(
first_rejection.any(dim=1, keepdim=True), first_rejection.any(dim=1, keepdim=True), first_rejection.float().argmax(dim=1, keepdim=True), default_pos
first_rejection.float().argmax(dim=1, keepdim=True), default_pos) )
pos_mask = pos_indices >= first_reject_pos pos_mask = pos_indices >= first_reject_pos
should_skip = pos_mask & valid_mask should_skip = pos_mask & valid_mask
@@ -543,16 +523,17 @@ def rejection_random_sample_pytorch(
non_greedy_mask = ~is_greedy non_greedy_mask = ~is_greedy
update_mask = non_greedy_mask[:, None] & valid_mask & (~should_skip) update_mask = non_greedy_mask[:, None] & valid_mask & (~should_skip)
first_reject_mask = (pos_indices == first_reject_pos first_reject_mask = (pos_indices == first_reject_pos) & valid_mask & non_greedy_mask[:, None]
) & valid_mask & non_greedy_mask[:, None]
final_update_mask = update_mask | first_reject_mask final_update_mask = update_mask | first_reject_mask
final_tokens = torch.where( final_tokens = torch.where(
first_reject_mask, recovered_tokens, first_reject_mask,
torch.where(final_acceptance, draft_tokens, recovered_tokens,
output_token_ids[:, :max_draft_len])) torch.where(final_acceptance, draft_tokens, output_token_ids[:, :max_draft_len]),
)
output_token_ids[:, :max_draft_len] = torch.where( output_token_ids[:, :max_draft_len] = torch.where(
final_update_mask, final_tokens, output_token_ids[:, :max_draft_len]) final_update_mask, final_tokens, output_token_ids[:, :max_draft_len]
)
no_rejection = first_reject_pos.squeeze(1) >= num_draft_per_batch no_rejection = first_reject_pos.squeeze(1) >= num_draft_per_batch
should_add_bonus = non_greedy_mask & no_rejection should_add_bonus = non_greedy_mask & no_rejection
@@ -561,8 +542,7 @@ def rejection_random_sample_pytorch(
seq_len = output_token_ids.shape[1] seq_len = output_token_ids.shape[1]
all_positions_cpu = torch.arange(seq_len, pin_memory=True) all_positions_cpu = torch.arange(seq_len, pin_memory=True)
all_positions = all_positions_cpu.to( all_positions = all_positions_cpu.to(device, non_blocking=True)[None, :] # [1, seq_len]
device, non_blocking=True)[None, :] # [1, seq_len]
batch_bonus_positions = bonus_positions[:, None] # [batch_size, 1] batch_bonus_positions = bonus_positions[:, None] # [batch_size, 1]
@@ -572,12 +552,11 @@ def rejection_random_sample_pytorch(
valid_bonus_pos = bonus_positions < (max_spec_len_device + 1) valid_bonus_pos = bonus_positions < (max_spec_len_device + 1)
final_bonus_mask = should_add_bonus & valid_bonus_pos final_bonus_mask = should_add_bonus & valid_bonus_pos
bonus_pos_match = (all_positions == batch_bonus_positions) bonus_pos_match = all_positions == batch_bonus_positions
bonus_pos_mask = bonus_pos_match & final_bonus_mask[:, None] bonus_pos_mask = bonus_pos_match & final_bonus_mask[:, None]
bonus_values_expanded = bonus_token_ids.view(-1, 1).expand(-1, seq_len) bonus_values_expanded = bonus_token_ids.view(-1, 1).expand(-1, seq_len)
output_token_ids[:] = torch.where(bonus_pos_mask, bonus_values_expanded, output_token_ids[:] = torch.where(bonus_pos_mask, bonus_values_expanded, output_token_ids)
output_token_ids)
def expand_pytorch( def expand_pytorch(
@@ -609,21 +588,16 @@ def expand_pytorch(
if batch_size == 0 or num_tokens == 0: if batch_size == 0 or num_tokens == 0:
return return
cu_start = torch.cat([ cu_start = torch.cat([torch.tensor([0], pin_memory=True).to(device, non_blocking=True), cu_num_tokens_ptr[:-1]])
torch.tensor([0], pin_memory=True).to(device, non_blocking=True),
cu_num_tokens_ptr[:-1]
])
cu_end = cu_num_tokens_ptr cu_end = cu_num_tokens_ptr
token_indices = torch.arange(num_tokens, token_indices = torch.arange(num_tokens, device=device)[:, None] # [num_tokens, 1]
device=device)[:, None] # [num_tokens, 1]
cu_start_exp = cu_start[None, :] # [1, batch_size] cu_start_exp = cu_start[None, :] # [1, batch_size]
cu_end_exp = cu_end[None, :] # [1, batch_size] cu_end_exp = cu_end[None, :] # [1, batch_size]
in_range = (token_indices >= cu_start_exp) & (token_indices < cu_end_exp) in_range = (token_indices >= cu_start_exp) & (token_indices < cu_end_exp)
replaced_input = torch.where(input_ptr == replace_from, replace_to, replaced_input = torch.where(input_ptr == replace_from, replace_to, input_ptr).float()
input_ptr).float()
token_values = torch.einsum("tb,b->t", in_range.float(), replaced_input) token_values = torch.einsum("tb,b->t", in_range.float(), replaced_input)
@@ -666,10 +640,12 @@ def sample_recovered_tokens_pytorch(
if num_tokens == 0: if num_tokens == 0:
return return
cu_start = torch.cat([ cu_start = torch.cat(
torch.tensor([0], pin_memory=True).to(device, non_blocking=True), [
cu_num_draft_tokens[:-1], torch.tensor([0], pin_memory=True).to(device, non_blocking=True),
]) cu_num_draft_tokens[:-1],
]
)
cu_end = cu_num_draft_tokens cu_end = cu_num_draft_tokens
token_indices = torch.arange(num_tokens, device=device) # [num_tokens] token_indices = torch.arange(num_tokens, device=device) # [num_tokens]
@@ -678,8 +654,7 @@ def sample_recovered_tokens_pytorch(
cu_start_expanded = cu_start[None, :] # [1, batch_size] cu_start_expanded = cu_start[None, :] # [1, batch_size]
cu_end_expanded = cu_end[None, :] # [1, batch_size] cu_end_expanded = cu_end[None, :] # [1, batch_size]
in_range_mask = (token_indices_expanded >= cu_start_expanded) & ( in_range_mask = (token_indices_expanded >= cu_start_expanded) & (token_indices_expanded < cu_end_expanded)
token_indices_expanded < cu_end_expanded)
token_to_batch = torch.argmax(in_range_mask.int(), dim=1) token_to_batch = torch.argmax(in_range_mask.int(), dim=1)
@@ -707,8 +682,7 @@ def sample_recovered_tokens_pytorch(
prob_over_q = prob / q_values_safe prob_over_q = prob / q_values_safe
prob_over_q = torch.where((q_values == 0) | torch.isinf(q_values), -1e10, prob_over_q = torch.where((q_values == 0) | torch.isinf(q_values), -1e10, prob_over_q)
prob_over_q)
recovered_ids = torch.argmax(prob_over_q, dim=1) recovered_ids = torch.argmax(prob_over_q, dim=1)
@@ -742,14 +716,12 @@ def rejection_random_sample_block_verify_pytorch(
pos_indices = pos_indices_cpu.to(device, non_blocking=True)[None, :] pos_indices = pos_indices_cpu.to(device, non_blocking=True)[None, :]
valid_mask = pos_indices < num_draft_per_batch valid_mask = pos_indices < num_draft_per_batch
global_token_indices = cu_start[:, None] + pos_indices global_token_indices = cu_start[:, None] + pos_indices
global_token_indices = global_token_indices.clamp( global_token_indices = global_token_indices.clamp(0, draft_token_ids.shape[0] - 1)
0, draft_token_ids.shape[0] - 1)
draft_tokens = draft_token_ids[global_token_indices] draft_tokens = draft_token_ids[global_token_indices]
if IS_NGRAM: if IS_NGRAM:
ones_cpu = torch.ones(1, pin_memory=True, dtype=torch.float32) ones_cpu = torch.ones(1, pin_memory=True, dtype=torch.float32)
draft_token_probs = ones_cpu.to( draft_token_probs = ones_cpu.to(device, non_blocking=True).expand_as(draft_tokens)
device, non_blocking=True).expand_as(draft_tokens)
else: else:
flat_indices = global_token_indices.flatten() flat_indices = global_token_indices.flatten()
flat_draft_tokens = draft_tokens.flatten() flat_draft_tokens = draft_tokens.flatten()
@@ -772,27 +744,21 @@ def rejection_random_sample_block_verify_pytorch(
last_accept_pos = torch.where( last_accept_pos = torch.where(
legal_mask.any(dim=-1, keepdim=True), legal_mask.any(dim=-1, keepdim=True),
(max_spec_len - (max_spec_len - legal_mask.flip(dims=[-1]).float().argmax(dim=-1, keepdim=True) - 1),
legal_mask.flip(dims=[-1]).float().argmax(dim=-1, keepdim=True) - 1), -1,
-1) )
non_greedy_mask = (~is_greedy)[:, None] non_greedy_mask = (~is_greedy)[:, None]
accept_mask = (pos_indices accept_mask = (pos_indices <= last_accept_pos) & valid_mask & non_greedy_mask
<= last_accept_pos) & valid_mask & non_greedy_mask output_token_ids[:, :max_spec_len] = torch.where(accept_mask, draft_tokens, output_token_ids[:, :max_spec_len])
output_token_ids[:, :max_spec_len] = torch.where(
accept_mask, draft_tokens, output_token_ids[:, :max_spec_len])
reject_mask = (pos_indices reject_mask = (pos_indices == last_accept_pos + 1) & valid_mask & non_greedy_mask
== last_accept_pos + 1) & valid_mask & non_greedy_mask output_token_ids[:, :max_spec_len] = torch.where(reject_mask, recovered_tokens, output_token_ids[:, :max_spec_len])
output_token_ids[:, :max_spec_len] = torch.where(
reject_mask, recovered_tokens, output_token_ids[:, :max_spec_len])
bonus_mask = (last_accept_pos + 1 >= num_draft_per_batch) & non_greedy_mask bonus_mask = (last_accept_pos + 1 >= num_draft_per_batch) & non_greedy_mask
all_positions_cpu = torch.arange(max_spec_len + 1, pin_memory=True) all_positions_cpu = torch.arange(max_spec_len + 1, pin_memory=True)
all_positions = all_positions_cpu.to(device, non_blocking=True)[None, :] all_positions = all_positions_cpu.to(device, non_blocking=True)[None, :]
bonus_pos_match = (all_positions == num_draft_per_batch) bonus_pos_match = all_positions == num_draft_per_batch
bonus_mask = bonus_mask & bonus_pos_match bonus_mask = bonus_mask & bonus_pos_match
bonus_values_expanded = bonus_token_ids.view(-1, 1).expand( bonus_values_expanded = bonus_token_ids.view(-1, 1).expand(-1, max_spec_len + 1)
-1, max_spec_len + 1) output_token_ids[:] = torch.where(bonus_mask, bonus_values_expanded, output_token_ids)
output_token_ids[:] = torch.where(bonus_mask, bonus_values_expanded,
output_token_ids)

View File

@@ -35,7 +35,6 @@ def random_sample(
class AscendSampler(Sampler): class AscendSampler(Sampler):
def __init__(self, logprobs_mode=DEFAULT_LOGPROBS_MODE): def __init__(self, logprobs_mode=DEFAULT_LOGPROBS_MODE):
# TODO: support logprobs_mode in vllm-ascend # TODO: support logprobs_mode in vllm-ascend
super().__init__(logprobs_mode=logprobs_mode) super().__init__(logprobs_mode=logprobs_mode)
@@ -62,7 +61,6 @@ class AscendSampler(Sampler):
class AscendTopKTopPSampler(TopKTopPSampler): class AscendTopKTopPSampler(TopKTopPSampler):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.apply_top_k_top_p = apply_top_k_top_p self.apply_top_k_top_p = apply_top_k_top_p
@@ -135,4 +133,9 @@ def _apply_top_k_top_p_ascendc(
return logits return logits
return torch.ops._C_ascend.npu_apply_top_k_top_p(logits, k=k, p=p) return torch.ops._C_ascend.npu_apply_top_k_top_p(logits, k=k, p=p)
apply_top_k_top_p = _apply_top_k_top_p_ascendc if get_ascend_device_type() in [AscendDeviceType.A2, AscendDeviceType.A3] else _apply_top_k_top_p_pytorch
apply_top_k_top_p = (
_apply_top_k_top_p_ascendc
if get_ascend_device_type() in [AscendDeviceType.A2, AscendDeviceType.A3]
else _apply_top_k_top_p_pytorch
)

View File

@@ -1,5 +1,3 @@
from typing import Optional, Union
import numpy as np import numpy as np
import torch import torch
from vllm.distributed import get_dcp_group, get_pcp_group from vllm.distributed import get_dcp_group, get_pcp_group
@@ -8,17 +6,18 @@ from vllm.v1.utils import CpuGpuBuffer
class BlockTable: class BlockTable:
def __init__(
def __init__(self, self,
block_size: int, block_size: int,
max_num_reqs: int, max_num_reqs: int,
max_num_blocks_per_req: int, max_num_blocks_per_req: int,
max_num_batched_tokens: int, max_num_batched_tokens: int,
pin_memory: bool, pin_memory: bool,
device: torch.device, device: torch.device,
kernel_sizes: Union[list[int], None] = None, kernel_sizes: list[int] | None = None,
cp_kv_cache_interleave_size: int = 1, cp_kv_cache_interleave_size: int = 1,
num_speculative_tokens: int = 0): num_speculative_tokens: int = 0,
):
self.max_num_reqs = max_num_reqs self.max_num_reqs = max_num_reqs
self.max_num_blocks_per_req = max_num_blocks_per_req self.max_num_blocks_per_req = max_num_blocks_per_req
self.max_num_batched_tokens = max_num_batched_tokens self.max_num_batched_tokens = max_num_batched_tokens
@@ -28,8 +27,7 @@ class BlockTable:
try: try:
self.pcp_world_size = get_pcp_group().world_size self.pcp_world_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group( self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_world_size > 1 else 0
).rank_in_group if self.pcp_world_size > 1 else 0
self.dcp_world_size = get_dcp_group().world_size self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group self.dcp_rank = get_dcp_group().rank_in_group
except AssertionError: except AssertionError:
@@ -49,42 +47,37 @@ class BlockTable:
# Find the first kernel size that divides physical_block_size evenly # Find the first kernel size that divides physical_block_size evenly
selected_kernel_size = None selected_kernel_size = None
for kernel_size in kernel_sizes: for kernel_size in kernel_sizes:
if kernel_size > 0 \ if kernel_size > 0 and self.physical_block_size % kernel_size == 0:
and self.physical_block_size % kernel_size == 0:
selected_kernel_size = kernel_size selected_kernel_size = kernel_size
break break
if selected_kernel_size is None: if selected_kernel_size is None:
raise ValueError( raise ValueError(
f"None of the kernel sizes {kernel_sizes} can divide " f"None of the kernel sizes {kernel_sizes} can divide "
f"physical block size {self.physical_block_size} evenly") f"physical block size {self.physical_block_size} evenly"
)
self.block_size = selected_kernel_size self.block_size = selected_kernel_size
self.logical_block_size = selected_kernel_size self.logical_block_size = selected_kernel_size
self.blocks_per_phys_block = (self.physical_block_size // self.blocks_per_phys_block = self.physical_block_size // self.logical_block_size
self.logical_block_size)
if self.blocks_per_phys_block > 1: if self.blocks_per_phys_block > 1:
self.use_hybrid_blocks = True self.use_hybrid_blocks = True
else: else:
self.use_hybrid_blocks = False self.use_hybrid_blocks = False
if self.use_hybrid_blocks: if self.use_hybrid_blocks:
logical_table_size = (max_num_blocks_per_req * logical_table_size = max_num_blocks_per_req * self.blocks_per_phys_block
self.blocks_per_phys_block)
else: else:
logical_table_size = max_num_blocks_per_req logical_table_size = max_num_blocks_per_req
duplicate_size = 1 duplicate_size = 1
if self.pcp_world_size * self.dcp_world_size > 1: if self.pcp_world_size * self.dcp_world_size > 1:
duplicate_size += num_speculative_tokens duplicate_size += num_speculative_tokens
self.block_table = self._make_buffer(max_num_reqs * duplicate_size, self.block_table = self._make_buffer(max_num_reqs * duplicate_size, logical_table_size, dtype=torch.int32)
logical_table_size,
dtype=torch.int32)
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
self.slot_mapping = self._make_buffer( self.slot_mapping = self._make_buffer(
self.max_num_batched_tokens + self.max_num_batched_tokens + 2 * self.pcp_world_size * self.max_num_reqs, dtype=torch.int32
2 * self.pcp_world_size * self.max_num_reqs, )
dtype=torch.int32)
self.kernel_sizes = kernel_sizes self.kernel_sizes = kernel_sizes
self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
@@ -103,7 +96,7 @@ class BlockTable:
num_blocks = len(block_ids) num_blocks = len(block_ids)
start = self.num_blocks_per_row[row_idx] start = self.num_blocks_per_row[row_idx]
self.block_table.np[row_idx, start:start + num_blocks] = block_ids self.block_table.np[row_idx, start : start + num_blocks] = block_ids
self.num_blocks_per_row[row_idx] += num_blocks self.num_blocks_per_row[row_idx] += num_blocks
def add_row(self, block_ids: list[int], row_idx: int) -> None: def add_row(self, block_ids: list[int], row_idx: int) -> None:
@@ -112,8 +105,7 @@ class BlockTable:
def move_row(self, src: int, tgt: int) -> None: def move_row(self, src: int, tgt: int) -> None:
num_blocks = self.num_blocks_per_row[src] num_blocks = self.num_blocks_per_row[src]
self.block_table.np[tgt, :num_blocks] = self.block_table.np[ self.block_table.np[tgt, :num_blocks] = self.block_table.np[src, :num_blocks]
src, :num_blocks]
self.num_blocks_per_row[tgt] = num_blocks self.num_blocks_per_row[tgt] = num_blocks
def swap_row(self, src: int, tgt: int) -> None: def swap_row(self, src: int, tgt: int) -> None:
@@ -124,8 +116,7 @@ class BlockTable:
self.block_table.np[[src, tgt]] = self.block_table.np[[tgt, src]] self.block_table.np[[src, tgt]] = self.block_table.np[[tgt, src]]
def compute_slot_mapping(self, req_indices: np.ndarray, def compute_slot_mapping(self, req_indices: np.ndarray, positions: np.ndarray) -> None:
positions: np.ndarray) -> None:
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2. # where K is the max_num_blocks_per_req and the block size is 2.
@@ -150,27 +141,30 @@ class BlockTable:
# (always needed with unified tensor) # (always needed with unified tensor)
# Each physical block is split into multiple logical blocks # Each physical block is split into multiple logical blocks
# The logical table has been expanded to accommodate this # The logical table has been expanded to accommodate this
block_table_indices = (req_indices * self.max_num_blocks_per_req * block_table_indices = (
self.blocks_per_phys_block + req_indices * self.max_num_blocks_per_req * self.blocks_per_phys_block + logical_block_idx
logical_block_idx) )
block_numbers = self.block_table.np.ravel()[block_table_indices] block_numbers = self.block_table.np.ravel()[block_table_indices]
# Use virtual_block_size for mask calculation, which marks local # Use virtual_block_size for mask calculation, which marks local
# tokens. # tokens.
virtual_block_offsets = positions % virtual_block_size virtual_block_offsets = positions % virtual_block_size
self.current_rank = self.dcp_world_size * self.pcp_rank + self.dcp_rank self.current_rank = self.dcp_world_size * self.pcp_rank + self.dcp_rank
mask = (virtual_block_offsets // self.cp_kv_cache_interleave_size % mask = (
(self.dcp_world_size * virtual_block_offsets // self.cp_kv_cache_interleave_size % (self.dcp_world_size * self.pcp_world_size)
self.pcp_world_size) == self.current_rank) == self.current_rank
)
# Calculate local block_offsets # Calculate local block_offsets
block_offsets = virtual_block_offsets \ block_offsets = (
// (self.dcp_world_size * self.pcp_world_size * self.cp_kv_cache_interleave_size) \ virtual_block_offsets
* self.cp_kv_cache_interleave_size + virtual_block_offsets % self.cp_kv_cache_interleave_size // (self.dcp_world_size * self.pcp_world_size * self.cp_kv_cache_interleave_size)
* self.cp_kv_cache_interleave_size
+ virtual_block_offsets % self.cp_kv_cache_interleave_size
)
# Calculate slot_mapping # Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets slot_mapping = block_numbers * self.block_size + block_offsets
# Write final slots, use -1 for not-local # Write final slots, use -1 for not-local
self.slot_mapping.np[:req_indices.shape[0]] = np.where( self.slot_mapping.np[: req_indices.shape[0]] = np.where(mask, slot_mapping, -1)
mask, slot_mapping, -1)
else: else:
assert self.kernel_sizes is not None assert self.kernel_sizes is not None
if self.block_size == self.kernel_sizes[0]: if self.block_size == self.kernel_sizes[0]:
@@ -183,15 +177,12 @@ class BlockTable:
# Each physical block is split into multiple logical blocks # Each physical block is split into multiple logical blocks
# The logical table has been expanded to accommodate this # The logical table has been expanded to accommodate this
block_table_indices = ( block_table_indices = (
req_indices * self.max_num_blocks_per_req * req_indices * self.max_num_blocks_per_req * self.blocks_per_phys_block + logical_block_idx
self.blocks_per_phys_block + logical_block_idx) )
block_numbers = self.block_table.np.ravel( block_numbers = self.block_table.np.ravel()[block_table_indices]
)[block_table_indices]
block_offsets = positions % self.block_size block_offsets = positions % self.block_size
np.add(block_numbers * self.block_size, np.add(block_numbers * self.block_size, block_offsets, out=self.slot_mapping.np[: req_indices.shape[0]])
block_offsets,
out=self.slot_mapping.np[:req_indices.shape[0]])
def commit_block_table(self, num_reqs: int) -> None: def commit_block_table(self, num_reqs: int) -> None:
self.block_table.copy_to_gpu(num_reqs) self.block_table.copy_to_gpu(num_reqs)
@@ -203,8 +194,7 @@ class BlockTable:
self.block_table.fill_(0) self.block_table.fill_(0)
self.block_table.cpu.fill_(0) self.block_table.cpu.fill_(0)
def _convert_physical_to_logical_blocks( def _convert_physical_to_logical_blocks(self, physical_blocks: np.ndarray) -> np.ndarray:
self, physical_blocks: np.ndarray) -> np.ndarray:
"""Convert physical block IDs to logical block IDs.""" """Convert physical block IDs to logical block IDs."""
if not self.use_hybrid_blocks: if not self.use_hybrid_blocks:
return physical_blocks return physical_blocks
@@ -217,8 +207,7 @@ class BlockTable:
# [1*split_ratio, 1*split_ratio+1, ...] # [1*split_ratio, 1*split_ratio+1, ...]
# But we need to account for the fact that block 0 is special # But we need to account for the fact that block 0 is special
base_logical = phys_block * self.blocks_per_phys_block base_logical = phys_block * self.blocks_per_phys_block
logical_blocks.extend( logical_blocks.extend(range(base_logical, base_logical + self.blocks_per_phys_block))
range(base_logical, base_logical + self.blocks_per_phys_block))
return np.array(logical_blocks, dtype=np.int32) return np.array(logical_blocks, dtype=np.int32)
@@ -234,27 +223,25 @@ class BlockTable:
"""Returns the numpy array of the block table.""" """Returns the numpy array of the block table."""
return self.block_table.np return self.block_table.np
def _make_buffer(self, *size: int | torch.SymInt, def _make_buffer(self, *size: int | torch.SymInt, dtype: torch.dtype) -> CpuGpuBuffer:
dtype: torch.dtype) -> CpuGpuBuffer: return CpuGpuBuffer(*size, dtype=dtype, device=self.device, pin_memory=self.pin_memory)
return CpuGpuBuffer(*size,
dtype=dtype,
device=self.device,
pin_memory=self.pin_memory)
class MultiGroupBlockTable: class MultiGroupBlockTable:
"""The BlockTables for each KV cache group.""" """The BlockTables for each KV cache group."""
def __init__(self, def __init__(
max_num_reqs: int, self,
max_model_len: int, max_num_reqs: int,
max_num_batched_tokens: int, max_model_len: int,
pin_memory: bool, max_num_batched_tokens: int,
device: torch.device, pin_memory: bool,
block_sizes: list[int], device: torch.device,
num_speculative_tokens: int = 0, block_sizes: list[int],
kernel_sizes: Optional[list[list[int]]] = None, num_speculative_tokens: int = 0,
cp_kv_cache_interleave_size: int = 1) -> None: kernel_sizes: list[list[int]] | None = None,
cp_kv_cache_interleave_size: int = 1,
) -> None:
# Note(hc): each dcp rank only store # Note(hc): each dcp rank only store
# (max_model_len//dcp_world_size) tokens in kvcache, # (max_model_len//dcp_world_size) tokens in kvcache,
# so the block_size which used for calc max_num_blocks_per_req # so the block_size which used for calc max_num_blocks_per_req
@@ -274,24 +261,26 @@ class MultiGroupBlockTable:
kernel_sizes = kernel_sizes * len(block_sizes) kernel_sizes = kernel_sizes * len(block_sizes)
elif len(kernel_sizes) != len(block_sizes): elif len(kernel_sizes) != len(block_sizes):
raise ValueError( raise ValueError(
f"kernel_sizes length ({len(kernel_sizes)}) must match " f"kernel_sizes length ({len(kernel_sizes)}) must match block_sizes length ({len(block_sizes)})"
f"block_sizes length ({len(block_sizes)})") )
# Use zip to pair block_sizes with kernel_sizes one-to-one # Use zip to pair block_sizes with kernel_sizes one-to-one
self.block_tables = [ self.block_tables = [
BlockTable( BlockTable(
block_size, max_num_reqs, block_size,
max( max_num_reqs,
cdiv(max_model_len, max(cdiv(max_model_len, block_size * dcp_world_size * pcp_world_size), 1 + num_speculative_tokens),
block_size * dcp_world_size * pcp_world_size), max_num_batched_tokens,
1 + num_speculative_tokens), max_num_batched_tokens, pin_memory,
pin_memory, device, kernel_size_list, device,
cp_kv_cache_interleave_size, num_speculative_tokens) kernel_size_list,
cp_kv_cache_interleave_size,
num_speculative_tokens,
)
for block_size, kernel_size_list in zip(block_sizes, kernel_sizes) for block_size, kernel_size_list in zip(block_sizes, kernel_sizes)
] ]
def append_row(self, block_ids: tuple[list[int], ...], def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
row_idx: int) -> None:
for i, block_table in enumerate(self.block_tables): for i, block_table in enumerate(self.block_tables):
block_table.append_row(block_ids[i], row_idx) block_table.append_row(block_ids[i], row_idx)
@@ -307,8 +296,7 @@ class MultiGroupBlockTable:
for block_table in self.block_tables: for block_table in self.block_tables:
block_table.swap_row(src, tgt) block_table.swap_row(src, tgt)
def compute_slot_mapping(self, req_indices: np.ndarray, def compute_slot_mapping(self, req_indices: np.ndarray, positions: np.ndarray) -> None:
positions: np.ndarray) -> None:
for block_table in self.block_tables: for block_table in self.block_tables:
block_table.compute_slot_mapping(req_indices, positions) block_table.compute_slot_mapping(req_indices, positions)