diff --git a/pyproject.toml b/pyproject.toml index 0baf2470..300492c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,10 +51,7 @@ line-length = 120 # Folder to be modified exclude = [ "tests/**", - # (7) - "vllm_ascend/quantization/**", - "vllm_ascend/sample/*.py", - "vllm_ascend/worker/block_table.py", + # (8) "vllm_ascend/ops/__init__.py", "vllm_ascend/ops/activation.py", @@ -66,6 +63,7 @@ exclude = [ "vllm_ascend/ops/vocab_parallel_embedding.py", "vllm_ascend/ops/weight_prefetch.py", "vllm_ascend/spec_decode/**", + # (10) "vllm_ascend/ops/*linear*.py", "vllm_ascend/worker/worker.py", @@ -76,6 +74,7 @@ exclude = [ "vllm_ascend/worker/v2/**", "vllm_ascend/worker/npu_input_batch.py", "vllm_ascend/ops/rotary_embedding.py", + # (11) "vllm_ascend/ops/fused_moe/**", ] diff --git a/vllm_ascend/quantization/__init__.py b/vllm_ascend/quantization/__init__.py index d5b31e33..1bf29125 100644 --- a/vllm_ascend/quantization/__init__.py +++ b/vllm_ascend/quantization/__init__.py @@ -29,6 +29,7 @@ Public API: # LLM-Compressor (compressed_tensors) quantization config from .compressed_tensors_config import AscendCompressedTensorsConfig + # ModelSlim quantization config from .modelslim_config import AscendModelSlimConfig diff --git a/vllm_ascend/quantization/compressed_tensors_config.py b/vllm_ascend/quantization/compressed_tensors_config.py index 7896d1b4..ea138110 100644 --- a/vllm_ascend/quantization/compressed_tensors_config.py +++ b/vllm_ascend/quantization/compressed_tensors_config.py @@ -17,23 +17,20 @@ # """LLM-Compressor (compressed_tensors) quantization configuration for Ascend.""" -from typing import Any, Optional, Union, cast +from typing import Any, Optional, cast import torch -from compressed_tensors.quantization import (QuantizationArgs, - QuantizationStrategy, - QuantizationType) +from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy, QuantizationType from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.linear import (LinearBase, - UnquantizedLinearMethod) -from vllm.model_executor.layers.quantization import ( - QUANTIZATION_METHODS, register_quantization_config) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS, register_quantization_config +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig, QuantizeMethodBase from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( - find_matched_target, is_activation_quantization_format, - should_ignore_layer) + find_matched_target, + is_activation_quantization_format, + should_ignore_layer, +) from vllm.model_executor.models.utils import WeightsMapper from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD @@ -51,14 +48,13 @@ def _remove_quantization_method(): _remove_quantization_method() -QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, - "QuantizationArgs"]]] +QUANTIZATION_SCHEME_MAP_TYPE = dict[str, dict[str, "QuantizationArgs"] | None] @register_quantization_config(COMPRESSED_TENSORS_METHOD) class AscendCompressedTensorsConfig(QuantizationConfig): """Config class for LLM-Compressor (compressed_tensors) quantization on Ascend. - + This class adapts the compressed_tensors format to work with Ascend's quantization implementations. 
""" @@ -68,7 +64,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig): target_scheme_map: dict[str, Any], ignore: list[str], quant_format: str, - config: Optional[dict[str, Any]] = None, + config: dict[str, Any] | None = None, ): super().__init__() self.ignore = ignore @@ -86,8 +82,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig): @classmethod def get_min_capability(cls) -> int: - raise NotImplementedError( - "Ascend hardware dose not support \"get_min_capability\" feature.") + raise NotImplementedError('Ascend hardware dose not support "get_min_capability" feature.') @classmethod def get_config_filenames(cls) -> list[str]: @@ -100,18 +95,15 @@ class AscendCompressedTensorsConfig(QuantizationConfig): targeting 'Linear' needs to also match FusedMoE modules. """ - if ("Linear" not in self.target_scheme_map - or "FusedMoE" in self.target_scheme_map): + if "Linear" not in self.target_scheme_map or "FusedMoE" in self.target_scheme_map: return self.target_scheme_map["FusedMoE"] = self.target_scheme_map["Linear"] @classmethod - def from_config(cls, config: dict[str, - Any]) -> "AscendCompressedTensorsConfig": + def from_config(cls, config: dict[str, Any]) -> "AscendCompressedTensorsConfig": ignore: list[str] = cast(list[str], config.get("ignore", [])) quant_format = cast(str, config.get("format")) - target_scheme_map = cls._quantization_scheme_map_from_config( - config=config) + target_scheme_map = cls._quantization_scheme_map_from_config(config=config) return cls( target_scheme_map=target_scheme_map, @@ -121,10 +113,9 @@ class AscendCompressedTensorsConfig(QuantizationConfig): ) @classmethod - def _quantization_scheme_map_from_config( - cls, config: dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE: + def _quantization_scheme_map_from_config(cls, config: dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE: """Build target scheme map from config. 
- + :param config: The `quantization_config` dictionary from config.json :return: A dictionary mapping target layer names to their corresponding quantization_args for weights and input activations @@ -138,24 +129,22 @@ class AscendCompressedTensorsConfig(QuantizationConfig): targets = quant_config.get("targets") for target in targets: target_scheme_map[target] = {} - target_scheme_map[target][ - "weights"] = QuantizationArgs.model_validate( - quant_config.get("weights")) + target_scheme_map[target]["weights"] = QuantizationArgs.model_validate(quant_config.get("weights")) target_scheme_map[target]["input_activations"] = None - target_scheme_map[target]["format"] = quant_config.get( - "format") + target_scheme_map[target]["format"] = quant_config.get("format") format = target_scheme_map[target].get("format") # If no per-config format defined, use global format in config act_quant_format = ( is_activation_quantization_format(format) - if format is not None else - is_activation_quantization_format(quant_format)) + if format is not None + else is_activation_quantization_format(quant_format) + ) input_activations = quant_config.get("input_activations") if act_quant_format and input_activations is not None: - target_scheme_map[target]["input_activations"] = ( - QuantizationArgs.model_validate( - quant_config.get("input_activations"))) + target_scheme_map[target]["input_activations"] = QuantizationArgs.model_validate( + quant_config.get("input_activations") + ) return target_scheme_map def get_quant_method( @@ -168,8 +157,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig): if isinstance(layer, LinearBase): layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD # Get the scheme for this layer - linear_scheme = self._get_linear_scheme(layer=layer, - layer_name=prefix) + linear_scheme = self._get_linear_scheme(layer=layer, layer_name=prefix) # Return unquantized method if no scheme found if linear_scheme is None: @@ -177,14 +165,12 @@ class AscendCompressedTensorsConfig(QuantizationConfig): # Store scheme on layer for reference (optional, for debugging) layer.scheme = linear_scheme - logger.info_once( - "Using the vLLM Ascend llmcompressor Quantization now!") + logger.info_once("Using the vLLM Ascend llmcompressor Quantization now!") return AscendLinearMethod(linear_scheme) if isinstance(layer, FusedMoE): # Delayed import to avoid circular import - from vllm_ascend.ops.fused_moe.fused_moe import \ - AscendUnquantizedFusedMoEMethod + from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD layer_name = prefix + ".0.gate_proj" @@ -197,24 +183,19 @@ class AscendCompressedTensorsConfig(QuantizationConfig): # Store scheme on layer for reference (optional, for debugging) layer.scheme = moe_scheme - logger.info_once( - "Using the vLLM Ascend llmcompressor Quantization now!") + logger.info_once("Using the vLLM Ascend llmcompressor Quantization now!") return AscendFusedMoEMethod(moe_scheme, layer.moe_config) return None - def _get_linear_scheme( - self, - layer: torch.nn.Module, - layer_name: Optional[str] = None) -> Optional[AscendLinearScheme]: + def _get_linear_scheme(self, layer: torch.nn.Module, layer_name: str | None = None) -> AscendLinearScheme | None: """Get the linear quantization scheme for a layer. - + Returns: An AscendLinearScheme instance, or None if the layer should use unquantized method. 
""" - weight_quant, input_quant, format = self._get_quant_args( - layer, layer_name) + weight_quant, input_quant, format = self._get_quant_args(layer, layer_name) if weight_quant is None: return None @@ -226,12 +207,9 @@ class AscendCompressedTensorsConfig(QuantizationConfig): ) return cast(AscendLinearScheme, scheme) - def _get_moe_scheme( - self, - layer: torch.nn.Module, - layer_name: Optional[str] = None) -> Optional[AscendMoEScheme]: + def _get_moe_scheme(self, layer: torch.nn.Module, layer_name: str | None = None) -> AscendMoEScheme | None: """Get the MoE quantization scheme for a layer. - + Returns: An AscendMoEScheme instance, or None if the layer should use unquantized method. @@ -239,8 +217,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig): # Add FusedMoE to target scheme map if needed self._add_fused_moe_to_target_scheme_map() - weight_quant, input_quant, format = self._get_quant_args( - layer, layer_name) + weight_quant, input_quant, format = self._get_quant_args(layer, layer_name) if weight_quant is None: return None @@ -253,13 +230,10 @@ class AscendCompressedTensorsConfig(QuantizationConfig): return cast(AscendMoEScheme, scheme) def _get_quant_args( - self, - layer: torch.nn.Module, - layer_name: Optional[str] = None - ) -> tuple[Optional["QuantizationArgs"], Optional["QuantizationArgs"], - Optional[str]]: + self, layer: torch.nn.Module, layer_name: str | None = None + ) -> tuple[Optional["QuantizationArgs"], Optional["QuantizationArgs"], str | None]: """Extract quantization arguments for a layer. - + compressed-tensors supports non uniform in the following way: targets of config_groups: There can be N config_groups which each @@ -269,7 +243,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig): Detect whether a layer_name is found in any target and use the quantization scheme corresponding to the matched target. - + Returns: A tuple of (weight_quant, input_quant, format). weight_quant is None if the layer should use unquantized method. @@ -284,16 +258,16 @@ class AscendCompressedTensorsConfig(QuantizationConfig): format = scheme_dict.get("format") if weight_quant is None: - logger.warning_once("Acceleration for non-quantized schemes is " - "not supported by Compressed Tensors. " - "Falling back to UnquantizedLinearMethod") + logger.warning_once( + "Acceleration for non-quantized schemes is " + "not supported by Compressed Tensors. " + "Falling back to UnquantizedLinearMethod" + ) return weight_quant, input_quant, format def get_scheme_dict( - self, - layer: torch.nn.Module, - layer_name: str | None = None + self, layer: torch.nn.Module, layer_name: str | None = None ) -> dict[str, QuantizationArgs | str | None] | None: """ Extract the QuantizationArgs for a given layer. @@ -305,9 +279,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig): "format": str | None } | None """ - if should_ignore_layer(layer_name, - ignore=self.ignore, - fused_mapping=self.packed_modules_mapping): + if should_ignore_layer(layer_name, ignore=self.ignore, fused_mapping=self.packed_modules_mapping): return None if self.target_scheme_map: @@ -328,17 +300,17 @@ class AscendCompressedTensorsConfig(QuantizationConfig): self, weight_quant: "QuantizationArgs", input_quant: Optional["QuantizationArgs"], - format: Optional[str], + format: str | None, layer_type: str, - ) -> Union[AscendLinearScheme, AscendMoEScheme]: + ) -> AscendLinearScheme | AscendMoEScheme: """Create the appropriate Ascend scheme based on quantization args and layer type. 
- + Args: weight_quant: Weight quantization arguments. input_quant: Input activation quantization arguments. format: Per-layer format, if defined. layer_type: Type of layer ("linear" or "moe"). - + Returns: An instance of the appropriate Ascend quantization scheme. """ @@ -352,7 +324,8 @@ class AscendCompressedTensorsConfig(QuantizationConfig): if scheme_cls is None: raise NotImplementedError( f"No compressed-tensors compatible scheme was found for " - f"quant_type={quant_type}, layer_type={layer_type}.") + f"quant_type={quant_type}, layer_type={layer_type}." + ) return scheme_cls() @@ -360,15 +333,15 @@ class AscendCompressedTensorsConfig(QuantizationConfig): self, weight_quant: "QuantizationArgs", input_quant: Optional["QuantizationArgs"], - format: Optional[str], + format: str | None, ) -> str: """Detect the quantization type from quantization arguments. - + Args: weight_quant: Weight quantization arguments. input_quant: Input activation quantization arguments. format: Per-layer format, if defined. - + Returns: A string representing the quantization type (e.g., "W8A8", "W8A8_DYNAMIC"). """ @@ -389,16 +362,12 @@ class AscendCompressedTensorsConfig(QuantizationConfig): if self._is_w4a16(weight_quant, input_quant): return "W4A16" - raise NotImplementedError( - "No compressed-tensors compatible quantization type was found.") + raise NotImplementedError("No compressed-tensors compatible quantization type was found.") - def _is_static_tensor_w8a8(self, weight_quant: "QuantizationArgs", - input_quant: "QuantizationArgs") -> bool: + def _is_static_tensor_w8a8(self, weight_quant: "QuantizationArgs", input_quant: "QuantizationArgs") -> bool: is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 - weight_strategy = ( - weight_quant.strategy == QuantizationStrategy.CHANNEL.value) - is_tensor = (weight_strategy and input_quant.strategy - == QuantizationStrategy.TENSOR.value) + weight_strategy = weight_quant.strategy == QuantizationStrategy.CHANNEL.value + is_tensor = weight_strategy and input_quant.strategy == QuantizationStrategy.TENSOR.value is_static = not weight_quant.dynamic and not input_quant.dynamic is_symmetric = weight_quant.symmetric and input_quant.symmetric @@ -406,28 +375,24 @@ class AscendCompressedTensorsConfig(QuantizationConfig): # Only symmetric weight quantization supported. return is_8_bits and is_tensor and is_symmetric and is_static - def _is_dynamic_token_w8a8(self, weight_quant: "QuantizationArgs", - input_quant: "QuantizationArgs") -> bool: + def _is_dynamic_token_w8a8(self, weight_quant: "QuantizationArgs", input_quant: "QuantizationArgs") -> bool: is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 - weight_strategy = ( - weight_quant.strategy == QuantizationStrategy.CHANNEL.value) - is_token = (weight_strategy and input_quant.strategy - == QuantizationStrategy.TOKEN.value) + weight_strategy = weight_quant.strategy == QuantizationStrategy.CHANNEL.value + is_token = weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value is_dynamic = not weight_quant.dynamic and input_quant.dynamic is_symmetric = weight_quant.symmetric and input_quant.symmetric # Only symmetric input quantization supported. # Only symmetric weight quantization supported. 
return is_8_bits and is_token and is_symmetric and is_dynamic - - def _is_dynamic_token_w4a8(self, weight_quant: QuantizationArgs, - input_quant: QuantizationArgs) -> bool: + + def _is_dynamic_token_w4a8(self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs) -> bool: is_4_bits = weight_quant.num_bits == 4 is_8_bits = input_quant.num_bits == 8 - weight_strategy = ( - weight_quant.strategy == QuantizationStrategy.CHANNEL.value) or (weight_quant.strategy == QuantizationStrategy.GROUP.value) - is_token = (weight_strategy and input_quant.strategy - == QuantizationStrategy.TOKEN.value) + weight_strategy = (weight_quant.strategy == QuantizationStrategy.CHANNEL.value) or ( + weight_quant.strategy == QuantizationStrategy.GROUP.value + ) + is_token = weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value is_dynamic = not weight_quant.dynamic and input_quant.dynamic is_symmetric = weight_quant.symmetric and input_quant.symmetric @@ -435,7 +400,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig): assert self.quant_description is not None, "quant_description should not be None" if weight_strategy: self.quant_description["group_size"] = weight_quant.group_size if weight_quant.group_size else 0 - + self.quant_description["version"] = "0" self.quant_description["ascend_quant_method"] = COMPRESSED_TENSORS_METHOD self.quant_description["weight_strategy"] = str(weight_quant.strategy) @@ -444,8 +409,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig): # Only symmetric weight quantization supported. return is_4_bits and is_8_bits and is_token and is_symmetric and is_dynamic - def _is_w4a16(self, weight_quant: "QuantizationArgs", - input_quant: Optional["QuantizationArgs"]) -> bool: + def _is_w4a16(self, weight_quant: "QuantizationArgs", input_quant: Optional["QuantizationArgs"]) -> bool: # Confirm weights quantized. if weight_quant is None: return False @@ -456,12 +420,11 @@ class AscendCompressedTensorsConfig(QuantizationConfig): input_quant_none = input_quant is None is_4_bits = weight_quant.num_bits == 4 - is_group = (weight_quant.strategy == QuantizationStrategy.GROUP.value) + is_group = weight_quant.strategy == QuantizationStrategy.GROUP.value is_static = not weight_quant.dynamic return input_quant_none and is_4_bits and is_group and is_static def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): - self.target_scheme_map = hf_to_vllm_mapper.apply_dict( - self.target_scheme_map) + self.target_scheme_map = hf_to_vllm_mapper.apply_dict(self.target_scheme_map) self.ignore = hf_to_vllm_mapper.apply_list(self.ignore) diff --git a/vllm_ascend/quantization/method_adapters.py b/vllm_ascend/quantization/method_adapters.py index b882e6d8..82056bec 100644 --- a/vllm_ascend/quantization/method_adapters.py +++ b/vllm_ascend/quantization/method_adapters.py @@ -16,27 +16,22 @@ # This file is a part of the vllm-ascend project. 
# -from typing import Callable, List, Optional +from collections.abc import Callable import torch from vllm.distributed import get_tensor_model_parallel_rank -from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase, - FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase, FusedMoeWeightScaleSupported from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig -from vllm.model_executor.layers.linear import (LinearMethodBase, - RowParallelLinear) +from vllm.model_executor.layers.linear import LinearMethodBase, RowParallelLinear from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.parameter import PerTensorScaleParameter from vllm.model_executor.utils import set_weight_attrs from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.distributed.parallel_state import (get_flashcomm2_otp_group, - get_mlp_tp_group, - get_otp_group) +from vllm_ascend.distributed.parallel_state import get_flashcomm2_otp_group, get_mlp_tp_group, get_otp_group from vllm_ascend.utils import flashcomm2_enable, mlp_tp_enable, oproj_tp_enable -from .methods import (AscendAttentionScheme, AscendLinearScheme, - AscendMoEScheme, is_mx_quant_type) +from .methods import AscendAttentionScheme, AscendLinearScheme, AscendMoEScheme, is_mx_quant_type class AscendLinearMethod(LinearMethodBase): @@ -56,7 +51,7 @@ class AscendLinearMethod(LinearMethodBase): self, layer: torch.nn.Module, input_size_per_partition: int, - output_partition_sizes: List[int], + output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, @@ -65,9 +60,7 @@ class AscendLinearMethod(LinearMethodBase): output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") - weight_dict = self.quant_method.get_weight(input_size_per_partition, - output_size_per_partition, - params_dtype) + weight_dict = self.quant_method.get_weight(input_size_per_partition, output_size_per_partition, params_dtype) # Extract packing information (if present) packed_dim = weight_dict.pop("_packed_dim", None) @@ -79,25 +72,20 @@ class AscendLinearMethod(LinearMethodBase): # Set packing attributes if the weight is packed if packed_dim is not None and packed_factor is not None: - set_weight_attrs(param, { - "packed_dim": packed_dim, - "packed_factor": packed_factor - }) + set_weight_attrs(param, {"packed_dim": packed_dim, "packed_factor": packed_factor}) layer.register_parameter(weight_name, param) set_weight_attrs(param, extra_weight_attrs) pertensor_dict = self.quant_method.get_pertensor_param(params_dtype) for pertensor_name, pertensor_param in pertensor_dict.items(): - param = PerTensorScaleParameter(data=pertensor_param, - weight_loader=weight_loader) + param = PerTensorScaleParameter(data=pertensor_param, weight_loader=weight_loader) # disable warning param.ignore_warning = True layer.register_parameter(pertensor_name, param) param.weight_loader = extra_weight_attrs.get("weight_loader") - perchannel_dict = self.quant_method.get_perchannel_param( - output_size_per_partition, params_dtype) + perchannel_dict = self.quant_method.get_perchannel_param(output_size_per_partition, params_dtype) for perchannel_name, perchannel_param in perchannel_dict.items(): param = torch.nn.Parameter(perchannel_param, requires_grad=False) set_weight_attrs(param, {"output_dim": 0}) @@ -107,22 +95,22 @@ class AscendLinearMethod(LinearMethodBase): # NOTE: In w4a8 quantization implementation, # 
for down_proj and o_proj scale_bias shape is [output_size, 16], # others are [output_size, 1] - layer_type = "row" if isinstance(layer, - RowParallelLinear) else "others" + layer_type = "row" if isinstance(layer, RowParallelLinear) else "others" pergroup_dict = self.quant_method.get_pergroup_param( - input_size_per_partition, - output_size_per_partition, - params_dtype, - layer_type=layer_type) + input_size_per_partition, output_size_per_partition, params_dtype, layer_type=layer_type + ) for pergroup_name, pergroup_param in pergroup_dict.items(): param = torch.nn.Parameter(pergroup_param, requires_grad=False) set_weight_attrs(param, {"output_dim": 0}) layer.register_parameter(pergroup_name, param) set_weight_attrs(param, extra_weight_attrs) - if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name \ - or is_mx_quant_type(self.quant_method): - setattr(param, "input_dim", 1) + if ( + "weight_scale_second" in pergroup_name + or "weight_offset_second" in pergroup_name + or is_mx_quant_type(self.quant_method) + ): + param.input_dim = 1 param.input_dim = 1 def process_weights_after_loading(self, layer: torch.nn.Module) -> None: @@ -133,17 +121,15 @@ class AscendLinearMethod(LinearMethodBase): self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, ) -> torch.Tensor: if isinstance(layer, RowParallelLinear): if layer.prefix.find("o_proj") != -1 and oproj_tp_enable(): tp_rank = get_otp_group().rank_in_group elif layer.prefix.find("down_proj") != -1 and mlp_tp_enable(): tp_rank = get_mlp_tp_group().rank_in_group - elif (layer.prefix.find("o_proj") != -1 or - layer.prefix.find("out_proj") != -1) and flashcomm2_enable(): - if get_ascend_config( - ).flashcomm2_oproj_tensor_parallel_size == 1: + elif (layer.prefix.find("o_proj") != -1 or layer.prefix.find("out_proj") != -1) and flashcomm2_enable(): + if get_ascend_config().flashcomm2_oproj_tensor_parallel_size == 1: tp_rank = 0 else: tp_rank = get_flashcomm2_otp_group().rank_in_group @@ -175,11 +161,19 @@ class AscendKVCacheMethod(BaseKVCacheMethod): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.quant_method.process_weights_after_loading(layer) - def apply(self, layer: torch.nn.Module, query: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata, - attn_type, scale, output) -> torch.Tensor: - return self.quant_method.apply(layer, query, key, value, kv_cache, - attn_metadata, attn_type, scale, output) + def apply( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache, + attn_metadata, + attn_type, + scale, + output, + ) -> torch.Tensor: + return self.quant_method.apply(layer, query, key, value, kv_cache, attn_metadata, attn_type, scale, output) class AscendFusedMoEMethod(FusedMoEMethodBase): @@ -192,8 +186,7 @@ class AscendFusedMoEMethod(FusedMoEMethodBase): moe_config: The FusedMoE configuration. 
""" - def __init__(self, scheme: AscendMoEScheme, - moe_config: FusedMoEConfig) -> None: + def __init__(self, scheme: AscendMoEScheme, moe_config: FusedMoEConfig) -> None: super().__init__(moe_config) self.quant_method = scheme @@ -207,30 +200,28 @@ class AscendFusedMoEMethod(FusedMoEMethodBase): **extra_weight_attrs, ) -> None: weight_param = self.quant_method.get_weight( - num_experts, intermediate_size_per_partition, hidden_size, - params_dtype) + num_experts, intermediate_size_per_partition, hidden_size, params_dtype + ) for param_key, param_value in weight_param.items(): param = torch.nn.Parameter(param_value, requires_grad=False) layer.register_parameter(param_key, param) set_weight_attrs(param, extra_weight_attrs) - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) - per_group_param = [ - "weight_scale_second", "weight_offset_second", "scale_bias" - ] + ["weight_scale", "weight_offset"] if hasattr( - self.quant_method, - "group_size") and self.quant_method.group_size > 0 else [] + extra_weight_attrs.update({"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) + per_group_param = ( + ["weight_scale_second", "weight_offset_second", "scale_bias"] + ["weight_scale", "weight_offset"] + if hasattr(self.quant_method, "group_size") and self.quant_method.group_size > 0 + else [] + ) dynamic_quant_param = self.quant_method.get_dynamic_quant_param( - num_experts, intermediate_size_per_partition, hidden_size, - params_dtype) + num_experts, intermediate_size_per_partition, hidden_size, params_dtype + ) for param_key, param_value in dynamic_quant_param.items(): param = torch.nn.Parameter(param_value, requires_grad=False) layer.register_parameter(param_key, param) set_weight_attrs(param, extra_weight_attrs) if any(fields in param_key for fields in per_group_param): - setattr(param, "quant_method", - FusedMoeWeightScaleSupported.GROUP.value) + param.quant_method = FusedMoeWeightScaleSupported.GROUP.value def apply( self, @@ -241,25 +232,40 @@ class AscendFusedMoEMethod(FusedMoEMethodBase): renormalize: bool, use_grouped_topk: bool = False, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + topk_group: int | None = None, + num_expert_group: int | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, is_prefill: bool = True, enable_force_load_balance: bool = False, - log2phy: Optional[torch.Tensor] = None, + log2phy: torch.Tensor | None = None, global_redundant_expert_num=0, **kwargs, ) -> torch.Tensor: return self.quant_method.apply( - layer, x, router_logits, top_k, renormalize, use_grouped_topk, - global_num_experts, expert_map, topk_group, num_expert_group, - custom_routing_function, scoring_func, routed_scaling_factor, - e_score_correction_bias, is_prefill, enable_force_load_balance, - log2phy, global_redundant_expert_num, **kwargs) + layer, + x, + router_logits, + top_k, + renormalize, + use_grouped_topk, + global_num_experts, + expert_map, + topk_group, + num_expert_group, + custom_routing_function, + scoring_func, + routed_scaling_factor, + e_score_correction_bias, + is_prefill, + enable_force_load_balance, + log2phy, + global_redundant_expert_num, + **kwargs, + 
) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if hasattr(self.quant_method, "process_weights_after_loading"): @@ -276,7 +282,7 @@ class AscendFusedMoEMethod(FusedMoEMethodBase): class AscendEmbeddingMethod(AscendLinearMethod): """Embedding method for Ascend quantization. - + This is essentially the same as AscendLinearMethod, just with a different name for clarity when used with VocabParallelEmbedding layers. diff --git a/vllm_ascend/quantization/methods/__init__.py b/vllm_ascend/quantization/methods/__init__.py index bfbfeae1..38840295 100644 --- a/vllm_ascend/quantization/methods/__init__.py +++ b/vllm_ascend/quantization/methods/__init__.py @@ -21,7 +21,7 @@ Schemes are automatically registered via the @register_scheme decorator. Usage: from vllm_ascend.quantization.methods import get_scheme_class - + # Get a scheme class by quant_type and layer_type scheme_cls = get_scheme_class("W8A8_DYNAMIC", "linear") scheme = scheme_cls() @@ -30,28 +30,26 @@ Usage: from typing import Any # Import base classes -from .base import (AscendAttentionScheme, AscendLinearScheme, AscendMoEScheme, - QuantType) +from .base import AscendAttentionScheme, AscendLinearScheme, AscendMoEScheme, QuantType + # Import registry functions from .registry import get_scheme_class, register_scheme + # Import all scheme classes for external access from .w4a4_flatquant import AscendW4A4FlatQuantDynamicLinearMethod from .w4a4_laos_dynamic import AscendW4A4LaosDynamicLinearMethod -from .w4a8 import (AscendW4A8DynamicFusedMoEMethod, - AscendW4A8DynamicLinearMethod) +from .w4a8 import AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod from .w4a16 import AscendW4A16FusedMoEMethod -from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, - AscendW8A8DynamicLinearMethod) +from .w8a8_dynamic import AscendW8A8DynamicFusedMoEMethod, AscendW8A8DynamicLinearMethod from .w8a8_mxfp8 import AscendW8A8MXFP8DynamicLinearMethod -from .w8a8_pdmix import (AscendW8A8PDMixFusedMoeMethod, - AscendW8A8PDMixLinearMethod) +from .w8a8_pdmix import AscendW8A8PDMixFusedMoeMethod, AscendW8A8PDMixLinearMethod from .w8a8_static import AscendW8A8LinearMethod from .w8a16 import AscendW8A16LinearMethod def is_mx_quant_type(instance: Any) -> bool: """Checks if the quantization method is a microscaling (MX) type.""" - MX_QUANT_TYPES = (AscendW8A8MXFP8DynamicLinearMethod, ) + MX_QUANT_TYPES = (AscendW8A8MXFP8DynamicLinearMethod,) return isinstance(instance, MX_QUANT_TYPES) diff --git a/vllm_ascend/quantization/methods/base.py b/vllm_ascend/quantization/methods/base.py index 9bcec5c2..c525167a 100644 --- a/vllm_ascend/quantization/methods/base.py +++ b/vllm_ascend/quantization/methods/base.py @@ -17,14 +17,16 @@ """Abstract base classes for Ascend quantization schemes.""" from abc import ABC, abstractmethod +from collections.abc import Callable from enum import Enum -from typing import Any, Callable, Dict, Optional +from typing import Any import torch class QuantType(Enum): """Quantization type enum for MoE schemes.""" + NONE = 0 W8A8 = 1 W4A8 = 2 @@ -32,84 +34,78 @@ class QuantType(Enum): class AscendLinearScheme(ABC): """Base class for all linear quantization schemes. - + Subclasses must implement get_weight() and apply() methods. Other methods have default implementations that return empty dicts or do nothing. 
""" @abstractmethod - def get_weight(self, input_size: int, output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]: """Return weight tensor specifications. - + Args: input_size: Input dimension of the linear layer. output_size: Output dimension of the linear layer. params_dtype: Data type for parameters. - + Returns: Dictionary mapping parameter names to empty tensors with the correct shape and dtype. """ ... - def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]: + def get_pertensor_param(self, params_dtype: torch.dtype) -> dict[str, Any]: """Return per-tensor parameter specifications (e.g., input_scale). - + Args: params_dtype: Data type for parameters. - + Returns: Dictionary mapping parameter names to empty tensors. """ return {} - def get_perchannel_param(self, output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def get_perchannel_param(self, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]: """Return per-channel parameter specifications (e.g., weight_scale). - + Args: output_size: Output dimension of the linear layer. params_dtype: Data type for parameters. - + Returns: Dictionary mapping parameter names to empty tensors. """ return {} - def get_pergroup_param(self, - input_size: int, - output_size: int, - params_dtype: torch.dtype, - layer_type: Optional[str] = None) -> Dict[str, Any]: + def get_pergroup_param( + self, input_size: int, output_size: int, params_dtype: torch.dtype, layer_type: str | None = None + ) -> dict[str, Any]: """Return per-group parameter specifications. - + Args: input_size: Input dimension of the linear layer. output_size: Output dimension of the linear layer. params_dtype: Data type for parameters. layer_type: Type of layer (e.g., "row" for RowParallelLinear). - + Returns: Dictionary mapping parameter names to empty tensors. """ return {} @abstractmethod - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - tp_rank: Optional[int] = 0) -> torch.Tensor: + def apply( + self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, tp_rank: int | None = 0 + ) -> torch.Tensor: """Forward computation. - + Args: layer: The linear layer module. x: Input tensor. bias: Optional bias tensor. tp_rank: Tensor parallel rank. - + Returns: Output tensor after quantized linear operation. """ @@ -117,42 +113,51 @@ class AscendLinearScheme(ABC): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: """Post-loading weight processing (transpose, format conversion, etc.). - + Args: layer: The linear layer module. """ - pass + return class AscendAttentionScheme(ABC): """Base class for all attention quantization schemes. - + Subclasses must implement apply() method. Other methods have default implementations. """ def create_weights(self, layer: torch.nn.Module) -> None: """Create weights for attention quantization. - + Args: layer: The attention layer module. """ - pass + return def process_weights_after_loading(self, layer: torch.nn.Module) -> None: """Post-loading weight processing for attention layer. - + Args: layer: The attention layer module. 
""" - pass + return @abstractmethod - def apply(self, layer: torch.nn.Module, query: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata, - attn_type, scale, output) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache, + attn_metadata, + attn_type, + scale, + output, + ) -> torch.Tensor: """Forward computation for attention layer. - + Args: layer: The attention layer module. query: Query tensor. @@ -163,7 +168,7 @@ class AscendAttentionScheme(ABC): attn_type: Attention type. scale: Scale factor. output: Output tensor. - + Returns: Output tensor after attention computation. """ @@ -172,10 +177,10 @@ class AscendAttentionScheme(ABC): class AscendMoEScheme(ABC): """Base class for all MoE quantization schemes. - + Subclasses must implement get_weight(), get_dynamic_quant_param(), and apply() methods. - + Attributes: quant_type: The quantization type for this scheme. Subclasses should override this class attribute to declare their quant type. @@ -185,35 +190,34 @@ class AscendMoEScheme(ABC): quant_type: QuantType = QuantType.NONE @abstractmethod - def get_weight(self, num_experts: int, - intermediate_size_per_partition: int, hidden_sizes: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def get_weight( + self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype + ) -> dict[str, Any]: """Return weight tensor specifications for MoE layer. - + Args: num_experts: Number of experts. intermediate_size_per_partition: Intermediate size per partition. hidden_sizes: Hidden dimension size. params_dtype: Data type for parameters. - + Returns: Dictionary mapping parameter names to empty tensors. """ ... @abstractmethod - def get_dynamic_quant_param(self, num_experts: int, - intermediate_size_per_partition: int, - hidden_sizes: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def get_dynamic_quant_param( + self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype + ) -> dict[str, Any]: """Return dynamic quantization parameters for MoE layer. - + Args: num_experts: Number of experts. intermediate_size_per_partition: Intermediate size per partition. hidden_sizes: Hidden dimension size. params_dtype: Data type for parameters. - + Returns: Dictionary mapping parameter names to empty tensors. """ @@ -229,21 +233,21 @@ class AscendMoEScheme(ABC): renormalize: bool, use_grouped_topk: bool = False, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + topk_group: int | None = None, + num_expert_group: int | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, is_prefill: bool = True, enable_force_load_balance: bool = False, - log2phy: Optional[torch.Tensor] = None, + log2phy: torch.Tensor | None = None, global_redundant_expert_num: int = 0, **kwargs, ) -> torch.Tensor: """Forward computation for MoE layer. - + Args: layer: The MoE layer module. x: Input hidden states. @@ -264,7 +268,7 @@ class AscendMoEScheme(ABC): log2phy: Logical to physical expert mapping. 
global_redundant_expert_num: Number of redundant experts. **kwargs: Additional keyword arguments. - + Returns: Output tensor after MoE computation. """ @@ -272,8 +276,8 @@ class AscendMoEScheme(ABC): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: """Post-loading weight processing for MoE layer. - + Args: layer: The MoE layer module. """ - pass + return diff --git a/vllm_ascend/quantization/methods/registry.py b/vllm_ascend/quantization/methods/registry.py index 59766402..b8eee21a 100644 --- a/vllm_ascend/quantization/methods/registry.py +++ b/vllm_ascend/quantization/methods/registry.py @@ -15,47 +15,47 @@ # limitations under the License. # -from typing import Any, Dict, Optional, Tuple, Type +from typing import Any # Registry: maps (quant_type, layer_type) -> SchemeClass -_SCHEME_REGISTRY: Dict[Tuple[str, str], Type[Any]] = {} +_SCHEME_REGISTRY: dict[tuple[str, str], type[Any]] = {} def register_scheme(quant_type: str, layer_type: str): """Decorator to register a quantization scheme. - + Args: quant_type: Quantization type (e.g., "W8A8", "W8A8_DYNAMIC"). layer_type: Layer type (e.g., "linear", "moe"). - + Returns: Decorator function that registers the class. - + Example: @register_scheme("W8A8_DYNAMIC", "linear") class W8A8DynamicLinearScheme(AscendLinearScheme): ... """ - def decorator(cls: Type[Any]) -> Type[Any]: + def decorator(cls: type[Any]) -> type[Any]: key = (quant_type, layer_type) if key in _SCHEME_REGISTRY: raise ValueError( - f"Scheme already registered for {quant_type}/{layer_type}: " - f"{_SCHEME_REGISTRY[key].__name__}") + f"Scheme already registered for {quant_type}/{layer_type}: {_SCHEME_REGISTRY[key].__name__}" + ) _SCHEME_REGISTRY[key] = cls return cls return decorator -def get_scheme_class(quant_type: str, layer_type: str) -> Optional[Type[Any]]: +def get_scheme_class(quant_type: str, layer_type: str) -> type[Any] | None: """Get scheme class for given quant_type and layer_type. - + Args: quant_type: Quantization type (e.g., "W8A8", "W8A8_DYNAMIC"). layer_type: Layer type (e.g., "linear", "moe"). - + Returns: The registered scheme class, or None if not found. """ diff --git a/vllm_ascend/quantization/methods/w4a16.py b/vllm_ascend/quantization/methods/w4a16.py index 92d5fdf4..f30ff88e 100644 --- a/vllm_ascend/quantization/methods/w4a16.py +++ b/vllm_ascend/quantization/methods/w4a16.py @@ -15,7 +15,8 @@ # limitations under the License. 
# -from typing import Any, Callable, Dict, Optional +from collections.abc import Callable +from typing import Any import torch import torch_npu @@ -56,8 +57,7 @@ def unpack_from_int32( dtype=torch.int32, ) for i in range(pack_factor): - unpacked_weight[:, i::pack_factor] = (weight >> - (num_bits * i)) & mask + unpacked_weight[:, i::pack_factor] = (weight >> (num_bits * i)) & mask original_row_size = int(shape[1]) unpacked_weight = unpacked_weight[:, :original_row_size] else: @@ -67,8 +67,7 @@ dtype=torch.int32, ) for i in range(pack_factor): - unpacked_weight[i::pack_factor, :] = (weight >> - (num_bits * i)) & mask + unpacked_weight[i::pack_factor, :] = (weight >> (num_bits * i)) & mask original_row_size = int(shape[0]) unpacked_weight = unpacked_weight[:original_row_size, :] @@ -84,22 +83,17 @@ def pack_to_int32(weight: torch.Tensor) -> torch.Tensor: :param weight: The 3D tensor to pack, must be int8 or int32 dtype :return: Packed tensor with int32 dtype optimized for storage """ - assert weight.dim( - ) == 3, f"Expecting `weight.dim()` is 3 ([e, n, k] or [e, k, n]) but got {weight.dim()}." - assert weight.dtype in [ - torch.int8, torch.int32 - ], f"Expecting `weight.dtype` is torch.int8 or torch.int32 bug got {weight.dtype}." + assert weight.dim() == 3, f"Expecting `weight.dim()` is 3 ([e, n, k] or [e, k, n]) but got {weight.dim()}." + assert weight.dtype in [torch.int8, torch.int32], ( + f"Expecting `weight.dtype` is torch.int8 or torch.int32 but got {weight.dtype}." + ) if weight.dtype == torch.int32: - assert weight.shape[ - -1] % 8 == 0, "the last dim of weight needs to be divided by 8." - packed_weight = torch_npu.npu_convert_weight_to_int4pack( - weight.flatten(0, 1)) - packed_weight = packed_weight.view(weight.shape[0], weight.shape[1], - -1) + assert weight.shape[-1] % 8 == 0, "the last dim of weight needs to be divided by 8." + packed_weight = torch_npu.npu_convert_weight_to_int4pack(weight.flatten(0, 1)) + packed_weight = packed_weight.view(weight.shape[0], weight.shape[1], -1) else: - assert weight.shape[ - -1] % 4 == 0, "the last dim of weight needs to be divided by 4." + assert weight.shape[-1] % 4 == 0, "the last dim of weight needs to be divided by 4."
packed_weight = weight.view(torch.int32).contiguous() return packed_weight @@ -115,8 +109,7 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme): self.pack_factor = 8 # pack 8 of torch.int4 tensors to torch.int32 vllm_config = get_current_vllm_config() - self.group_size = vllm_config.quant_config.quant_description.get( - "group_size", 32) + self.group_size = vllm_config.quant_config.quant_description.get("group_size", 32) self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb def get_weight( @@ -125,22 +118,23 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme): intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype, - ) -> Dict[str, Any]: - assert intermediate_size_per_partition % self.pack_factor == 0, f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} can be divided by `pack_factor` {self.pack_factor}" - assert hidden_sizes % self.pack_factor == 0, f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `pack_factor` {self.pack_factor}" + ) -> dict[str, Any]: + assert intermediate_size_per_partition % self.pack_factor == 0, ( + f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} " + f"can be divided by `pack_factor` {self.pack_factor}" + ) + assert hidden_sizes % self.pack_factor == 0, ( + f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `pack_factor` {self.pack_factor}" + ) param_dict = {} param_dict["w13_weight_packed"] = torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_sizes // self.pack_factor, - dtype=torch.int32) + num_experts, 2 * intermediate_size_per_partition, hidden_sizes // self.pack_factor, dtype=torch.int32 + ) param_dict["w2_weight_packed"] = torch.empty( - num_experts, - hidden_sizes, - intermediate_size_per_partition // self.pack_factor, - dtype=torch.int32) + num_experts, hidden_sizes, intermediate_size_per_partition // self.pack_factor, dtype=torch.int32 + ) return param_dict @@ -150,38 +144,31 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme): intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype, - ) -> Dict[str, Any]: - assert intermediate_size_per_partition % self.group_size == 0, f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} can be divided by `group_size` {self.group_size}" - assert hidden_sizes % self.group_size == 0, f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `group_size` {self.group_size}" + ) -> dict[str, Any]: + assert intermediate_size_per_partition % self.group_size == 0, ( + f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} " + f"can be divided by `group_size` {self.group_size}" + ) + assert hidden_sizes % self.group_size == 0, ( + f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `group_size` {self.group_size}" + ) param_dict = {} param_dict["w13_weight_scale"] = torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_sizes // self.group_size, - dtype=torch.bfloat16) + num_experts, 2 * intermediate_size_per_partition, hidden_sizes // self.group_size, dtype=torch.bfloat16 + ) param_dict["w2_weight_scale"] = torch.empty( - num_experts, - hidden_sizes, - intermediate_size_per_partition // self.group_size, - dtype=torch.bfloat16) - param_dict["w13_weight_shape"] = torch.empty(num_experts, - 2, - dtype=torch.int32) - param_dict["w2_weight_shape"] = torch.empty(num_experts, - 2, - dtype=torch.int32) + num_experts, hidden_sizes, intermediate_size_per_partition // self.group_size, 
dtype=torch.bfloat16 + ) + param_dict["w13_weight_shape"] = torch.empty(num_experts, 2, dtype=torch.int32) + param_dict["w2_weight_shape"] = torch.empty(num_experts, 2, dtype=torch.int32) param_dict["w13_weight_offset"] = torch.zeros( - num_experts, - 2 * intermediate_size_per_partition, - hidden_sizes // self.group_size, - dtype=torch.bfloat16) + num_experts, 2 * intermediate_size_per_partition, hidden_sizes // self.group_size, dtype=torch.bfloat16 + ) param_dict["w2_weight_offset"] = torch.zeros( - num_experts, - hidden_sizes, - intermediate_size_per_partition // self.group_size, - dtype=torch.bfloat16) + num_experts, hidden_sizes, intermediate_size_per_partition // self.group_size, dtype=torch.bfloat16 + ) return param_dict @@ -194,21 +181,22 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme): renormalize: bool, use_grouped_topk: bool = False, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + topk_group: int | None = None, + num_expert_group: int | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, is_prefill: bool = True, enable_force_load_balance: bool = True, - log2phy: Optional[torch.Tensor] = None, + log2phy: torch.Tensor | None = None, global_redundant_expert_num: int = 0, **kwargs, ) -> torch.Tensor: - assert router_logits.shape[ - 1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)" + assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, ( + "Number of global experts mismatch (excluding redundancy)" + ) topk_weights, topk_ids = select_experts( hidden_states=x, @@ -221,7 +209,8 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme): custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, - global_num_experts=global_num_experts) + global_num_experts=global_num_experts, + ) topk_ids = topk_ids.to(torch.int32) topk_weights = topk_weights.to(x.dtype) @@ -241,38 +230,40 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme): expert_map=expert_map, log2phy=log2phy, dynamic_eplb=self.dynamic_eplb, - mc2_mask=kwargs.get("mc2_mask", None)) + mc2_mask=kwargs.get("mc2_mask"), + ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if self.transpose_weight: w13_shape = layer.w13_weight_packed.data.shape w2_shape = layer.w2_weight_packed.data.shape - unpacked_w13_weight = (unpack_from_int32( - layer.w13_weight_packed.data.flatten(0, 1), - torch.Size([ - w13_shape[0] * w13_shape[1], - w13_shape[2] * self.pack_factor - ]), - self.num_bits, - ).view(w13_shape[0], w13_shape[1], - -1).transpose(1, 2).contiguous().int()) - unpacked_w2_weight = (unpack_from_int32( - layer.w2_weight_packed.data.flatten(0, 1), - torch.Size([ - w2_shape[0] * w2_shape[1], w2_shape[2] * self.pack_factor - ]), - self.num_bits, - ).view(w2_shape[0], w2_shape[1], - -1).transpose(1, 2).contiguous().int()) + unpacked_w13_weight = ( + unpack_from_int32( + layer.w13_weight_packed.data.flatten(0, 1), + torch.Size([w13_shape[0] * w13_shape[1], w13_shape[2] * self.pack_factor]), + self.num_bits, + ) + .view(w13_shape[0], w13_shape[1], -1) + .transpose(1, 2) + 
.contiguous() + .int() + ) + unpacked_w2_weight = ( + unpack_from_int32( + layer.w2_weight_packed.data.flatten(0, 1), + torch.Size([w2_shape[0] * w2_shape[1], w2_shape[2] * self.pack_factor]), + self.num_bits, + ) + .view(w2_shape[0], w2_shape[1], -1) + .transpose(1, 2) + .contiguous() + .int() + ) layer.w13_weight_packed.data = pack_to_int32(unpacked_w13_weight) layer.w2_weight_packed.data = pack_to_int32(unpacked_w2_weight) - layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose( - 1, 2).contiguous() - layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose( - 1, 2).contiguous() + layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(1, 2).contiguous() + layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(1, 2).contiguous() - layer.w13_weight_offset.data = layer.w13_weight_offset.data.transpose( - 1, 2).contiguous() - layer.w2_weight_offset.data = layer.w2_weight_offset.data.transpose( - 1, 2).contiguous() + layer.w13_weight_offset.data = layer.w13_weight_offset.data.transpose(1, 2).contiguous() + layer.w2_weight_offset.data = layer.w2_weight_offset.data.transpose(1, 2).contiguous() diff --git a/vllm_ascend/quantization/methods/w4a4_flatquant.py b/vllm_ascend/quantization/methods/w4a4_flatquant.py index 773709fb..1c830c41 100644 --- a/vllm_ascend/quantization/methods/w4a4_flatquant.py +++ b/vllm_ascend/quantization/methods/w4a4_flatquant.py @@ -16,7 +16,7 @@ # import math -from typing import Any, Dict, Optional, Tuple +from typing import Any import torch import torch_npu @@ -31,8 +31,7 @@ def pack_int4_weights(weight_tensor: torch.Tensor) -> torch.Tensor: """Pack int4 weights for NPU.""" original_device = weight_tensor.device weight_tensor_npu = weight_tensor.npu() - weight_int4_packed = torch_npu.npu_convert_weight_to_int4pack( - weight_tensor_npu.to(torch.int32), inner_k_tiles=1) + weight_int4_packed = torch_npu.npu_convert_weight_to_int4pack(weight_tensor_npu.to(torch.int32), inner_k_tiles=1) return weight_int4_packed.to(original_device) @@ -58,22 +57,14 @@ def batched_kronecker_quant( left_trans: torch.Tensor, right_trans: torch.Tensor, clip_ratio: float, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """Batched Kronecker quantization with batch size limit handling.""" batch_tokens = x.shape[0] if batch_tokens <= KRONECKER_QUANT_MAX_BATCH_SIZE: - return torch_npu.npu_kronecker_quant(x, - left_trans, - right_trans, - clip_ratio=clip_ratio, - dst_dtype=torch.int32) + return torch_npu.npu_kronecker_quant(x, left_trans, right_trans, clip_ratio=clip_ratio, dst_dtype=torch.int32) x_chunks = torch.split(x, KRONECKER_QUANT_MAX_BATCH_SIZE, dim=0) processed_chunks = [ - torch_npu.npu_kronecker_quant(chunk, - left_trans, - right_trans, - clip_ratio=clip_ratio, - dst_dtype=torch.int32) + torch_npu.npu_kronecker_quant(chunk, left_trans, right_trans, clip_ratio=clip_ratio, dst_dtype=torch.int32) for chunk in x_chunks ] quantized_list, scale_list = zip(*processed_chunks) @@ -85,39 +76,32 @@ def batched_kronecker_quant( @register_scheme("W4A4_FLATQUANT_DYNAMIC", "linear") class AscendW4A4FlatQuantDynamicLinearMethod(AscendLinearScheme): """Linear method for Ascend W4A4_FLATQUANT_DYNAMIC. - + This class implements W4A4 quantization with FlatQuant approach and dynamic activation quantization. 
- Weight: 4-bit quantization (per-channel) with scale and offset, stored as int8 and packed to int32 during loading - - Activation: 4-bit dynamic quantization with FlatQuant transform matrices (left_trans, right_trans) for distribution smoothing - - Parameters: clip_ratio for controlling quantization clipping, weight_offset for asymmetric quantization, loaded from external weights + - Activation: 4-bit dynamic quantization with FlatQuant transform matrices (left_trans, right_trans) for + distribution smoothing + - Parameters: clip_ratio for controlling quantization clipping, weight_offset for asymmetric quantization, loaded + from external weights """ + input_size = 0 def __init__(self): self.sym = True - def get_weight(self, input_size: int, output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]: if input_size % 8 != 0: - raise ValueError( - f"input_size ({input_size}) must be divisible by 8 for int4 packing" - ) + raise ValueError(f"input_size ({input_size}) must be divisible by 8 for int4 packing") AscendW4A4FlatQuantDynamicLinearMethod.input_size = input_size - params_dict = { - "weight": torch.empty(output_size, input_size, dtype=torch.int8) - } + params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)} return params_dict - def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]: + def get_pertensor_param(self, params_dtype: torch.dtype) -> dict[str, Any]: params_dict = {} - left_trans_dim, right_trans_dim = get_decompose_dim( - AscendW4A4FlatQuantDynamicLinearMethod.input_size) - params_dict["left_trans"] = torch.empty(left_trans_dim, - left_trans_dim, - dtype=params_dtype) - params_dict["right_trans"] = torch.empty(right_trans_dim, - right_trans_dim, - dtype=params_dtype) + left_trans_dim, right_trans_dim = get_decompose_dim(AscendW4A4FlatQuantDynamicLinearMethod.input_size) + params_dict["left_trans"] = torch.empty(left_trans_dim, left_trans_dim, dtype=params_dtype) + params_dict["right_trans"] = torch.empty(right_trans_dim, right_trans_dim, dtype=params_dtype) params_dict["clip_ratio"] = torch.empty(1, dtype=torch.float32) return params_dict @@ -125,22 +109,18 @@ class AscendW4A4FlatQuantDynamicLinearMethod(AscendLinearScheme): self, output_size: int, params_dtype: torch.dtype, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: params_dict = {} - params_dict["weight_scale"] = torch.empty(output_size, - 1, - dtype=torch.float32) - params_dict["weight_offset"] = torch.empty(output_size, - 1, - dtype=torch.float32) + params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=torch.float32) + params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=torch.float32) return params_dict def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - tp_rank: Optional[int] = 0, + bias: torch.Tensor | None = None, + tp_rank: int | None = 0, ) -> torch.Tensor: original_dtype = x.dtype input_shape = x.shape @@ -156,18 +136,18 @@ class AscendW4A4FlatQuantDynamicLinearMethod(AscendLinearScheme): right_trans_matched = layer.right_trans.to(original_dtype) x_reshaped = x.view(-1, left_dim, right_dim) x_quantized_int4, activation_scale = batched_kronecker_quant( - x_reshaped, left_trans_matched, right_trans_matched, - layer.aclnn_clip_ratio) - x_quantized_reshaped = x_quantized_int4.view(-1, - left_dim * right_dim // 8) + x_reshaped, left_trans_matched, right_trans_matched, layer.aclnn_clip_ratio + ) + 
x_quantized_reshaped = x_quantized_int4.view(-1, left_dim * right_dim // 8) pertoken_scale = activation_scale.view(-1).to(torch.float32) - output = torch_npu.npu_quant_matmul(x_quantized_reshaped, - layer.weight_packed.t(), - layer.weight_scale.view(-1).to( - torch.float32), - pertoken_scale=pertoken_scale, - bias=None, - output_dtype=original_dtype) + output = torch_npu.npu_quant_matmul( + x_quantized_reshaped, + layer.weight_packed.t(), + layer.weight_scale.view(-1).to(torch.float32), + pertoken_scale=pertoken_scale, + bias=None, + output_dtype=original_dtype, + ) output = output.view(*input_shape[:-1], -1) if bias is not None: output = output + bias.to(original_dtype) @@ -176,15 +156,11 @@ class AscendW4A4FlatQuantDynamicLinearMethod(AscendLinearScheme): def process_weights_after_loading(self, layer): # NOTE: Currently, w4a4 can't support weight nz weight_packed = pack_int4_weights(layer.weight.data) - layer.register_parameter( - 'weight_packed', - torch.nn.Parameter(weight_packed, requires_grad=False)) + layer.register_parameter("weight_packed", torch.nn.Parameter(weight_packed, requires_grad=False)) del layer.weight layer.weight_scale.data = layer.weight_scale.data.to(torch.float32) layer.weight_offset.data = layer.weight_offset.data.to(torch.float32) - layer.left_trans = torch.nn.Parameter( - layer.left_trans.data.t().contiguous()) + layer.left_trans = torch.nn.Parameter(layer.left_trans.data.t().contiguous()) layer.right_trans = torch.nn.Parameter(layer.right_trans.data) - layer.clip_ratio = torch.nn.Parameter( - layer.clip_ratio.data.to(torch.float32)) + layer.clip_ratio = torch.nn.Parameter(layer.clip_ratio.data.to(torch.float32)) layer.aclnn_clip_ratio = layer.clip_ratio.item() diff --git a/vllm_ascend/quantization/methods/w4a4_laos_dynamic.py b/vllm_ascend/quantization/methods/w4a4_laos_dynamic.py index 3fc83b25..204cf95e 100644 --- a/vllm_ascend/quantization/methods/w4a4_laos_dynamic.py +++ b/vllm_ascend/quantization/methods/w4a4_laos_dynamic.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import Any, Dict, Optional +from typing import Any import torch import torch_npu @@ -27,7 +27,7 @@ from .registry import register_scheme @register_scheme("W4A4_DYNAMIC", "linear") class AscendW4A4LaosDynamicLinearMethod(AscendLinearScheme): """Linear method for Ascend W4A4_DYNAMIC. - + This class implements W4A4 quantization with LAOS approach and dynamic activation quantization. - Weight: 4-bit quantization (per-channel) with scale and offset, stored as int8. - Activation: 4-bit dynamic quantization. 
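The w4a4 schemes store int4 weights packed eight to an int32 word (hence the input_size % 8 == 0 check in get_weight). A standalone sketch of that packing follows; the nibble order is chosen for illustration only, and the actual layout is whatever pack_int4_weights / npu_convert_weight_to_int4pack produce.

    import torch

    def pack_int4_to_int32(w_int4: torch.Tensor) -> torch.Tensor:
        # eight int4 values along the input dim share one int32 word; nibble order is illustrative
        out, k = w_int4.shape
        nibbles = (w_int4.to(torch.int32) & 0xF).view(out, k // 8, 8)
        packed = torch.zeros(out, k // 8, dtype=torch.int32)
        for i in range(8):
            packed |= nibbles[..., i] << (4 * i)
        return packed                                               # [out, k // 8]

    w = torch.randint(-8, 8, (16, 64), dtype=torch.int8)            # int4 range stored in int8
    print(pack_int4_to_int32(w).shape)                              # torch.Size([16, 8])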
@@ -37,7 +37,7 @@ class AscendW4A4LaosDynamicLinearMethod(AscendLinearScheme): self.transpose_weight = True self.rotation_type = None - def set_rotation_config(self, prefix: str, metadata: Dict) -> Optional[str]: + def set_rotation_config(self, prefix: str, metadata: dict) -> str | None: """Set rotation config based on prefix and metadata.""" layer_idx = prefix.split(".")[2] if prefix.endswith("o_proj"): @@ -50,34 +50,22 @@ class AscendW4A4LaosDynamicLinearMethod(AscendLinearScheme): return "kronecker_rotation" return None - def get_weight(self, input_size: int, output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: - params_dict = { - "weight": torch.empty(output_size, input_size, dtype=torch.int8) - } + def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]: + params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)} return params_dict - def get_perchannel_param(self, output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def get_perchannel_param(self, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]: params_dict = {} - params_dict["weight_scale"] = torch.empty(output_size, - 1, - dtype=torch.float32) - params_dict["weight_offset"] = torch.empty(output_size, - 1, - dtype=torch.float32) + params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=torch.float32) + params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=torch.float32) if self.rotation_type == "heads_rotation": - params_dict["heads_rotation"] = torch.zeros((64, 64), - dtype=torch.float32) + params_dict["heads_rotation"] = torch.zeros((64, 64), dtype=torch.float32) if self.rotation_type == "kronecker_rotation": - params_dict["kronecker_rotation_n"] = torch.zeros( - (160, 160), dtype=torch.float32) - params_dict["kronecker_rotation_m"] = torch.zeros( - (160, 160), dtype=torch.float32) + params_dict["kronecker_rotation_n"] = torch.zeros((160, 160), dtype=torch.float32) + params_dict["kronecker_rotation_m"] = torch.zeros((160, 160), dtype=torch.float32) return params_dict - def apply_rotation(self, layer: torch.nn.Module, - x: torch.Tensor) -> torch.Tensor: + def apply_rotation(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor: """Apply rotation transformation to input tensor.""" init_shape = x.shape dtype = x.dtype @@ -100,8 +88,8 @@ class AscendW4A4LaosDynamicLinearMethod(AscendLinearScheme): self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - tp_rank: Optional[int] = 0, + bias: torch.Tensor | None = None, + tp_rank: int | None = 0, ) -> torch.Tensor: dtype = x.dtype x, pertoken_scale = torch_npu.npu_dynamic_quant(x, dst_type=torch.quint4x2) @@ -113,14 +101,14 @@ class AscendW4A4LaosDynamicLinearMethod(AscendLinearScheme): scale=layer.weight_scale.data.view(-1), pertoken_scale=pertoken_scale, bias=None, - output_dtype=dtype) + output_dtype=dtype, + ) if bias is not None: output = output + bias.to(dtype) return output def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.weight_scale.data = layer.weight_scale.data.to(torch.float32) - layer.weight.data = torch_npu.npu_convert_weight_to_int4pack( - layer.weight.data.to(torch.int32)) + layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(layer.weight.data.to(torch.int32)) if self.transpose_weight: layer.weight.data = layer.weight.data.transpose(-1, -2) diff --git a/vllm_ascend/quantization/methods/w4a8.py b/vllm_ascend/quantization/methods/w4a8.py index a5fc3afa..8a5ebca2 100644 
--- a/vllm_ascend/quantization/methods/w4a8.py +++ b/vllm_ascend/quantization/methods/w4a8.py @@ -15,7 +15,8 @@ # limitations under the License. # -from typing import Any, Callable, Dict, Optional +from collections.abc import Callable +from typing import Any import numpy as np import torch @@ -27,7 +28,7 @@ from vllm.forward_context import get_forward_context from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.fused_moe.experts_selector import select_experts -from vllm_ascend.utils import maybe_trans_nz, COMPRESSED_TENSORS_METHOD +from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD, maybe_trans_nz from .base import AscendLinearScheme, AscendMoEScheme, QuantType from .registry import register_scheme @@ -39,19 +40,17 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme): def __init__(self): vllm_config = get_current_vllm_config() - self.group_size = vllm_config.quant_config.quant_description.get( - "group_size", 256) - quant_version = vllm_config.quant_config.quant_description.get( - "version", "0") + self.group_size = vllm_config.quant_config.quant_description.get("group_size", 256) + quant_version = vllm_config.quant_config.quant_description.get("version", "0") self.new_quant_version = quant_version == "1.0.0" from vllm.distributed import get_tensor_model_parallel_world_size + self.tp_size = get_tensor_model_parallel_world_size() - def get_weight(self, input_size: int, output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]: """Create weight parameters. - + For new quantization version (double int4 pack into int8), the output dimension is compressed by factor 2 (e.g., [2048, 3072] -> [1024, 3072]). The returned dict includes "_packed_dim" and "_packed_factor" for vLLM's weight loader. 
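A small sketch of the "double int4 pack into int8" layout described in the docstring above: pairs of int4 values share one int8, so the output dimension is halved, and _packed_dim=0 / _packed_factor=2 tell vLLM's weight loader how to shard the packed tensor. Which nibble holds which row is an assumption here, since the checkpoint arrives already packed.

    import torch

    def pack_int4_pairs(w_int4: torch.Tensor) -> torch.Tensor:
        # w_int4: [out, in] with values in [-8, 7]; adjacent rows along dim 0 (the packed dim)
        # merge into one int8: low nibble = even row, high nibble = odd row (illustrative order)
        assert w_int4.shape[0] % 2 == 0
        lo = (w_int4[0::2] & 0x0F).to(torch.uint8)
        hi = (w_int4[1::2] & 0x0F).to(torch.uint8)
        return (lo | (hi << 4)).view(torch.int8)                    # [out // 2, in]

    w = torch.randint(-8, 8, (2048, 3072), dtype=torch.int8)
    print(pack_int4_pairs(w).shape)                                 # torch.Size([1024, 3072])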
@@ -62,40 +61,26 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme): # double int4 pack into int8: output dimension is compressed pack_factor = 2 actual_output_size = output_size // pack_factor - params_dict["weight"] = torch.empty(actual_output_size, - input_size, - dtype=torch.int8) + params_dict["weight"] = torch.empty(actual_output_size, input_size, dtype=torch.int8) # Add packing information for vLLM's weight_loader params_dict["_packed_dim"] = 0 params_dict["_packed_factor"] = pack_factor else: - params_dict["weight"] = torch.empty(output_size, - input_size, - dtype=torch.int8) + params_dict["weight"] = torch.empty(output_size, input_size, dtype=torch.int8) return params_dict - def get_pergroup_param(self, - input_size: int, - output_size: int, - params_dtype: torch.dtype, - layer_type: Optional[str] = None) -> Dict[str, Any]: + def get_pergroup_param( + self, input_size: int, output_size: int, params_dtype: torch.dtype, layer_type: str | None = None + ) -> dict[str, Any]: """Create per-group quantization parameters.""" params_dict = {} - params_dict["weight_scale"] = torch.empty(output_size, - 1, - dtype=params_dtype) - params_dict["weight_offset"] = torch.empty(output_size, - 1, - dtype=params_dtype) - params_dict["weight_scale_second"] = torch.empty(output_size, - input_size // - self.group_size, - dtype=params_dtype) - params_dict["weight_offset_second"] = torch.empty(output_size, - input_size // - self.group_size, - dtype=params_dtype) + params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype) + params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype) + params_dict["weight_scale_second"] = torch.empty(output_size, input_size // self.group_size, dtype=params_dtype) + params_dict["weight_offset_second"] = torch.empty( + output_size, input_size // self.group_size, dtype=params_dtype + ) # NOTE: In w4a8 quantization implementation, # for down_proj and o_proj(layer_type == "row") scale_bias shape is [output_size, 16], @@ -103,24 +88,21 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme): if self.new_quant_version: scale_bias_dim = 16 if layer_type == "row" else 1 - params_dict["scale_bias"] = torch.empty(output_size, - scale_bias_dim, - dtype=torch.float32) + params_dict["scale_bias"] = torch.empty(output_size, scale_bias_dim, dtype=torch.float32) return params_dict @staticmethod - def process_scale_second(weight: torch.Tensor, - scale: torch.Tensor, - per_group_scale: torch.Tensor, - is_new_quant: bool = False): + def process_scale_second( + weight: torch.Tensor, scale: torch.Tensor, per_group_scale: torch.Tensor, is_new_quant: bool = False + ): """Process the scale for second-level quantization. 
- + Args: weight: weight tensor [k, n] (in new version, n is already compressed to n/2) scale: first-level quantization scale [output_size] per_group_scale: second-level per-group quantization scale [group_num, n_scale] is_new_quant: whether it's the new quantization version (weight already compressed) - + Returns: (antiquant_scale, bias): dequantization scale and bias (bias=None for new version) """ @@ -133,8 +115,7 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme): bias = None if not is_new_quant: - weight_high = weight.to(torch.float32).reshape( - group_num, -1, n) * per_group_scale.reshape(group_num, 1, n) + weight_high = weight.to(torch.float32).reshape(group_num, -1, n) * per_group_scale.reshape(group_num, 1, n) weight_high = weight_high.reshape(k, n) bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0) # NOTE: scale_bias is not used currently @@ -148,8 +129,8 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme): self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - tp_rank: Optional[int] = None, + bias: torch.Tensor | None = None, + tp_rank: int | None = None, ) -> torch.Tensor: return torch_npu.npu_weight_quant_batchmatmul( x, @@ -161,8 +142,7 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme): def process_weights_after_loading(self, layer: torch.nn.Module): layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() layer.weight.data = maybe_trans_nz(layer.weight.data) - layer.weight_scale.data = layer.weight_scale.data.flatten().to( - torch.float32) + layer.weight_scale.data = layer.weight_scale.data.flatten().to(torch.float32) layer.weight_offset.data = layer.weight_offset.data.flatten() layer.weight_scale_second.data, scale_bias = self.process_scale_second( layer.weight.data, @@ -187,15 +167,14 @@ class AscendW4A8DynamicLinearMethod(AscendLinearScheme): if self.new_quant_version: # weights on disk are already in packed int4 format # pack 4 int8(int4*2) to int32 - assert layer.weight.data.shape[-1] % 4 == 0, \ + assert layer.weight.data.shape[-1] % 4 == 0, ( f"the last dim of weight needs to be divided by 4, got shape {layer.weight.data.shape}" - layer.weight.data = layer.weight.data.view( - torch.int32).contiguous() + ) + layer.weight.data = layer.weight.data.view(torch.int32).contiguous() else: # weights are not compressed # need to be packed via npu_convert_weight_to_int4pack - layer.weight.data = torch_npu.npu_convert_weight_to_int4pack( - layer.weight.data.to(torch.int32)) + layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(layer.weight.data.to(torch.int32)) @register_scheme("W4A8_DYNAMIC", "moe") @@ -209,69 +188,56 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme): self.ep_group = get_ep_group() vllm_config = get_current_vllm_config() - self.group_size = vllm_config.quant_config.quant_description.get( - "group_size", 256) + self.group_size = vllm_config.quant_config.quant_description.get("group_size", 256) # NOTE: the weights are quantized from bf16 to int4 through a per-channel quantization process self.is_per_channel_weight = self.group_size == 0 - quant_version = vllm_config.quant_config.quant_description.get( - "version", "0") + quant_version = vllm_config.quant_config.quant_description.get("version", "0") # NOTE: new quantize weights: 2 int4 pack into int8 self.new_quant_version = quant_version == "1.0.0" - self.quant_method = vllm_config.quant_config.quant_description.get( - "ascend_quant_method", "") + self.quant_method = 
vllm_config.quant_config.quant_description.get("ascend_quant_method", "") if self.quant_method == COMPRESSED_TENSORS_METHOD: - self.weight_strategy = vllm_config.quant_config.quant_description.get( - "weight_strategy", "group") + self.weight_strategy = vllm_config.quant_config.quant_description.get("weight_strategy", "group") self.tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else self.ep_group.world_size self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb if self.new_quant_version and self.tp_size > 16: - raise ValueError( - "The current weight does not support moe part tp>16.") + raise ValueError("The current weight does not support moe part tp>16.") try: device_group = get_mc2_group().device_group # TODO: Try local_rank = ep_group.rank_in_group local_rank = torch.distributed.get_rank(group=device_group) backend = device_group._get_backend(torch.device("npu")) - self.moe_all_to_all_group_name = backend.get_hccl_comm_name( - local_rank) + self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank) except AttributeError: self.moe_all_to_all_group_name = "" - def get_weight(self, num_experts: int, - intermediate_size_per_partition: int, hidden_sizes: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def get_weight( + self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype + ) -> dict[str, Any]: if self.quant_method == COMPRESSED_TENSORS_METHOD: return self.get_weight_compressed_tensors( - num_experts, intermediate_size_per_partition, - hidden_sizes, params_dtype) + num_experts, intermediate_size_per_partition, hidden_sizes, params_dtype + ) else: - return self.get_weight_modelslim( - num_experts, intermediate_size_per_partition, - hidden_sizes, params_dtype) - - def get_weight_compressed_tensors(self, num_experts: int, - intermediate_size_per_partition: int, hidden_sizes: int, - params_dtype: torch.dtype) -> Dict[str, Any]: - + return self.get_weight_modelslim(num_experts, intermediate_size_per_partition, hidden_sizes, params_dtype) + + def get_weight_compressed_tensors( + self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype + ) -> dict[str, Any]: param_dict = {} E = num_experts H = hidden_sizes IN = intermediate_size_per_partition - g = self.group_size - param_dict["w13_weight"] = torch.empty(E, 2 * IN, H, - dtype=torch.int8) - param_dict["w2_weight"] = torch.empty(E, H, IN, - dtype=torch.int8) + param_dict["w13_weight"] = torch.empty(E, 2 * IN, H, dtype=torch.int8) + param_dict["w2_weight"] = torch.empty(E, H, IN, dtype=torch.int8) return param_dict - - def get_weight_modelslim(self, num_experts: int, - intermediate_size_per_partition: int, hidden_sizes: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def get_weight_modelslim( + self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype + ) -> dict[str, Any]: param_dict = {} if self.new_quant_version: w13_output_size = intermediate_size_per_partition @@ -280,33 +246,27 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme): w13_output_size = 2 * intermediate_size_per_partition w2_output_size = hidden_sizes - param_dict["w13_weight"] = torch.empty(num_experts, - w13_output_size, - hidden_sizes, - dtype=torch.int8) - param_dict["w2_weight"] = torch.empty(num_experts, - w2_output_size, - intermediate_size_per_partition, - dtype=torch.int8) + param_dict["w13_weight"] = torch.empty(num_experts, w13_output_size, hidden_sizes, 
dtype=torch.int8) + param_dict["w2_weight"] = torch.empty( + num_experts, w2_output_size, intermediate_size_per_partition, dtype=torch.int8 + ) return param_dict - def get_dynamic_quant_param(self, num_experts: int, - intermediate_size_per_partition: int, - hidden_sizes: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def get_dynamic_quant_param( + self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype + ) -> dict[str, Any]: if self.quant_method == COMPRESSED_TENSORS_METHOD: return self.get_dynamic_quant_param_compressed_tensors( - num_experts, intermediate_size_per_partition, - hidden_sizes, params_dtype) + num_experts, intermediate_size_per_partition, hidden_sizes, params_dtype + ) else: return self.get_dynamic_quant_param_modelslim( - num_experts, intermediate_size_per_partition, - hidden_sizes, params_dtype) - - def get_dynamic_quant_param_compressed_tensors(self, num_experts: int, - intermediate_size_per_partition: int, - hidden_sizes: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + num_experts, intermediate_size_per_partition, hidden_sizes, params_dtype + ) + + def get_dynamic_quant_param_compressed_tensors( + self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype + ) -> dict[str, Any]: param_dict = {} E = num_experts @@ -318,72 +278,48 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme): def _n_scale_cols(in_features: int) -> int: return 1 if g <= 0 else (in_features // g) - param_dict["w13_weight_scale"] = torch.empty( - E, 2 * IN, _n_scale_cols(H), dtype=torch.bfloat16) + param_dict["w13_weight_scale"] = torch.empty(E, 2 * IN, _n_scale_cols(H), dtype=torch.bfloat16) - param_dict["w2_weight_scale"] = torch.empty(E, H, _n_scale_cols(IN), - dtype=torch.bfloat16) + param_dict["w2_weight_scale"] = torch.empty(E, H, _n_scale_cols(IN), dtype=torch.bfloat16) return param_dict - def get_dynamic_quant_param_modelslim(self, num_experts: int, - intermediate_size_per_partition: int, - hidden_sizes: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def get_dynamic_quant_param_modelslim( + self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype + ) -> dict[str, Any]: param_dict = {} param_dict["w13_weight_scale"] = torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - 1, - dtype=torch.float32) + num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 + ) param_dict["w13_weight_offset"] = torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - 1, - dtype=torch.float32) + num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 + ) - param_dict["w2_weight_scale"] = torch.empty(num_experts, - hidden_sizes, - 1, - dtype=torch.float32) - param_dict["w2_weight_offset"] = torch.empty(num_experts, - hidden_sizes, - 1, - dtype=torch.float32) + param_dict["w2_weight_scale"] = torch.empty(num_experts, hidden_sizes, 1, dtype=torch.float32) + param_dict["w2_weight_offset"] = torch.empty(num_experts, hidden_sizes, 1, dtype=torch.float32) if not self.is_per_channel_weight: param_dict["w13_weight_scale_second"] = torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_sizes // self.group_size, - dtype=torch.float32) + num_experts, 2 * intermediate_size_per_partition, hidden_sizes // self.group_size, dtype=torch.float32 + ) param_dict["w13_weight_offset_second"] = torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_sizes // 
self.group_size, - dtype=torch.float32) + num_experts, 2 * intermediate_size_per_partition, hidden_sizes // self.group_size, dtype=torch.float32 + ) param_dict["w2_weight_scale_second"] = torch.empty( - num_experts, - hidden_sizes, - intermediate_size_per_partition // self.group_size, - dtype=torch.float32) + num_experts, hidden_sizes, intermediate_size_per_partition // self.group_size, dtype=torch.float32 + ) param_dict["w2_weight_offset_second"] = torch.empty( - num_experts, - hidden_sizes, - intermediate_size_per_partition // self.group_size, - dtype=torch.float32) + num_experts, hidden_sizes, intermediate_size_per_partition // self.group_size, dtype=torch.float32 + ) if self.new_quant_version: param_dict["w13_scale_bias"] = torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - 1, - dtype=torch.float32) - param_dict["w2_scale_bias"] = torch.empty(num_experts, - hidden_sizes, - 16 // self.tp_size, - dtype=torch.float32) + num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 + ) + param_dict["w2_scale_bias"] = torch.empty( + num_experts, hidden_sizes, 16 // self.tp_size, dtype=torch.float32 + ) return param_dict @@ -396,21 +332,22 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme): renormalize: bool, use_grouped_topk: bool = False, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + topk_group: int | None = None, + num_expert_group: int | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, is_prefill: bool = True, enable_force_load_balance: bool = False, - log2phy: Optional[torch.Tensor] = None, + log2phy: torch.Tensor | None = None, global_redundant_expert_num: int = 0, **kwargs, ) -> torch.Tensor: - assert router_logits.shape[ - 1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)" + assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, ( + "Number of global experts mismatch (excluding redundancy)" + ) # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern topk_weights, topk_ids = select_experts( @@ -424,18 +361,17 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme): custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, - global_num_experts=global_num_experts) + global_num_experts=global_num_experts, + ) # this is a naive implementation for experts load balance so as # to avoid accumulating too much tokens on a single rank. # currently it is only activated when doing profile runs. 
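The force-load-balance branch that follows replaces the routed experts with uniformly random ones during profile runs. A standalone sketch of the trick: argsort over i.i.d. uniform noise yields an independent random permutation per token, so the first top_k columns are top_k distinct experts drawn without replacement.

    import torch

    num_tokens, num_experts, top_k = 5, 16, 2
    random_matrix = torch.rand(num_tokens, num_experts)
    balanced_topk_ids = torch.argsort(random_matrix, dim=1)[:, :top_k]
    print(balanced_topk_ids.shape)                                  # torch.Size([5, 2])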
if enable_force_load_balance: - random_matrix = torch.rand(topk_ids.size(0), - global_num_experts - - global_redundant_expert_num, - device=topk_ids.device) - topk_ids = torch.argsort( - random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype) + random_matrix = torch.rand( + topk_ids.size(0), global_num_experts - global_redundant_expert_num, device=topk_ids.device + ) + topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype) topk_weights = topk_weights.to(x.dtype) @@ -446,25 +382,23 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme): w2=[layer.w2_weight], w1_scale=[layer.w13_weight_scale], w2_scale=[layer.w2_weight_scale], - w1_scale_bias=layer.w13_scale_bias if hasattr( - layer, "w13_scale_bias") else None, - w2_scale_bias=layer.w2_scale_bias if hasattr( - layer, "w2_scale_bias") else None, + w1_scale_bias=layer.w13_scale_bias if hasattr(layer, "w13_scale_bias") else None, + w2_scale_bias=layer.w2_scale_bias if hasattr(layer, "w2_scale_bias") else None, topk_weights=topk_weights, topk_ids=topk_ids, use_int4_w4a8=True, expert_map=expert_map, log2phy=log2phy, dynamic_eplb=self.dynamic_eplb, - mc2_mask=kwargs.get("mc2_mask", None)) + mc2_mask=kwargs.get("mc2_mask"), + ) def process_scale(self, weight: torch.Tensor, scale, per_group_scale): scale = scale.transpose(1, 2).contiguous() if self.is_per_channel_weight: scale_np = scale.cpu().numpy() scale_np.dtype = np.uint32 - scale_uint64_tensor = torch.from_numpy(scale_np.astype( - np.int64)).npu() + scale_uint64_tensor = torch.from_numpy(scale_np.astype(np.int64)).npu() return scale_uint64_tensor, None per_group_scale = per_group_scale.transpose(1, 2).contiguous() group_num, k, n = weight.shape @@ -475,32 +409,27 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme): group_num, quantgroup_num, n = per_group_scale.shape bias = None if not self.new_quant_version: - weight_high = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \ - per_group_scale.reshape([group_num, quantgroup_num, 1, n]) + weight_high = weight.to(torch.float32).reshape( + [group_num, quantgroup_num, -1, n] + ) * per_group_scale.reshape([group_num, quantgroup_num, 1, n]) weight_high = weight_high.reshape([group_num, k, n]) bias = 8 * (weight_high.to(torch.float32) * scale).sum(axis=1) - scale_fp32 = (scale * per_group_scale).to(torch.float16).to( - torch.float32) + scale_fp32 = (scale * per_group_scale).to(torch.float16).to(torch.float32) scale_fp32_np = scale_fp32.cpu().numpy() scale_fp32_np.dtype = np.uint32 - sscale_uint64 = np.zeros((group_num, quantgroup_num, n * 2), - dtype=np.uint32) + sscale_uint64 = np.zeros((group_num, quantgroup_num, n * 2), dtype=np.uint32) sscale_uint64[..., ::2] = scale_fp32_np - sscale_uint64_buffer = np.frombuffer(sscale_uint64.tobytes(), - dtype=np.int64).copy() - sscale_uint64_tensor = torch.from_numpy(sscale_uint64_buffer).reshape( - group_num, quantgroup_num, n) + sscale_uint64_buffer = np.frombuffer(sscale_uint64.tobytes(), dtype=np.int64).copy() + sscale_uint64_tensor = torch.from_numpy(sscale_uint64_buffer).reshape(group_num, quantgroup_num, n) sscale_uint64_tensor = sscale_uint64_tensor.npu() return sscale_uint64_tensor, bias def update_bias(self, layer, w13_bias, w2_bias): if self.new_quant_version: - layer.w13_scale_bias.data = layer.w13_scale_bias.data.transpose( - 1, 2).contiguous().sum(axis=1) - layer.w2_scale_bias.data = layer.w2_scale_bias.data.transpose( - 1, 2).contiguous().sum(axis=1) + layer.w13_scale_bias.data = layer.w13_scale_bias.data.transpose(1, 
2).contiguous().sum(axis=1) + layer.w2_scale_bias.data = layer.w2_scale_bias.data.transpose(1, 2).contiguous().sum(axis=1) else: w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False) layer.register_parameter("w13_scale_bias", w13_scale_bias) @@ -510,13 +439,12 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme): def pack_to_int32(self, weight: torch.Tensor): if self.new_quant_version: # pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4 - assert weight.shape[ - -1] % 4 == 0, "the last dim of weight needs to be divided by 4" + assert weight.shape[-1] % 4 == 0, "the last dim of weight needs to be divided by 4" return weight.view(torch.int32).contiguous() else: - return torch_npu.npu_quantize(weight.to(torch.float32), - torch.tensor([1.]).npu(), None, - torch.quint4x2, -1, False) + return torch_npu.npu_quantize( + weight.to(torch.float32), torch.tensor([1.0]).npu(), None, torch.quint4x2, -1, False + ) def process_weights_after_loading(self, layer): if self.quant_method == COMPRESSED_TENSORS_METHOD: @@ -524,23 +452,18 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme): else: self.process_weights_after_loading_modelslim(layer) - def process_weights_after_loading_compressed_tensors(self, layer): - layer.w13_weight.data = layer.w13_weight.data.transpose( - 1, 2).contiguous() - layer.w2_weight.data = layer.w2_weight.data.transpose(1, - 2).contiguous() - + layer.w13_weight.data = layer.w13_weight.data.transpose(1, 2).contiguous() + layer.w2_weight.data = layer.w2_weight.data.transpose(1, 2).contiguous() + def process_scale_compressed_tensors(scale: torch.Tensor): scale = scale.transpose(1, 2).to(torch.float32).contiguous() scale_np = scale.cpu().numpy() scale_np.dtype = np.uint32 - scale_uint64_tensor = torch.from_numpy(scale_np.astype( - np.int64)).npu() + scale_uint64_tensor = torch.from_numpy(scale_np.astype(np.int64)).npu() return scale_uint64_tensor - def update_bias_compressed_tensors(weight: torch.Tensor, - scale: torch.Tensor, strategy:str): + def update_bias_compressed_tensors(weight: torch.Tensor, scale: torch.Tensor, strategy: str): group_num, k, n = weight.shape scale = scale.transpose(1, 2).contiguous() scale = scale.reshape(group_num, -1, n) @@ -548,8 +471,9 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme): bias = None if strategy == "group": - tmp = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \ - scale.reshape([group_num, quantgroup_num, 1, n]) + tmp = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * scale.reshape( + [group_num, quantgroup_num, 1, n] + ) tmp = tmp.reshape([group_num, k, n]) bias = 8 * tmp.sum(axis=1) elif strategy == "channel": @@ -558,19 +482,14 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme): raise ValueError(f"Unsupported weight strategy: {strategy}") return bias - w13_bias = update_bias_compressed_tensors(layer.w13_weight.data, - layer.w13_weight_scale.data, - self.weight_strategy) - w2_bias = update_bias_compressed_tensors(layer.w2_weight.data, - layer.w2_weight_scale.data, - self.weight_strategy) + w13_bias = update_bias_compressed_tensors( + layer.w13_weight.data, layer.w13_weight_scale.data, self.weight_strategy + ) + w2_bias = update_bias_compressed_tensors(layer.w2_weight.data, layer.w2_weight_scale.data, self.weight_strategy) - layer.w13_weight_scale.data = process_scale_compressed_tensors( - layer.w13_weight_scale.data) - layer.w2_weight_scale.data = process_scale_compressed_tensors( - layer.w2_weight_scale.data) + 
layer.w13_weight_scale.data = process_scale_compressed_tensors(layer.w13_weight_scale.data) + layer.w2_weight_scale.data = process_scale_compressed_tensors(layer.w2_weight_scale.data) - w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False) layer.register_parameter("w13_scale_bias", w13_scale_bias) w2_scale_bias = torch.nn.Parameter(w2_bias, requires_grad=False) @@ -583,21 +502,19 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme): layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data) def process_weights_after_loading_modelslim(self, layer): - layer.w13_weight.data = layer.w13_weight.data.transpose( - 1, 2).contiguous() - layer.w2_weight.data = layer.w2_weight.data.transpose(1, - 2).contiguous() + layer.w13_weight.data = layer.w13_weight.data.transpose(1, 2).contiguous() + layer.w2_weight.data = layer.w2_weight.data.transpose(1, 2).contiguous() - w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr( - layer, "w13_weight_scale_second") else None - w2_weight_scale_second = layer.w2_weight_scale_second.data if hasattr( - layer, "w2_weight_scale_second") else None + w13_weight_scale_second = ( + layer.w13_weight_scale_second.data if hasattr(layer, "w13_weight_scale_second") else None + ) + w2_weight_scale_second = layer.w2_weight_scale_second.data if hasattr(layer, "w2_weight_scale_second") else None layer.w13_weight_scale.data, w13_bias = self.process_scale( - layer.w13_weight, layer.w13_weight_scale.data, - w13_weight_scale_second) + layer.w13_weight, layer.w13_weight_scale.data, w13_weight_scale_second + ) layer.w2_weight_scale.data, w2_bias = self.process_scale( - layer.w2_weight, layer.w2_weight_scale.data, - w2_weight_scale_second) + layer.w2_weight, layer.w2_weight_scale.data, w2_weight_scale_second + ) if hasattr(layer, "w13_weight_scale_second"): # scale_second is no longer used, release this part of the memory del layer.w13_weight_scale_second diff --git a/vllm_ascend/quantization/methods/w8a16.py b/vllm_ascend/quantization/methods/w8a16.py index 97fc0468..b70d56bc 100644 --- a/vllm_ascend/quantization/methods/w8a16.py +++ b/vllm_ascend/quantization/methods/w8a16.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import Any, Dict, Optional +from typing import Any import torch import torch_npu @@ -29,7 +29,7 @@ from .registry import register_scheme @register_scheme("W8A16", "linear") class AscendW8A16LinearMethod(AscendLinearScheme): """Linear method for Ascend W8A16. - + This scheme uses 8-bit quantized weights with 16-bit activations. 
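A plain-PyTorch reference of the weight-only numerics this scheme relies on, using the per-output-channel scale/offset created in get_perchannel_param. The fused NPU kernel npu_weight_quant_batchmatmul performs the equivalent dequantize-then-matmul, so this is only an illustration of the math, not the kernel.

    import torch

    def w8a16_reference(x, w_int8, scale, offset, bias=None):
        # dequantize the int8 weight per output channel, then matmul in the activation dtype
        w = (w_int8.to(x.dtype) - offset) * scale                   # [out, in]
        y = x @ w.t()
        return y if bias is None else y + bias

    x = torch.randn(2, 64)                                          # fp32 here; bf16/fp16 in practice
    w_int8 = torch.randint(-128, 128, (32, 64), dtype=torch.int8)
    scale = torch.rand(32, 1) * 0.01
    offset = torch.zeros(32, 1)
    print(w8a16_reference(x, w_int8, scale, offset).shape)          # torch.Size([2, 32])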
""" @@ -41,39 +41,34 @@ class AscendW8A16LinearMethod(AscendLinearScheme): input_size: int, output_size: int, params_dtype: torch.dtype = torch.bfloat16, - ) -> Dict[str, Any]: - params_dict = { - "weight": torch.empty(output_size, input_size, dtype=torch.int8) - } + ) -> dict[str, Any]: + params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)} return params_dict def get_perchannel_param( self, output_size: int, params_dtype: torch.dtype, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: params_dict = {} - params_dict["weight_scale"] = torch.empty(output_size, - 1, - dtype=params_dtype) - params_dict["weight_offset"] = torch.empty(output_size, - 1, - dtype=params_dtype) + params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype) + params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype) return params_dict def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - tp_rank: Optional[int] = 0, + bias: torch.Tensor | None = None, + tp_rank: int | None = 0, ) -> torch.Tensor: output = torch_npu.npu_weight_quant_batchmatmul( x=x, weight=layer.weight, antiquant_scale=layer.weight_scale, antiquant_offset=layer.weight_offset, - bias=bias) + bias=bias, + ) return output def process_weights_after_loading(self, layer): diff --git a/vllm_ascend/quantization/methods/w8a8_dynamic.py b/vllm_ascend/quantization/methods/w8a8_dynamic.py index d5e5db64..68dea550 100644 --- a/vllm_ascend/quantization/methods/w8a8_dynamic.py +++ b/vllm_ascend/quantization/methods/w8a8_dynamic.py @@ -15,7 +15,8 @@ # limitations under the License. # -from typing import Any, Callable, Dict, Optional +from collections.abc import Callable +from typing import Any import torch import torch_npu @@ -28,8 +29,7 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.flash_common3_context import get_flash_common3_context -from vllm_ascend.ops.fused_moe.experts_selector import (select_experts, - zero_experts_compute) +from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, maybe_trans_nz from .base import AscendLinearScheme, AscendMoEScheme, QuantType @@ -39,16 +39,17 @@ from .registry import register_scheme def scale_from_float_to_int64(scale): """Convert float32 scale to int64 representation.""" import numpy as np + scale = torch.from_numpy( - np.frombuffer(scale.cpu().to(torch.float32).numpy().tobytes(), - dtype=np.int32).astype(np.int64)).to(scale.device) + np.frombuffer(scale.cpu().to(torch.float32).numpy().tobytes(), dtype=np.int32).astype(np.int64) + ).to(scale.device) return scale @register_scheme("W8A8_DYNAMIC", "linear") class AscendW8A8DynamicLinearMethod(AscendLinearScheme): """Linear method for Ascend W8A8_DYNAMIC. - + This scheme uses dynamic per-token quantization for activations and per-channel quantization for weights. 
""" @@ -56,33 +57,26 @@ class AscendW8A8DynamicLinearMethod(AscendLinearScheme): def __init__(self): pass - def get_weight(self, input_size: int, output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: - params_dict = { - "weight": torch.empty(output_size, input_size, dtype=torch.int8) - } + def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]: + params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)} return params_dict def get_perchannel_param( self, output_size: int, params_dtype: torch.dtype, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: params_dict = {} - params_dict["weight_scale"] = torch.empty(output_size, - 1, - dtype=params_dtype) - params_dict["weight_offset"] = torch.empty(output_size, - 1, - dtype=params_dtype) + params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype) + params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype) return params_dict def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - tp_rank: Optional[int] = 0, + bias: torch.Tensor | None = None, + tp_rank: int | None = 0, ) -> torch.Tensor: quantized_x, pertoken_scale = torch_npu.npu_dynamic_quant(x) output = torch_npu.npu_quant_matmul( @@ -116,9 +110,10 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme): vllm_config = get_current_vllm_config() ascend_config = get_ascend_config() - self.use_aclgraph = (vllm_config.compilation_config.mode - == CompilationMode.VLLM_COMPILE - and not vllm_config.model_config.enforce_eager) + self.use_aclgraph = ( + vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE + and not vllm_config.model_config.enforce_eager + ) self.multistream_overlap_gate = ascend_config.multistream_overlap_gate self.dynamic_eplb = ascend_config.eplb_config.dynamic_eplb @@ -130,49 +125,34 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme): # TODO: Try local_rank = ep_group.rank_in_group local_rank = torch.distributed.get_rank(group=device_group) backend = device_group._get_backend(torch.device("npu")) - self.moe_all_to_all_group_name = backend.get_hccl_comm_name( - local_rank) + self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank) except AttributeError: self.moe_all_to_all_group_name = "" - def get_weight(self, num_experts: int, - intermediate_size_per_partition: int, hidden_sizes: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def get_weight( + self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype + ) -> dict[str, Any]: param_dict = {} - param_dict["w13_weight"] = torch.empty(num_experts, - 2 * - intermediate_size_per_partition, - hidden_sizes, - dtype=torch.int8) - param_dict["w2_weight"] = torch.empty(num_experts, - hidden_sizes, - intermediate_size_per_partition, - dtype=torch.int8) + param_dict["w13_weight"] = torch.empty( + num_experts, 2 * intermediate_size_per_partition, hidden_sizes, dtype=torch.int8 + ) + param_dict["w2_weight"] = torch.empty( + num_experts, hidden_sizes, intermediate_size_per_partition, dtype=torch.int8 + ) return param_dict - def get_dynamic_quant_param(self, num_experts: int, - intermediate_size_per_partition: int, - hidden_sizes: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def get_dynamic_quant_param( + self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype + ) -> dict[str, Any]: param_dict = {} param_dict["w13_weight_scale"] = 
torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - 1, - dtype=params_dtype) + num_experts, 2 * intermediate_size_per_partition, 1, dtype=params_dtype + ) param_dict["w13_weight_offset"] = torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - 1, - dtype=params_dtype) - param_dict["w2_weight_scale"] = torch.empty(num_experts, - hidden_sizes, - 1, - dtype=params_dtype) - param_dict["w2_weight_offset"] = torch.empty(num_experts, - hidden_sizes, - 1, - dtype=params_dtype) + num_experts, 2 * intermediate_size_per_partition, 1, dtype=params_dtype + ) + param_dict["w2_weight_scale"] = torch.empty(num_experts, hidden_sizes, 1, dtype=params_dtype) + param_dict["w2_weight_offset"] = torch.empty(num_experts, hidden_sizes, 1, dtype=params_dtype) return param_dict def apply( @@ -184,25 +164,26 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme): renormalize: bool, use_grouped_topk: bool = False, global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, + expert_map: torch.Tensor | None = None, + topk_group: int | None = None, + num_expert_group: int | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, + e_score_correction_bias: torch.Tensor | None = None, is_prefill: bool = True, enable_force_load_balance: bool = False, - log2phy: Optional[torch.Tensor] = None, + log2phy: torch.Tensor | None = None, global_redundant_expert_num: int = 0, - pertoken_scale: Optional[Any] = None, + pertoken_scale: Any | None = None, **kwargs, ) -> torch.Tensor: zero_expert_num = getattr(layer, "zero_expert_num", 0) zero_expert_type = getattr(layer, "zero_expert_type", None) if zero_expert_num == 0 or zero_expert_type is None: - assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, \ + assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, ( "Number of global experts mismatch (excluding redundancy)" + ) if self.multistream_overlap_gate: fc3_context = get_flash_common3_context() @@ -222,7 +203,8 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme): scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - global_num_experts=global_num_experts) + global_num_experts=global_num_experts, + ) assert topk_ids is not None assert topk_weights is not None if zero_expert_num > 0 and zero_expert_type is not None: @@ -237,12 +219,10 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme): # to avoid accumulating too much tokens on a single rank. # currently it is only activated when doing profile runs. 
if enable_force_load_balance: - random_matrix = torch.rand(topk_ids.size(0), - global_num_experts - - global_redundant_expert_num, - device=topk_ids.device) - topk_ids = torch.argsort( - random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype) + random_matrix = torch.rand( + topk_ids.size(0), global_num_experts - global_redundant_expert_num, device=topk_ids.device + ) + topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype) assert topk_weights is not None topk_weights = topk_weights.to(self.in_dtype) @@ -259,9 +239,10 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme): w2 = [layer.w2_weight] w2_scale = [layer.w2_weight_scale] - fused_scale_flag = (get_forward_context().moe_comm_type - == MoECommType.FUSED_MC2 - and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1) + fused_scale_flag = ( + get_forward_context().moe_comm_type == MoECommType.FUSED_MC2 + and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1 + ) final_hidden_states = moe_comm_method.fused_experts( hidden_states=x, pertoken_scale=pertoken_scale, @@ -275,54 +256,35 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme): expert_map=expert_map, log2phy=log2phy, dynamic_eplb=self.dynamic_eplb, - mc2_mask=kwargs.get("mc2_mask", None)) + mc2_mask=kwargs.get("mc2_mask"), + ) if zero_expert_num > 0 and zero_expert_type is not None: final_hidden_states += zero_expert_result return final_hidden_states def process_weights_after_loading(self, layer): - layer.w13_weight.data = layer.w13_weight.data.transpose( - 1, 2).contiguous() - layer.w2_weight.data = layer.w2_weight.data.transpose(1, - 2).contiguous() + layer.w13_weight.data = layer.w13_weight.data.transpose(1, 2).contiguous() + layer.w2_weight.data = layer.w2_weight.data.transpose(1, 2).contiguous() # TODO(zzzzwwjj): Currently, `torch_npu.npu_grouped_matmul_swiglu_quant` # can only support weight nz. 
- layer.w13_weight.data = torch_npu.npu_format_cast( - layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ) - layer.w2_weight.data = torch_npu.npu_format_cast( - layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ) - layer.w13_weight_scale.data = layer.w13_weight_scale.data.view( - layer.w13_weight_scale.data.shape[0], -1) - layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to( - torch.float32) - layer.w13_weight_offset.data = layer.w13_weight_offset.data.view( - layer.w13_weight_offset.data.shape[0], -1) - layer.w2_weight_scale.data = layer.w2_weight_scale.data.view( - layer.w2_weight_scale.data.shape[0], -1) - layer.w2_weight_offset.data = layer.w2_weight_offset.data.view( - layer.w2_weight_offset.data.shape[0], -1) + layer.w13_weight.data = torch_npu.npu_format_cast(layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ) + layer.w2_weight.data = torch_npu.npu_format_cast(layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ) + layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(layer.w13_weight_scale.data.shape[0], -1) + layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(torch.float32) + layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(layer.w13_weight_offset.data.shape[0], -1) + layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(layer.w2_weight_scale.data.shape[0], -1) + layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(layer.w2_weight_offset.data.shape[0], -1) - layer.fused_w1_scale = scale_from_float_to_int64( - layer.w13_weight_scale.data) - layer.fused_w2_scale = scale_from_float_to_int64( - layer.w2_weight_scale.data) + layer.fused_w1_scale = scale_from_float_to_int64(layer.w13_weight_scale.data) + layer.fused_w2_scale = scale_from_float_to_int64(layer.w2_weight_scale.data) if self.dynamic_eplb: - layer.w13_weight_list = [ - weight.clone() - for weight in layer.w13_weight.data.unbind(dim=0) - ] - layer.w2_weight_list = [ - weight.clone() for weight in layer.w2_weight.data.unbind(dim=0) - ] + layer.w13_weight_list = [weight.clone() for weight in layer.w13_weight.data.unbind(dim=0)] + layer.w2_weight_list = [weight.clone() for weight in layer.w2_weight.data.unbind(dim=0)] layer.w13_weight_scale_fp32_list = [ - weight.clone() - for weight in layer.w13_weight_scale_fp32.data.unbind(dim=0) - ] - layer.w2_weight_scale_list = [ - weight.clone() - for weight in layer.w2_weight_scale.data.unbind(dim=0) + weight.clone() for weight in layer.w13_weight_scale_fp32.data.unbind(dim=0) ] + layer.w2_weight_scale_list = [weight.clone() for weight in layer.w2_weight_scale.data.unbind(dim=0)] del layer.w13_weight del layer.w2_weight del layer.w13_weight_scale diff --git a/vllm_ascend/quantization/methods/w8a8_mxfp8.py b/vllm_ascend/quantization/methods/w8a8_mxfp8.py index e42abfae..dc772952 100644 --- a/vllm_ascend/quantization/methods/w8a8_mxfp8.py +++ b/vllm_ascend/quantization/methods/w8a8_mxfp8.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import Any, Dict, Optional +from typing import Any import torch import torch_npu @@ -28,48 +28,37 @@ from .registry import register_scheme @register_scheme("W8A8_MXFP8", "linear") class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme): """Linear method for Ascend W8A8_MXFP8 (Microscaling FP8) quantization. - + This scheme uses microscaling FP8 quantization with per-group scales. The activation is dynamically quantized to FP8 (E4M3FN format) with microscaling, and weights are stored in FP8 format with per-group scales. 
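A rough sketch of the microscaling idea behind this scheme: every group of group_size (default 32) contiguous elements shares one power-of-two scale (an E8M0 exponent), and the scaled elements are stored as float8_e4m3fn. The exponent rule below is one simple choice and needs a PyTorch build with float8 dtypes; the actual rounding is done by npu_dynamic_mx_quant.

    import torch

    def mx_quant_reference(x: torch.Tensor, group_size: int = 32):
        groups = x.view(-1, group_size)
        amax = groups.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12)
        exponent = torch.floor(torch.log2(amax)) - 7.0      # keeps scaled values inside E4M3 range
        scale = torch.pow(2.0, exponent)                    # power of two, exactly representable as E8M0
        q = torch.clamp(groups / scale, -448.0, 448.0).to(torch.float8_e4m3fn)
        return q.view_as(x), exponent                       # fp8 data + shared per-group exponents

    x = torch.randn(4, 64)
    q, shared_exponent = mx_quant_reference(x)
    print(q.dtype, shared_exponent.shape)                   # torch.float8_e4m3fn torch.Size([8, 1])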
""" + model_dtype = None def __init__(self): vllm_config = get_current_vllm_config() - self.group_size = vllm_config.quant_config.quant_description.get( - "group_size", 32) + self.group_size = vllm_config.quant_config.quant_description.get("group_size", 32) - def get_weight(self, input_size: int, output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: - params_dict = { - "weight": - torch.empty(output_size, input_size, dtype=torch.float8_e4m3fn) - } + def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]: + params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.float8_e4m3fn)} return params_dict - def get_pergroup_param(self, - input_size: int, - output_size: int, - params_dtype: torch.dtype, - layer_type: Optional[str] = None) -> Dict[str, Any]: + def get_pergroup_param( + self, input_size: int, output_size: int, params_dtype: torch.dtype, layer_type: str | None = None + ) -> dict[str, Any]: params_dict = {} - params_dict["weight_scale"] = torch.empty(output_size, - input_size // - self.group_size, - dtype=torch.uint8) + params_dict["weight_scale"] = torch.empty(output_size, input_size // self.group_size, dtype=torch.uint8) return params_dict def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - tp_rank: Optional[int] = 0, + bias: torch.Tensor | None = None, + tp_rank: int | None = 0, ) -> torch.Tensor: - - quantized_x, dynamic_scale = torch_npu.npu_dynamic_mx_quant( - x, dst_type=torch.float8_e4m3fn) + quantized_x, dynamic_scale = torch_npu.npu_dynamic_mx_quant(x, dst_type=torch.float8_e4m3fn) pertoken_scale = dynamic_scale output_dtype = x.dtype @@ -82,13 +71,13 @@ class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme): pertoken_scale_dtype=torch_npu.float8_e8m0fnu, bias=bias, output_dtype=output_dtype, - group_sizes=[1, 1, self.group_size]) + group_sizes=[1, 1, self.group_size], + ) return output def process_weights_after_loading(self, layer): n_dim, k_dim = layer.weight_scale.data.shape - layer.weight_scale.data = layer.weight_scale.data.reshape( - n_dim, k_dim // 2, 2) + layer.weight_scale.data = layer.weight_scale.data.reshape(n_dim, k_dim // 2, 2) layer.weight.data = layer.weight.data.transpose(0, 1) layer.weight_scale.data = layer.weight_scale.data.transpose(0, 1) diff --git a/vllm_ascend/quantization/methods/w8a8_pdmix.py b/vllm_ascend/quantization/methods/w8a8_pdmix.py index 6c7cd394..ee2af55e 100644 --- a/vllm_ascend/quantization/methods/w8a8_pdmix.py +++ b/vllm_ascend/quantization/methods/w8a8_pdmix.py @@ -22,15 +22,14 @@ for prefill and decode phases: - Decode (KV consumer): Uses static W8A8 quantization """ -from typing import Any, Dict, Optional +from typing import Any import torch from vllm.config import get_current_vllm_config from .base import AscendLinearScheme from .registry import register_scheme -from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, - AscendW8A8DynamicLinearMethod) +from .w8a8_dynamic import AscendW8A8DynamicFusedMoEMethod, AscendW8A8DynamicLinearMethod from .w8a8_static import AscendW8A8LinearMethod @@ -53,31 +52,27 @@ class AscendW8A8PDMixLinearMethod(AscendLinearScheme): self._dynamic_method = AscendW8A8DynamicLinearMethod() kv_transfer_config = get_current_vllm_config().kv_transfer_config - self._is_kv_consumer = (kv_transfer_config is not None - and kv_transfer_config.is_kv_consumer) + self._is_kv_consumer = kv_transfer_config is not None and kv_transfer_config.is_kv_consumer - def get_weight(self, input_size: int, 
output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: - return self._static_method.get_weight(input_size, output_size, - params_dtype) + def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]: + return self._static_method.get_weight(input_size, output_size, params_dtype) - def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]: + def get_pertensor_param(self, params_dtype: torch.dtype) -> dict[str, Any]: return self._static_method.get_pertensor_param(params_dtype) def get_perchannel_param( self, output_size: int, params_dtype: torch.dtype, - ) -> Dict[str, Any]: - return self._static_method.get_perchannel_param( - output_size, params_dtype) + ) -> dict[str, Any]: + return self._static_method.get_perchannel_param(output_size, params_dtype) def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - tp_rank: Optional[int] = 0, + bias: torch.Tensor | None = None, + tp_rank: int | None = 0, ) -> torch.Tensor: if layer.is_kv_consumer: return self._static_method.apply(layer, x, bias, tp_rank) @@ -92,26 +87,15 @@ class AscendW8A8PDMixLinearMethod(AscendLinearScheme): @register_scheme("W8A8_MIX", "moe") class AscendW8A8PDMixFusedMoeMethod(AscendW8A8DynamicFusedMoEMethod): - - def get_dynamic_quant_param(self, num_experts: int, - intermediate_size_per_partition: int, - hidden_sizes: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def get_dynamic_quant_param( + self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype + ) -> dict[str, Any]: param_dict = super().get_dynamic_quant_param( - num_experts, intermediate_size_per_partition, hidden_sizes, - params_dtype) - param_dict["w2_deq_scale"] = torch.empty(num_experts, - hidden_sizes, - dtype=torch.float32) - param_dict["w13_deq_scale"] = torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - dtype=torch.float32) - param_dict["w2_input_offset"] = torch.empty(num_experts, - 1, - dtype=torch.int8) - param_dict["w13_input_offset"] = torch.empty(num_experts, - 1, - dtype=torch.int8) + num_experts, intermediate_size_per_partition, hidden_sizes, params_dtype + ) + param_dict["w2_deq_scale"] = torch.empty(num_experts, hidden_sizes, dtype=torch.float32) + param_dict["w13_deq_scale"] = torch.empty(num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32) + param_dict["w2_input_offset"] = torch.empty(num_experts, 1, dtype=torch.int8) + param_dict["w13_input_offset"] = torch.empty(num_experts, 1, dtype=torch.int8) return param_dict diff --git a/vllm_ascend/quantization/methods/w8a8_static.py b/vllm_ascend/quantization/methods/w8a8_static.py index 3a00b4eb..a2101010 100644 --- a/vllm_ascend/quantization/methods/w8a8_static.py +++ b/vllm_ascend/quantization/methods/w8a8_static.py @@ -15,14 +15,16 @@ # limitations under the License. 
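For the static W8A8 linear path in this file, one common formulation of the numerics (the exact quant_bias convention is defined by the exported checkpoint, not by this sketch): the activation is quantized with a fixed per-tensor scale/offset, the zero-point contribution is folded into an int32 quant_bias, and deq_scale = input_scale * weight_scale dequantizes the accumulator, as the compressed-tensors branch of process_weights_after_loading computes it.

    import torch

    def w8a8_static_reference(x, w_int8, input_scale, input_offset, quant_bias, deq_scale):
        x_q = torch.clamp(torch.round(x / input_scale) + input_offset, -128, 127)
        acc = x_q @ w_int8.to(torch.float32).t() + quant_bias.to(torch.float32)
        return acc * deq_scale

    x = torch.randn(4, 64)
    w_int8 = torch.randint(-128, 128, (32, 64), dtype=torch.int8)
    input_scale, input_offset = torch.tensor(0.05), torch.tensor(0.0)
    weight_scale = torch.rand(32) * 0.01
    quant_bias = torch.zeros(32, dtype=torch.int32)         # zero-point term; zero when offset is 0
    deq_scale = input_scale * weight_scale
    print(w8a8_static_reference(x, w_int8, input_scale, input_offset, quant_bias, deq_scale).shape)   # torch.Size([4, 32])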
# -from typing import Any, Dict, Optional +from typing import Any import torch import torch_npu -from vllm_ascend.utils import (COMPRESSED_TENSORS_METHOD, AscendDeviceType, - get_ascend_device_type, - get_weight_prefetch_method, maybe_trans_nz) +from vllm_ascend.utils import ( + COMPRESSED_TENSORS_METHOD, + get_weight_prefetch_method, + maybe_trans_nz, +) from .base import AscendLinearScheme from .registry import register_scheme @@ -44,13 +46,11 @@ class AscendW8A8LinearMethod(AscendLinearScheme): input_size: int, output_size: int, params_dtype: torch.dtype = torch.bfloat16, - ) -> Dict[str, Any]: - params_dict = { - "weight": torch.empty(output_size, input_size, dtype=torch.int8) - } + ) -> dict[str, Any]: + params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)} return params_dict - def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]: + def get_pertensor_param(self, params_dtype: torch.dtype) -> dict[str, Any]: params_dict = {} params_dict["input_scale"] = torch.empty(1, dtype=params_dtype) params_dict["input_offset"] = torch.empty(1, dtype=torch.int8) @@ -60,29 +60,23 @@ class AscendW8A8LinearMethod(AscendLinearScheme): self, output_size: int, params_dtype: torch.dtype, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: params_dict = {} params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32) if params_dtype == torch.bfloat16: - params_dict["deq_scale"] = torch.empty(output_size, - dtype=torch.float32) + params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.float32) elif params_dtype == torch.float16: - params_dict["deq_scale"] = torch.empty(output_size, - dtype=torch.int64) - params_dict["weight_scale"] = torch.empty(output_size, - 1, - dtype=params_dtype) - params_dict["weight_offset"] = torch.empty(output_size, - 1, - dtype=params_dtype) + params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.int64) + params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype) + params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype) return params_dict def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - tp_rank: Optional[int] = 0, + bias: torch.Tensor | None = None, + tp_rank: int | None = 0, ) -> torch.Tensor: if x.dtype != torch.int8: layer_cls_name = layer.__class__.__name__ @@ -95,15 +89,15 @@ class AscendW8A8LinearMethod(AscendLinearScheme): start_flag=x, ) try: - quant_comm_config = getattr(layer, "_quant_comm_config") + quant_comm_config = layer._quant_comm_config except AttributeError: quant_comm_config = {} comm_fn = quant_comm_config.get("communication_fn") enable_flashcomm2_quant_comm = comm_fn is not None and ( - "o_proj" in layer.prefix or "out_proj" in layer.prefix) + "o_proj" in layer.prefix or "out_proj" in layer.prefix + ) if enable_flashcomm2_quant_comm: - quant_input_x = x.contiguous().view( - -1, layer.aclnn_input_scale_reciprocal.size(0)) + quant_input_x = x.contiguous().view(-1, layer.aclnn_input_scale_reciprocal.size(0)) quant_x = torch.ops.vllm.quantize( quant_input_x, layer.aclnn_input_scale, @@ -132,7 +126,7 @@ class AscendW8A8LinearMethod(AscendLinearScheme): quant_bias = layer.quant_bias if tp_rank == 0 else None try: - ascend_quant_method = getattr(layer, "ascend_quant_method") + ascend_quant_method = layer.ascend_quant_method except AttributeError: ascend_quant_method = "" if ascend_quant_method == COMPRESSED_TENSORS_METHOD: @@ -150,14 +144,14 @@ class AscendW8A8LinearMethod(AscendLinearScheme): def 
process_weights_after_loading(self, layer): expanding_factor = layer.weight.data.shape[1] layer.aclnn_input_scale = torch.nn.Parameter( - layer.input_scale.data.repeat(expanding_factor), - requires_grad=False) + layer.input_scale.data.repeat(expanding_factor), requires_grad=False + ) layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter( - layer.input_scale.data.repeat(expanding_factor), - requires_grad=False) + layer.input_scale.data.repeat(expanding_factor), requires_grad=False + ) layer.aclnn_input_offset = torch.nn.Parameter( - layer.input_offset.data.repeat(expanding_factor), - requires_grad=False).to(layer.aclnn_input_scale.dtype) + layer.input_offset.data.repeat(expanding_factor), requires_grad=False + ).to(layer.aclnn_input_scale.dtype) layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() layer.weight.data = maybe_trans_nz(layer.weight.data) @@ -166,5 +160,4 @@ class AscendW8A8LinearMethod(AscendLinearScheme): ascend_quant_method = getattr(layer, "ascend_quant_method", "") if ascend_quant_method == COMPRESSED_TENSORS_METHOD: deq_scale = layer.input_scale.data * layer.weight_scale.data - layer.deq_scale = torch.nn.Parameter(deq_scale, - requires_grad=False) + layer.deq_scale = torch.nn.Parameter(deq_scale, requires_grad=False) diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index 227b97fb..5034604e 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -21,20 +21,18 @@ This module provides the AscendModelSlimConfig class for parsing quantization configs generated by the ModelSlim tool, along with model-specific mappings. """ +from collections.abc import Mapping from types import MappingProxyType -from typing import Any, Dict, List, Mapping, Optional +from typing import Any, Optional import torch from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import LinearBase -from vllm.model_executor.layers.quantization import \ - register_quantization_config -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.layers.vocab_parallel_embedding import ( - UnquantizedEmbeddingMethod, VocabParallelEmbedding) +from vllm.model_executor.layers.quantization import register_quantization_config +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig, QuantizeMethodBase +from vllm.model_executor.layers.vocab_parallel_embedding import UnquantizedEmbeddingMethod, VocabParallelEmbedding from vllm.model_executor.models.utils import WeightsMapper from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD @@ -45,7 +43,7 @@ logger = init_logger(__name__) # key: model_type # value: orig_to_new_prefix -QUANT_MODEL_PREFIX_MAPPINGS: Dict[str, Dict[str, str]] = { +QUANT_MODEL_PREFIX_MAPPINGS: dict[str, dict[str, str]] = { "qwen3_vl_moe": { "visual.": "model.visual.", "language_model.lm_head.": "lm_head.", @@ -60,7 +58,7 @@ QUANT_MODEL_PREFIX_MAPPINGS: Dict[str, Dict[str, str]] = { # key: model_type # value: dict of fused module name -> list of original module names -packed_modules_model_mapping: Dict[str, Dict[str, List[str]]] = { +packed_modules_model_mapping: dict[str, dict[str, list[str]]] = { "qwen3_moe": { "qkv_proj": [ "q_proj", @@ -71,52 +69,44 @@ packed_modules_model_mapping: Dict[str, Dict[str, List[str]]] = { "gate_proj", "up_proj", ], - 
"experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], }, "deepseek_v2": { "gate_up_proj": ["gate_proj", "up_proj"], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], - "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] + "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"], }, "deepseek_v3": { "gate_up_proj": ["gate_proj", "up_proj"], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], - "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] + "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"], }, "pangu_ultra_moe": { "gate_up_proj": ["gate_proj", "up_proj"], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], - "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] + "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"], }, "kimi_k2": { "gate_up_proj": ["gate_proj", "up_proj"], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], - "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] + "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"], }, "deepseek_v32": { "gate_up_proj": ["gate_proj", "up_proj"], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], - "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] + "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"], }, # NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized; # NOTE 2.The description file generated by the current msmodelslim tool does not have # MTP layer info. Please manually add it and set the value to FLOAT. 
"deepseek_mtp": { "gate_up_proj": ["gate_proj", "up_proj"], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], }, "pangu_ultra_moe_mtp": { "gate_up_proj": ["gate_proj", "up_proj"], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], - "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] + "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"], }, "qwen3_next": { "qkv_proj": [ @@ -126,8 +116,7 @@ packed_modules_model_mapping: Dict[str, Dict[str, List[str]]] = { ], "gate_up_proj": ["gate_proj", "up_proj"], "in_proj": ["in_proj_qkvz", "in_proj_ba"], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], }, "qwen2_5_vl": { "qkv_proj": [ @@ -150,8 +139,7 @@ packed_modules_model_mapping: Dict[str, Dict[str, List[str]]] = { "gate_proj", "up_proj", ], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], }, "glm4_moe": { "qkv_proj": [ @@ -163,20 +151,17 @@ packed_modules_model_mapping: Dict[str, Dict[str, List[str]]] = { "gate_proj", "up_proj", ], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], }, - "glm4_moe_lite": { + "glm4_moe_lite": { "gate_up_proj": ["gate_proj", "up_proj"], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], - "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] + "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"], }, "longcat_flash": { "gate_up_proj": ["gate_proj", "up_proj"], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], - "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] + "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"], }, "minimax_m2": { "qkv_proj": [ @@ -184,17 +169,17 @@ packed_modules_model_mapping: Dict[str, Dict[str, List[str]]] = { "k_proj", "v_proj", ], - "experts": ["experts.0.w1", "experts.0.w2", "experts.0.w3"] - } + "experts": ["experts.0.w1", "experts.0.w2", "experts.0.w3"], + }, } -def get_packed_modules_mapping(model_type: str) -> Dict[str, List[str]]: +def get_packed_modules_mapping(model_type: str) -> dict[str, list[str]]: """Get packed modules mapping for a model type. - + Args: model_type: The model type string (e.g., "deepseek_v3"). - + Returns: Dictionary mapping fused module names to their component module names. Returns empty dict if model_type is not found. @@ -202,12 +187,12 @@ def get_packed_modules_mapping(model_type: str) -> Dict[str, List[str]]: return packed_modules_model_mapping.get(model_type, {}) -def get_prefix_mapping(model_type: str) -> Dict[str, str]: +def get_prefix_mapping(model_type: str) -> dict[str, str]: """Get prefix mapping for a model type. - + Args: model_type: The model type string (e.g., "qwen3_vl_moe"). - + Returns: Dictionary mapping original prefixes to new prefixes. Returns empty dict if model_type is not found. 
@@ -216,15 +201,15 @@ def get_prefix_mapping(model_type: str) -> Dict[str, str]: def get_linear_quant_type( - quant_description: Dict[str, Any], prefix: str, - packed_modules_mapping: Dict[str, Any]) -> Optional[str]: + quant_description: dict[str, Any], prefix: str, packed_modules_mapping: dict[str, Any] +) -> str | None: """Determine the quantization type for a linear layer. - + Args: quant_description: The quantization description dictionary. prefix: The layer prefix. packed_modules_mapping: Mapping for packed/fused modules. - + Returns: The quantization type string (e.g., "W8A8_DYNAMIC"). """ @@ -232,11 +217,10 @@ def get_linear_quant_type( if proj_name in packed_modules_mapping: quant_type = None shard_prefixes = [ - prefix.replace(proj_name, shard_proj_name) - for shard_proj_name in packed_modules_mapping[proj_name] + prefix.replace(proj_name, shard_proj_name) for shard_proj_name in packed_modules_mapping[proj_name] ] for shard_prefix in shard_prefixes: - shard_quant_type = quant_description[shard_prefix + '.weight'] + shard_quant_type = quant_description[shard_prefix + ".weight"] if quant_type is None: quant_type = shard_quant_type @@ -244,72 +228,68 @@ def get_linear_quant_type( raise ValueError( f"Not all shards of {prefix} are quantized with same quant type." f"Shard {proj_name} uses {shard_quant_type}, but another shard" - f"use {quant_type}. Please check quantization config.") + f"use {quant_type}. Please check quantization config." + ) else: - quant_type = quant_description[prefix + '.weight'] + quant_type = quant_description[prefix + ".weight"] return quant_type def get_quant_type_for_layer( - quant_description: Dict[str, Any], - prefix: str, - layer_type: str, - packed_modules_mapping: Optional[Dict[str, - Any]] = None) -> Optional[str]: + quant_description: dict[str, Any], + prefix: str, + layer_type: str, + packed_modules_mapping: dict[str, Any] | None = None, +) -> str | None: """Determine the quantization type for a layer. - + Args: quant_description: The quantization description dictionary. prefix: The layer prefix. layer_type: The type of layer ("linear", "moe", "attention"). packed_modules_mapping: Mapping for packed/fused modules. - + Returns: The quantization type string (e.g., "W8A8_DYNAMIC"). """ if packed_modules_mapping is None: packed_modules_mapping = dict() # Attention - if layer_type == "attention" and 'fa_quant_type' in quant_description.keys( - ): - return quant_description['fa_quant_type'] + if layer_type == "attention" and "fa_quant_type" in quant_description: + return quant_description["fa_quant_type"] # Linear / MoE - return get_linear_quant_type(quant_description, prefix, - packed_modules_mapping) + return get_linear_quant_type(quant_description, prefix, packed_modules_mapping) def create_scheme_for_layer( - quant_description: Dict[str, Any], - prefix: str, - layer_type: str, - packed_modules_mapping: Optional[Dict[str, Any]] = None): + quant_description: dict[str, Any], + prefix: str, + layer_type: str, + packed_modules_mapping: dict[str, Any] | None = None, +): """Create a quantization scheme instance for a layer. - + Args: quant_description: The quantization description dictionary. prefix: The layer prefix. layer_type: The type of layer ("linear", "moe", "attention"). packed_modules_mapping: Mapping for packed/fused modules. - + Returns: An instance of the appropriate quantization scheme class. 
""" logger.info_once("Using the vLLM Ascend modelslim Quantization now!") - quant_type = get_quant_type_for_layer(quant_description, prefix, - layer_type, packed_modules_mapping) + quant_type = get_quant_type_for_layer(quant_description, prefix, layer_type, packed_modules_mapping) if quant_type is None: - raise ValueError( - f"Could not determine quantization type for layer {prefix}.") + raise ValueError(f"Could not determine quantization type for layer {prefix}.") # Use registry to get scheme class scheme_cls = get_scheme_class(quant_type, layer_type) if scheme_cls is not None: return scheme_cls() - raise NotImplementedError( - f"Currently, vLLM Ascend doesn't support {quant_type} for {layer_type}." - ) + raise NotImplementedError(f"Currently, vLLM Ascend doesn't support {quant_type} for {layer_type}.") @register_quantization_config(ASCEND_QUANTIZATION_METHOD) @@ -321,13 +301,13 @@ class AscendModelSlimConfig(QuantizationConfig): quantized using the ModelSlim tool. """ - def __init__(self, quant_config: Dict[str, Any]): + def __init__(self, quant_config: dict[str, Any]): super().__init__() self.quant_description = quant_config # TODO(whx): remove this adaptation after adding "shared_head" # to prefix of DeepSeekShareHead in vLLM. extra_quant_dict = {} - for k in self.quant_description.keys(): + for k in self.quant_description: if "shared_head" in k: new_k = k.replace(".shared_head.", ".") extra_quant_dict[new_k] = self.quant_description[k] @@ -344,25 +324,23 @@ class AscendModelSlimConfig(QuantizationConfig): return ASCEND_QUANTIZATION_METHOD @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: + def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.int8, torch.float16, torch.bfloat16] @classmethod def get_min_capability(cls) -> int: - raise NotImplementedError( - "Ascend hardware dose not support \"get_min_capability\" feature.") + raise NotImplementedError('Ascend hardware dose not support "get_min_capability" feature.') @classmethod - def get_config_filenames(cls) -> List[str]: + def get_config_filenames(cls) -> list[str]: return ["quant_model_description.json"] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "AscendModelSlimConfig": + def from_config(cls, config: dict[str, Any]) -> "AscendModelSlimConfig": return cls(config) @classmethod - def override_quantization_method(cls, hf_quant_cfg, - user_quant) -> Optional[str]: + def override_quantization_method(cls, hf_quant_cfg, user_quant) -> str | None: if hf_quant_cfg is not None: quant_method = hf_quant_cfg.get("quant_method", None) if not quant_method and torch.npu.is_available(): @@ -373,15 +351,17 @@ class AscendModelSlimConfig(QuantizationConfig): # TODO (Levi-JQ): will be removed when QuantizationConfig.apply_vllm_mapper is implemented prefix_mapping = QUANT_MODEL_PREFIX_MAPPINGS.get(model_type) if prefix_mapping: - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix=prefix_mapping) + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix=prefix_mapping) return hf_to_vllm_mapper._map_name(prefix) return prefix - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: - from .method_adapters import (AscendEmbeddingMethod, AscendFusedMoEMethod, - AscendKVCacheMethod, AscendLinearMethod) + def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: + from .method_adapters import ( + AscendEmbeddingMethod, + AscendFusedMoEMethod, + AscendKVCacheMethod, + AscendLinearMethod, + ) vllm_config = 
get_current_vllm_config() model_type = vllm_config.model_config.hf_config.model_type @@ -390,81 +370,67 @@ class AscendModelSlimConfig(QuantizationConfig): # Adapt to Minimax architecture: update layer names to MoE convention prefix = prefix.replace("mlp", "block_sparse_moe") # Normalize the prefix by stripping specific expert indices (e.g., 'experts.0' -> 'experts') - parts = prefix.split('.') + parts = prefix.split(".") if "experts" in parts and len(parts) > 2: exp_idx = parts.index("experts") if exp_idx + 1 < len(parts) and parts[exp_idx + 1].isdigit(): - parts = parts[:exp_idx + 1] + parts = parts[: exp_idx + 1] prefix = ".".join(parts) if model_type in packed_modules_model_mapping: - self.packed_modules_mapping = packed_modules_model_mapping[ - model_type] + self.packed_modules_mapping = packed_modules_model_mapping[model_type] prefix = self.quant_prefix_mapper(model_type, prefix) from vllm_ascend.utils import vllm_version_is + if vllm_version_is("v0.15.0"): - from vllm.attention.layer import Attention # type: ignore + from vllm.attention.layer import Attention # type: ignore else: from vllm.model_executor.layers.attention import Attention if prefix.startswith("language_model"): - prefix = prefix.split('.', 1)[-1] + prefix = prefix.split(".", 1)[-1] if isinstance(layer, LinearBase): - if self.is_layer_skipped_ascend(prefix, - self.packed_modules_mapping): + if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping): # Delayed import to avoid circular import - from vllm_ascend.ops.linear import \ - AscendUnquantizedLinearMethod + from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod + return AscendUnquantizedLinearMethod() - scheme = create_scheme_for_layer(self.quant_description, prefix, - "linear", - self.packed_modules_mapping) + scheme = create_scheme_for_layer(self.quant_description, prefix, "linear", self.packed_modules_mapping) return AscendLinearMethod(scheme) - elif isinstance(layer, Attention) and \ - 'fa_quant_type' in self.quant_description.keys() and \ - self.quant_description['fa_quant_type'] is not None: - scheme = create_scheme_for_layer(self.quant_description, prefix, - "attention", - self.packed_modules_mapping) + elif ( + isinstance(layer, Attention) + and "fa_quant_type" in self.quant_description + and self.quant_description["fa_quant_type"] is not None + ): + scheme = create_scheme_for_layer(self.quant_description, prefix, "attention", self.packed_modules_mapping) return AscendKVCacheMethod(scheme) elif isinstance(layer, FusedMoE): - if self.is_layer_skipped_ascend(prefix, - self.packed_modules_mapping): + if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping): # Delayed import to avoid circular import - from vllm_ascend.ops.fused_moe.fused_moe import \ - AscendUnquantizedFusedMoEMethod + from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod + return AscendUnquantizedFusedMoEMethod(layer.moe_config) - scheme = create_scheme_for_layer(self.quant_description, prefix, - "moe", - self.packed_modules_mapping) + scheme = create_scheme_for_layer(self.quant_description, prefix, "moe", self.packed_modules_mapping) return AscendFusedMoEMethod(scheme, layer.moe_config) elif isinstance(layer, VocabParallelEmbedding): - if self.is_layer_skipped_ascend(prefix, - self.packed_modules_mapping): + if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping): return UnquantizedEmbeddingMethod() - scheme = create_scheme_for_layer(self.quant_description, prefix, - "linear", - self.packed_modules_mapping) + scheme = 
create_scheme_for_layer(self.quant_description, prefix, "linear", self.packed_modules_mapping) return AscendEmbeddingMethod(scheme) return None - def is_layer_skipped_ascend( - self, - prefix: str, - fused_mapping: Mapping[str, List[str]] = MappingProxyType({})): + def is_layer_skipped_ascend(self, prefix: str, fused_mapping: Mapping[str, list[str]] = MappingProxyType({})): # adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped proj_name = prefix.split(".")[-1] if proj_name in fused_mapping: shard_prefixes = [ - prefix.replace(proj_name, shard_proj_name) - for shard_proj_name in fused_mapping[proj_name] + prefix.replace(proj_name, shard_proj_name) for shard_proj_name in fused_mapping[proj_name] ] is_skipped = None for shard_prefix in shard_prefixes: - is_shard_skipped = self.quant_description[shard_prefix + - '.weight'] == "FLOAT" + is_shard_skipped = self.quant_description[shard_prefix + ".weight"] == "FLOAT" if is_skipped is None: is_skipped = is_shard_skipped @@ -472,12 +438,13 @@ class AscendModelSlimConfig(QuantizationConfig): raise ValueError( f"Detected some but not all shards of {prefix} " "are quantized. All shards of fused layers " - "to have the same precision.") + "to have the same precision." + ) else: - is_skipped = self.quant_description[prefix + '.weight'] == "FLOAT" + is_skipped = self.quant_description[prefix + ".weight"] == "FLOAT" assert is_skipped is not None return is_skipped - def get_scaled_act_names(self) -> List[str]: + def get_scaled_act_names(self) -> list[str]: return [] diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py index 7d8b8078..f048dcce 100644 --- a/vllm_ascend/sample/rejection_sampler.py +++ b/vllm_ascend/sample/rejection_sampler.py @@ -1,18 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional import torch from vllm.triton_utils import HAS_TRITON, triton from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.sample.rejection_sampler import (GREEDY_TEMPERATURE, MAX_SPEC_LEN, - PLACEHOLDER_TOKEN_ID, - generate_uniform_probs) +from vllm.v1.sample.rejection_sampler import ( + GREEDY_TEMPERATURE, + MAX_SPEC_LEN, + PLACEHOLDER_TOKEN_ID, + generate_uniform_probs, +) from vllm_ascend.ops.triton.reject_sample import ( - cal_grid_and_block_size, expand_triton, + cal_grid_and_block_size, + expand_triton, rejection_greedy_sample_with_triton, rejection_random_sample_block_verify_kernel, - rejection_random_sample_kernel, sample_recovered_tokens_kernel) + rejection_random_sample_kernel, + sample_recovered_tokens_kernel, +) from vllm_ascend.sample.sampler import apply_top_k_top_p @@ -83,7 +88,7 @@ def rejection_sample( # [batch_size] cu_num_draft_tokens: torch.Tensor, # [num_tokens, vocab_size] - draft_probs: Optional[torch.Tensor], + draft_probs: torch.Tensor | None, # [num_tokens, vocab_size] target_probs: torch.Tensor, # [batch_size, 1] @@ -126,15 +131,20 @@ def rejection_sample( # Rejection sampling for greedy sampling requests. 
target_argmax = target_probs.argmax(dim=-1) if HAS_TRITON: - rejection_greedy_sample_with_triton(output_token_ids, - num_draft_tokens, - cu_num_draft_tokens, - draft_token_ids, target_argmax, - bonus_token_ids, is_greedy, - max_spec_len, grid, block_size) + rejection_greedy_sample_with_triton( + output_token_ids, + num_draft_tokens, + cu_num_draft_tokens, + draft_token_ids, + target_argmax, + bonus_token_ids, + is_greedy, + max_spec_len, + grid, + block_size, + ) else: - if min(num_draft_tokens) == 1 and max( - num_draft_tokens) == 1 and sampling_metadata.all_greedy: + if min(num_draft_tokens) == 1 and max(num_draft_tokens) == 1 and sampling_metadata.all_greedy: rejection_greedy_sample_spec_len_1_pytorch( output_token_ids, draft_token_ids, @@ -179,7 +189,7 @@ def rejection_sample( if not using_block_verify: # Rejection sampling for random sampling requests. if HAS_TRITON: - rejection_random_sample_kernel[(grid, )]( + rejection_random_sample_kernel[(grid,)]( output_token_ids, cu_num_draft_tokens, draft_token_ids, @@ -214,7 +224,7 @@ def rejection_sample( else: # MagicMTP: Improving acceptance rate with Block Verify. if HAS_TRITON: - rejection_random_sample_block_verify_kernel[(grid, )]( + rejection_random_sample_block_verify_kernel[(grid,)]( output_token_ids, cu_num_draft_tokens, draft_token_ids, @@ -231,19 +241,20 @@ def rejection_sample( BLOCK_SIZE=block_size, ) else: - rejection_random_sample_block_verify_pytorch(output_token_ids, - cu_num_draft_tokens, - draft_token_ids, - draft_probs, - target_probs, - bonus_token_ids, - recovered_token_ids, - uniform_probs, - is_greedy, - max_spec_len, - vocab_size, - IS_NGRAM=draft_probs - is None) + rejection_random_sample_block_verify_pytorch( + output_token_ids, + cu_num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + bonus_token_ids, + recovered_token_ids, + uniform_probs, + is_greedy, + max_spec_len, + vocab_size, + IS_NGRAM=draft_probs is None, + ) return output_token_ids @@ -277,13 +288,7 @@ def expand_batch_to_tokens( assert cu_num_tokens.shape[0] == batch_size expanded_x = x.new_empty(num_tokens) if HAS_TRITON: - expand_triton(batch_size, - expanded_x, - x, - cu_num_tokens, - replace_from, - replace_to, - max_num_tokens=MAX_SPEC_LEN) + expand_triton(batch_size, expanded_x, x, cu_num_tokens, replace_from, replace_to, max_num_tokens=MAX_SPEC_LEN) else: expand_pytorch( expanded_x, @@ -301,7 +306,7 @@ def sample_recovered_tokens( num_draft_tokens: list[int], cu_num_draft_tokens: torch.Tensor, draft_token_ids: torch.Tensor, - draft_probs: Optional[torch.Tensor], + draft_probs: torch.Tensor | None, target_probs: torch.Tensor, sampling_metadata: SamplingMetadata, device: torch.device, @@ -316,9 +321,7 @@ def sample_recovered_tokens( ) q.exponential_() - num_draft_tensor = torch.tensor(num_draft_tokens, - pin_memory=True).to(device, - non_blocking=True) + num_draft_tensor = torch.tensor(num_draft_tokens, pin_memory=True).to(device, non_blocking=True) has_draft_mask = num_draft_tensor > 0 for i, generator in sampling_metadata.generators.items(): @@ -357,10 +360,10 @@ def sample_recovered_tokens( def rejection_greedy_sample_spec_len_1_pytorch( - output_token_ids, # [batch_size, 2] - draft_token_ids, # [num_tokens] - target_argmax, # [num_tokens] - bonus_token_ids, # [batch_size] + output_token_ids, # [batch_size, 2] + draft_token_ids, # [num_tokens] + target_argmax, # [num_tokens] + bonus_token_ids, # [batch_size] ): batch_size = output_token_ids.size(0) num_tokens = draft_token_ids.size(0) @@ -368,73 +371,56 @@ def 
rejection_greedy_sample_spec_len_1_pytorch( accept_req_mask = draft_token_ids == target_argmax output_token_ids[:, 0] = target_argmax bonus_token_ids = bonus_token_ids.squeeze(1) - output_token_ids[:, 1] = torch.where(accept_req_mask, bonus_token_ids, - output_token_ids[:, 1]) + output_token_ids[:, 1] = torch.where(accept_req_mask, bonus_token_ids, output_token_ids[:, 1]) def rejection_greedy_sample_pytorch( - output_token_ids, # [batch_size, max_spec_len + 1] - cu_num_draft_tokens, # [batch_size] - draft_token_ids, # [num_tokens] - target_argmax, # [num_tokens] - bonus_token_ids, # [batch_size] - draft_tokens_per_req, # [batch_size], list - max_spec_len, - is_greedy=None, # [batch_size] or None + output_token_ids, # [batch_size, max_spec_len + 1] + cu_num_draft_tokens, # [batch_size] + draft_token_ids, # [num_tokens] + target_argmax, # [num_tokens] + bonus_token_ids, # [batch_size] + draft_tokens_per_req, # [batch_size], list + max_spec_len, + is_greedy=None, # [batch_size] or None ): batch_size = output_token_ids.size(0) num_tokens = draft_token_ids.size(0) device = output_token_ids.device - draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to( - device, non_blocking=True) + draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to(device, non_blocking=True) if is_greedy is None: is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device) start_indices = cu_num_draft_tokens - draft_tokens_per_req req_ids = torch.arange(batch_size, device=device) token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req) - token_positions = torch.arange( - num_tokens, device=device) - start_indices[token_req_ids] + token_positions = torch.arange(num_tokens, device=device) - start_indices[token_req_ids] # Find the first mismatch position of each request. - mismatch_global = (draft_token_ids != target_argmax) + mismatch_global = draft_token_ids != target_argmax if max_spec_len == 0: - first_mismatch_pos_per_req = torch.zeros(batch_size, - dtype=torch.long, - device=device) + first_mismatch_pos_per_req = torch.zeros(batch_size, dtype=torch.long, device=device) else: # [bs, max_spec_len] - pos_matrix = torch.full((batch_size, max_spec_len), - -1, - dtype=torch.long, - device=device) + pos_matrix = torch.full((batch_size, max_spec_len), -1, dtype=torch.long, device=device) pos_matrix[token_req_ids, token_positions] = token_positions - mismatch_matrix = torch.full((batch_size, max_spec_len), - False, - dtype=torch.bool, - device=device) + mismatch_matrix = torch.full((batch_size, max_spec_len), False, dtype=torch.bool, device=device) mismatch_matrix[token_req_ids, token_positions] = mismatch_global - mismatch_positions = torch.where(mismatch_matrix, pos_matrix, - max_spec_len * 2) + mismatch_positions = torch.where(mismatch_matrix, pos_matrix, max_spec_len * 2) first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1) - no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2) - first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[ - no_mismatch_mask] + no_mismatch_mask = first_mismatch_pos_per_req == max_spec_len * 2 + first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[no_mismatch_mask] # Copy matched target tokens into output. 
- copy_len = torch.minimum(first_mismatch_pos_per_req + 1, - draft_tokens_per_req) - copy_indices = torch.arange(max_spec_len + 1, - device=device).expand(batch_size, -1) + copy_len = torch.minimum(first_mismatch_pos_per_req + 1, draft_tokens_per_req) + copy_indices = torch.arange(max_spec_len + 1, device=device).expand(batch_size, -1) copy_mask = copy_indices < copy_len.unsqueeze(1) greedy_mask = is_greedy.unsqueeze(1) final_copy_mask = copy_mask & greedy_mask global_idx = start_indices.unsqueeze(1) + copy_indices - output_token_ids[final_copy_mask] = target_argmax[ - global_idx[final_copy_mask]].to(output_token_ids.dtype) + output_token_ids[final_copy_mask] = target_argmax[global_idx[final_copy_mask]].to(output_token_ids.dtype) # Fill bonus token. - needs_bonus = is_greedy & (first_mismatch_pos_per_req - >= draft_tokens_per_req) + needs_bonus = is_greedy & (first_mismatch_pos_per_req >= draft_tokens_per_req) if torch.any(needs_bonus): bonus_rows = torch.where(needs_bonus)[0] bonus_cols = draft_tokens_per_req[bonus_rows] @@ -458,24 +444,24 @@ def rejection_random_sample_pytorch( ): """ This function implements the Speculative Decoding rejection sampling step. - Instead of looping through each request and each token (which causes high + Instead of looping through each request and each token (which causes high overhead), it uses a fully vectorized approach: - - 1. **Index Mapping**: Converts the flattened 1D token arrays into a 2D - [batch_size, max_draft_len] grid using 'cu_num_draft_tokens' to handle + + 1. **Index Mapping**: Converts the flattened 1D token arrays into a 2D + [batch_size, max_draft_len] grid using 'cu_num_draft_tokens' to handle variable-length sequences in the batch. - 2. **Parallel Validation**: Calculates the acceptance condition - (target_prob / draft_prob >= uniform_sample) for ALL draft tokens + 2. **Parallel Validation**: Calculates the acceptance condition + (target_prob / draft_prob >= uniform_sample) for ALL draft tokens simultaneously across the entire batch. - 3. **Short-circuit Simulation**: In the loop version, once a token is rejected, - subsequent tokens are ignored. Here, we simulate this by finding the - 'first_reject_pos' using argmax on the rejection mask and creating a + 3. **Short-circuit Simulation**: In the loop version, once a token is rejected, + subsequent tokens are ignored. Here, we simulate this by finding the + 'first_reject_pos' using argmax on the rejection mask and creating a 'should_skip' mask for all indices after the first failure. 4. **Token Selection**: Uses 'torch.where' to select: - Draft tokens (if accepted) - Recovered tokens (at the point of first rejection) - Bonus tokens (if all tokens in a sequence were accepted) - 5. **Masking**: Ensures operations only apply to non-greedy requests and + 5. **Masking**: Ensures operations only apply to non-greedy requests and within valid sequence lengths. 
""" @@ -495,15 +481,12 @@ def rejection_random_sample_pytorch( valid_mask = pos_indices < num_draft_per_batch[:, None] global_token_indices = cu_start[:, None] + pos_indices - global_token_indices = global_token_indices.clamp( - 0, draft_token_ids.shape[0] - 1) - draft_tokens = draft_token_ids[ - global_token_indices] # [batch_size, max_draft_len] + global_token_indices = global_token_indices.clamp(0, draft_token_ids.shape[0] - 1) + draft_tokens = draft_token_ids[global_token_indices] # [batch_size, max_draft_len] if IS_NGRAM: ones_cpu = torch.ones(1, pin_memory=True, dtype=torch.float32) - draft_token_probs = ones_cpu.to( - device, non_blocking=True).expand_as(draft_tokens) + draft_token_probs = ones_cpu.to(device, non_blocking=True).expand_as(draft_tokens) else: flat_indices = global_token_indices.flatten() flat_draft_tokens = draft_tokens.flatten() @@ -518,24 +501,21 @@ def rejection_random_sample_pytorch( uniform_token_probs = uniform_probs[global_token_indices] recovered_tokens = recovered_token_ids[global_token_indices] - zero_threshold_cpu = torch.tensor([0.0], - pin_memory=True, - dtype=torch.float32) + zero_threshold_cpu = torch.tensor([0.0], pin_memory=True, dtype=torch.float32) zero_threshold = zero_threshold_cpu.to(device, non_blocking=True) acceptance_condition = (draft_token_probs > zero_threshold) & ( - target_token_probs / draft_token_probs >= uniform_token_probs) + target_token_probs / draft_token_probs >= uniform_token_probs + ) first_rejection = (~acceptance_condition) & valid_mask - default_pos_cpu = torch.full([batch_size, 1], - max_draft_len, - pin_memory=True) + default_pos_cpu = torch.full([batch_size, 1], max_draft_len, pin_memory=True) default_pos = default_pos_cpu.to(device, non_blocking=True) first_reject_pos = torch.where( - first_rejection.any(dim=1, keepdim=True), - first_rejection.float().argmax(dim=1, keepdim=True), default_pos) + first_rejection.any(dim=1, keepdim=True), first_rejection.float().argmax(dim=1, keepdim=True), default_pos + ) pos_mask = pos_indices >= first_reject_pos should_skip = pos_mask & valid_mask @@ -543,16 +523,17 @@ def rejection_random_sample_pytorch( non_greedy_mask = ~is_greedy update_mask = non_greedy_mask[:, None] & valid_mask & (~should_skip) - first_reject_mask = (pos_indices == first_reject_pos - ) & valid_mask & non_greedy_mask[:, None] + first_reject_mask = (pos_indices == first_reject_pos) & valid_mask & non_greedy_mask[:, None] final_update_mask = update_mask | first_reject_mask final_tokens = torch.where( - first_reject_mask, recovered_tokens, - torch.where(final_acceptance, draft_tokens, - output_token_ids[:, :max_draft_len])) + first_reject_mask, + recovered_tokens, + torch.where(final_acceptance, draft_tokens, output_token_ids[:, :max_draft_len]), + ) output_token_ids[:, :max_draft_len] = torch.where( - final_update_mask, final_tokens, output_token_ids[:, :max_draft_len]) + final_update_mask, final_tokens, output_token_ids[:, :max_draft_len] + ) no_rejection = first_reject_pos.squeeze(1) >= num_draft_per_batch should_add_bonus = non_greedy_mask & no_rejection @@ -561,8 +542,7 @@ def rejection_random_sample_pytorch( seq_len = output_token_ids.shape[1] all_positions_cpu = torch.arange(seq_len, pin_memory=True) - all_positions = all_positions_cpu.to( - device, non_blocking=True)[None, :] # [1, seq_len] + all_positions = all_positions_cpu.to(device, non_blocking=True)[None, :] # [1, seq_len] batch_bonus_positions = bonus_positions[:, None] # [batch_size, 1] @@ -572,12 +552,11 @@ def rejection_random_sample_pytorch( 
valid_bonus_pos = bonus_positions < (max_spec_len_device + 1) final_bonus_mask = should_add_bonus & valid_bonus_pos - bonus_pos_match = (all_positions == batch_bonus_positions) + bonus_pos_match = all_positions == batch_bonus_positions bonus_pos_mask = bonus_pos_match & final_bonus_mask[:, None] bonus_values_expanded = bonus_token_ids.view(-1, 1).expand(-1, seq_len) - output_token_ids[:] = torch.where(bonus_pos_mask, bonus_values_expanded, - output_token_ids) + output_token_ids[:] = torch.where(bonus_pos_mask, bonus_values_expanded, output_token_ids) def expand_pytorch( @@ -589,17 +568,17 @@ def expand_pytorch( MAX_NUM_TOKENS, ): """ - This function broadcasts batch-level values (input_ptr) to token-level - positions (output_ptr) based on cumulative token offsets. It acts like + This function broadcasts batch-level values (input_ptr) to token-level + positions (output_ptr) based on cumulative token offsets. It acts like a "scatter" or "repeat_interleave" operation but with custom logic: - - 1. **Range Broadcasting**: It creates a boolean matrix 'in_range' of size - [num_tokens, batch_size] that identifies which batch index each token + + 1. **Range Broadcasting**: It creates a boolean matrix 'in_range' of size + [num_tokens, batch_size] that identifies which batch index each token belongs to by checking if the token index falls between cu_start and cu_end. - 2. **Conditional Replacement**: Before expansion, it replaces specific values + 2. **Conditional Replacement**: Before expansion, it replaces specific values (e.g., padding or special markers) in the input to prepare the data. - 3. **Matrix-based Mapping**: It uses 'torch.einsum' to perform a weighted - sum that effectively "picks" the correct batch value for every token position + 3. **Matrix-based Mapping**: It uses 'torch.einsum' to perform a weighted + sum that effectively "picks" the correct batch value for every token position simultaneously, avoiding a Python loop over the batch. """ device = cu_num_tokens_ptr.device @@ -609,21 +588,16 @@ def expand_pytorch( if batch_size == 0 or num_tokens == 0: return - cu_start = torch.cat([ - torch.tensor([0], pin_memory=True).to(device, non_blocking=True), - cu_num_tokens_ptr[:-1] - ]) + cu_start = torch.cat([torch.tensor([0], pin_memory=True).to(device, non_blocking=True), cu_num_tokens_ptr[:-1]]) cu_end = cu_num_tokens_ptr - token_indices = torch.arange(num_tokens, - device=device)[:, None] # [num_tokens, 1] + token_indices = torch.arange(num_tokens, device=device)[:, None] # [num_tokens, 1] cu_start_exp = cu_start[None, :] # [1, batch_size] cu_end_exp = cu_end[None, :] # [1, batch_size] in_range = (token_indices >= cu_start_exp) & (token_indices < cu_end_exp) - replaced_input = torch.where(input_ptr == replace_from, replace_to, - input_ptr).float() + replaced_input = torch.where(input_ptr == replace_from, replace_to, input_ptr).float() token_values = torch.einsum("tb,b->t", in_range.float(), replaced_input) @@ -643,21 +617,21 @@ def sample_recovered_tokens_pytorch( IS_NGRAM=False, ): """ - When a draft token is rejected, we must sample a "recovered" token from - a modified distribution. This function calculates that distribution across + When a draft token is rejected, we must sample a "recovered" token from + a modified distribution. This function calculates that distribution across the entire flattened batch. - - 1. **Token-to-Batch Mapping**: Using the cumulative draft token counts, it - determines which request in the batch each token belongs to. This is + + 1. 
**Token-to-Batch Mapping**: Using the cumulative draft token counts, it + determines which request in the batch each token belongs to. This is necessary because 'q' (normalization factor) is stored per-request. - 2. **Probability Adjustment**: + 2. **Probability Adjustment**: - If N-GRAM: It zeroes out the draft token's probability in the target. - - If Probabilistic: It calculates max(0, target_probs - draft_probs) + - If Probabilistic: It calculates max(0, target_probs - draft_probs) as per the standard speculative decoding algorithm. - 3. **Normalization & Sampling**: It divides the adjusted probabilities - by the normalization distribution 'q'. To remain vectorized, it + 3. **Normalization & Sampling**: It divides the adjusted probabilities + by the normalization distribution 'q'. To remain vectorized, it broadcasts 'q' from [batch_size, vocab] to [num_tokens, vocab]. - 4. **Argmax Selection**: It selects the best recovery token for every + 4. **Argmax Selection**: It selects the best recovery token for every position in one pass using torch.argmax. """ device = output_token_ids.device @@ -666,10 +640,12 @@ def sample_recovered_tokens_pytorch( if num_tokens == 0: return - cu_start = torch.cat([ - torch.tensor([0], pin_memory=True).to(device, non_blocking=True), - cu_num_draft_tokens[:-1], - ]) + cu_start = torch.cat( + [ + torch.tensor([0], pin_memory=True).to(device, non_blocking=True), + cu_num_draft_tokens[:-1], + ] + ) cu_end = cu_num_draft_tokens token_indices = torch.arange(num_tokens, device=device) # [num_tokens] @@ -678,8 +654,7 @@ def sample_recovered_tokens_pytorch( cu_start_expanded = cu_start[None, :] # [1, batch_size] cu_end_expanded = cu_end[None, :] # [1, batch_size] - in_range_mask = (token_indices_expanded >= cu_start_expanded) & ( - token_indices_expanded < cu_end_expanded) + in_range_mask = (token_indices_expanded >= cu_start_expanded) & (token_indices_expanded < cu_end_expanded) token_to_batch = torch.argmax(in_range_mask.int(), dim=1) @@ -707,8 +682,7 @@ def sample_recovered_tokens_pytorch( prob_over_q = prob / q_values_safe - prob_over_q = torch.where((q_values == 0) | torch.isinf(q_values), -1e10, - prob_over_q) + prob_over_q = torch.where((q_values == 0) | torch.isinf(q_values), -1e10, prob_over_q) recovered_ids = torch.argmax(prob_over_q, dim=1) @@ -742,14 +716,12 @@ def rejection_random_sample_block_verify_pytorch( pos_indices = pos_indices_cpu.to(device, non_blocking=True)[None, :] valid_mask = pos_indices < num_draft_per_batch global_token_indices = cu_start[:, None] + pos_indices - global_token_indices = global_token_indices.clamp( - 0, draft_token_ids.shape[0] - 1) + global_token_indices = global_token_indices.clamp(0, draft_token_ids.shape[0] - 1) draft_tokens = draft_token_ids[global_token_indices] if IS_NGRAM: ones_cpu = torch.ones(1, pin_memory=True, dtype=torch.float32) - draft_token_probs = ones_cpu.to( - device, non_blocking=True).expand_as(draft_tokens) + draft_token_probs = ones_cpu.to(device, non_blocking=True).expand_as(draft_tokens) else: flat_indices = global_token_indices.flatten() flat_draft_tokens = draft_tokens.flatten() @@ -772,27 +744,21 @@ def rejection_random_sample_block_verify_pytorch( last_accept_pos = torch.where( legal_mask.any(dim=-1, keepdim=True), - (max_spec_len - - legal_mask.flip(dims=[-1]).float().argmax(dim=-1, keepdim=True) - 1), - -1) + (max_spec_len - legal_mask.flip(dims=[-1]).float().argmax(dim=-1, keepdim=True) - 1), + -1, + ) non_greedy_mask = (~is_greedy)[:, None] - accept_mask = (pos_indices - <= last_accept_pos) 
& valid_mask & non_greedy_mask - output_token_ids[:, :max_spec_len] = torch.where( - accept_mask, draft_tokens, output_token_ids[:, :max_spec_len]) + accept_mask = (pos_indices <= last_accept_pos) & valid_mask & non_greedy_mask + output_token_ids[:, :max_spec_len] = torch.where(accept_mask, draft_tokens, output_token_ids[:, :max_spec_len]) - reject_mask = (pos_indices - == last_accept_pos + 1) & valid_mask & non_greedy_mask - output_token_ids[:, :max_spec_len] = torch.where( - reject_mask, recovered_tokens, output_token_ids[:, :max_spec_len]) + reject_mask = (pos_indices == last_accept_pos + 1) & valid_mask & non_greedy_mask + output_token_ids[:, :max_spec_len] = torch.where(reject_mask, recovered_tokens, output_token_ids[:, :max_spec_len]) bonus_mask = (last_accept_pos + 1 >= num_draft_per_batch) & non_greedy_mask all_positions_cpu = torch.arange(max_spec_len + 1, pin_memory=True) all_positions = all_positions_cpu.to(device, non_blocking=True)[None, :] - bonus_pos_match = (all_positions == num_draft_per_batch) + bonus_pos_match = all_positions == num_draft_per_batch bonus_mask = bonus_mask & bonus_pos_match - bonus_values_expanded = bonus_token_ids.view(-1, 1).expand( - -1, max_spec_len + 1) - output_token_ids[:] = torch.where(bonus_mask, bonus_values_expanded, - output_token_ids) + bonus_values_expanded = bonus_token_ids.view(-1, 1).expand(-1, max_spec_len + 1) + output_token_ids[:] = torch.where(bonus_mask, bonus_values_expanded, output_token_ids) diff --git a/vllm_ascend/sample/sampler.py b/vllm_ascend/sample/sampler.py index 62410cdb..40cc496d 100644 --- a/vllm_ascend/sample/sampler.py +++ b/vllm_ascend/sample/sampler.py @@ -35,7 +35,6 @@ def random_sample( class AscendSampler(Sampler): - def __init__(self, logprobs_mode=DEFAULT_LOGPROBS_MODE): # TODO: support logprobs_mode in vllm-ascend super().__init__(logprobs_mode=logprobs_mode) @@ -62,7 +61,6 @@ class AscendSampler(Sampler): class AscendTopKTopPSampler(TopKTopPSampler): - def __init__(self, **kwargs): super().__init__(**kwargs) self.apply_top_k_top_p = apply_top_k_top_p @@ -135,4 +133,9 @@ def _apply_top_k_top_p_ascendc( return logits return torch.ops._C_ascend.npu_apply_top_k_top_p(logits, k=k, p=p) -apply_top_k_top_p = _apply_top_k_top_p_ascendc if get_ascend_device_type() in [AscendDeviceType.A2, AscendDeviceType.A3] else _apply_top_k_top_p_pytorch \ No newline at end of file + +apply_top_k_top_p = ( + _apply_top_k_top_p_ascendc + if get_ascend_device_type() in [AscendDeviceType.A2, AscendDeviceType.A3] + else _apply_top_k_top_p_pytorch +) diff --git a/vllm_ascend/worker/block_table.py b/vllm_ascend/worker/block_table.py index 9fba9b52..4ffc7df6 100644 --- a/vllm_ascend/worker/block_table.py +++ b/vllm_ascend/worker/block_table.py @@ -1,5 +1,3 @@ -from typing import Optional, Union - import numpy as np import torch from vllm.distributed import get_dcp_group, get_pcp_group @@ -8,17 +6,18 @@ from vllm.v1.utils import CpuGpuBuffer class BlockTable: - - def __init__(self, - block_size: int, - max_num_reqs: int, - max_num_blocks_per_req: int, - max_num_batched_tokens: int, - pin_memory: bool, - device: torch.device, - kernel_sizes: Union[list[int], None] = None, - cp_kv_cache_interleave_size: int = 1, - num_speculative_tokens: int = 0): + def __init__( + self, + block_size: int, + max_num_reqs: int, + max_num_blocks_per_req: int, + max_num_batched_tokens: int, + pin_memory: bool, + device: torch.device, + kernel_sizes: list[int] | None = None, + cp_kv_cache_interleave_size: int = 1, + num_speculative_tokens: int = 0, + ): 
self.max_num_reqs = max_num_reqs self.max_num_blocks_per_req = max_num_blocks_per_req self.max_num_batched_tokens = max_num_batched_tokens @@ -28,8 +27,7 @@ class BlockTable: try: self.pcp_world_size = get_pcp_group().world_size - self.pcp_rank = get_pcp_group( - ).rank_in_group if self.pcp_world_size > 1 else 0 + self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_world_size > 1 else 0 self.dcp_world_size = get_dcp_group().world_size self.dcp_rank = get_dcp_group().rank_in_group except AssertionError: @@ -49,42 +47,37 @@ class BlockTable: # Find the first kernel size that divides physical_block_size evenly selected_kernel_size = None for kernel_size in kernel_sizes: - if kernel_size > 0 \ - and self.physical_block_size % kernel_size == 0: + if kernel_size > 0 and self.physical_block_size % kernel_size == 0: selected_kernel_size = kernel_size break if selected_kernel_size is None: raise ValueError( f"None of the kernel sizes {kernel_sizes} can divide " - f"physical block size {self.physical_block_size} evenly") + f"physical block size {self.physical_block_size} evenly" + ) self.block_size = selected_kernel_size self.logical_block_size = selected_kernel_size - self.blocks_per_phys_block = (self.physical_block_size // - self.logical_block_size) + self.blocks_per_phys_block = self.physical_block_size // self.logical_block_size if self.blocks_per_phys_block > 1: self.use_hybrid_blocks = True else: self.use_hybrid_blocks = False if self.use_hybrid_blocks: - logical_table_size = (max_num_blocks_per_req * - self.blocks_per_phys_block) + logical_table_size = max_num_blocks_per_req * self.blocks_per_phys_block else: logical_table_size = max_num_blocks_per_req duplicate_size = 1 if self.pcp_world_size * self.dcp_world_size > 1: duplicate_size += num_speculative_tokens - self.block_table = self._make_buffer(max_num_reqs * duplicate_size, - logical_table_size, - dtype=torch.int32) + self.block_table = self._make_buffer(max_num_reqs * duplicate_size, logical_table_size, dtype=torch.int32) self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) self.slot_mapping = self._make_buffer( - self.max_num_batched_tokens + - 2 * self.pcp_world_size * self.max_num_reqs, - dtype=torch.int32) + self.max_num_batched_tokens + 2 * self.pcp_world_size * self.max_num_reqs, dtype=torch.int32 + ) self.kernel_sizes = kernel_sizes self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size @@ -103,7 +96,7 @@ class BlockTable: num_blocks = len(block_ids) start = self.num_blocks_per_row[row_idx] - self.block_table.np[row_idx, start:start + num_blocks] = block_ids + self.block_table.np[row_idx, start : start + num_blocks] = block_ids self.num_blocks_per_row[row_idx] += num_blocks def add_row(self, block_ids: list[int], row_idx: int) -> None: @@ -112,8 +105,7 @@ class BlockTable: def move_row(self, src: int, tgt: int) -> None: num_blocks = self.num_blocks_per_row[src] - self.block_table.np[tgt, :num_blocks] = self.block_table.np[ - src, :num_blocks] + self.block_table.np[tgt, :num_blocks] = self.block_table.np[src, :num_blocks] self.num_blocks_per_row[tgt] = num_blocks def swap_row(self, src: int, tgt: int) -> None: @@ -124,8 +116,7 @@ class BlockTable: self.block_table.np[[src, tgt]] = self.block_table.np[[tgt, src]] - def compute_slot_mapping(self, req_indices: np.ndarray, - positions: np.ndarray) -> None: + def compute_slot_mapping(self, req_indices: np.ndarray, positions: np.ndarray) -> None: # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] # where K is 
the max_num_blocks_per_req and the block size is 2. @@ -150,27 +141,30 @@ class BlockTable: # (always needed with unified tensor) # Each physical block is split into multiple logical blocks # The logical table has been expanded to accommodate this - block_table_indices = (req_indices * self.max_num_blocks_per_req * - self.blocks_per_phys_block + - logical_block_idx) + block_table_indices = ( + req_indices * self.max_num_blocks_per_req * self.blocks_per_phys_block + logical_block_idx + ) block_numbers = self.block_table.np.ravel()[block_table_indices] # Use virtual_block_size for mask calculation, which marks local # tokens. virtual_block_offsets = positions % virtual_block_size self.current_rank = self.dcp_world_size * self.pcp_rank + self.dcp_rank - mask = (virtual_block_offsets // self.cp_kv_cache_interleave_size % - (self.dcp_world_size * - self.pcp_world_size) == self.current_rank) + mask = ( + virtual_block_offsets // self.cp_kv_cache_interleave_size % (self.dcp_world_size * self.pcp_world_size) + == self.current_rank + ) # Calculate local block_offsets - block_offsets = virtual_block_offsets \ - // (self.dcp_world_size * self.pcp_world_size * self.cp_kv_cache_interleave_size) \ - * self.cp_kv_cache_interleave_size + virtual_block_offsets % self.cp_kv_cache_interleave_size + block_offsets = ( + virtual_block_offsets + // (self.dcp_world_size * self.pcp_world_size * self.cp_kv_cache_interleave_size) + * self.cp_kv_cache_interleave_size + + virtual_block_offsets % self.cp_kv_cache_interleave_size + ) # Calculate slot_mapping slot_mapping = block_numbers * self.block_size + block_offsets # Write final slots, use -1 for not-local - self.slot_mapping.np[:req_indices.shape[0]] = np.where( - mask, slot_mapping, -1) + self.slot_mapping.np[: req_indices.shape[0]] = np.where(mask, slot_mapping, -1) else: assert self.kernel_sizes is not None if self.block_size == self.kernel_sizes[0]: @@ -183,15 +177,12 @@ class BlockTable: # Each physical block is split into multiple logical blocks # The logical table has been expanded to accommodate this block_table_indices = ( - req_indices * self.max_num_blocks_per_req * - self.blocks_per_phys_block + logical_block_idx) + req_indices * self.max_num_blocks_per_req * self.blocks_per_phys_block + logical_block_idx + ) - block_numbers = self.block_table.np.ravel( - )[block_table_indices] + block_numbers = self.block_table.np.ravel()[block_table_indices] block_offsets = positions % self.block_size - np.add(block_numbers * self.block_size, - block_offsets, - out=self.slot_mapping.np[:req_indices.shape[0]]) + np.add(block_numbers * self.block_size, block_offsets, out=self.slot_mapping.np[: req_indices.shape[0]]) def commit_block_table(self, num_reqs: int) -> None: self.block_table.copy_to_gpu(num_reqs) @@ -203,8 +194,7 @@ class BlockTable: self.block_table.fill_(0) self.block_table.cpu.fill_(0) - def _convert_physical_to_logical_blocks( - self, physical_blocks: np.ndarray) -> np.ndarray: + def _convert_physical_to_logical_blocks(self, physical_blocks: np.ndarray) -> np.ndarray: """Convert physical block IDs to logical block IDs.""" if not self.use_hybrid_blocks: return physical_blocks @@ -217,8 +207,7 @@ class BlockTable: # [1*split_ratio, 1*split_ratio+1, ...] 
# But we need to account for the fact that block 0 is special base_logical = phys_block * self.blocks_per_phys_block - logical_blocks.extend( - range(base_logical, base_logical + self.blocks_per_phys_block)) + logical_blocks.extend(range(base_logical, base_logical + self.blocks_per_phys_block)) return np.array(logical_blocks, dtype=np.int32) @@ -234,27 +223,25 @@ class BlockTable: """Returns the numpy array of the block table.""" return self.block_table.np - def _make_buffer(self, *size: int | torch.SymInt, - dtype: torch.dtype) -> CpuGpuBuffer: - return CpuGpuBuffer(*size, - dtype=dtype, - device=self.device, - pin_memory=self.pin_memory) + def _make_buffer(self, *size: int | torch.SymInt, dtype: torch.dtype) -> CpuGpuBuffer: + return CpuGpuBuffer(*size, dtype=dtype, device=self.device, pin_memory=self.pin_memory) class MultiGroupBlockTable: """The BlockTables for each KV cache group.""" - def __init__(self, - max_num_reqs: int, - max_model_len: int, - max_num_batched_tokens: int, - pin_memory: bool, - device: torch.device, - block_sizes: list[int], - num_speculative_tokens: int = 0, - kernel_sizes: Optional[list[list[int]]] = None, - cp_kv_cache_interleave_size: int = 1) -> None: + def __init__( + self, + max_num_reqs: int, + max_model_len: int, + max_num_batched_tokens: int, + pin_memory: bool, + device: torch.device, + block_sizes: list[int], + num_speculative_tokens: int = 0, + kernel_sizes: list[list[int]] | None = None, + cp_kv_cache_interleave_size: int = 1, + ) -> None: # Note(hc): each dcp rank only store # (max_model_len//dcp_world_size) tokens in kvcache, # so the block_size which used for calc max_num_blocks_per_req @@ -274,24 +261,26 @@ class MultiGroupBlockTable: kernel_sizes = kernel_sizes * len(block_sizes) elif len(kernel_sizes) != len(block_sizes): raise ValueError( - f"kernel_sizes length ({len(kernel_sizes)}) must match " - f"block_sizes length ({len(block_sizes)})") + f"kernel_sizes length ({len(kernel_sizes)}) must match block_sizes length ({len(block_sizes)})" + ) # Use zip to pair block_sizes with kernel_sizes one-to-one self.block_tables = [ BlockTable( - block_size, max_num_reqs, - max( - cdiv(max_model_len, - block_size * dcp_world_size * pcp_world_size), - 1 + num_speculative_tokens), max_num_batched_tokens, - pin_memory, device, kernel_size_list, - cp_kv_cache_interleave_size, num_speculative_tokens) + block_size, + max_num_reqs, + max(cdiv(max_model_len, block_size * dcp_world_size * pcp_world_size), 1 + num_speculative_tokens), + max_num_batched_tokens, + pin_memory, + device, + kernel_size_list, + cp_kv_cache_interleave_size, + num_speculative_tokens, + ) for block_size, kernel_size_list in zip(block_sizes, kernel_sizes) ] - def append_row(self, block_ids: tuple[list[int], ...], - row_idx: int) -> None: + def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: for i, block_table in enumerate(self.block_tables): block_table.append_row(block_ids[i], row_idx) @@ -307,8 +296,7 @@ class MultiGroupBlockTable: for block_table in self.block_tables: block_table.swap_row(src, tgt) - def compute_slot_mapping(self, req_indices: np.ndarray, - positions: np.ndarray) -> None: + def compute_slot_mapping(self, req_indices: np.ndarray, positions: np.ndarray) -> None: for block_table in self.block_tables: block_table.compute_slot_mapping(req_indices, positions)
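Note: ignoring the hybrid-block and pcp/dcp branches above, BlockTable.compute_slot_mapping reduces to block_table[req, pos // block_size] * block_size + pos % block_size. The following is a minimal NumPy sketch of that arithmetic with made-up sizes and block ids; the real class keeps these arrays in CpuGpuBuffer objects and writes the result into self.slot_mapping.

```python
import numpy as np

block_size = 2
max_num_blocks_per_req = 4
# block_table[req, i] = physical block id of the i-th block assigned to request `req`
block_table = np.array([[10, 11, 12, 13],
                        [20, 21, 22, 23]], dtype=np.int32)

req_indices = np.array([0, 0, 0, 1, 1])  # which request each scheduled token belongs to
positions = np.array([0, 1, 2, 0, 1])    # token position inside that request

# Flattened lookup of the block holding each token, then offset within the block.
block_numbers = block_table.ravel()[req_indices * max_num_blocks_per_req + positions // block_size]
slot_mapping = block_numbers * block_size + positions % block_size
print(slot_mapping)  # [20 21 22 40 41]
```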