diff --git a/.github/workflows/vllm-dependency-test.yml b/.github/workflows/vllm-dependency-test.yml new file mode 100644 index 000000000..96a2b362e --- /dev/null +++ b/.github/workflows/vllm-dependency-test.yml @@ -0,0 +1,45 @@ +name: VLLM Dependency Test + +on: + push: + branches: [ main ] + paths: + - "python/pyproject.toml" + - "python/sglang/**" + - "test/**" + - "docs/**" + - "scripts/**" + pull_request: + branches: [ main ] + paths: + - "python/pyproject.toml" + - "python/sglang/**" + - "test/**" + - "docs/**" + - "scripts/**" + +concurrency: + group: vllm-dependency-test-${{ github.ref }} + cancel-in-progress: true + +jobs: + vllm-dependency-test: + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && + github.event.pull_request.draft == false + runs-on: 1-gpu-runner + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install dependencies + env: + FLASHINFER_REPO: 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' + run: | + bash scripts/ci_install_dependency.sh + pip install "vllm>=0.6.4.post1,<=0.7.2" + + - name: Run VLLM dependency tests + timeout-minutes: 60 + run: | + cd test/srt + python3 run_suite.py --suite vllm_dependency_test --timeout-per-file 3600 diff --git a/python/pyproject.toml b/python/pyproject.toml index 3a682804e..5593ef6e1 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -47,7 +47,6 @@ srt = [ "sgl-kernel==0.0.5.post3", "flashinfer_python==0.2.3", "torch==2.5.1", - "vllm>=0.6.4.post1,<=0.7.2", "cuda-python", "outlines>=0.0.44,<=0.1.11", ] diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index e0df392dd..5941264f0 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -22,7 +22,11 @@ import torch from transformers import PretrainedConfig from sglang.srt.hf_transformers_utils import get_config, get_context_length -from sglang.srt.layers.quantization import 
QUANTIZATION_METHODS +from sglang.srt.layers.quantization import ( + BASE_QUANTIZATION_METHODS, + QUANTIZATION_METHODS, + VLLM_AVAILABLE, +) from sglang.srt.utils import get_bool_env_var, is_hip logger = logging.getLogger(__name__) @@ -235,7 +239,12 @@ class ModelConfig: # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py def _verify_quantization(self) -> None: - supported_quantization = [*QUANTIZATION_METHODS] + # Select supported quantization methods based on vllm availability + if VLLM_AVAILABLE: + supported_quantization = [*QUANTIZATION_METHODS] + else: + supported_quantization = [*BASE_QUANTIZATION_METHODS] + rocm_supported_quantization = [ "awq", "gptq", @@ -273,7 +282,11 @@ class ModelConfig: quant_method = quant_cfg.get("quant_method", "").lower() # Detect which checkpoint is it - for _, method in QUANTIZATION_METHODS.items(): + # Only iterate through currently available quantization methods + available_methods = ( + QUANTIZATION_METHODS if VLLM_AVAILABLE else BASE_QUANTIZATION_METHODS + ) + for _, method in available_methods.items(): quantization_override = method.override_quantization_method( quant_cfg, self.quantization ) diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index eed6125a9..ced337205 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -1316,7 +1316,10 @@ vllm_get_world_group = None def monkey_patch_vllm_parallel_state(reverse: bool = False): - import vllm.distributed.parallel_state as vllm_parrlel_state + try: + import vllm.distributed.parallel_state as vllm_parrlel_state + except ImportError: + return global vllm_get_pp_group, vllm_get_tp_group, vllm_get_world_group if vllm_get_pp_group is None: diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 32bcf1572..c4627082f 100644 --- a/python/sglang/srt/layers/linear.py +++ 
b/python/sglang/srt/layers/linear.py @@ -23,6 +23,7 @@ from sglang.srt.layers.parameter import ( PackedvLLMParameter, PerTensorScaleParameter, RowvLLMParameter, + _ColumnvLLMParameter, ) from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, @@ -423,8 +424,6 @@ class ColumnParallelLinear(LinearBase): assert loaded_weight.numel() == 1 loaded_weight = loaded_weight.reshape(1) - from sglang.srt.layers.parameter import _ColumnvLLMParameter - if isinstance(param, _ColumnvLLMParameter): param.load_column_parallel_weight( loaded_weight, @@ -1247,7 +1246,7 @@ class RowParallelLinear(LinearBase): assert loaded_weight.numel() == 1 loaded_weight = loaded_weight.reshape(1) - if isinstance(param, BasevLLMParameter): + if isinstance(param, RowvLLMParameter): # This `BasevLLMParameter` is defined in sglang/srt/layers/parameter.py, # It supports additional parameters like tp_rank and use_presharded_weights. param.load_row_parallel_weight( diff --git a/python/sglang/srt/layers/moe/fused_moe_native.py b/python/sglang/srt/layers/moe/fused_moe_native.py index 97299baa2..1766e2c25 100644 --- a/python/sglang/srt/layers/moe/fused_moe_native.py +++ b/python/sglang/srt/layers/moe/fused_moe_native.py @@ -8,7 +8,6 @@ from typing import Callable, Optional import torch from torch.nn import functional as F -from sglang.srt.layers.activation import GeluAndMul, SiluAndMul from sglang.srt.layers.moe.topk import select_experts @@ -69,6 +68,8 @@ def moe_forward_native( activation: str = "silu", ) -> torch.Tensor: + from sglang.srt.layers.activation import GeluAndMul, SiluAndMul + topk_weights, topk_ids = select_experts( hidden_states=x, router_logits=router_logits, diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 42f697fbf..da16e1680 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -305,6 +305,7 @@ class 
FusedMoE(torch.nn.Module): self.use_presharded_weights = use_presharded_weights self.inplace = inplace self.no_combine = no_combine + self.local_num_experts = num_experts if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = ( @@ -629,8 +630,6 @@ class FusedMoE(torch.nn.Module): custom_routing_function=self.custom_routing_function, correction_bias=self.correction_bias, activation=self.activation, - inplace=self.inplace, - no_combine=self.no_combine, ) if self.reduce_results and self.tp_size > 1: diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 378eef795..c04e06d2e 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -17,11 +17,12 @@ from typing import Callable, Optional import torch import torch.nn.functional as F -from sglang.srt.utils import get_compiler_backend, is_cuda +from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip _is_cuda = is_cuda() +_is_hip = is_hip() -from sglang.srt.managers.utils import ExpertDistributionRecorder +from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder expert_distribution_recorder = ExpertDistributionRecorder() @@ -53,10 +54,10 @@ def fused_topk( topk: int, renormalize: bool, ): - if _is_cuda: + if _is_cuda or _is_hip: from sgl_kernel import topk_softmax else: - from vllm import _custom_ops as ops + from vllm import _custom_ops as vllm_ops assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" @@ -70,7 +71,7 @@ def fused_topk( M, topk, dtype=torch.int32, device=hidden_states.device ) - if _is_cuda: + if _is_cuda or _is_hip: topk_softmax( topk_weights, topk_ids, @@ -78,7 +79,7 @@ def fused_topk( gating_output.float(), ) else: - ops.topk_softmax( + vllm_ops.topk_softmax( topk_weights, topk_ids, token_expert_indicies, diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index 8de731420..1da678c58 
100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -12,9 +12,6 @@ try: from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig - from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( - CompressedTensorsConfig, - ) from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config @@ -26,6 +23,8 @@ try: from vllm.model_executor.layers.quantization.qqq import QQQConfig from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig + from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig + VLLM_AVAILABLE = True except ImportError: VLLM_AVAILABLE = False @@ -44,8 +43,10 @@ except ImportError: from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config +from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import ( + CompressedTensorsConfig, +) from sglang.srt.layers.quantization.fp8 import Fp8Config -from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config @@ -55,10 +56,9 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "fp8": Fp8Config, "blockwise_int8": BlockInt8Config, "modelopt": ModelOptFp8Config, - "gptq_marlin": GPTQMarlinConfig, - "gptq": GPTQConfig, "w8a8_int8": W8A8Int8Config, "w8a8_fp8": W8A8Fp8Config, + 
"compressed-tensors": CompressedTensorsConfig, } # Add vllm-dependent methods if available @@ -74,10 +74,11 @@ if VLLM_AVAILABLE: "gguf": GGUFConfig, "gptq_marlin_24": GPTQMarlin24Config, "awq_marlin": AWQMarlinConfig, - "compressed-tensors": CompressedTensorsConfig, "bitsandbytes": BitsAndBytesConfig, "qqq": QQQConfig, "experts_int8": ExpertsInt8Config, + "gptq_marlin": GPTQMarlinConfig, + "gptq": GPTQConfig, } QUANTIZATION_METHODS.update(VLLM_QUANTIZATION_METHODS) diff --git a/python/sglang/srt/layers/quantization/base_config.py b/python/sglang/srt/layers/quantization/base_config.py index e45dda2cc..6058702c9 100644 --- a/python/sglang/srt/layers/quantization/base_config.py +++ b/python/sglang/srt/layers/quantization/base_config.py @@ -38,6 +38,11 @@ class QuantizeMethodBase(ABC): class QuantizationConfig(ABC): """Base class for quantization configs.""" + def __init__(self): + super().__init__() + # mapping is updated by models as they initialize + self.packed_modules_mapping: Dict[str, List[str]] = dict() + @abstractmethod def get_name(self) -> str: """Name of the quantization method.""" diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/README.md b/python/sglang/srt/layers/quantization/compressed_tensors/README.md new file mode 100644 index 000000000..e8caf538b --- /dev/null +++ b/python/sglang/srt/layers/quantization/compressed_tensors/README.md @@ -0,0 +1,6 @@ +# quantization compressed_tensors module + +To support compressed_tensors format quantization models, we adapted https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors into SGLang. + + +For practical purposes, we have only applied the compressed_tensors format of `w8a8_fp8`. If you have requirements for other formats, you can submit an issue through this [link](https://github.com/sgl-project/sglang/issues). 
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/__init__.py b/python/sglang/srt/layers/quantization/compressed_tensors/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py new file mode 100644 index 000000000..e056ce95f --- /dev/null +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -0,0 +1,652 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors +# SPDX-License-Identifier: Apache-2.0 + +import logging +from contextlib import suppress +from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, cast + +import torch +from compressed_tensors.config import ( + CompressionFormat, + SparsityCompressionConfig, + SparsityStructure, +) +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationStrategy, + QuantizationType, +) +from pydantic import BaseModel + +from sglang.srt.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) +from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501 + CompressedTensorsMoEMethod, +) +from sglang.srt.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme, + CompressedTensorsW8A8Fp8, +) +from sglang.srt.layers.quantization.compressed_tensors.utils import ( + find_matched_target, + is_activation_quantization_format, + should_ignore_layer, +) +from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod + +logger = logging.getLogger(__name__) + +__all__ = ["CompressedTensorsLinearMethod"] + +SPARSITY_CONFIG_NAME: 
Literal["sparsity_config"] = "sparsity_config" +QUANTIZATION_SCHEME_MAP_TYPE = Dict[str, Optional[Dict[str, QuantizationArgs]]] + + +class DeviceCapability(NamedTuple): + major: int + minor: int + + def as_version_str(self) -> str: + return f"{self.major}.{self.minor}" + + def to_int(self) -> int: + """ + Express device capability as an integer ````. + + It is assumed that the minor version is always a single digit. + """ + assert 0 <= self.minor < 10 + return self.major * 10 + self.minor + + +class CompressedTensorsConfig(QuantizationConfig): + + def __init__( + self, + target_scheme_map: Dict[str, Any], + ignore: List[str], + quant_format: str, + sparsity_scheme_map: Dict[str, SparsityCompressionConfig], + sparsity_ignore_list: List[str], + kv_cache_scheme: Optional[Dict[str, Any]] = None, + config: Optional[Dict[str, Any]] = None, + ): + super().__init__() + self.ignore = ignore + self.quant_format = quant_format + # Map from [target -> scheme] + self.target_scheme_map = target_scheme_map + self.kv_cache_scheme = kv_cache_scheme + self.sparsity_scheme_map = sparsity_scheme_map + self.sparsity_ignore_list = sparsity_ignore_list + self.config = config + + def get_linear_method(self) -> "CompressedTensorsLinearMethod": + return CompressedTensorsLinearMethod(self) + + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + def get_name(self) -> str: + return "compressed_tensors" + + def get_scaled_act_names(self) -> List[str]: + return [] + + def get_quant_method( + self, + layer: torch.nn.Module, + prefix: str, + ) -> Optional["QuantizeMethodBase"]: + + # Check if the layer is skipped for quantization. 
+ # TODO (@robertgshaw2): support module names + if should_ignore_layer( + prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping + ): + return UnquantizedLinearMethod() + if isinstance(layer, LinearBase): + scheme = self.get_scheme(layer=layer, layer_name=prefix) + if scheme is None: + return UnquantizedLinearMethod() + layer.scheme = scheme + return CompressedTensorsLinearMethod(self) + if isinstance(layer, FusedMoE): + return CompressedTensorsMoEMethod.get_moe_method(self) + return None + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": + ignore: List[str] = cast(List[str], config.get("ignore", [])) + quant_format = cast(str, config.get("format")) + target_scheme_map = cls._quantization_scheme_map_from_config(config=config) + sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config( + config=config + ) + + return cls( + target_scheme_map=target_scheme_map, + ignore=ignore, + quant_format=quant_format, + sparsity_scheme_map=sparsity_scheme_map, + sparsity_ignore_list=sparsity_ignore_list, + config=config, + ) + + @classmethod + def _parse_sparsity_config( + cls, config: Dict[str, Any] + ) -> Tuple[Dict[str, SparsityCompressionConfig], List[str]]: + """ + :param config: The `quantization_config` dictionary from config.json + :return: A tuple with two elements + 1. A dictionary mapping target layer names to their corresponding + sparsity_config + 2. 
A list of layer names to ignore for sparsity + """ + if not (sparsity_config := config.get(SPARSITY_CONFIG_NAME)): + return dict(), [] + + sparsity_config = SparsityCompressionConfig.model_validate(sparsity_config) + sparse_scheme_map: Dict[str, SparsityCompressionConfig] = { + target: sparsity_config for target in sparsity_config.targets or list() + } + sparsity_ignore_list = sparsity_config.ignore or list() + return sparse_scheme_map, sparsity_ignore_list + + @classmethod + def _quantization_scheme_map_from_config( + cls, config: Dict[str, Any] + ) -> QUANTIZATION_SCHEME_MAP_TYPE: + """ + :param config: The `quantization_config` dictionary from config.json + :return: A dictionary mapping target layer names to their corresponding + quantization_args for weights and input activations + """ + target_scheme_map: Dict[str, Any] = dict() + quant_format = cast(str, config.get("format")) + + # The quant_config has multiple config_groups, each containing + # an input_activations key with details about how the activations are + # quantized, a weights key indicating how the weights are quantized, + # and a list of targets under the `targets` key, dictating which + # layers are impacted by the quantization details. The quantization + # details follow the structure defined by the QuantizationArgs + # pydantic model, which is used to verify the structure of the + # quant_config and also store the details for later use. 
+ + config_groups = config.get("config_groups", dict()) + for _, quant_config in config_groups.items(): + targets = quant_config.get("targets") + for target in targets: + target_scheme_map[target] = {} + target_scheme_map[target]["weights"] = QuantizationArgs.model_validate( + quant_config.get("weights") + ) + + target_scheme_map[target]["input_activations"] = None + if is_activation_quantization_format(quant_format): + input_activations = quant_config.get("input_activations") + # The only case where we have activation quant supported + # but no input_activations provided in the config + # should be w8a16fp8 w8a16fp8 can also run for cases where + # there is an input_quant but it is ignored + if not input_activations: + assert ( + target_scheme_map[target]["weights"].type + == QuantizationType.FLOAT + ) + else: + target_scheme_map[target]["input_activations"] = ( + QuantizationArgs.model_validate( # noqa: E501 + quant_config.get("input_activations") + ) + ) + return target_scheme_map + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bool: + capability_tuple = DeviceCapability(*torch.cuda.get_device_capability()) + + if capability_tuple is not None: + capability = capability_tuple.to_int() + supported = capability >= min_capability + if error and not supported: + raise RuntimeError( + "Quantization scheme is not supported for ", + f"the current GPU. Min capability: {min_capability}. 
", + f"Current capability: {capability}.", + ) + return supported + else: + return False + + def _is_static_tensor_w8a8( + self, weight_quant: BaseModel, input_quant: BaseModel + ) -> bool: + is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 + weight_strategy = ( + weight_quant.strategy == QuantizationStrategy.TENSOR.value + or weight_quant.strategy == QuantizationStrategy.CHANNEL.value + ) + is_tensor = ( + weight_strategy + and input_quant.strategy == QuantizationStrategy.TENSOR.value + ) + is_static = not weight_quant.dynamic and not input_quant.dynamic + + # Both symmetric and asymmetric input quantization supported. + # Only symmetric weight quantization supported. + return is_8_bits and is_tensor and weight_quant.symmetric and is_static + + def _is_dynamic_token_w8a8( + self, weight_quant: BaseModel, input_quant: BaseModel + ) -> bool: + is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 + weight_strategy = ( + weight_quant.strategy == QuantizationStrategy.TENSOR.value + or weight_quant.strategy == QuantizationStrategy.CHANNEL.value + ) + is_token = ( + weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value + ) + is_dynamic = not weight_quant.dynamic and input_quant.dynamic + + # Both symmetric and asymmetric input quantization supported. + # Only symmetric weight quantization supported. + return is_8_bits and is_token and weight_quant.symmetric and is_dynamic + + def _is_fp8_w8a8(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: + # Confirm weights and activations quantized. + if weight_quant is None or input_quant is None: + return False + + # Confirm weight scheme is supported. 
+ is_floating_point = ( + weight_quant.type == QuantizationType.FLOAT + and input_quant.type == QuantizationType.FLOAT + ) + is_symmetric_weight = weight_quant.symmetric + is_static_weight = not weight_quant.dynamic + is_per_tensor_or_channel_weight = weight_quant.strategy in [ + QuantizationStrategy.TENSOR, + QuantizationStrategy.CHANNEL, + ] + if not ( + is_floating_point + and is_symmetric_weight + and is_static_weight + and is_per_tensor_or_channel_weight + ): + return False + + # Dynamic quantization is always supported if weights supported. + if input_quant.dynamic: + return True + + # Confirm activation scheme is supported. + is_symmetric_activation = input_quant.symmetric + is_per_tensor_activation = input_quant.strategy == QuantizationStrategy.TENSOR + return is_symmetric_activation and is_per_tensor_activation + + def _is_fp8_w8a16(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: + # Confirm weights quantized. + if weight_quant is None: + return False + + # Confirm we have floating points. + if weight_quant.type != QuantizationType.FLOAT: + return False + + # Confirm weight scheme is supported. + is_symmetric_weight = weight_quant.symmetric + is_static_weight = not weight_quant.dynamic + is_per_tensor_or_channel_weight = weight_quant.strategy in [ + QuantizationStrategy.TENSOR, + QuantizationStrategy.CHANNEL, + ] + if not ( + is_symmetric_weight + and is_static_weight # noqa: SIM103 + and is_per_tensor_or_channel_weight + ): + return False + + # All conditions satisfied. 
+ return True + + def _is_wNa16_group_channel( + self, weight_quant: BaseModel, input_quant: BaseModel + ) -> bool: + input_quant_none = input_quant is None + is_symmetric = weight_quant.symmetric + is_channel_group = ( + weight_quant.strategy == QuantizationStrategy.CHANNEL.value + or weight_quant.strategy == QuantizationStrategy.GROUP.value + ) + is_static = not weight_quant.dynamic + + return is_channel_group and input_quant_none and is_symmetric and is_static + + def _get_scheme_from_parts( + self, weight_quant: BaseModel, input_quant: BaseModel + ) -> "CompressedTensorsScheme": + + # Detect If Mixed Precision + if self._is_wNa16_group_channel(weight_quant, input_quant): + if not VLLM_AVAILABLE: + raise ImportError( + "vllm is not installed, to use CompressedTensorsW4A16Sparse24 and CompressedTensorsWNA16, please install vllm" + ) + if ( + self.quant_format == CompressionFormat.marlin_24.value + and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS + ): + return CompressedTensorsW4A16Sparse24( + strategy=weight_quant.strategy, + num_bits=weight_quant.num_bits, + group_size=weight_quant.group_size, + ) + if ( + self.quant_format == CompressionFormat.pack_quantized.value + and weight_quant.num_bits in WNA16_SUPPORTED_BITS + ): + return CompressedTensorsWNA16( + num_bits=weight_quant.num_bits, + strategy=weight_quant.strategy, + group_size=weight_quant.group_size, + actorder=weight_quant.actorder, + ) + + if is_activation_quantization_format(self.quant_format): + if self._is_fp8_w8a8(weight_quant, input_quant): + is_fp8_w8a8_supported = self._check_scheme_supported( + CompressedTensorsW8A8Fp8.get_min_capability(), error=False + ) + if is_fp8_w8a8_supported: + return CompressedTensorsW8A8Fp8( + strategy=weight_quant.strategy, + is_static_input_scheme=( + input_quant and not input_quant.dynamic + ), + ) + else: + # note: input_quant will be present for converted models; + # will be ignored during inference post loading + return CompressedTensorsW8A16Fp8( + 
strategy=weight_quant.strategy, + is_static_input_scheme=not input_quant.dynamic, + ) + + # note: input_quant can be None + if self._is_fp8_w8a16(weight_quant, input_quant): + if not VLLM_AVAILABLE: + raise ImportError( + "vllm is not installed, to use CompressedTensorsW8A16Fp8, please install vllm" + ) + is_static_input_scheme = input_quant and not input_quant.dynamic + return CompressedTensorsW8A16Fp8( + strategy=weight_quant.strategy, + is_static_input_scheme=is_static_input_scheme, + ) + + if self._is_static_tensor_w8a8(weight_quant, input_quant): + return CompressedTensorsW8A8Int8( + strategy=weight_quant.strategy, + is_static_input_scheme=True, + input_symmetric=input_quant.symmetric, + ) + + if self._is_dynamic_token_w8a8(weight_quant, input_quant): + return CompressedTensorsW8A8Int8( + strategy=weight_quant.strategy, + is_static_input_scheme=False, + input_symmetric=input_quant.symmetric, + ) + + raise NotImplementedError("No compressed-tensors compatible scheme was found.") + + def get_scheme( + self, layer: torch.nn.Module, layer_name: Optional[str] = None + ) -> Optional["CompressedTensorsScheme"]: + """ + compressed-tensors supports non uniform in the following way: + + targets of config_groups: There can be N config_groups which each + have a quantization scheme. Each config_group has a list of targets + which can be a full layer_name, a regex for a layer_name, or + an nn.Module name. + + Detect whether a layer_name is found in any target and + use the quantization scheme corresponding to the matched target + to select the CompressedTensorsScheme used for infernece. + """ + + # Find the "target" in the compressed-tensors config + # that our layer conforms to. 
+ # TODO (@robertgshaw): add compressed-tensors as dep + # so we do not have to re-write these functions + # need to make accelerate optional in ct to do this + + # Will be empty for models with only sparsity + weight_quant = input_quant = None + if self.target_scheme_map: + matched_target = find_matched_target( + layer_name=layer_name, + module=layer, + targets=self.target_scheme_map.keys(), + fused_mapping=self.packed_modules_mapping, + ) + + scheme_dict = self.target_scheme_map[matched_target] + weight_quant = scheme_dict.get("weights") + input_quant = scheme_dict.get("input_activations") + + # Find the sparsity scheme of the layer + # assume that fused layers inerhit first component's sparsity scheme + sparsity_targets = self.sparsity_scheme_map.keys() - set( + self.sparsity_ignore_list + ) + sparsity_scheme: Optional[SparsityCompressionConfig] = None + with suppress(ValueError): + matched_target = find_matched_target( + layer_name=layer_name, + module=layer, + targets=sparsity_targets, + fused_mapping=self.packed_modules_mapping, + ) + sparsity_scheme = self.sparsity_scheme_map[matched_target] + + if self.supports_cutlass_24( + weight_quant=weight_quant, + input_quant=input_quant, + sparsity_scheme=sparsity_scheme, + ): + if not VLLM_AVAILABLE: + raise ImportError( + "vllm is not installed, to use CompressedTensors24, please install vllm" + ) + # Have a valid sparsity scheme + # Validate layer is supported by Cutlass 2:4 Kernel + model_compression_config = ( + None + if sparsity_scheme is None or sparsity_scheme.format == "dense" + else self.config + ) + + scheme = CompressedTensors24( + quantized=weight_quant is not None or input_quant is not None, + weight_quant=weight_quant, + input_quant=input_quant, + model_compression_config=model_compression_config, + ) + elif weight_quant is None: + logger.warning_once( + "Acceleration for non-quantized schemes is " + "not supported by Compressed Tensors. 
" + "Falling back to UnquantizedLinearMethod" + ) + return None + + else: + # Find the quant_scheme + scheme = self._get_scheme_from_parts( # type: ignore + weight_quant=weight_quant, + input_quant=input_quant, + ) + + # Raise error if device does not support the scheme + # (e.g. fp8 needs ada lovelace) + self._check_scheme_supported(scheme.get_min_capability()) + logger.debug("Using scheme: %s for %s", scheme.__class__.__name__, layer_name) + return scheme + + def get_cache_scale(self, name: str) -> Optional[str]: + """ + Check whether the param name matches the format for k/v cache scales + in compressed-tensors. If this is the case, return its equivalent + param name expected by vLLM + + :param name: param name + :return: matching param name for KV cache scale in vLLM + """ + if name.endswith(".output_scale") and ".k_proj" in name: + return name.replace(".k_proj.output_scale", ".attn.k_scale") + if name.endswith(".output_scale") and ".v_proj" in name: + return name.replace(".v_proj.output_scale", ".attn.v_scale") + # If no matches, return None + return None + + @staticmethod + def supports_cutlass_24( + weight_quant: Optional[QuantizationArgs], + input_quant: Optional[QuantizationArgs], + sparsity_scheme: Optional[SparsityCompressionConfig] = None, + ) -> bool: + """ + Check if the layer is supported by the Cutlass 2:4 Kernel + Conditions: + - Overarching condition: Sparsity Structure is 2:4 + - Unquantized cases are supported + - Weight only quantization is not-supported + - Supported weight quantization strategies are TENSOR and CHANNEL + - Supported input quantization strategies are TENSOR and TOKEN + - Only 8 bit quantization is supported + + :return: True if the layer is supported by the Cutlass 2:4 Kernel + False otherwise + """ + if sparsity_scheme is None: + return False + + is_valid_sparsity_structure: bool = ( + sparsity_scheme.sparsity_structure == SparsityStructure.TWO_FOUR.value + ) + + valid_compressors = { + CompressionFormat.dense.value, + 
CompressionFormat.sparse_24_bitmask.value, + } + + is_valid_sparsity = ( + is_valid_sparsity_structure and sparsity_scheme.format in valid_compressors + ) + + if not is_valid_sparsity: + return False + + # Unquantized cases are supported + if weight_quant is None and input_quant is None: + return True + + # Weight only quantization is not-supported + if weight_quant is not None and input_quant is None: + return False + + supported_weight_quant_strategies = [ + QuantizationStrategy.TENSOR.value, + QuantizationStrategy.CHANNEL.value, + ] + + assert weight_quant is not None + assert input_quant is not None + if weight_quant.strategy not in supported_weight_quant_strategies: + return False + + supported_input_quant_strategies = [ + QuantizationStrategy.TENSOR.value, + QuantizationStrategy.TOKEN.value, + ] + + if input_quant.strategy not in supported_input_quant_strategies: + return False + + return weight_quant.num_bits == input_quant.num_bits == 8 + + +class CompressedTensorsLinearMethod(LinearMethodBase): + + def __init__(self, quantization_config: CompressedTensorsConfig): + self.quantization_config = quantization_config + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.scheme.process_weights_after_loading(layer) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + """ + Use the CompressedTensorsScheme associated with each layer to create + the necessary parameters for the layer. 
See LinearMethodBase for param + details + """ + weight_loader = extra_weight_attrs.get("weight_loader") + layer.scheme.create_weights( + layer=layer, + input_size=input_size, + input_size_per_partition=input_size_per_partition, + output_partition_sizes=output_partition_sizes, + output_size=output_size, + params_dtype=params_dtype, + weight_loader=weight_loader, + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ): + """ + Use the output of create_weights and the CompressedTensorsScheme + associated with the layer to apply the forward pass with the + layer input. See LinearMethodBase for param details + + """ + + scheme = layer.scheme + if scheme is None: + raise ValueError("A scheme must be defined for each layer") + return scheme.apply_weights(layer, x, bias=bias) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py new file mode 100644 index 000000000..032ff8b60 --- /dev/null +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -0,0 +1,658 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors +# SPDX-License-Identifier: Apache-2.0 + +import enum +import logging +from enum import Enum +from typing import Callable, List, Optional + +import torch +from compressed_tensors import CompressionFormat +from compressed_tensors.quantization import QuantizationStrategy + +from sglang.srt.layers.moe.fused_moe_triton import ( + FusedMoE, + FusedMoEMethodBase, + FusedMoeWeightScaleSupported, +) +from sglang.srt.layers.moe.topk import select_experts +from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz +from sglang.srt.layers.quantization.utils import ( + all_close_1d, + is_cuda, + is_fp8_fnuz, + per_tensor_dequantize, + replace_parameter, +) +from 
sglang.srt.utils import set_weight_attrs + +_is_cuda = is_cuda() + +if _is_cuda: + from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant +else: + from vllm import _custom_ops as vllm_ops + +try: + import vllm + + VLLM_AVAILABLE = True +except ImportError: + VLLM_AVAILABLE = False + +logger = logging.getLogger(__name__) + + +class GPTQMarlinState(Enum): + REPACK = enum.auto() + READY = enum.auto() + + +__all__ = [ + "CompressedTensorsMoEMethod", + "CompressedTensorsW8A8Fp8MoEMethod", + "CompressedTensorsWNA16MoEMethod", +] + + +class CompressedTensorsMoEMethod(FusedMoEMethodBase): + + @staticmethod + def get_moe_method( + quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + ) -> "CompressedTensorsMoEMethod": + # TODO: @dsikka: refactor this to use schemes as other kernels + # are supported + check if the layer is being ignored. + weight_quant = quant_config.target_scheme_map["Linear"].get("weights") + input_quant = quant_config.target_scheme_map["Linear"].get("input_activations") + + if quant_config._is_wNa16_group_channel(weight_quant, input_quant): + if not VLLM_AVAILABLE: + raise ImportError( + "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm" + ) + return CompressedTensorsWNA16MoEMethod(quant_config) + elif quant_config._is_fp8_w8a8(weight_quant, input_quant): + return CompressedTensorsW8A8Fp8MoEMethod(quant_config) + else: + raise RuntimeError( + f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}" + ) + + +class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): + + def __init__( + self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + ): + self.quant_config = quant_config + self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights") + self.input_quant = self.quant_config.target_scheme_map["Linear"].get( + "input_activations" + ) + + if not ( + self.weight_quant.strategy == QuantizationStrategy.TENSOR + and 
self.input_quant.strategy == QuantizationStrategy.TENSOR + ): + raise ValueError( + "For FP8 Fused MoE layers, only per-tensor scales " + "for weights and activations are supported. Found " + f"{self.weight_quant}, {self.input_quant}" + ) + + self.static_input_scales = not self.input_quant.dynamic + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. 
+ w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.static_input_scales: + w13_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.static_input_scales: + if layer.w13_input_scale is None or layer.w2_input_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None." + ) + if not all_close_1d(layer.w13_input_scale) or not all_close_1d( + layer.w2_input_scale + ): + logger.warning( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer." 
+ ) + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False + ) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False + ) + + if is_fp8_fnuz(): + # Normalize the weights and scales + w13_weight, w13_weight_scale, w13_input_scale = ( + normalize_e4m3fn_to_e4m3fnuz( + layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale + ) + ) + w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale + ) + # Reset the parameter + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + w13_weight_scale, requires_grad=False + ) + if w13_input_scale is not None: + layer.w13_input_scale = torch.nn.Parameter( + w13_input_scale, requires_grad=False + ) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter( + w2_weight_scale, requires_grad=False + ) + if w2_input_scale is not None: + layer.w2_input_scale = torch.nn.Parameter( + w2_input_scale, requires_grad=False + ) + + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max then dequant and requant each expert. 
+ assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.local_num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start : start + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id], + ) + + if _is_cuda: + layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( + sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) + ) + else: + layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( + vllm_ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) + ) + start += shard_size + + layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", + ) -> torch.Tensor: + from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts + + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + correction_bias=correction_bias, + ) + + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + use_fp8_w8a8=True, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + 
a2_scale=layer.w2_input_scale, + ) + + +class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): + + def __init__( + self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + ): + self.quant_config = quant_config + # TODO: @dsikka: refactor this to use schemes as other kernels + # are supported + check if the layer is being ignored. + config = self.quant_config.target_scheme_map["Linear"].get("weights") + self.num_bits = config.num_bits + self.packed_factor = 32 // config.num_bits + self.strategy = config.strategy + self.group_size = config.group_size + self.actorder = config.actorder + assert config.symmetric, "Only symmetric quantization is supported for MoE" + + if not ( + self.quant_config.quant_format == CompressionFormat.pack_quantized.value + and self.num_bits in WNA16_SUPPORTED_BITS + ): + raise ValueError( + "For Fused MoE layers, only ", + f"{CompressionFormat.pack_quantized.value} ", + "is supported for the following bits: ", + f"{WNA16_SUPPORTED_BITS}", + ) + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + assert ( + params_dtype == torch.float16 + ), "float16 is required for MoE compressed models. Set dtype=torch.float16" # noqa: E501 + + intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full") + + # Will transpose the loaded weight along the + # intermediate and hidden dim sizes. 
Will + # shard for TP along the transposed dims + extra_weight_attrs.update( + {"is_transposed": True, "quant_method": self.strategy} + ) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size // self.packed_factor, + 2 * intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_packed", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition // self.packed_factor, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_packed", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # In the case where we have actorder/g_idx, + # we do not partition the w2 scales + load_full_w2 = self.actorder and self.group_size != -1 + w2_scales_size = ( + intermediate_size_full if load_full_w2 else intermediate_size_per_partition + ) + + self.is_k_full = (not self.actorder) or ( + intermediate_size_per_partition == intermediate_size_full + ) + + if self.strategy == "channel": + num_groups_w2 = num_groups_w13 = 1 + self.group_size = -1 + else: + num_groups_w2 = w2_scales_size // self.group_size + num_groups_w13 = hidden_size // self.group_size + + w13_scale = torch.nn.Parameter( + torch.ones( + num_experts, + num_groups_w13, + 2 * intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_scale) + set_weight_attrs(w13_scale, extra_weight_attrs) + + w2_scale = torch.nn.Parameter( + torch.ones(num_experts, num_groups_w2, hidden_size, dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_scale) + set_weight_attrs(w2_scale, extra_weight_attrs) + set_weight_attrs(w2_scale, {"load_full_w2": load_full_w2}) + + w2_weight_shape = torch.nn.Parameter( + torch.empty(num_experts, 2), requires_grad=False + ) + 
layer.register_parameter("w2_weight_shape", w2_weight_shape) + set_weight_attrs(w2_weight_shape, extra_weight_attrs) + w13_weight_shape = torch.nn.Parameter( + torch.empty(num_experts, 2), requires_grad=False + ) + + layer.register_parameter("w13_weight_shape", w13_weight_shape) + set_weight_attrs(w13_weight_shape, extra_weight_attrs) + + w13_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_g_idx", w13_g_idx) + set_weight_attrs(w13_g_idx, extra_weight_attrs) + + w2_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_g_idx", w2_g_idx) + set_weight_attrs(w2_g_idx, extra_weight_attrs) + + w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices) + set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) + + w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices) + set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) + + layer.a13_scale = None + layer.a2_scale = None + layer.marlin_state = GPTQMarlinState.REPACK + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + def replace_tensor(name, new_t): + # It is important to use resize_() here since it ensures + # the same buffer is reused + getattr(layer, name).resize_(new_t.shape) + getattr(layer, name).copy_(new_t) + del new_t + + def get_scale_perms(num_bits: int): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + 
for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]] + ) + return scale_perm, scale_perm_single + + def marlin_permute_scales( + s: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int + ): + scale_perm, scale_perm_single = get_scale_perms(num_bits) + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + return s + + def marlin_moe_permute_scales( + s: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int + ): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), device=s.device, dtype=s.dtype + ) + for e in range(num_experts): + output[e] = marlin_permute_scales( + s[e], size_k, size_n, group_size, num_bits + ) + return output + + size_k2 = layer.w2_weight_packed.shape[2] + size_k13 = layer.w13_weight_packed.shape[2] + + num_experts = layer.w13_weight_g_idx.shape[0] + device = layer.w13_weight_g_idx.device + + # when running models with grouped act order, + # resort to g_idx values provided in checkpoint + if self.actorder == "group": + w13_g_idx_sort_indices = torch.empty_like(layer.w13_weight_g_idx) + w2_g_idx_sort_indices = torch.empty_like(layer.w2_weight_g_idx) + w13_sorted_g_idx = torch.empty_like(layer.w13_weight_g_idx) + w2_sorted_g_idx = torch.empty_like(layer.w2_weight_g_idx) + + for e in range(num_experts): + w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_weight_g_idx[e]).to( + torch.int32 + ) + w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_weight_g_idx[e]).to( + torch.int32 + ) + w13_sorted_g_idx[e] = layer.w13_weight_g_idx[e][ + w13_g_idx_sort_indices[e] + ] + w2_sorted_g_idx[e] = layer.w2_weight_g_idx[e][w2_g_idx_sort_indices[e]] + + replace_parameter(layer, "w13_weight_g_idx", w13_sorted_g_idx) + replace_parameter(layer, "w2_weight_g_idx", w2_sorted_g_idx) + 
replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices) + replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices) + + else: + layer.w13_weight_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w2_weight_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + + marlin_w13_qweight = ops.gptq_marlin_moe_repack( + layer.w13_weight_packed, + layer.w13_g_idx_sort_indices, + layer.w13_weight_packed.shape[1] * self.packed_factor, + layer.w13_weight_packed.shape[2], + self.num_bits, + ) + replace_tensor("w13_weight_packed", marlin_w13_qweight) + marlin_w2_qweight = ops.gptq_marlin_moe_repack( + layer.w2_weight_packed, + layer.w2_g_idx_sort_indices, + layer.w2_weight_packed.shape[1] * self.packed_factor, + layer.w2_weight_packed.shape[2], + self.num_bits, + ) + replace_tensor("w2_weight_packed", marlin_w2_qweight) + # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( + layer.w13_weight_scale, + size_k13, + layer.w13_weight_scale.shape[2], + self.group_size, + self.num_bits, + ) + replace_tensor("w13_weight_scale", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( + layer.w2_weight_scale, + layer.w2_weight_scale.shape[1] + * (self.group_size if self.group_size != -1 else self.packed_factor), + size_k2, + self.group_size, + self.num_bits, + ) + replace_tensor("w2_weight_scale", marlin_w2_scales) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: 
Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", + ) -> torch.Tensor: + assert activation == "silu", "Only SiLU activation is supported." + if not VLLM_AVAILABLE: + raise ImportError( + "vllm is not installed, to use fused_marlin_moe, please install vllm" + ) + if expert_map is not None: + raise NotImplementedError( + "Expert Parallelism is not supported for " "fused Marlin MoE method." + ) + + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + correction_bias=correction_bias, + ) + + return torch.ops.vllm.fused_marlin_moe( + x, + layer.w13_weight_packed, + layer.w2_weight_packed, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + g_idx1=layer.w13_weight_g_idx, + g_idx2=layer.w2_weight_g_idx, + sort_indices1=layer.w13_g_idx_sort_indices, + sort_indices2=layer.w2_g_idx_sort_indices, + num_bits=self.num_bits, + is_k_full=self.is_k_full, + ) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py new file mode 100644 index 000000000..fafed717c --- /dev/null +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: Apache-2.0 + +from .compressed_tensors_scheme import CompressedTensorsScheme +from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8 + +__all__ = [ + "CompressedTensorsScheme", + "CompressedTensorsW8A8Fp8", +] 
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py new file mode 100644 index 000000000..3795d0a54 --- /dev/null +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py @@ -0,0 +1,56 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from typing import Optional + +import torch + +__all__ = ["CompressedTensorsScheme"] + + +class CompressedTensorsScheme(ABC): + """ + Abstract class used to describe the weight creation and forward pass + of different quantization schemes supported by CompressedTensors. + """ + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + """ + Get minimum device capability. + """ + raise NotImplementedError + + @abstractmethod + def create_weights(self, *args, **kwargs): + """ + Weight creation for the particular scheme. Inputs to this function + + """ + raise NotImplementedError + + @abstractmethod + def apply_weights( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + ): + """ + Run the forward pass for the particular scheme. This is where + scheme-specific dequant/quant steps/kernels should be applied. + + :param layer: torch.nn.Module with the registered weights and + other parameters relevant to the particular scheme. + :param x: input to the layer + :param bias: bias parameter + + """ + raise NotImplementedError + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module): + """ + Called after weight loading is complete for any cleanup that + needs to occur. 
+ """ + raise NotImplementedError diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py new file mode 100644 index 000000000..6c624a070 --- /dev/null +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -0,0 +1,162 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable, List, Optional + +import torch +from compressed_tensors.quantization import QuantizationStrategy +from torch.nn import Parameter + +from sglang.srt.layers.parameter import ( + ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) +from sglang.srt.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme, +) +from sglang.srt.layers.quantization.fp8_utils import ( + Fp8LinearOp, + maybe_create_device_identity, + normalize_e4m3fn_to_e4m3fnuz, +) +from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale + +__all__ = ["CompressedTensorsW8A8Fp8"] + + +class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): + + def __init__(self, strategy: str, is_static_input_scheme: bool): + self.strategy = strategy + self.is_static_input_scheme = is_static_input_scheme + self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True) + + @classmethod + def get_min_capability(cls) -> int: + # lovelace and up + return 89 + + def process_weights_after_loading(self, layer) -> None: + # If per tensor, when we have a fused module (e.g. 
QKV) with per + # tensor scales (thus N scales being passed to the kernel), + # requantize so we can always run per tensor + if self.strategy == QuantizationStrategy.TENSOR: + max_w_scale, weight = requantize_with_max_scale( + weight=layer.weight, + weight_scale=layer.weight_scale, + logical_widths=layer.logical_widths, + ) + + if is_fp8_fnuz(): + input_scale = getattr(layer, "input_scale", None) + + weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, weight_scale=max_w_scale, input_scale=input_scale + ) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, requires_grad=False) + + layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight_scale = Parameter(max_w_scale, requires_grad=False) + + # If channelwise, scales are already lined up, so just transpose. + elif self.strategy == QuantizationStrategy.CHANNEL: + weight = layer.weight + + if is_fp8_fnuz(): + input_scale = getattr(layer, "input_scale", None) + + weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=layer.weight_scale, + input_scale=input_scale, + ) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, requires_grad=False) + else: + weight_scale = layer.weight_scale.data + + layer.weight = Parameter(weight.t(), requires_grad=False) + # required by torch.compile to be torch.nn.Parameter + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + + else: + raise ValueError(f"Unknown quantization strategy {self.strategy}") + + # INPUT SCALE + if self.is_static_input_scheme and hasattr(layer, "input_scale"): + layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) + else: + layer.input_scale = None + + def create_weights( + self, + layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): + maybe_create_device_identity() + + 
output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + + # WEIGHT + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + # TODO: update create_xxx_parameter functions to return + # the newly added parameters + if self.strategy == QuantizationStrategy.CHANNEL: + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader, + ) + else: + assert self.strategy == QuantizationStrategy.TENSOR + weight_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + + # min requirement for fp8 kernels + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) + + # INPUT SCALE + if self.is_static_input_scheme: + input_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + input_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("input_scale", input_scale) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + return self.fp8_linear.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias, + ) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/utils.py b/python/sglang/srt/layers/quantization/compressed_tensors/utils.py new file mode 100644 index 000000000..ddefa5ea3 --- /dev/null +++ b/python/sglang/srt/layers/quantization/compressed_tensors/utils.py @@ -0,0 +1,218 @@ +# Adapted from 
https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors +# SPDX-License-Identifier: Apache-2.0 + +import re +from types import MappingProxyType +from typing import Iterable, List, Mapping, Optional + +from compressed_tensors import CompressionFormat +from torch.nn import Module + + +def is_activation_quantization_format(format: str) -> bool: + _ACTIVATION_QUANTIZATION_FORMATS = [ + CompressionFormat.naive_quantized.value, + CompressionFormat.int_quantized.value, + CompressionFormat.float_quantized.value, + ] + return format in _ACTIVATION_QUANTIZATION_FORMATS + + +def should_ignore_layer( + layer_name: Optional[str], + ignore: Iterable[str] = tuple(), + fused_mapping: Mapping[str, List[str]] = MappingProxyType({}), +) -> bool: + if layer_name is None: + return False + + # layer_name = model.layers.0.self_attn.qkv_proj + # proj_name = qkv_proj + proj_name = layer_name.split(".")[-1] + + # Fused layers like gate_up_proj or qkv_proj will not be fused + # in the safetensors checkpoint. So, we convert the name + # from the fused version to unfused + check to make sure that + # each shard of the fused layer has the same scheme. + if proj_name in fused_mapping and layer_name not in ignore: + shard_proj_names = fused_mapping[proj_name] + + # Convert fused_name --> [shard_names] + shard_names = [ + layer_name.replace(proj_name, shard_proj_name) + for shard_proj_name in shard_proj_names + ] + + # Layer should be ignored if shards are ignored. + should_ignore_layer = None + for shard_name in shard_names: + should_ignore_shard = check_equal_or_regex_match( + layer_name=shard_name, targets=ignore + ) + + # If shard_idx=0, set layer ignore to match shard. + if should_ignore_layer is None: + should_ignore_layer = should_ignore_shard + + # If shard_idx=1+ confirm scheme matches prior shards. 
+ elif should_ignore_shard != should_ignore_layer: + raise ValueError( + f"Found a different quantization schemes for " + f"{shard_proj_names} in {layer_name}. vLLM " + "requires all to use the same scheme." + ) + + # Unfused layers like down_proj and o_proj will match + # the safetensors checkpoint already. + else: + should_ignore_layer = check_equal_or_regex_match( + layer_name=layer_name, targets=ignore + ) + + assert should_ignore_layer is not None + return should_ignore_layer + + +def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool: + """ + Checks whether a layer_name is exactly equal or a regex match for + if target starts with 're:' to any target in list. + """ + for target in targets: + if _is_equal_or_regex_match(layer_name, target): + return True + return False + + +def find_matched_target( + layer_name: Optional[str], + module: Module, + targets: Iterable[str], + fused_mapping: Mapping[str, List[str]] = MappingProxyType({}), +) -> str: + """ + Helper function to look up which "target" in the compressed-tensors + config that a layer corresponds to. + + Recall that a compressed-tensors config has a concept of + config_groups, where each layer can be quantized with a different + scheme. + + targets in each config_group will be a list of either layer names + (or regexes corresponding to layer names) or names of torch Modules. + + First, we try to match the layer_name with a target + Second, we try to match the module's name with a target + Third, we try to map the layer_name to a list of fused module names. + *All* component module names must match in order for a match to be + successful. A successful match returns the first component target + + :param layer_name: layer name + :param module: torch.nn.Module + :param targets: list of targets to match the layer against + :param fused_mapping: map from fused layer names to its components + :param fused_strategy: either "all" or "any".
If using "all", fused + layers match if "all" of its components match + """ + + if layer_name is None: + layer_name = "" + + matched_target = ( + _find_first_match(layer_name, targets) + or _find_first_match(module.__class__.__name__, targets, True) + or _match_fused_layer(layer_name, targets, fused_mapping) + ) + + if matched_target is None: + raise ValueError( + f"Unable to find matching target for {layer_name} in the " + "compressed-tensors config." + ) + + return matched_target + + +def _find_first_match( + value: str, targets: Iterable[str], check_contains: bool = False +) -> Optional[str]: + """ + Returns first element of target that matches value either + exactly or as a regex after 're:'. If check_contains is set to True, + additionally checks if the target string is contained within the value. + + :param value: string to compare the list of targets against + :param targets: list of targets to match the layer against + :param check_contains: whether or not to do a substring match + """ + + for target in targets: + if _is_equal_or_regex_match(value, target, check_contains=check_contains): + return target + return None + + +def _is_equal_or_regex_match( + value: str, target: str, check_contains: bool = False +) -> bool: + """ + Checks whether a value is exactly equal or a regex match for target + if target starts with 're:'. If check_contains is set to True, + additionally checks if the target string is contained within the value. + """ + + if target.startswith("re:"): + pattern = target[3:] + if re.match(pattern, value): + return True + elif check_contains: + if target.lower() in value.lower(): + return True + elif target == value: + return True + return False + + +def _match_fused_layer( + layer_name: str, + target_layers: Iterable[str], + fused_mapping: Mapping[str, List[str]], +) -> Optional[str]: + """ + Match a fused layer name to its corresponding individual layer in + target_layers. 
Returns first value in fused_mapping which matches targets + + Implements an "all" matching strategy where a fused layer matches iff + "all" of its components match + + :param layer_name: layer name + :param target_layers: list of targets to match the layer against + :param fused_mapping: map from fused layer names to its components + + Examples: + layer_name = "model.layers.0.self_attn.qkv_proj" + target_layers = ["model.layers.0.self_attn.q_proj", + "model.layers.0.self_attn.k_proj", + "model.layers.0.self_attn.v_proj"] + """ + # find layer_name in mapping + fused = next((key for key in fused_mapping if layer_name.endswith(key)), None) + if fused is None: + return None + + # expand path of unfused components + unfused_paths = [ + layer_name.replace(fused, unfused) for unfused in fused_mapping[fused] + ] + + # for each unfused component, find a match in targets + unfused_matches: List[Optional[str]] = [] + for unfused in unfused_paths: + for target in target_layers: + if _is_equal_or_regex_match(unfused, target): + unfused_matches.append(target) + break + else: + unfused_matches.append(None) + + return unfused_matches[0] if all(unfused_matches) else None diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index b1a2034b9..9ba62a6f6 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -1,3 +1,4 @@ +import os from typing import List, Optional, Tuple import torch @@ -18,6 +19,7 @@ from sglang.srt.utils import ( try: import vllm + from vllm import _custom_ops as ops VLLM_AVAILABLE = True except ImportError: @@ -31,19 +33,29 @@ if _is_hip and get_bool_env_var("CK_MOE"): _is_cuda = is_cuda() if _is_cuda: - from sgl_kernel import fp8_blockwise_scaled_mm + from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm + from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant from sglang.srt.layers.quantization.fp8_kernel 
import sglang_per_token_quant_fp8 - if use_vllm_cutlass_w8a8_fp8_kernel and VLLM_AVAILABLE: - from vllm import _custom_ops as ops - else: - from sgl_kernel import fp8_scaled_mm - # Input scaling factors are no longer optional in _scaled_mm starting # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) +_TORCH_VERSION = torch.__version__.split("+")[0] +try: + _TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3])) +except ValueError: + _TORCH_VERSION_TUPLE = (0, 0, 0) + +# The condition to determine if it is on a platform that supports +# torch._scaled_mm rowwise feature. +# The condition is determined once as the operations +# are time consuming. +USE_ROWWISE_TORCH_SCALED_MM = ( + _is_hip and get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0) +) + def cutlass_fp8_supported(): if not _is_cuda: @@ -330,3 +342,223 @@ def apply_fp8_linear( if bias is not None: output = output + bias return output.to(dtype=input.dtype).view(*output_shape) + + +def maybe_create_device_identity(): + # Allocate dummy ones tensor for torch._scaled_mm + global TORCH_DEVICE_IDENTITY + if TORCH_DEVICE_IDENTITY is None: + TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) + + +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +# TODO(luka): follow similar pattern for marlin and block-fp8-linear +# https://github.com/vllm-project/vllm/issues/14397 +class Fp8LinearOp: + """ + This class executes a FP8 linear layer using cutlass if supported and + torch.scaled_mm otherwise. + It needs to be a class instead of a method so that config can be read + in the __init__ method, as reading config is not allowed inside forward. 
+ """ + + def __init__( + self, + cutlass_fp8_supported: bool = cutlass_fp8_supported(), + use_per_token_if_dynamic: bool = False, + pad_output: Optional[bool] = None, + ): + self.cutlass_fp8_supported = cutlass_fp8_supported + self.use_per_token_if_dynamic = use_per_token_if_dynamic + + # Note: we pad the input because torch._scaled_mm is more performant + # for matrices with batch dimension > 16. + # This could change in the future. + # We also don't pad when using torch.compile, + # as it breaks with dynamic shapes. + if pad_output is None: + enable_torch_compile = os.environ.get( + "SGLANG_ENABLE_TORCH_COMPILE", "0" + ).lower() in ("1", "true", "yes") + pad_output = not enable_torch_compile + self.output_padding = 17 if pad_output else None + + def apply( + self, + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + input_scale_ub: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + # TODO(luka) remove this parameter in favor of __init__ + use_per_token_if_dynamic: Optional[bool] = None, + ) -> torch.Tensor: + # ops.scaled_fp8_quant supports both dynamic and static quant. + # If dynamic, layer.input_scale is None and x_scale computed from x. + # If static, layer.input_scale is scalar and x_scale is input_scale. + + # View input as 2D matrix for fp8 methods + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], weight.shape[1]] + + # TODO(luka) this is here because currently MLA only decides this + # during the forward method instead of in __init__. 
+ if use_per_token_if_dynamic is None: + use_per_token_if_dynamic = self.use_per_token_if_dynamic + + # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A + # for sgl-kernel fp8_scaled_mm, it support per channel W now + if self.cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]: + if _is_cuda: + qinput, x_scale = sgl_scaled_fp8_quant( + input_2d, + input_scale, + use_per_token_if_dynamic=use_per_token_if_dynamic, + ) + else: + qinput, x_scale = ops.scaled_fp8_quant( + input_2d, + input_scale, + scale_ub=input_scale_ub, + use_per_token_if_dynamic=use_per_token_if_dynamic, + ) + + # Fused GEMM_DQ + if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel: + # Fall back to vllm cutlass w8a8 fp8 kernel + output = ops.cutlass_scaled_mm( + qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias, + ) + else: + assert ( + weight_scale.numel() == weight.shape[1] + ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale" + output = fp8_scaled_mm( + qinput, + weight, + x_scale, + weight_scale, + out_dtype=input.dtype, + bias=bias, + ) + return output.view(*output_shape) + + # torch.scaled_mm supports per tensor weights + activations only + # so fallback to naive if per channel or per token + else: + # Maybe apply padding to output, see comment in __init__ + if _is_cuda: + qinput, x_scale = sgl_scaled_fp8_quant( + input_2d, + input_scale, + use_per_token_if_dynamic=use_per_token_if_dynamic, + ) + if self.output_padding: + pad_size = max(self.output_padding - qinput.shape[0], 0) + if pad_size > 0: + qinput = torch.nn.functional.pad(qinput, (0, 0, 0, pad_size)) + else: + qinput, x_scale = ops.scaled_fp8_quant( + input_2d, + input_scale, + num_token_padding=self.output_padding, + use_per_token_if_dynamic=use_per_token_if_dynamic, + ) + + per_tensor_weights = weight_scale.numel() == 1 + per_tensor_activations = x_scale.numel() == 1 + + if per_tensor_weights and per_tensor_activations: + # Fused 
GEMM_DQ + output = torch._scaled_mm( + qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias, + ) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] + + return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + + elif ( + use_per_token_if_dynamic + and not per_tensor_weights + and not per_tensor_activations + and USE_ROWWISE_TORCH_SCALED_MM + ): + # For now validated on ROCm platform + # fp8 rowwise scaling in torch._scaled_mm is introduced in + # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt + # and ROCm 6.3, which only exists in torch 2.7 and above. + # For CUDA platform please validate if the + # torch._scaled_mm support rowwise scaled GEMM + # Fused GEMM_DQ Rowwise GEMM + output = torch._scaled_mm( + qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale.t(), + bias=bias, + ) + + output = torch.narrow(output, 0, 0, input_2d.shape[0]) + output = output.view(*output_shape) + return output + + else: + # Fallback for channelwise case, where we use unfused DQ + # due to limitations with scaled_mm + + # Symmetric quantized GEMM by definition computes the following: + # C = (s_x * X) (s_w * W) + bias + # This is equivalent to dequantizing the weights and activations + # before applying a GEMM. + # + # In order to compute quantized operands, a quantized kernel + # will rewrite the above like so: + # C = s_w * s_x * (X * W) + bias + # + # For the scaled_mm fallback case, we break this down, since it + # does not support s_w being a vector. + + # GEMM + # This computes C = (X * W). 
+ # Output in fp32 to allow subsequent ops to happen in-place + + global TORCH_DEVICE_IDENTITY + if TORCH_DEVICE_IDENTITY.device != weight.device: + TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device) + + output = torch._scaled_mm( + qinput, + weight, + scale_a=TORCH_DEVICE_IDENTITY, + scale_b=TORCH_DEVICE_IDENTITY, + out_dtype=torch.float32, + ) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] + # Unpad (undo num_token_padding) + output = torch.narrow(output, 0, 0, input_2d.shape[0]) + x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0]) + + # DQ + # C = sw * sx * (X * W) + bias + output = output * x_scale * weight_scale.t() + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py index 567e97fcb..c6524af49 100644 --- a/python/sglang/srt/layers/quantization/utils.py +++ b/python/sglang/srt/layers/quantization/utils.py @@ -15,6 +15,11 @@ else: from vllm import _custom_ops as vllm_ops +def is_fp8_fnuz() -> bool: + # only device 0 is checked, this assumes MI300 platforms are homogeneous + return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName + + def is_layer_skipped( prefix: str, ignored_layers: List[str], @@ -120,3 +125,29 @@ def requantize_with_max_scale( start = end return max_w_scale, weight + + +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/layer_utils.py +# Newly generated tensors need to replace existing tensors that are +# already registered as parameters by vLLM (and won't be freed) +def replace_parameter( + mod: torch.nn.Module, name: str, new: Union[torch.Tensor, torch.nn.Parameter] +) -> None: + + old = getattr(mod, name) + if ( + type(old) is type(new) + and old.dtype == 
new.dtype + and old.untyped_storage().nbytes() == new.untyped_storage().nbytes() + ): + # If we can just update in-place to avoid re-registering + # can be faster if the underlying storage is the same + update_tensor_inplace(old, new) + else: + # Fallback re-register parameter, convert to Parameter if necessary + # this not only ensures we don't register a tensor as a parameter, but + # also ensures that all parameter subclasses get re-registered as + # parameters for `torch.compile` compatibility + if not isinstance(new, torch.nn.Parameter): + new = torch.nn.Parameter(new, requires_grad=False) + mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False)) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py new file mode 100644 index 000000000..226256ed2 --- /dev/null +++ b/python/sglang/srt/managers/expert_distribution.py @@ -0,0 +1,81 @@ +import json +import logging +import time +from collections import defaultdict +from typing import Dict, List, Tuple + +import torch + +logger = logging.getLogger(__name__) + + +# global expert distribution recording +class ExpertDistributionRecorder: + # This class is a singleton class + def __new__(cls): + if not hasattr(cls, "instance"): + cls.instance = super(ExpertDistributionRecorder, cls).__new__(cls) + return cls.instance + + def __init__(self): + # the length of the dictionary is the number of layers + # the length of the list is the number of tokens + # the length of the tuple is topk's k value + self._expert_distribution_record: Dict[int, List[Tuple[int]]] = defaultdict( + list + ) + self._record = False + self._current_layer_id = "UNKNOWN" + + def set_current_layer(self, layer_idx): + self._current_layer_id = layer_idx + + def record_new_token(self, topk_ids): + if not self._record: + return + topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist() + torch.cuda.synchronize() + for i in topk_ids_list: + 
self._expert_distribution_record[self._current_layer_id].append(tuple(i)) + + def reset(self): + """Reset the expert distribution recorder.""" + logger.info("Resetting expert distribution record...") + self._record = False + self._expert_distribution_record.clear() + self._current_layer_id = "UNKNOWN" + + def start_record(self): + """Start recording the expert distribution. Reset the recorder and set the recording flag to True.""" + if self._record == True: + logger.warning( + "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?" + ) + self.reset() + self._record = True + + def stop_record(self): + """Stop recording the expert distribution. Set the recording flag to False.""" + if self._record == False: + logger.warning( + "SGLang server has not been recording expert ids. Did you forget to start recording by sending request to the `/start_expert_distribution_record` endpoint?" + ) + self._record = False + + def dump_record(self): + """Dump the expert distribution record to a file. 
Reset the recorder after dumping.""" + results = {} + for layer_idx, layer_record in self._expert_distribution_record.items(): + results[layer_idx] = defaultdict(int) + for token_record in layer_record: + for expert_idx in token_record: + results[layer_idx][expert_idx] += 1 + with open( + f"expert_distribution_rank{torch.distributed.get_rank()}_timestamp{time.time()}.csv", + "w", + ) as fd: + fd.write("layer_id,expert_id,count\n") + for layer_idx, layer_results in results.items(): + for expert_idx, count in layer_results.items(): + fd.write(f"{layer_idx},{expert_idx},{count}\n") + self.reset() diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 28d875015..7b5dc4520 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -53,6 +53,7 @@ from sglang.srt.disaggregation.utils import ( from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder from sglang.srt.managers.io_struct import ( AbortReq, CloseSessionReqInput, @@ -106,7 +107,7 @@ from sglang.srt.managers.scheduler_output_processor_mixin import ( from sglang.srt.managers.session_controller import Session from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient -from sglang.srt.managers.utils import ExpertDistributionRecorder, validate_input_length +from sglang.srt.managers.utils import validate_input_length from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.hiradix_cache import HiRadixCache from sglang.srt.mem_cache.radix_cache import RadixCache diff --git a/python/sglang/srt/managers/utils.py b/python/sglang/srt/managers/utils.py index 2730075ff..8bdc4929f 100644 --- 
a/python/sglang/srt/managers/utils.py +++ b/python/sglang/srt/managers/utils.py @@ -47,75 +47,3 @@ def validate_input_length( return error_msg return None - - -# global expert distribution recording -class ExpertDistributionRecorder: - # This class is a singleton class - def __new__(cls): - if not hasattr(cls, "instance"): - cls.instance = super(ExpertDistributionRecorder, cls).__new__(cls) - return cls.instance - - def __init__(self): - # the length of the dictionary is the number of layers - # the length of the list is the number of tokens - # the length of the tuple is topk's k value - self._expert_distribution_record: Dict[int, List[Tuple[int]]] = defaultdict( - list - ) - self._record = False - self._current_layer_id = "UNKNOWN" - - def set_current_layer(self, layer_idx): - self._current_layer_id = layer_idx - - def record_new_token(self, topk_ids): - if not self._record: - return - topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist() - torch.cuda.synchronize() - for i in topk_ids_list: - self._expert_distribution_record[self._current_layer_id].append(tuple(i)) - - def reset(self): - """Reset the expert distribution recorder.""" - logger.info("Resetting expert distribution record...") - self._record = False - self._expert_distribution_record.clear() - self._current_layer_id = "UNKNOWN" - - def start_record(self): - """Start recording the expert distribution. Reset the recorder and set the recording flag to True.""" - if self._record == True: - logger.warning( - "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?" - ) - self.reset() - self._record = True - - def stop_record(self): - """Stop recording the expert distribution. Set the recording flag to False.""" - if self._record == False: - logger.warning( - "SGLang server has not been recording expert ids. 
Did you forget to start recording by sending request to the `/start_expert_distribution_record` endpoint?" - ) - self._record = False - - def dump_record(self): - """Dump the expert distribution record to a file. Reset the recorder after dumping.""" - results = {} - for layer_idx, layer_record in self._expert_distribution_record.items(): - results[layer_idx] = defaultdict(int) - for token_record in layer_record: - for expert_idx in token_record: - results[layer_idx][expert_idx] += 1 - with open( - f"expert_distribution_rank{torch.distributed.get_rank()}_timestamp{time.time()}.csv", - "w", - ) as fd: - fd.write("layer_id,expert_id,count\n") - for layer_idx, layer_results in results.items(): - for expert_idx, count in layer_results.items(): - fd.write(f"{layer_idx},{expert_idx},{count}\n") - self.reset() diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 0555e0cd2..8ff4c4373 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -67,8 +67,8 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) +from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder from sglang.srt.managers.schedule_batch import global_server_args_dict -from sglang.srt.managers.utils import ExpertDistributionRecorder from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 7d40bdbf9..fa00b35e1 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -44,7 +44,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.utils import ExpertDistributionRecorder 
+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import add_prefix diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 241682d4e..59b6e73d4 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -16,6 +16,7 @@ import argparse import dataclasses import logging +import os import random import tempfile from typing import List, Optional @@ -341,6 +342,10 @@ class ServerArgs: self.disable_overlap_schedule = True logger.warning("Overlap scheduler is disabled for decode server") + os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = ( + "1" if self.enable_torch_compile else "0" + ) + @staticmethod def add_cli_args(parser: argparse.ArgumentParser): # Model and port args diff --git a/scripts/ci_install_dependency.sh b/scripts/ci_install_dependency.sh index ca79ebcb6..20f559076 100755 --- a/scripts/ci_install_dependency.sh +++ b/scripts/ci_install_dependency.sh @@ -29,3 +29,5 @@ pip install cuda-python nvidia-cuda-nvrtc-cu12 pip install timm pip install sgl-kernel==0.0.5.post3 --force-reinstall + +pip uninstall vllm -y || true diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 43bf46c03..daae6f095 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -23,16 +23,12 @@ suites = { TestFile("models/test_reward_models.py", 83), TestFile("models/test_gme_qwen_models.py", 45), TestFile("test_abort.py", 51), - TestFile("test_awq.py"), TestFile("test_block_int8.py", 22), TestFile("test_chunked_prefill.py", 336), TestFile("test_eagle_infer.py", 447), TestFile("test_ebnf_constrained.py"), TestFile("test_fp8_kernel.py", 2), TestFile("test_embedding_openai_server.py", 36), - TestFile("test_expert_distribution.py", 31), - TestFile("test_gguf.py", 78), - TestFile("test_gptqmodel_dynamic.py", 72), 
TestFile("test_hidden_states.py", 55), TestFile("test_int8_kernel.py", 1), TestFile("test_input_embeddings.py", 38), @@ -82,6 +78,12 @@ suites = { "nightly": [ TestFile("test_nightly_gsm8k_eval.py"), ], + "vllm_dependency_test": [ + TestFile("test_vllm_dependency.py"), + TestFile("test_awq.py"), + TestFile("test_gguf.py", 78), + TestFile("test_gptqmodel_dynamic.py", 72), + ], } diff --git a/test/srt/test_nightly_gsm8k_eval.py b/test/srt/test_nightly_gsm8k_eval.py index 5caa076b1..600b60228 100644 --- a/test/srt/test_nightly_gsm8k_eval.py +++ b/test/srt/test_nightly_gsm8k_eval.py @@ -37,9 +37,6 @@ MODEL_SCORE_THRESHOLDS = { "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8": 0.65, "neuralmagic/Qwen2-72B-Instruct-FP8": 0.94, "neuralmagic/Qwen2-57B-A14B-Instruct-FP8": 0.82, - "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4": 0.84, - "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4": 0.83, - "hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4": 0.62, } @@ -138,7 +135,6 @@ class TestNightlyGsm8KEval(CustomTestCase): (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2), False, True), (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1), True, False), (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2), True, True), - (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1), False, False), ] cls.base_url = DEFAULT_URL_FOR_TEST diff --git a/test/srt/test_vllm_dependency.py b/test/srt/test_vllm_dependency.py new file mode 100644 index 000000000..168f681ac --- /dev/null +++ b/test/srt/test_vllm_dependency.py @@ -0,0 +1,168 @@ +import json +import os +import unittest +import warnings +from datetime import datetime +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + is_in_ci, + popen_launch_server, + write_github_step_summary, +) + 
+MODEL_SCORE_THRESHOLDS = { + "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4": 0.84, + "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4": 0.83, + "hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4": 0.62, +} + + +def parse_models(model_string): + return [model.strip() for model in model_string.split(",") if model.strip()] + + +def popen_launch_server_wrapper(base_url, model, is_fp8, is_tp2): + other_args = ["--log-level-http", "warning", "--trust-remote-code"] + if is_fp8: + if "Llama-3" in model or "gemma-2" in model: + other_args.extend(["--kv-cache-dtype", "fp8_e5m2"]) + elif "Qwen2-72B-Instruct-FP8" in model: + other_args.extend(["--quantization", "fp8"]) + elif "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8" in model: + other_args.extend([]) + else: + other_args.extend(["--quantization", "fp8", "--kv-cache-dtype", "fp8_e5m2"]) + if is_tp2: + other_args.extend(["--tp", "2"]) + if "DeepSeek" in model: + other_args.extend(["--mem-frac", "0.85"]) + if "AWQ" in model: + other_args.extend(["--quantization", "awq"]) + elif "GPTQ" in model: + other_args.extend(["--quantization", "gptq"]) + + process = popen_launch_server( + model, + base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + return process + + +def write_results_to_json(model, metrics, mode="a"): + result = { + "timestamp": datetime.now().isoformat(), + "model": model, + "metrics": metrics, + "score": metrics["score"], + } + + existing_results = [] + if mode == "a" and os.path.exists("results.json"): + try: + with open("results.json", "r") as f: + existing_results = json.load(f) + except json.JSONDecodeError: + existing_results = [] + + if isinstance(existing_results, list): + existing_results.append(result) + else: + existing_results = [result] + + with open("results.json", "w") as f: + json.dump(existing_results, f, indent=2) + + +def check_model_scores(results): + failed_models = [] + summary = " | model | score | threshold |\n" + summary += "| ----- | ----- | 
--------- |\n" + + for model, score in results: + threshold = MODEL_SCORE_THRESHOLDS.get(model) + if threshold is None: + print(f"Warning: No threshold defined for model {model}") + continue + + if score < threshold: + failed_models.append( + f"\nScore Check Failed: {model}\n" + f"Model {model} score ({score:.4f}) is below threshold ({threshold:.4f})" + ) + + line = f"| {model} | {score} | {threshold} |\n" + summary += line + + print(summary) + + if is_in_ci(): + write_github_step_summary( + f"### TestNightlyGsm8KEval for vLLM awq, gptq, gguf\n{summary}" + ) + + if failed_models: + raise AssertionError("\n".join(failed_models)) + + +class TestNightlyGsm8KEval(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model_groups = [ + (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1), False, False), + ] + cls.base_url = DEFAULT_URL_FOR_TEST + + def test_mgsm_en_all_models(self): + warnings.filterwarnings( + "ignore", category=ResourceWarning, message="unclosed.*socket" + ) + is_first = True + all_results = [] + + for model_group, is_fp8, is_tp2 in self.model_groups: + for model in model_group: + with self.subTest(model=model): + process = popen_launch_server_wrapper( + self.base_url, model, is_fp8, is_tp2 + ) + + args = SimpleNamespace( + base_url=self.base_url, + model=model, + eval_name="mgsm_en", + num_examples=None, + num_threads=1024, + ) + + metrics = run_eval(args) + print( + f"{'=' * 42}\n{model} - metrics={metrics} score={metrics['score']}\n{'=' * 42}\n" + ) + + write_results_to_json(model, metrics, "w" if is_first else "a") + is_first = False + + all_results.append((model, metrics["score"])) + kill_process_tree(process.pid) + + try: + with open("results.json", "r") as f: + print("\nFinal Results from results.json:") + print(json.dumps(json.load(f), indent=2)) + except Exception as e: + print(f"Error reading results.json: {e}") + + # Check all scores after collecting all results + check_model_scores(all_results) + + +if __name__ == 
"__main__": + unittest.main()