Support compressed tensors fp8w8a8 (#4743)
.github/workflows/vllm-dependency-test.yml (new file, 45 lines)
@@ -0,0 +1,45 @@
name: VLLM Dependency Test

on:
  push:
    branches: [ main ]
    paths:
      - "python/pyproject.toml"
      - "python/sglang/**"
      - "test/**"
      - "docs/**"
      - "scripts/**"
  pull_request:
    branches: [ main ]
    paths:
      - "python/pyproject.toml"
      - "python/sglang/**"
      - "test/**"
      - "docs/**"
      - "scripts/**"

concurrency:
  group: vllm-dependency-test-${{ github.ref }}
  cancel-in-progress: true

jobs:
  vllm-dependency-test:
    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
        github.event.pull_request.draft == false
    runs-on: 1-gpu-runner
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Install dependencies
        env:
          FLASHINFER_REPO: 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python'
        run: |
          bash scripts/ci_install_dependency.sh
          pip install "vllm>=0.6.4.post1,<=0.7.2"

      - name: Run VLLM dependency tests
        timeout-minutes: 60
        run: |
          cd test/srt
          python3 run_suite.py --suite vllm_dependency_test --timeout-per-file 3600
@@ -47,7 +47,6 @@ srt = [
     "sgl-kernel==0.0.5.post3",
     "flashinfer_python==0.2.3",
     "torch==2.5.1",
-    "vllm>=0.6.4.post1,<=0.7.2",
     "cuda-python",
     "outlines>=0.0.44,<=0.1.11",
 ]

@@ -22,7 +22,11 @@ import torch
 from transformers import PretrainedConfig

 from sglang.srt.hf_transformers_utils import get_config, get_context_length
-from sglang.srt.layers.quantization import QUANTIZATION_METHODS
+from sglang.srt.layers.quantization import (
+    BASE_QUANTIZATION_METHODS,
+    QUANTIZATION_METHODS,
+    VLLM_AVAILABLE,
+)
 from sglang.srt.utils import get_bool_env_var, is_hip

 logger = logging.getLogger(__name__)
@@ -235,7 +239,12 @@ class ModelConfig:
     # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
     def _verify_quantization(self) -> None:
-        supported_quantization = [*QUANTIZATION_METHODS]
+        # Select supported quantization methods based on vllm availability
+        if VLLM_AVAILABLE:
+            supported_quantization = [*QUANTIZATION_METHODS]
+        else:
+            supported_quantization = [*BASE_QUANTIZATION_METHODS]

         rocm_supported_quantization = [
             "awq",
             "gptq",
@@ -273,7 +282,11 @@ class ModelConfig:
         quant_method = quant_cfg.get("quant_method", "").lower()

         # Detect which checkpoint is it
-        for _, method in QUANTIZATION_METHODS.items():
+        # Only iterate through currently available quantization methods
+        available_methods = (
+            QUANTIZATION_METHODS if VLLM_AVAILABLE else BASE_QUANTIZATION_METHODS
+        )
+        for _, method in available_methods.items():
             quantization_override = method.override_quantization_method(
                 quant_cfg, self.quantization
             )

@@ -1316,7 +1316,10 @@ vllm_get_world_group = None

 def monkey_patch_vllm_parallel_state(reverse: bool = False):
-    import vllm.distributed.parallel_state as vllm_parrlel_state
+    try:
+        import vllm.distributed.parallel_state as vllm_parrlel_state
+    except ImportError:
+        return

     global vllm_get_pp_group, vllm_get_tp_group, vllm_get_world_group
     if vllm_get_pp_group is None:
@@ -23,6 +23,7 @@ from sglang.srt.layers.parameter import (
     PackedvLLMParameter,
     PerTensorScaleParameter,
     RowvLLMParameter,
+    _ColumnvLLMParameter,
 )
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -423,8 +424,6 @@ class ColumnParallelLinear(LinearBase):
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)

-        from sglang.srt.layers.parameter import _ColumnvLLMParameter
-
         if isinstance(param, _ColumnvLLMParameter):
             param.load_column_parallel_weight(
                 loaded_weight,
@@ -1247,7 +1246,7 @@ class RowParallelLinear(LinearBase):
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)

-        if isinstance(param, BasevLLMParameter):
+        if isinstance(param, RowvLLMParameter):
             # This `BasevLLMParameter` is defined in sglang/srt/layers/parameter.py,
             # It supports additional parameters like tp_rank and use_presharded_weights.
             param.load_row_parallel_weight(

@@ -8,7 +8,6 @@ from typing import Callable, Optional
 import torch
 from torch.nn import functional as F

-from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
 from sglang.srt.layers.moe.topk import select_experts

@@ -69,6 +68,8 @@ def moe_forward_native(
     activation: str = "silu",
 ) -> torch.Tensor:
+    from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
+
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
         router_logits=router_logits,
@@ -305,6 +305,7 @@ class FusedMoE(torch.nn.Module):
         self.use_presharded_weights = use_presharded_weights
         self.inplace = inplace
         self.no_combine = no_combine
+        self.local_num_experts = num_experts

         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -629,8 +630,6 @@ class FusedMoE(torch.nn.Module):
             custom_routing_function=self.custom_routing_function,
             correction_bias=self.correction_bias,
             activation=self.activation,
-            inplace=self.inplace,
-            no_combine=self.no_combine,
         )

         if self.reduce_results and self.tp_size > 1:
@@ -17,11 +17,12 @@ from typing import Callable, Optional
 import torch
 import torch.nn.functional as F

-from sglang.srt.utils import get_compiler_backend, is_cuda
+from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip

 _is_cuda = is_cuda()
+_is_hip = is_hip()

-from sglang.srt.managers.utils import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder

 expert_distribution_recorder = ExpertDistributionRecorder()

@@ -53,10 +54,10 @@ def fused_topk(
     topk: int,
     renormalize: bool,
 ):
-    if _is_cuda:
+    if _is_cuda or _is_hip:
         from sgl_kernel import topk_softmax
     else:
-        from vllm import _custom_ops as ops
+        from vllm import _custom_ops as vllm_ops

     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -70,7 +71,7 @@ def fused_topk(
         M, topk, dtype=torch.int32, device=hidden_states.device
     )

-    if _is_cuda:
+    if _is_cuda or _is_hip:
         topk_softmax(
             topk_weights,
             topk_ids,
@@ -78,7 +79,7 @@ def fused_topk(
             gating_output.float(),
         )
     else:
-        ops.topk_softmax(
+        vllm_ops.topk_softmax(
             topk_weights,
             topk_ids,
             token_expert_indicies,

@@ -12,9 +12,6 @@ try:
     from vllm.model_executor.layers.quantization.awq import AWQConfig
     from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
     from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
-    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
-        CompressedTensorsConfig,
-    )
     from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
     from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
     from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
@@ -26,6 +23,8 @@ try:
     from vllm.model_executor.layers.quantization.qqq import QQQConfig
     from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

     from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig

+    VLLM_AVAILABLE = True
 except ImportError:
+    VLLM_AVAILABLE = False
@@ -44,8 +43,10 @@ except ImportError:

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
+from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
+    CompressedTensorsConfig,
+)
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
@@ -55,10 +56,9 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "fp8": Fp8Config,
     "blockwise_int8": BlockInt8Config,
     "modelopt": ModelOptFp8Config,
-    "gptq_marlin": GPTQMarlinConfig,
-    "gptq": GPTQConfig,
     "w8a8_int8": W8A8Int8Config,
     "w8a8_fp8": W8A8Fp8Config,
+    "compressed-tensors": CompressedTensorsConfig,
 }

 # Add vllm-dependent methods if available
@@ -74,10 +74,11 @@ if VLLM_AVAILABLE:
         "gguf": GGUFConfig,
         "gptq_marlin_24": GPTQMarlin24Config,
         "awq_marlin": AWQMarlinConfig,
-        "compressed-tensors": CompressedTensorsConfig,
         "bitsandbytes": BitsAndBytesConfig,
         "qqq": QQQConfig,
         "experts_int8": ExpertsInt8Config,
+        "gptq_marlin": GPTQMarlinConfig,
+        "gptq": GPTQConfig,
     }
     QUANTIZATION_METHODS.update(VLLM_QUANTIZATION_METHODS)

@@ -38,6 +38,11 @@ class QuantizeMethodBase(ABC):
 class QuantizationConfig(ABC):
     """Base class for quantization configs."""

+    def __init__(self):
+        super().__init__()
+        # mapping is updated by models as they initialize
+        self.packed_modules_mapping: Dict[str, List[str]] = dict()
+
     @abstractmethod
     def get_name(self) -> str:
         """Name of the quantization method."""
@@ -0,0 +1,6 @@
# quantization compressed_tensors module

To support compressed_tensors format quantization models, we adapted https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors into SGLang.

For now, only the `w8a8_fp8` compressed_tensors format has been ported. If you need other formats, please open an issue through this [link](https://github.com/sgl-project/sglang/issues).
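A minimal sketch of how the ported config path can be exercised directly. The `quantization_config` dict below only mirrors the `config_groups` layout that `compressed_tensors.py` parses; its field values are illustrative assumptions, not copied from a real checkpoint (in normal use this block comes from the checkpoint's `config.json` and is detected automatically):

```python
# Hypothetical example: build a CompressedTensorsConfig for a w8a8_fp8 scheme.
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
    CompressedTensorsConfig,
)

# Illustrative quantization_config, shaped like a compressed-tensors
# checkpoint entry (values are assumptions for this sketch).
sample_quantization_config = {
    "format": "float-quantized",
    "ignore": ["lm_head"],
    "config_groups": {
        "group_0": {
            "targets": ["Linear"],
            "weights": {
                "num_bits": 8, "type": "float", "strategy": "tensor",
                "symmetric": True, "dynamic": False,
            },
            "input_activations": {
                "num_bits": 8, "type": "float", "strategy": "tensor",
                "symmetric": True, "dynamic": False,
            },
        }
    },
}

quant_config = CompressedTensorsConfig.from_config(sample_quantization_config)
print(quant_config.get_name())  # -> "compressed_tensors"
```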
@@ -0,0 +1,652 @@
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0

import logging
from contextlib import suppress
from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, cast

import torch
from compressed_tensors.config import (
    CompressionFormat,
    SparsityCompressionConfig,
    SparsityStructure,
)
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationStrategy,
    QuantizationType,
)
from pydantic import BaseModel

from sglang.srt.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
    CompressedTensorsMoEMethod,
)
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme,
    CompressedTensorsW8A8Fp8,
)
from sglang.srt.layers.quantization.compressed_tensors.utils import (
    find_matched_target,
    is_activation_quantization_format,
    should_ignore_layer,
)
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod

logger = logging.getLogger(__name__)

__all__ = ["CompressedTensorsLinearMethod"]

SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
QUANTIZATION_SCHEME_MAP_TYPE = Dict[str, Optional[Dict[str, QuantizationArgs]]]


class DeviceCapability(NamedTuple):
    major: int
    minor: int

    def as_version_str(self) -> str:
        return f"{self.major}.{self.minor}"

    def to_int(self) -> int:
        """
        Express device capability as an integer ``<major><minor>``.

        It is assumed that the minor version is always a single digit.
        """
        assert 0 <= self.minor < 10
        return self.major * 10 + self.minor


class CompressedTensorsConfig(QuantizationConfig):

    def __init__(
        self,
        target_scheme_map: Dict[str, Any],
        ignore: List[str],
        quant_format: str,
        sparsity_scheme_map: Dict[str, SparsityCompressionConfig],
        sparsity_ignore_list: List[str],
        kv_cache_scheme: Optional[Dict[str, Any]] = None,
        config: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()
        self.ignore = ignore
        self.quant_format = quant_format
        # Map from [target -> scheme]
        self.target_scheme_map = target_scheme_map
        self.kv_cache_scheme = kv_cache_scheme
        self.sparsity_scheme_map = sparsity_scheme_map
        self.sparsity_ignore_list = sparsity_ignore_list
        self.config = config

    def get_linear_method(self) -> "CompressedTensorsLinearMethod":
        return CompressedTensorsLinearMethod(self)

    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    def get_name(self) -> str:
        return "compressed_tensors"

    def get_scaled_act_names(self) -> List[str]:
        return []

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:

        # Check if the layer is skipped for quantization.
        # TODO (@robertgshaw2): support module names
        if should_ignore_layer(
            prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
        ):
            return UnquantizedLinearMethod()
        if isinstance(layer, LinearBase):
            scheme = self.get_scheme(layer=layer, layer_name=prefix)
            if scheme is None:
                return UnquantizedLinearMethod()
            layer.scheme = scheme
            return CompressedTensorsLinearMethod(self)
        if isinstance(layer, FusedMoE):
            return CompressedTensorsMoEMethod.get_moe_method(self)
        return None

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
        ignore: List[str] = cast(List[str], config.get("ignore", []))
        quant_format = cast(str, config.get("format"))
        target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
        sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
            config=config
        )

        return cls(
            target_scheme_map=target_scheme_map,
            ignore=ignore,
            quant_format=quant_format,
            sparsity_scheme_map=sparsity_scheme_map,
            sparsity_ignore_list=sparsity_ignore_list,
            config=config,
        )

    @classmethod
    def _parse_sparsity_config(
        cls, config: Dict[str, Any]
    ) -> Tuple[Dict[str, SparsityCompressionConfig], List[str]]:
        """
        :param config: The `quantization_config` dictionary from config.json
        :return: A tuple with two elements
            1. A dictionary mapping target layer names to their corresponding
                sparsity_config
            2. A list of layer names to ignore for sparsity
        """
        if not (sparsity_config := config.get(SPARSITY_CONFIG_NAME)):
            return dict(), []

        sparsity_config = SparsityCompressionConfig.model_validate(sparsity_config)
        sparse_scheme_map: Dict[str, SparsityCompressionConfig] = {
            target: sparsity_config for target in sparsity_config.targets or list()
        }
        sparsity_ignore_list = sparsity_config.ignore or list()
        return sparse_scheme_map, sparsity_ignore_list

    @classmethod
    def _quantization_scheme_map_from_config(
        cls, config: Dict[str, Any]
    ) -> QUANTIZATION_SCHEME_MAP_TYPE:
        """
        :param config: The `quantization_config` dictionary from config.json
        :return: A dictionary mapping target layer names to their corresponding
            quantization_args for weights and input activations
        """
        target_scheme_map: Dict[str, Any] = dict()
        quant_format = cast(str, config.get("format"))

        # The quant_config has multiple config_groups, each containing
        # an input_activations key with details about how the activations are
        # quantized, a weights key indicating how the weights are quantized,
        # and a list of targets under the `targets` key, dictating which
        # layers are impacted by the quantization details. The quantization
        # details follow the structure defined by the QuantizationArgs
        # pydantic model, which is used to verify the structure of the
        # quant_config and also store the details for later use.

        config_groups = config.get("config_groups", dict())
        for _, quant_config in config_groups.items():
            targets = quant_config.get("targets")
            for target in targets:
                target_scheme_map[target] = {}
                target_scheme_map[target]["weights"] = QuantizationArgs.model_validate(
                    quant_config.get("weights")
                )

                target_scheme_map[target]["input_activations"] = None
                if is_activation_quantization_format(quant_format):
                    input_activations = quant_config.get("input_activations")
                    # The only case where we have activation quant supported
                    # but no input_activations provided in the config
                    # should be w8a16fp8 w8a16fp8 can also run for cases where
                    # there is an input_quant but it is ignored
                    if not input_activations:
                        assert (
                            target_scheme_map[target]["weights"].type
                            == QuantizationType.FLOAT
                        )
                    else:
                        target_scheme_map[target]["input_activations"] = (
                            QuantizationArgs.model_validate(  # noqa: E501
                                quant_config.get("input_activations")
                            )
                        )
        return target_scheme_map

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bool:
        capability_tuple = DeviceCapability(*torch.cuda.get_device_capability())

        if capability_tuple is not None:
            capability = capability_tuple.to_int()
            supported = capability >= min_capability
            if error and not supported:
                raise RuntimeError(
                    "Quantization scheme is not supported for ",
                    f"the current GPU. Min capability: {min_capability}. ",
                    f"Current capability: {capability}.",
                )
            return supported
        else:
            return False

    def _is_static_tensor_w8a8(
        self, weight_quant: BaseModel, input_quant: BaseModel
    ) -> bool:
        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
        weight_strategy = (
            weight_quant.strategy == QuantizationStrategy.TENSOR.value
            or weight_quant.strategy == QuantizationStrategy.CHANNEL.value
        )
        is_tensor = (
            weight_strategy
            and input_quant.strategy == QuantizationStrategy.TENSOR.value
        )
        is_static = not weight_quant.dynamic and not input_quant.dynamic

        # Both symmetric and asymmetric input quantization supported.
        # Only symmetric weight quantization supported.
        return is_8_bits and is_tensor and weight_quant.symmetric and is_static

    def _is_dynamic_token_w8a8(
        self, weight_quant: BaseModel, input_quant: BaseModel
    ) -> bool:
        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
        weight_strategy = (
            weight_quant.strategy == QuantizationStrategy.TENSOR.value
            or weight_quant.strategy == QuantizationStrategy.CHANNEL.value
        )
        is_token = (
            weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value
        )
        is_dynamic = not weight_quant.dynamic and input_quant.dynamic

        # Both symmetric and asymmetric input quantization supported.
        # Only symmetric weight quantization supported.
        return is_8_bits and is_token and weight_quant.symmetric and is_dynamic

    def _is_fp8_w8a8(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool:
        # Confirm weights and activations quantized.
        if weight_quant is None or input_quant is None:
            return False

        # Confirm weight scheme is supported.
        is_floating_point = (
            weight_quant.type == QuantizationType.FLOAT
            and input_quant.type == QuantizationType.FLOAT
        )
        is_symmetric_weight = weight_quant.symmetric
        is_static_weight = not weight_quant.dynamic
        is_per_tensor_or_channel_weight = weight_quant.strategy in [
            QuantizationStrategy.TENSOR,
            QuantizationStrategy.CHANNEL,
        ]
        if not (
            is_floating_point
            and is_symmetric_weight
            and is_static_weight
            and is_per_tensor_or_channel_weight
        ):
            return False

        # Dynamic quantization is always supported if weights supported.
        if input_quant.dynamic:
            return True

        # Confirm activation scheme is supported.
        is_symmetric_activation = input_quant.symmetric
        is_per_tensor_activation = input_quant.strategy == QuantizationStrategy.TENSOR
        return is_symmetric_activation and is_per_tensor_activation

    def _is_fp8_w8a16(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool:
        # Confirm weights quantized.
        if weight_quant is None:
            return False

        # Confirm we have floating points.
        if weight_quant.type != QuantizationType.FLOAT:
            return False

        # Confirm weight scheme is supported.
        is_symmetric_weight = weight_quant.symmetric
        is_static_weight = not weight_quant.dynamic
        is_per_tensor_or_channel_weight = weight_quant.strategy in [
            QuantizationStrategy.TENSOR,
            QuantizationStrategy.CHANNEL,
        ]
        if not (
            is_symmetric_weight
            and is_static_weight  # noqa: SIM103
            and is_per_tensor_or_channel_weight
        ):
            return False

        # All conditions satisfied.
        return True

    def _is_wNa16_group_channel(
        self, weight_quant: BaseModel, input_quant: BaseModel
    ) -> bool:
        input_quant_none = input_quant is None
        is_symmetric = weight_quant.symmetric
        is_channel_group = (
            weight_quant.strategy == QuantizationStrategy.CHANNEL.value
            or weight_quant.strategy == QuantizationStrategy.GROUP.value
        )
        is_static = not weight_quant.dynamic

        return is_channel_group and input_quant_none and is_symmetric and is_static

    def _get_scheme_from_parts(
        self, weight_quant: BaseModel, input_quant: BaseModel
    ) -> "CompressedTensorsScheme":

        # Detect If Mixed Precision
        if self._is_wNa16_group_channel(weight_quant, input_quant):
            if not VLLM_AVAILABLE:
                raise ImportError(
                    "vllm is not installed, to use CompressedTensorsW4A16Sparse24 and CompressedTensorsWNA16, please install vllm"
                )
            if (
                self.quant_format == CompressionFormat.marlin_24.value
                and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS
            ):
                return CompressedTensorsW4A16Sparse24(
                    strategy=weight_quant.strategy,
                    num_bits=weight_quant.num_bits,
                    group_size=weight_quant.group_size,
                )
            if (
                self.quant_format == CompressionFormat.pack_quantized.value
                and weight_quant.num_bits in WNA16_SUPPORTED_BITS
            ):
                return CompressedTensorsWNA16(
                    num_bits=weight_quant.num_bits,
                    strategy=weight_quant.strategy,
                    group_size=weight_quant.group_size,
                    actorder=weight_quant.actorder,
                )

        if is_activation_quantization_format(self.quant_format):
            if self._is_fp8_w8a8(weight_quant, input_quant):
                is_fp8_w8a8_supported = self._check_scheme_supported(
                    CompressedTensorsW8A8Fp8.get_min_capability(), error=False
                )
                if is_fp8_w8a8_supported:
                    return CompressedTensorsW8A8Fp8(
                        strategy=weight_quant.strategy,
                        is_static_input_scheme=(
                            input_quant and not input_quant.dynamic
                        ),
                    )
                else:
                    # note: input_quant will be present for converted models;
                    # will be ignored during inference post loading
                    return CompressedTensorsW8A16Fp8(
                        strategy=weight_quant.strategy,
                        is_static_input_scheme=not input_quant.dynamic,
                    )

            # note: input_quant can be None
            if self._is_fp8_w8a16(weight_quant, input_quant):
                if not VLLM_AVAILABLE:
                    raise ImportError(
                        "vllm is not installed, to use CompressedTensorsW8A16Fp8, please install vllm"
                    )
                is_static_input_scheme = input_quant and not input_quant.dynamic
                return CompressedTensorsW8A16Fp8(
                    strategy=weight_quant.strategy,
                    is_static_input_scheme=is_static_input_scheme,
                )

            if self._is_static_tensor_w8a8(weight_quant, input_quant):
                return CompressedTensorsW8A8Int8(
                    strategy=weight_quant.strategy,
                    is_static_input_scheme=True,
                    input_symmetric=input_quant.symmetric,
                )

            if self._is_dynamic_token_w8a8(weight_quant, input_quant):
                return CompressedTensorsW8A8Int8(
                    strategy=weight_quant.strategy,
                    is_static_input_scheme=False,
                    input_symmetric=input_quant.symmetric,
                )

        raise NotImplementedError("No compressed-tensors compatible scheme was found.")

    def get_scheme(
        self, layer: torch.nn.Module, layer_name: Optional[str] = None
    ) -> Optional["CompressedTensorsScheme"]:
        """
        compressed-tensors supports non uniform in the following way:

        targets of config_groups: There can be N config_groups which each
            have a quantization scheme. Each config_group has a list of targets
            which can be a full layer_name, a regex for a layer_name, or
            an nn.Module name.

        Detect whether a layer_name is found in any target and
        use the quantization scheme corresponding to the matched target
        to select the CompressedTensorsScheme used for infernece.
        """

        # Find the "target" in the compressed-tensors config
        # that our layer conforms to.
        # TODO (@robertgshaw): add compressed-tensors as dep
        # so we do not have to re-write these functions
        # need to make accelerate optional in ct to do this

        # Will be empty for models with only sparsity
        weight_quant = input_quant = None
        if self.target_scheme_map:
            matched_target = find_matched_target(
                layer_name=layer_name,
                module=layer,
                targets=self.target_scheme_map.keys(),
                fused_mapping=self.packed_modules_mapping,
            )

            scheme_dict = self.target_scheme_map[matched_target]
            weight_quant = scheme_dict.get("weights")
            input_quant = scheme_dict.get("input_activations")

        # Find the sparsity scheme of the layer
        # assume that fused layers inerhit first component's sparsity scheme
        sparsity_targets = self.sparsity_scheme_map.keys() - set(
            self.sparsity_ignore_list
        )
        sparsity_scheme: Optional[SparsityCompressionConfig] = None
        with suppress(ValueError):
            matched_target = find_matched_target(
                layer_name=layer_name,
                module=layer,
                targets=sparsity_targets,
                fused_mapping=self.packed_modules_mapping,
            )
            sparsity_scheme = self.sparsity_scheme_map[matched_target]

        if self.supports_cutlass_24(
            weight_quant=weight_quant,
            input_quant=input_quant,
            sparsity_scheme=sparsity_scheme,
        ):
            if not VLLM_AVAILABLE:
                raise ImportError(
                    "vllm is not installed, to use CompressedTensors24, please install vllm"
                )
            # Have a valid sparsity scheme
            # Validate layer is supported by Cutlass 2:4 Kernel
            model_compression_config = (
                None
                if sparsity_scheme is None or sparsity_scheme.format == "dense"
                else self.config
            )

            scheme = CompressedTensors24(
                quantized=weight_quant is not None or input_quant is not None,
                weight_quant=weight_quant,
                input_quant=input_quant,
                model_compression_config=model_compression_config,
            )
        elif weight_quant is None:
            logger.warning_once(
                "Acceleration for non-quantized schemes is "
                "not supported by Compressed Tensors. "
                "Falling back to UnquantizedLinearMethod"
            )
            return None

        else:
            # Find the quant_scheme
            scheme = self._get_scheme_from_parts(  # type: ignore
                weight_quant=weight_quant,
                input_quant=input_quant,
            )

        # Raise error if device does not support the scheme
        # (e.g. fp8 needs ada lovelace)
        self._check_scheme_supported(scheme.get_min_capability())
        logger.debug("Using scheme: %s for %s", scheme.__class__.__name__, layer_name)
        return scheme

    def get_cache_scale(self, name: str) -> Optional[str]:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        # If no matches, return None
        return None

    @staticmethod
    def supports_cutlass_24(
        weight_quant: Optional[QuantizationArgs],
        input_quant: Optional[QuantizationArgs],
        sparsity_scheme: Optional[SparsityCompressionConfig] = None,
    ) -> bool:
        """
        Check if the layer is supported by the Cutlass 2:4 Kernel
        Conditions:
            - Overarching condition: Sparsity Structure is 2:4
            - Unquantized cases are supported
            - Weight only quantization is not-supported
            - Supported weight quantization strategies are TENSOR and CHANNEL
            - Supported input quantization strategies are TENSOR and TOKEN
            - Only 8 bit quantization is supported

        :return: True if the layer is supported by the Cutlass 2:4 Kernel
            False otherwise
        """
        if sparsity_scheme is None:
            return False

        is_valid_sparsity_structure: bool = (
            sparsity_scheme.sparsity_structure == SparsityStructure.TWO_FOUR.value
        )

        valid_compressors = {
            CompressionFormat.dense.value,
            CompressionFormat.sparse_24_bitmask.value,
        }

        is_valid_sparsity = (
            is_valid_sparsity_structure and sparsity_scheme.format in valid_compressors
        )

        if not is_valid_sparsity:
            return False

        # Unquantized cases are supported
        if weight_quant is None and input_quant is None:
            return True

        # Weight only quantization is not-supported
        if weight_quant is not None and input_quant is None:
            return False

        supported_weight_quant_strategies = [
            QuantizationStrategy.TENSOR.value,
            QuantizationStrategy.CHANNEL.value,
        ]

        assert weight_quant is not None
        assert input_quant is not None
        if weight_quant.strategy not in supported_weight_quant_strategies:
            return False

        supported_input_quant_strategies = [
            QuantizationStrategy.TENSOR.value,
            QuantizationStrategy.TOKEN.value,
        ]

        if input_quant.strategy not in supported_input_quant_strategies:
            return False

        return weight_quant.num_bits == input_quant.num_bits == 8


class CompressedTensorsLinearMethod(LinearMethodBase):

    def __init__(self, quantization_config: CompressedTensorsConfig):
        self.quantization_config = quantization_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.scheme.process_weights_after_loading(layer)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """
        Use the CompressedTensorsScheme associated with each layer to create
        the necessary parameters for the layer. See LinearMethodBase for param
        details
        """
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.scheme.create_weights(
            layer=layer,
            input_size=input_size,
            input_size_per_partition=input_size_per_partition,
            output_partition_sizes=output_partition_sizes,
            output_size=output_size,
            params_dtype=params_dtype,
            weight_loader=weight_loader,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ):
        """
        Use the output of create_weights and the CompressedTensorsScheme
        associated with the layer to apply the forward pass with the
        layer input. See LinearMethodBase for param details
        """

        scheme = layer.scheme
        if scheme is None:
            raise ValueError("A scheme must be defined for each layer")
        return scheme.apply_weights(layer, x, bias=bias)
@@ -0,0 +1,658 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import enum
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors import CompressionFormat
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
|
||||
from sglang.srt.layers.moe.fused_moe_triton import (
|
||||
FusedMoE,
|
||||
FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported,
|
||||
)
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
||||
from sglang.srt.layers.quantization.utils import (
|
||||
all_close_1d,
|
||||
is_cuda,
|
||||
is_fp8_fnuz,
|
||||
per_tensor_dequantize,
|
||||
replace_parameter,
|
||||
)
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
if _is_cuda:
|
||||
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
|
||||
else:
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
|
||||
try:
|
||||
import vllm
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
VLLM_AVAILABLE = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GPTQMarlinState(Enum):
|
||||
REPACK = enum.auto()
|
||||
READY = enum.auto()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CompressedTensorsMoEMethod",
|
||||
"CompressedTensorsW8A8Fp8MoEMethod",
|
||||
"CompressedTensorsWNA16MoEMethod",
|
||||
]
|
||||
|
||||
|
||||
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
@staticmethod
|
||||
def get_moe_method(
|
||||
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
|
||||
) -> "CompressedTensorsMoEMethod":
|
||||
# TODO: @dsikka: refactor this to use schemes as other kernels
|
||||
# are supported + check if the layer is being ignored.
|
||||
weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
|
||||
input_quant = quant_config.target_scheme_map["Linear"].get("input_activations")
|
||||
|
||||
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
|
||||
if not VLLM_AVAILABLE:
|
||||
raise ImportError(
|
||||
"vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm"
|
||||
)
|
||||
return CompressedTensorsWNA16MoEMethod(quant_config)
|
||||
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
|
||||
return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}"
|
||||
)
|
||||
|
||||
|
||||
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def __init__(
|
||||
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
||||
):
|
||||
self.quant_config = quant_config
|
||||
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
|
||||
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
|
||||
"input_activations"
|
||||
)
|
||||
|
||||
if not (
|
||||
self.weight_quant.strategy == QuantizationStrategy.TENSOR
|
||||
and self.input_quant.strategy == QuantizationStrategy.TENSOR
|
||||
):
|
||||
raise ValueError(
|
||||
"For FP8 Fused MoE layers, only per-tensor scales "
|
||||
"for weights and activations are supported. Found "
|
||||
f"{self.weight_quant}, {self.input_quant}"
|
||||
)
|
||||
|
||||
self.static_input_scales = not self.input_quant.dynamic
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
|
||||
params_dtype = torch.float8_e4m3fn
|
||||
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# WEIGHT_SCALES
|
||||
# Allocate 2 scales for w1 and w3 respectively.
|
||||
# They will be combined to a single scale after weight loading.
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
# Add the quantization method used (per tensor/grouped/channel)
|
||||
# to ensure the weight scales are loaded in properly
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
|
||||
)
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
# INPUT_SCALES
|
||||
if self.static_input_scales:
|
||||
w13_input_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||
set_weight_attrs(w13_input_scale, extra_weight_attrs)
|
||||
|
||||
w2_input_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
||||
else:
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# Fp8 moe kernels require a single activation scale.
|
||||
# We take the max of all the scales in case they differ.
|
||||
if self.static_input_scales:
|
||||
if layer.w13_input_scale is None or layer.w2_input_scale is None:
|
||||
raise ValueError(
|
||||
"QuantConfig has static quantization, but found "
|
||||
"activation scales are None."
|
||||
)
|
||||
if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
|
||||
layer.w2_input_scale
|
||||
):
|
||||
logger.warning(
|
||||
"Found input_scales that are not equal for "
|
||||
"fp8 MoE layer. Using the maximum across experts "
|
||||
"for each layer."
|
||||
)
|
||||
layer.w13_input_scale = torch.nn.Parameter(
|
||||
layer.w13_input_scale.max(), requires_grad=False
|
||||
)
|
||||
layer.w2_input_scale = torch.nn.Parameter(
|
||||
layer.w2_input_scale.max(), requires_grad=False
|
||||
)
|
||||
|
||||
if is_fp8_fnuz():
|
||||
# Normalize the weights and scales
|
||||
w13_weight, w13_weight_scale, w13_input_scale = (
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
|
||||
)
|
||||
)
|
||||
w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
|
||||
)
|
||||
# Reset the parameter
|
||||
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
w13_weight_scale, requires_grad=False
|
||||
)
|
||||
if w13_input_scale is not None:
|
||||
layer.w13_input_scale = torch.nn.Parameter(
|
||||
w13_input_scale, requires_grad=False
|
||||
)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(
|
||||
w2_weight_scale, requires_grad=False
|
||||
)
|
||||
if w2_input_scale is not None:
|
||||
layer.w2_input_scale = torch.nn.Parameter(
|
||||
w2_input_scale, requires_grad=False
|
||||
)
|
||||
|
||||
# Fp8 moe kernel needs single weight scale for w13 per expert.
|
||||
# We take the max then dequant and requant each expert.
|
||||
assert layer.w13_weight_scale is not None
|
||||
shard_size = layer.intermediate_size_per_partition
|
||||
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
||||
for expert_id in range(layer.local_num_experts):
|
||||
start = 0
|
||||
for shard_id in range(2):
|
||||
dq_weight = per_tensor_dequantize(
|
||||
layer.w13_weight[expert_id][start : start + shard_size, :],
|
||||
layer.w13_weight_scale[expert_id][shard_id],
|
||||
)
|
||||
|
||||
if _is_cuda:
|
||||
layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
|
||||
sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
|
||||
)
|
||||
else:
|
||||
layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
|
||||
vllm_ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
|
||||
)
|
||||
start += shard_size
|
||||
|
||||
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
)
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
)
|
||||
|
||||
|
||||
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def __init__(
|
||||
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
||||
):
|
||||
self.quant_config = quant_config
|
||||
# TODO: @dsikka: refactor this to use schemes as other kernels
|
||||
# are supported + check if the layer is being ignored.
|
||||
config = self.quant_config.target_scheme_map["Linear"].get("weights")
|
||||
self.num_bits = config.num_bits
|
||||
self.packed_factor = 32 // config.num_bits
|
||||
self.strategy = config.strategy
|
||||
self.group_size = config.group_size
|
||||
self.actorder = config.actorder
|
||||
assert config.symmetric, "Only symmetric quantization is supported for MoE"
|
||||
|
||||
if not (
|
||||
self.quant_config.quant_format == CompressionFormat.pack_quantized.value
|
||||
and self.num_bits in WNA16_SUPPORTED_BITS
|
||||
):
|
||||
raise ValueError(
|
||||
"For Fused MoE layers, only ",
|
||||
f"{CompressionFormat.pack_quantized.value} ",
|
||||
"is supported for the following bits: ",
|
||||
f"{WNA16_SUPPORTED_BITS}",
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
|
||||
assert (
|
||||
params_dtype == torch.float16
|
||||
), "float16 is required for MoE compressed models. Set dtype=torch.float16" # noqa: E501
|
||||
|
||||
intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")
|
||||
|
||||
# Will transpose the loaded weight along the
|
||||
# intermediate and hidden dim sizes. Will
|
||||
# shard for TP along the transposed dims
|
||||
extra_weight_attrs.update(
|
||||
{"is_transposed": True, "quant_method": self.strategy}
|
||||
)
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size // self.packed_factor,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_packed", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
intermediate_size_per_partition // self.packed_factor,
|
||||
hidden_size,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight_packed", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# In the case where we have actorder/g_idx,
|
||||
# we do not partition the w2 scales
|
||||
load_full_w2 = self.actorder and self.group_size != -1
|
||||
w2_scales_size = (
|
||||
intermediate_size_full if load_full_w2 else intermediate_size_per_partition
|
||||
)
|
||||
|
||||
self.is_k_full = (not self.actorder) or (
|
||||
intermediate_size_per_partition == intermediate_size_full
|
||||
)
|
||||
|
||||
if self.strategy == "channel":
|
||||
num_groups_w2 = num_groups_w13 = 1
|
||||
self.group_size = -1
|
||||
else:
|
||||
num_groups_w2 = w2_scales_size // self.group_size
|
||||
num_groups_w13 = hidden_size // self.group_size
|
||||
|
||||
w13_scale = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
num_experts,
|
||||
num_groups_w13,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_scale", w13_scale)
|
||||
set_weight_attrs(w13_scale, extra_weight_attrs)
|
||||
|
||||
w2_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, num_groups_w2, hidden_size, dtype=params_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight_scale", w2_scale)
|
||||
set_weight_attrs(w2_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w2_scale, {"load_full_w2": load_full_w2})
|
||||
|
||||
w2_weight_shape = torch.nn.Parameter(
|
||||
torch.empty(num_experts, 2), requires_grad=False
|
||||
)
|
||||
layer.register_parameter("w2_weight_shape", w2_weight_shape)
|
||||
set_weight_attrs(w2_weight_shape, extra_weight_attrs)
|
||||
w13_weight_shape = torch.nn.Parameter(
|
||||
torch.empty(num_experts, 2), requires_grad=False
|
||||
)
|
||||
|
||||
layer.register_parameter("w13_weight_shape", w13_weight_shape)
|
||||
set_weight_attrs(w13_weight_shape, extra_weight_attrs)
|
||||
|
||||
w13_g_idx = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight_g_idx", w13_g_idx)
|
||||
set_weight_attrs(w13_g_idx, extra_weight_attrs)
|
||||
|
||||
w2_g_idx = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight_g_idx", w2_g_idx)
|
||||
set_weight_attrs(w2_g_idx, extra_weight_attrs)
|
||||
|
||||
w13_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices)
|
||||
set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
|
||||
|
||||
w2_g_idx_sort_indices = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices)
|
||||
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
|
||||
|
||||
layer.a13_scale = None
|
||||
layer.a2_scale = None
|
||||
layer.marlin_state = GPTQMarlinState.REPACK
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
|
||||
def replace_tensor(name, new_t):
|
||||
# It is important to use resize_() here since it ensures
|
||||
# the same buffer is reused
|
||||
getattr(layer, name).resize_(new_t.shape)
|
||||
getattr(layer, name).copy_(new_t)
|
||||
del new_t
|
||||
|
||||
def get_scale_perms(num_bits: int):
|
||||
scale_perm: List[int] = []
|
||||
for i in range(8):
|
||||
scale_perm.extend([i + 8 * j for j in range(8)])
|
||||
scale_perm_single: List[int] = []
|
||||
for i in range(4):
|
||||
scale_perm_single.extend(
|
||||
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]
|
||||
)
|
||||
return scale_perm, scale_perm_single
|
||||
|
||||
def marlin_permute_scales(
|
||||
s: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int
|
||||
):
|
||||
scale_perm, scale_perm_single = get_scale_perms(num_bits)
|
||||
if group_size < size_k and group_size != -1:
|
||||
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
|
||||
else:
|
||||
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
|
||||
s = s.reshape((-1, size_n)).contiguous()
|
||||
return s
|
||||
|
||||
def marlin_moe_permute_scales(
|
||||
            s: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int
        ):
            num_experts = s.shape[0]
            output = torch.empty(
                (num_experts, s.shape[1], s.shape[2]), device=s.device, dtype=s.dtype
            )
            for e in range(num_experts):
                output[e] = marlin_permute_scales(
                    s[e], size_k, size_n, group_size, num_bits
                )
            return output

        size_k2 = layer.w2_weight_packed.shape[2]
        size_k13 = layer.w13_weight_packed.shape[2]

        num_experts = layer.w13_weight_g_idx.shape[0]
        device = layer.w13_weight_g_idx.device

        # when running models with grouped act order,
        # resort to g_idx values provided in checkpoint
        if self.actorder == "group":
            w13_g_idx_sort_indices = torch.empty_like(layer.w13_weight_g_idx)
            w2_g_idx_sort_indices = torch.empty_like(layer.w2_weight_g_idx)
            w13_sorted_g_idx = torch.empty_like(layer.w13_weight_g_idx)
            w2_sorted_g_idx = torch.empty_like(layer.w2_weight_g_idx)

            for e in range(num_experts):
                w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_weight_g_idx[e]).to(
                    torch.int32
                )
                w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_weight_g_idx[e]).to(
                    torch.int32
                )
                w13_sorted_g_idx[e] = layer.w13_weight_g_idx[e][
                    w13_g_idx_sort_indices[e]
                ]
                w2_sorted_g_idx[e] = layer.w2_weight_g_idx[e][w2_g_idx_sort_indices[e]]

            replace_parameter(layer, "w13_weight_g_idx", w13_sorted_g_idx)
            replace_parameter(layer, "w2_weight_g_idx", w2_sorted_g_idx)
            replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
            replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)

        else:
            layer.w13_weight_g_idx = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
                requires_grad=False,
            )
            layer.w2_weight_g_idx = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
                requires_grad=False,
            )
            layer.w13_g_idx_sort_indices = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
                requires_grad=False,
            )
            layer.w2_g_idx_sort_indices = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
                requires_grad=False,
            )

        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
            layer.w13_weight_packed,
            layer.w13_g_idx_sort_indices,
            layer.w13_weight_packed.shape[1] * self.packed_factor,
            layer.w13_weight_packed.shape[2],
            self.num_bits,
        )
        replace_tensor("w13_weight_packed", marlin_w13_qweight)
        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
            layer.w2_weight_packed,
            layer.w2_g_idx_sort_indices,
            layer.w2_weight_packed.shape[1] * self.packed_factor,
            layer.w2_weight_packed.shape[2],
            self.num_bits,
        )
        replace_tensor("w2_weight_packed", marlin_w2_qweight)
        # Repack scales
        marlin_w13_scales = marlin_moe_permute_scales(
            layer.w13_weight_scale,
            size_k13,
            layer.w13_weight_scale.shape[2],
            self.group_size,
            self.num_bits,
        )
        replace_tensor("w13_weight_scale", marlin_w13_scales)
        marlin_w2_scales = marlin_moe_permute_scales(
            layer.w2_weight_scale,
            layer.w2_weight_scale.shape[1]
            * (self.group_size if self.group_size != -1 else self.packed_factor),
            size_k2,
            self.group_size,
            self.num_bits,
        )
        replace_tensor("w2_weight_scale", marlin_w2_scales)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
    ) -> torch.Tensor:
        assert activation == "silu", "Only SiLU activation is supported."
        if not VLLM_AVAILABLE:
            raise ImportError(
                "vllm is not installed, to use fused_marlin_moe, please install vllm"
            )
        if expert_map is not None:
            raise NotImplementedError(
                "Expert Parallelism is not supported for " "fused Marlin MoE method."
            )

        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            correction_bias=correction_bias,
        )

        return torch.ops.vllm.fused_marlin_moe(
            x,
            layer.w13_weight_packed,
            layer.w2_weight_packed,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            router_logits,
            topk_weights,
            topk_ids,
            g_idx1=layer.w13_weight_g_idx,
            g_idx2=layer.w2_weight_g_idx,
            sort_indices1=layer.w13_g_idx_sort_indices,
            sort_indices2=layer.w2_g_idx_sort_indices,
            num_bits=self.num_bits,
            is_k_full=self.is_k_full,
        )
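For readers skimming the act-order branch above: the per-expert argsort yields the permutation that groups channels sharing a quantization group, which is what the Marlin repack expects. A tiny illustrative sketch (values and shapes below are invented, not from the patch):

# Illustration only: how argsort-derived indices reorder a per-expert g_idx.
import torch

g_idx = torch.tensor([2, 0, 1, 0, 2, 1])             # group id per input channel
sort_indices = torch.argsort(g_idx).to(torch.int32)  # -> [1, 3, 2, 5, 0, 4]
sorted_g_idx = g_idx[sort_indices]                   # -> [0, 0, 1, 1, 2, 2]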
@@ -0,0 +1,9 @@
# SPDX-License-Identifier: Apache-2.0

from .compressed_tensors_scheme import CompressedTensorsScheme
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8

__all__ = [
    "CompressedTensorsScheme",
    "CompressedTensorsW8A8Fp8",
]
@@ -0,0 +1,56 @@
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from typing import Optional

import torch

__all__ = ["CompressedTensorsScheme"]


class CompressedTensorsScheme(ABC):
    """
    Abstract class used to describe the weight creation and forward pass
    of different quantization schemes supported by CompressedTensors.
    """

    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
        """
        Get minimum device capability.
        """
        raise NotImplementedError

    @abstractmethod
    def create_weights(self, *args, **kwargs):
        """
        Weight creation for the particular scheme. Inputs to this function

        """
        raise NotImplementedError

    @abstractmethod
    def apply_weights(
        self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
    ):
        """
        Run the forward pass for the particular scheme. This is where
        scheme-specific dequant/quant steps/kernels should be applied.

        :param layer: torch.nn.Module with the registered weights and
            other parameters relevant to the particular scheme.
        :param x: input to the layer
        :param bias: bias parameter

        """
        raise NotImplementedError

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module):
        """
        Called after weight loading is complete for any cleanup that
        needs to occur.
        """
        raise NotImplementedError
@@ -0,0 +1,162 @@
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0

from typing import Callable, List, Optional

import torch
from compressed_tensors.quantization import QuantizationStrategy
from torch.nn import Parameter

from sglang.srt.layers.parameter import (
    ChannelQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme,
)
from sglang.srt.layers.quantization.fp8_utils import (
    Fp8LinearOp,
    maybe_create_device_identity,
    normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale

__all__ = ["CompressedTensorsW8A8Fp8"]


class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):

    def __init__(self, strategy: str, is_static_input_scheme: bool):
        self.strategy = strategy
        self.is_static_input_scheme = is_static_input_scheme
        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)

    @classmethod
    def get_min_capability(cls) -> int:
        # lovelace and up
        return 89

    def process_weights_after_loading(self, layer) -> None:
        # If per tensor, when we have a fused module (e.g. QKV) with per
        # tensor scales (thus N scales being passed to the kernel),
        # requantize so we can always run per tensor
        if self.strategy == QuantizationStrategy.TENSOR:
            max_w_scale, weight = requantize_with_max_scale(
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                logical_widths=layer.logical_widths,
            )

            if is_fp8_fnuz():
                input_scale = getattr(layer, "input_scale", None)

                weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                    weight=weight, weight_scale=max_w_scale, input_scale=input_scale
                )
                if input_scale is not None:
                    layer.input_scale = Parameter(input_scale, requires_grad=False)

            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)

        # If channelwise, scales are already lined up, so just transpose.
        elif self.strategy == QuantizationStrategy.CHANNEL:
            weight = layer.weight

            if is_fp8_fnuz():
                input_scale = getattr(layer, "input_scale", None)

                weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                    weight=weight,
                    weight_scale=layer.weight_scale,
                    input_scale=input_scale,
                )
                if input_scale is not None:
                    layer.input_scale = Parameter(input_scale, requires_grad=False)
            else:
                weight_scale = layer.weight_scale.data

            layer.weight = Parameter(weight.t(), requires_grad=False)
            # required by torch.compile to be torch.nn.Parameter
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)

        else:
            raise ValueError(f"Unknown quantization strategy {self.strategy}")

        # INPUT SCALE
        if self.is_static_input_scheme and hasattr(layer, "input_scale"):
            layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
        else:
            layer.input_scale = None

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: List[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes

        # WEIGHT
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        # TODO: update create_xxx_parameter functions to return
        # the newly added parameters
        if self.strategy == QuantizationStrategy.CHANNEL:
            weight_scale = ChannelQuantScaleParameter(
                data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
                output_dim=0,
                weight_loader=weight_loader,
            )
        else:
            assert self.strategy == QuantizationStrategy.TENSOR
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )

        # min requirement for fp8 kernels
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE
        if self.is_static_input_scheme:
            input_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
            input_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", input_scale)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
        )
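For orientation, here is a rough sketch of the call order a linear layer's quantization method follows when it drives this scheme. The `layer`, `in_features`, `out_features` and `weight_loader` names below are placeholders; the real wiring lives in the compressed-tensors config and linear method, outside this hunk:

# Hypothetical driver code; only the scheme API calls mirror the class above.
scheme = CompressedTensorsW8A8Fp8(
    strategy=QuantizationStrategy.CHANNEL, is_static_input_scheme=False
)
scheme.create_weights(
    layer=layer,
    output_partition_sizes=[out_features],
    input_size_per_partition=in_features,
    params_dtype=torch.bfloat16,
    weight_loader=weight_loader,
)
# ... checkpoint shards are loaded into layer.weight / layer.weight_scale ...
scheme.process_weights_after_loading(layer)      # transpose weight, finalize scales
out = scheme.apply_weights(layer, x, bias=None)  # fp8 GEMM through Fp8LinearOp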
@@ -0,0 +1,218 @@
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0

import re
from types import MappingProxyType
from typing import Iterable, List, Mapping, Optional

from compressed_tensors import CompressionFormat
from torch.nn import Module


def is_activation_quantization_format(format: str) -> bool:
    _ACTIVATION_QUANTIZATION_FORMATS = [
        CompressionFormat.naive_quantized.value,
        CompressionFormat.int_quantized.value,
        CompressionFormat.float_quantized.value,
    ]
    return format in _ACTIVATION_QUANTIZATION_FORMATS


def should_ignore_layer(
    layer_name: Optional[str],
    ignore: Iterable[str] = tuple(),
    fused_mapping: Mapping[str, List[str]] = MappingProxyType({}),
) -> bool:
    if layer_name is None:
        return False

    # layer_name = model.layers.0.self_attn.qkv_proj
    # proj_name = qkv_proj
    proj_name = layer_name.split(".")[-1]

    # Fused layers like gate_up_proj or qkv_proj will not be fused
    # in the safetensors checkpoint. So, we convert the name
    # from the fused version to unfused + check to make sure that
    # each shard of the fused layer has the same scheme.
    if proj_name in fused_mapping and layer_name not in ignore:
        shard_proj_names = fused_mapping[proj_name]

        # Convert fused_name --> [shard_names]
        shard_names = [
            layer_name.replace(proj_name, shard_proj_name)
            for shard_proj_name in shard_proj_names
        ]

        # Layer should be ignored if shards are ignored.
        should_ignore_layer = None
        for shard_name in shard_names:
            should_ignore_shard = check_equal_or_regex_match(
                layer_name=shard_name, targets=ignore
            )

            # If shard_idx=0, set layer ignore to match shard.
            if should_ignore_layer is None:
                should_ignore_layer = should_ignore_shard

            # If shard_idx=1+ confirm scheme matches prior shards.
            elif should_ignore_shard != should_ignore_layer:
                raise ValueError(
                    f"Found different quantization schemes for "
                    f"{shard_proj_names} in {layer_name}. vLLM "
                    "requires all to use the same scheme."
                )

    # Unfused layers like down_proj and o_proj will match
    # the safetensors checkpoint already.
    else:
        should_ignore_layer = check_equal_or_regex_match(
            layer_name=layer_name, targets=ignore
        )

    assert should_ignore_layer is not None
    return should_ignore_layer


def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool:
    """
    Checks whether a layer_name exactly matches, or (when a target starts
    with 're:') regex-matches, any target in the list.
    """
    for target in targets:
        if _is_equal_or_regex_match(layer_name, target):
            return True
    return False


def find_matched_target(
    layer_name: Optional[str],
    module: Module,
    targets: Iterable[str],
    fused_mapping: Mapping[str, List[str]] = MappingProxyType({}),
) -> str:
    """
    Helper function to look up which "target" in the compressed-tensors
    config that a layer corresponds to.

    Recall that a compressed-tensors config has a concept of
    config_groups, where each layer can be quantized with a different
    scheme.

    targets in each config_group will be a list of either layer names
    (or regexes corresponding to layer names) or names of torch Modules.

    First, we try to match the layer_name with a target
    Second, we try to match the module's name with a target
    Third, we try to map the layer_name to a list of fused module names.
    *All* component module names must match in order for a match to be
    successful. A successful match returns the first component target

    :param layer_name: layer name
    :param module: torch.nn.Module
    :param targets: list of targets to match the layer against
    :param fused_mapping: map from fused layer names to its components
    :param fused_strategy: either "all" or "any". If using "all", fused
        layers match if "all" of its components match
    """

    if layer_name is None:
        layer_name = ""

    matched_target = (
        _find_first_match(layer_name, targets)
        or _find_first_match(module.__class__.__name__, targets, True)
        or _match_fused_layer(layer_name, targets, fused_mapping)
    )

    if matched_target is None:
        raise ValueError(
            f"Unable to find matching target for {layer_name} in the "
            "compressed-tensors config."
        )

    return matched_target


def _find_first_match(
    value: str, targets: Iterable[str], check_contains: bool = False
) -> Optional[str]:
    """
    Returns first element of target that matches value either
    exactly or as a regex after 're:'. If check_contains is set to True,
    additionally checks if the target string is contained within the value.

    :param value: string to compare the list of targets against
    :param targets: list of targets to match the layer against
    :param check_contains: whether or not to do a substring match
    """

    for target in targets:
        if _is_equal_or_regex_match(value, target, check_contains=check_contains):
            return target
    return None


def _is_equal_or_regex_match(
    value: str, target: str, check_contains: bool = False
) -> bool:
    """
    Checks whether a value is exactly equal or a regex match for target
    if target starts with 're:'. If check_contains is set to True,
    additionally checks if the target string is contained within the value.
    """

    if target.startswith("re:"):
        pattern = target[3:]
        if re.match(pattern, value):
            return True
    elif check_contains:
        if target.lower() in value.lower():
            return True
    elif target == value:
        return True
    return False


def _match_fused_layer(
    layer_name: str,
    target_layers: Iterable[str],
    fused_mapping: Mapping[str, List[str]],
) -> Optional[str]:
    """
    Match a fused layer name to its corresponding individual layer in
    target_layers. Returns first value in fused_mapping which matches targets

    Implements an "all" matching strategy where a fused layer matches iff
    "all" of its components match

    :param layer_name: layer name
    :param target_layers: list of targets to match the layer against
    :param fused_mapping: map from fused layer names to its components

    Examples:
        layer_name = "model.layers.0.self_attn.qkv_proj"
        target_layers = ["model.layers.0.self_attn.q_proj",
                         "model.layers.0.self_attn.k_proj",
                         "model.layers.0.self_attn.v_proj"]
    """
    # find layer_name in mapping
    fused = next((key for key in fused_mapping if layer_name.endswith(key)), None)
    if fused is None:
        return None

    # expand path of unfused components
    unfused_paths = [
        layer_name.replace(fused, unfused) for unfused in fused_mapping[fused]
    ]

    # for each unfused component, find a match in targets
    unfused_matches: List[Optional[str]] = []
    for unfused in unfused_paths:
        for target in target_layers:
            if _is_equal_or_regex_match(unfused, target):
                unfused_matches.append(target)
                break
        else:
            unfused_matches.append(None)

    return unfused_matches[0] if all(unfused_matches) else None
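The helpers above accept both literal layer names and `re:`-prefixed regexes, and fused modules are checked through their unfused shard names. A small example of the resulting behavior (the module paths are made up):

# Illustration only; layer names are invented.
ignore = ["re:.*lm_head", "model.layers.0.mlp.down_proj"]
fused_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

should_ignore_layer("lm_head", ignore=ignore)                       # True, regex match
should_ignore_layer("model.layers.0.mlp.down_proj", ignore=ignore)  # True, exact match
# A fused layer is expanded to q_proj/k_proj/v_proj; all shards must agree:
should_ignore_layer(
    "model.layers.0.self_attn.qkv_proj", ignore=ignore, fused_mapping=fused_mapping
)                                                                   # False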
@@ -1,3 +1,4 @@
import os
from typing import List, Optional, Tuple

import torch
@@ -18,6 +19,7 @@ from sglang.srt.utils import (

try:
    import vllm
    from vllm import _custom_ops as ops

    VLLM_AVAILABLE = True
except ImportError:
@@ -31,19 +33,29 @@ if _is_hip and get_bool_env_var("CK_MOE"):

_is_cuda = is_cuda()
if _is_cuda:
    from sgl_kernel import fp8_blockwise_scaled_mm
    from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm

    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
    from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8

if use_vllm_cutlass_w8a8_fp8_kernel and VLLM_AVAILABLE:
    from vllm import _custom_ops as ops
else:
    from sgl_kernel import fp8_scaled_mm

# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)

_TORCH_VERSION = torch.__version__.split("+")[0]
try:
    _TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3]))
except ValueError:
    _TORCH_VERSION_TUPLE = (0, 0, 0)

# The condition to determine if it is on a platform that supports
# torch._scaled_mm rowwise feature.
# The condition is determined once as the operations
# are time consuming.
USE_ROWWISE_TORCH_SCALED_MM = (
    _is_hip and get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0)
)


def cutlass_fp8_supported():
    if not _is_cuda:
@@ -330,3 +342,223 @@ def apply_fp8_linear(
        if bias is not None:
            output = output + bias
        return output.to(dtype=input.dtype).view(*output_shape)


def maybe_create_device_identity():
    # Allocate dummy ones tensor for torch._scaled_mm
    global TORCH_DEVICE_IDENTITY
    if TORCH_DEVICE_IDENTITY is None:
        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)


# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
# TODO(luka): follow similar pattern for marlin and block-fp8-linear
# https://github.com/vllm-project/vllm/issues/14397
class Fp8LinearOp:
    """
    This class executes an FP8 linear layer using cutlass if supported and
    torch.scaled_mm otherwise.
    It needs to be a class instead of a method so that config can be read
    in the __init__ method, as reading config is not allowed inside forward.
    """

    def __init__(
        self,
        cutlass_fp8_supported: bool = cutlass_fp8_supported(),
        use_per_token_if_dynamic: bool = False,
        pad_output: Optional[bool] = None,
    ):
        self.cutlass_fp8_supported = cutlass_fp8_supported
        self.use_per_token_if_dynamic = use_per_token_if_dynamic

        # Note: we pad the input because torch._scaled_mm is more performant
        # for matrices with batch dimension > 16.
        # This could change in the future.
        # We also don't pad when using torch.compile,
        # as it breaks with dynamic shapes.
        if pad_output is None:
            enable_torch_compile = os.environ.get(
                "SGLANG_ENABLE_TORCH_COMPILE", "0"
            ).lower() in ("1", "true", "yes")
            pad_output = not enable_torch_compile
        self.output_padding = 17 if pad_output else None

    def apply(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        input_scale: Optional[torch.Tensor] = None,
        input_scale_ub: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        # TODO(luka) remove this parameter in favor of __init__
        use_per_token_if_dynamic: Optional[bool] = None,
    ) -> torch.Tensor:
        # ops.scaled_fp8_quant supports both dynamic and static quant.
        # If dynamic, layer.input_scale is None and x_scale computed from x.
        # If static, layer.input_scale is scalar and x_scale is input_scale.

        # View input as 2D matrix for fp8 methods
        input_2d = input.view(-1, input.shape[-1])
        output_shape = [*input.shape[:-1], weight.shape[1]]

        # TODO(luka) this is here because currently MLA only decides this
        # during the forward method instead of in __init__.
        if use_per_token_if_dynamic is None:
            use_per_token_if_dynamic = self.use_per_token_if_dynamic

        # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
        # for sgl-kernel fp8_scaled_mm, it supports per-channel W now
        if self.cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
            if _is_cuda:
                qinput, x_scale = sgl_scaled_fp8_quant(
                    input_2d,
                    input_scale,
                    use_per_token_if_dynamic=use_per_token_if_dynamic,
                )
            else:
                qinput, x_scale = ops.scaled_fp8_quant(
                    input_2d,
                    input_scale,
                    scale_ub=input_scale_ub,
                    use_per_token_if_dynamic=use_per_token_if_dynamic,
                )

            # Fused GEMM_DQ
            if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
                # Fall back to vllm cutlass w8a8 fp8 kernel
                output = ops.cutlass_scaled_mm(
                    qinput,
                    weight,
                    out_dtype=input.dtype,
                    scale_a=x_scale,
                    scale_b=weight_scale,
                    bias=bias,
                )
            else:
                assert (
                    weight_scale.numel() == weight.shape[1]
                ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
                output = fp8_scaled_mm(
                    qinput,
                    weight,
                    x_scale,
                    weight_scale,
                    out_dtype=input.dtype,
                    bias=bias,
                )
            return output.view(*output_shape)

        # torch.scaled_mm supports per tensor weights + activations only
        # so fallback to naive if per channel or per token
        else:
            # Maybe apply padding to output, see comment in __init__
            if _is_cuda:
                qinput, x_scale = sgl_scaled_fp8_quant(
                    input_2d,
                    input_scale,
                    use_per_token_if_dynamic=use_per_token_if_dynamic,
                )
                if self.output_padding:
                    pad_size = max(self.output_padding - qinput.shape[0], 0)
                    if pad_size > 0:
                        qinput = torch.nn.functional.pad(qinput, (0, 0, 0, pad_size))
            else:
                qinput, x_scale = ops.scaled_fp8_quant(
                    input_2d,
                    input_scale,
                    num_token_padding=self.output_padding,
                    use_per_token_if_dynamic=use_per_token_if_dynamic,
                )

            per_tensor_weights = weight_scale.numel() == 1
            per_tensor_activations = x_scale.numel() == 1

            if per_tensor_weights and per_tensor_activations:
                # Fused GEMM_DQ
                output = torch._scaled_mm(
                    qinput,
                    weight,
                    out_dtype=input.dtype,
                    scale_a=x_scale,
                    scale_b=weight_scale,
                    bias=bias,
                )
                # A fix for discrepancy in scaled_mm which returns tuple
                # for torch < 2.5 and a single value in torch >= 2.5
                if type(output) is tuple and len(output) == 2:
                    output = output[0]

                return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)

            elif (
                use_per_token_if_dynamic
                and not per_tensor_weights
                and not per_tensor_activations
                and USE_ROWWISE_TORCH_SCALED_MM
            ):
                # For now validated on the ROCm platform only.
                # fp8 rowwise scaling in torch._scaled_mm is introduced in
                # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
                # and ROCm 6.3, which only exists in torch 2.7 and above.
                # For the CUDA platform, please validate whether
                # torch._scaled_mm supports rowwise scaled GEMM.
                # Fused GEMM_DQ Rowwise GEMM
                output = torch._scaled_mm(
                    qinput,
                    weight,
                    out_dtype=input.dtype,
                    scale_a=x_scale,
                    scale_b=weight_scale.t(),
                    bias=bias,
                )

                output = torch.narrow(output, 0, 0, input_2d.shape[0])
                output = output.view(*output_shape)
                return output

            else:
                # Fallback for channelwise case, where we use unfused DQ
                # due to limitations with scaled_mm

                # Symmetric quantized GEMM by definition computes the following:
                #   C = (s_x * X) (s_w * W) + bias
                # This is equivalent to dequantizing the weights and activations
                # before applying a GEMM.
                #
                # In order to compute quantized operands, a quantized kernel
                # will rewrite the above like so:
                #   C = s_w * s_x * (X * W) + bias
                #
                # For the scaled_mm fallback case, we break this down, since it
                # does not support s_w being a vector.

                # GEMM
                # This computes C = (X * W).
                # Output in fp32 to allow subsequent ops to happen in-place

                global TORCH_DEVICE_IDENTITY
                if TORCH_DEVICE_IDENTITY.device != weight.device:
                    TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)

                output = torch._scaled_mm(
                    qinput,
                    weight,
                    scale_a=TORCH_DEVICE_IDENTITY,
                    scale_b=TORCH_DEVICE_IDENTITY,
                    out_dtype=torch.float32,
                )
                # A fix for discrepancy in scaled_mm which returns tuple
                # for torch < 2.5 and a single value in torch >= 2.5
                if type(output) is tuple and len(output) == 2:
                    output = output[0]
                # Unpad (undo num_token_padding)
                output = torch.narrow(output, 0, 0, input_2d.shape[0])
                x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])

                # DQ
                # C = sw * sx * (X * W) + bias
                output = output * x_scale * weight_scale.t()
                if bias is not None:
                    output = output + bias
                return output.to(dtype=input.dtype).view(*output_shape)
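The unfused fallback at the end of Fp8LinearOp.apply leans on the identity C = (s_x X)(s_w W) = s_x s_w (X W): run the GEMM on the quantized operands first, then rescale rows by the activation scale and columns by the weight scale. A toy fp32 check of that rescaling step (values invented for readability):

# Toy check of the dequant identity used by the channelwise fallback.
import torch

X = torch.tensor([[2.0, 4.0]])            # "quantized" activations, 1 token x 2 ch
W = torch.tensor([[1.0, 3.0],
                  [2.0, 1.0]])            # "quantized" weights, already (K x N)
s_x = torch.tensor([[0.5]])               # per-token activation scale
s_w = torch.tensor([[0.1], [0.2]])        # per-output-channel weight scale (N x 1)

dequant_first = (s_x * X) @ (W * s_w.t())  # dequantize, then GEMM
gemm_first = (X @ W) * s_x * s_w.t()       # GEMM, then rescale (what the code does)
assert torch.allclose(dequant_first, gemm_first)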
@@ -15,6 +15,11 @@ else:
    from vllm import _custom_ops as vllm_ops


def is_fp8_fnuz() -> bool:
    # only device 0 is checked, this assumes MI300 platforms are homogeneous
    return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName


def is_layer_skipped(
    prefix: str,
    ignored_layers: List[str],
@@ -120,3 +125,29 @@ def requantize_with_max_scale(
        start = end

    return max_w_scale, weight


# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/layer_utils.py
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_parameter(
    mod: torch.nn.Module, name: str, new: Union[torch.Tensor, torch.nn.Parameter]
) -> None:

    old = getattr(mod, name)
    if (
        type(old) is type(new)
        and old.dtype == new.dtype
        and old.untyped_storage().nbytes() == new.untyped_storage().nbytes()
    ):
        # If we can just update in-place to avoid re-registering
        # can be faster if the underlying storage is the same
        update_tensor_inplace(old, new)
    else:
        # Fallback re-register parameter, convert to Parameter if necessary
        # this not only ensures we don't register a tensor as a parameter, but
        # also ensures that all parameter subclasses get re-registered as
        # parameters for `torch.compile` compatibility
        if not isinstance(new, torch.nn.Parameter):
            new = torch.nn.Parameter(new, requires_grad=False)
        mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))
81
python/sglang/srt/managers/expert_distribution.py
Normal file
@@ -0,0 +1,81 @@
import json
import logging
import time
from collections import defaultdict
from typing import Dict, List, Tuple

import torch

logger = logging.getLogger(__name__)


# global expert distribution recording
class ExpertDistributionRecorder:
    # This class is a singleton class
    def __new__(cls):
        if not hasattr(cls, "instance"):
            cls.instance = super(ExpertDistributionRecorder, cls).__new__(cls)
        return cls.instance

    def __init__(self):
        # the length of the dictionary is the number of layers
        # the length of the list is the number of tokens
        # the length of the tuple is topk's k value
        self._expert_distribution_record: Dict[int, List[Tuple[int]]] = defaultdict(
            list
        )
        self._record = False
        self._current_layer_id = "UNKNOWN"

    def set_current_layer(self, layer_idx):
        self._current_layer_id = layer_idx

    def record_new_token(self, topk_ids):
        if not self._record:
            return
        topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist()
        torch.cuda.synchronize()
        for i in topk_ids_list:
            self._expert_distribution_record[self._current_layer_id].append(tuple(i))

    def reset(self):
        """Reset the expert distribution recorder."""
        logger.info("Resetting expert distribution record...")
        self._record = False
        self._expert_distribution_record.clear()
        self._current_layer_id = "UNKNOWN"

    def start_record(self):
        """Start recording the expert distribution. Reset the recorder and set the recording flag to True."""
        if self._record == True:
            logger.warning(
                "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?"
            )
        self.reset()
        self._record = True

    def stop_record(self):
        """Stop recording the expert distribution. Set the recording flag to False."""
        if self._record == False:
            logger.warning(
                "SGLang server has not been recording expert ids. Did you forget to start recording by sending request to the `/start_expert_distribution_record` endpoint?"
            )
        self._record = False

    def dump_record(self):
        """Dump the expert distribution record to a file. Reset the recorder after dumping."""
        results = {}
        for layer_idx, layer_record in self._expert_distribution_record.items():
            results[layer_idx] = defaultdict(int)
            for token_record in layer_record:
                for expert_idx in token_record:
                    results[layer_idx][expert_idx] += 1
        with open(
            f"expert_distribution_rank{torch.distributed.get_rank()}_timestamp{time.time()}.csv",
            "w",
        ) as fd:
            fd.write("layer_id,expert_id,count\n")
            for layer_idx, layer_results in results.items():
                for expert_idx, count in layer_results.items():
                    fd.write(f"{layer_idx},{expert_idx},{count}\n")
        self.reset()
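The recorder is a process-wide singleton; in the server it is driven by the `/start_expert_distribution_record`, `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints. A condensed sketch of the same flow (the topk ids are invented, and dumping assumes CUDA plus an initialized torch.distributed group, as in a running server):

# Sketch of the recording flow outside the HTTP endpoints.
recorder = ExpertDistributionRecorder()        # singleton: same instance everywhere
recorder.start_record()

recorder.set_current_layer(0)
recorder.record_new_token(torch.tensor([[3, 7], [1, 3]]))  # top-k expert ids per token

recorder.stop_record()
recorder.dump_record()   # writes expert_distribution_rank{rank}_timestamp{...}.csv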
@@ -53,6 +53,7 @@ from sglang.srt.disaggregation.utils import (
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from sglang.srt.managers.io_struct import (
    AbortReq,
    CloseSessionReqInput,
@@ -106,7 +107,7 @@ from sglang.srt.managers.scheduler_output_processor_mixin import (
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.managers.utils import ExpertDistributionRecorder, validate_input_length
from sglang.srt.managers.utils import validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache

@@ -47,75 +47,3 @@ def validate_input_length(
        return error_msg

    return None


# global expert distribution recording
class ExpertDistributionRecorder:
    # This class is a singleton class
    def __new__(cls):
        if not hasattr(cls, "instance"):
            cls.instance = super(ExpertDistributionRecorder, cls).__new__(cls)
        return cls.instance

    def __init__(self):
        # the length of the dictionary is the number of layers
        # the length of the list is the number of tokens
        # the length of the tuple is topk's k value
        self._expert_distribution_record: Dict[int, List[Tuple[int]]] = defaultdict(
            list
        )
        self._record = False
        self._current_layer_id = "UNKNOWN"

    def set_current_layer(self, layer_idx):
        self._current_layer_id = layer_idx

    def record_new_token(self, topk_ids):
        if not self._record:
            return
        topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist()
        torch.cuda.synchronize()
        for i in topk_ids_list:
            self._expert_distribution_record[self._current_layer_id].append(tuple(i))

    def reset(self):
        """Reset the expert distribution recorder."""
        logger.info("Resetting expert distribution record...")
        self._record = False
        self._expert_distribution_record.clear()
        self._current_layer_id = "UNKNOWN"

    def start_record(self):
        """Start recording the expert distribution. Reset the recorder and set the recording flag to True."""
        if self._record == True:
            logger.warning(
                "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?"
            )
        self.reset()
        self._record = True

    def stop_record(self):
        """Stop recording the expert distribution. Set the recording flag to False."""
        if self._record == False:
            logger.warning(
                "SGLang server has not been recording expert ids. Did you forget to start recording by sending request to the `/start_expert_distribution_record` endpoint?"
            )
        self._record = False

    def dump_record(self):
        """Dump the expert distribution record to a file. Reset the recorder after dumping."""
        results = {}
        for layer_idx, layer_record in self._expert_distribution_record.items():
            results[layer_idx] = defaultdict(int)
            for token_record in layer_record:
                for expert_idx in token_record:
                    results[layer_idx][expert_idx] += 1
        with open(
            f"expert_distribution_rank{torch.distributed.get_rank()}_timestamp{time.time()}.csv",
            "w",
        ) as fd:
            fd.write("layer_id,expert_id,count\n")
            for layer_idx, layer_results in results.items():
                for expert_idx, count in layer_results.items():
                    fd.write(f"{layer_idx},{expert_idx},{count}\n")
        self.reset()
@@ -67,8 +67,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.managers.utils import ExpertDistributionRecorder
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip

@@ -44,7 +44,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.managers.utils import ExpertDistributionRecorder
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix

@@ -16,6 +16,7 @@
import argparse
import dataclasses
import logging
import os
import random
import tempfile
from typing import List, Optional
@@ -341,6 +342,10 @@ class ServerArgs:
            self.disable_overlap_schedule = True
            logger.warning("Overlap scheduler is disabled for decode server")

        os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
            "1" if self.enable_torch_compile else "0"
        )

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        # Model and port args

@@ -29,3 +29,5 @@ pip install cuda-python nvidia-cuda-nvrtc-cu12
pip install timm

pip install sgl-kernel==0.0.5.post3 --force-reinstall

pip uninstall vllm -y || true
@@ -23,16 +23,12 @@ suites = {
        TestFile("models/test_reward_models.py", 83),
        TestFile("models/test_gme_qwen_models.py", 45),
        TestFile("test_abort.py", 51),
        TestFile("test_awq.py"),
        TestFile("test_block_int8.py", 22),
        TestFile("test_chunked_prefill.py", 336),
        TestFile("test_eagle_infer.py", 447),
        TestFile("test_ebnf_constrained.py"),
        TestFile("test_fp8_kernel.py", 2),
        TestFile("test_embedding_openai_server.py", 36),
        TestFile("test_expert_distribution.py", 31),
        TestFile("test_gguf.py", 78),
        TestFile("test_gptqmodel_dynamic.py", 72),
        TestFile("test_hidden_states.py", 55),
        TestFile("test_int8_kernel.py", 1),
        TestFile("test_input_embeddings.py", 38),
@@ -82,6 +78,12 @@ suites = {
    "nightly": [
        TestFile("test_nightly_gsm8k_eval.py"),
    ],
    "vllm_dependency_test": [
        TestFile("test_vllm_dependency.py"),
        TestFile("test_awq.py"),
        TestFile("test_gguf.py", 78),
        TestFile("test_gptqmodel_dynamic.py", 72),
    ],
}


@@ -37,9 +37,6 @@ MODEL_SCORE_THRESHOLDS = {
    "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8": 0.65,
    "neuralmagic/Qwen2-72B-Instruct-FP8": 0.94,
    "neuralmagic/Qwen2-57B-A14B-Instruct-FP8": 0.82,
    "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4": 0.84,
    "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4": 0.83,
    "hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4": 0.62,
}

@@ -138,7 +135,6 @@ class TestNightlyGsm8KEval(CustomTestCase):
            (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2), False, True),
            (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1), True, False),
            (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2), True, True),
            (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1), False, False),
        ]
        cls.base_url = DEFAULT_URL_FOR_TEST
168
test/srt/test_vllm_dependency.py
Normal file
@@ -0,0 +1,168 @@
import json
import os
import unittest
import warnings
from datetime import datetime
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    is_in_ci,
    popen_launch_server,
    write_github_step_summary,
)

MODEL_SCORE_THRESHOLDS = {
    "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4": 0.84,
    "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4": 0.83,
    "hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4": 0.62,
}


def parse_models(model_string):
    return [model.strip() for model in model_string.split(",") if model.strip()]


def popen_launch_server_wrapper(base_url, model, is_fp8, is_tp2):
    other_args = ["--log-level-http", "warning", "--trust-remote-code"]
    if is_fp8:
        if "Llama-3" in model or "gemma-2" in model:
            other_args.extend(["--kv-cache-dtype", "fp8_e5m2"])
        elif "Qwen2-72B-Instruct-FP8" in model:
            other_args.extend(["--quantization", "fp8"])
        elif "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8" in model:
            other_args.extend([])
        else:
            other_args.extend(["--quantization", "fp8", "--kv-cache-dtype", "fp8_e5m2"])
    if is_tp2:
        other_args.extend(["--tp", "2"])
    if "DeepSeek" in model:
        other_args.extend(["--mem-frac", "0.85"])
    if "AWQ" in model:
        other_args.extend(["--quantization", "awq"])
    elif "GPTQ" in model:
        other_args.extend(["--quantization", "gptq"])

    process = popen_launch_server(
        model,
        base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=other_args,
    )
    return process


def write_results_to_json(model, metrics, mode="a"):
    result = {
        "timestamp": datetime.now().isoformat(),
        "model": model,
        "metrics": metrics,
        "score": metrics["score"],
    }

    existing_results = []
    if mode == "a" and os.path.exists("results.json"):
        try:
            with open("results.json", "r") as f:
                existing_results = json.load(f)
        except json.JSONDecodeError:
            existing_results = []

    if isinstance(existing_results, list):
        existing_results.append(result)
    else:
        existing_results = [result]

    with open("results.json", "w") as f:
        json.dump(existing_results, f, indent=2)


def check_model_scores(results):
    failed_models = []
    summary = " | model | score | threshold |\n"
    summary += "| ----- | ----- | --------- |\n"

    for model, score in results:
        threshold = MODEL_SCORE_THRESHOLDS.get(model)
        if threshold is None:
            print(f"Warning: No threshold defined for model {model}")
            continue

        if score < threshold:
            failed_models.append(
                f"\nScore Check Failed: {model}\n"
                f"Model {model} score ({score:.4f}) is below threshold ({threshold:.4f})"
            )

        line = f"| {model} | {score} | {threshold} |\n"
        summary += line

    print(summary)

    if is_in_ci():
        write_github_step_summary(
            f"### TestNightlyGsm8KEval for vLLM awq, gptq, gguf\n{summary}"
        )

    if failed_models:
        raise AssertionError("\n".join(failed_models))


class TestNightlyGsm8KEval(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model_groups = [
            (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1), False, False),
        ]
        cls.base_url = DEFAULT_URL_FOR_TEST

    def test_mgsm_en_all_models(self):
        warnings.filterwarnings(
            "ignore", category=ResourceWarning, message="unclosed.*socket"
        )
        is_first = True
        all_results = []

        for model_group, is_fp8, is_tp2 in self.model_groups:
            for model in model_group:
                with self.subTest(model=model):
                    process = popen_launch_server_wrapper(
                        self.base_url, model, is_fp8, is_tp2
                    )

                    args = SimpleNamespace(
                        base_url=self.base_url,
                        model=model,
                        eval_name="mgsm_en",
                        num_examples=None,
                        num_threads=1024,
                    )

                    metrics = run_eval(args)
                    print(
                        f"{'=' * 42}\n{model} - metrics={metrics} score={metrics['score']}\n{'=' * 42}\n"
                    )

                    write_results_to_json(model, metrics, "w" if is_first else "a")
                    is_first = False

                    all_results.append((model, metrics["score"]))
                    kill_process_tree(process.pid)

        try:
            with open("results.json", "r") as f:
                print("\nFinal Results from results.json:")
                print(json.dumps(json.load(f), indent=2))
        except Exception as e:
            print(f"Error reading results.json: {e}")

        # Check all scores after collecting all results
        check_model_scores(all_results)


if __name__ == "__main__":
    unittest.main()