[Feature] support compressed-tensors w4a16 quantization (#154)
- native int4 kimi model inference is supported Signed-off-by: Li Wei <liwei.109@outlook.com>
This commit is contained in:
@@ -149,6 +149,14 @@ By utilizing the vLLM Kunlun plugin, popular open-source models, including Trans
|
|||||||
<td class="status-support">✅</td>
|
<td class="status-support">✅</td>
|
||||||
<td></td>
|
<td></td>
|
||||||
</tr>
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td class="model-name">Kimi-K2</td>
|
||||||
|
<td class="status-support">✅</td>
|
||||||
|
<td class="status-support">✅</td>
|
||||||
|
<td></td>
|
||||||
|
<td class="status-support">✅</td>
|
||||||
|
<td></td>
|
||||||
|
</tr>
|
||||||
</tbody>
|
</tbody>
|
||||||
</table>
|
</table>
|
||||||
|
|
||||||
|
|||||||
@@ -8,22 +8,23 @@ Like vLLM, we now support quantization methods such as compressed-tensors, AWQ,
|
|||||||
<table border="1" style="border-collapse: collapse; width: auto; margin: 0 0 0 0; text-align: center;">
|
<table border="1" style="border-collapse: collapse; width: auto; margin: 0 0 0 0; text-align: center;">
|
||||||
<thead>
|
<thead>
|
||||||
<tr>
|
<tr>
|
||||||
<td colspan="2" style="padding: 10px; font-weight: bold; border: 1px solid #000;">Compressed-Tensor (w8a8)</td>
|
<td colspan="2" style="padding: 10px; font-weight: bold; border: 1px solid #000;">Compressed-Tensors (w8a8-Int8)</td>
|
||||||
<td colspan="4" style="padding: 10px; font-weight: bold; border: 1px solid #000;">Weight only (w4a16/w8a16)</td>
|
<td colspan="4" style="padding: 10px; font-weight: bold; border: 1px solid #000;">Weight only (w4a16/w8a16)</td>
|
||||||
</tr>
|
</tr>
|
||||||
<tr>
|
<tr>
|
||||||
<td style="padding: 10px; border: 1px solid #000;">Dynamic</td>
|
<td style="padding: 10px; border: 1px solid #000;">Dynamic</td>
|
||||||
<td style="padding: 10px; border: 1px solid #000;">Static</td>
|
<td style="padding: 10px; border: 1px solid #000;">Static</td>
|
||||||
<td colspan="2" style="padding: 10px; border: 1px solid #000;">AWQ (w4a16)</td>
|
<td colspan="1" style="padding: 10px; border: 1px solid #000;">AWQ (w4a16)</td>
|
||||||
<td colspan="2" style="padding: 10px; border: 1px solid #000;">GPTQ (w4a16/w8a16)</td>
|
<td colspan="2" style="padding: 10px; border: 1px solid #000;">GPTQ (w4a16/w8a16)</td>
|
||||||
|
<td colspan="1" style="padding: 10px; border: 1px solid #000;">Compressed-Tensors (w4a16)</td>
|
||||||
</tr>
|
</tr>
|
||||||
<tr>
|
<tr>
|
||||||
|
<td style="padding: 10px; border: 1px solid #000;">Dense/MoE</td>
|
||||||
<td style="padding: 10px; border: 1px solid #000;">Dense/MoE</td>
|
<td style="padding: 10px; border: 1px solid #000;">Dense/MoE</td>
|
||||||
<td style="padding: 10px; border: 1px solid #000;">Dense/MoE</td>
|
<td style="padding: 10px; border: 1px solid #000;">Dense/MoE</td>
|
||||||
<td style="padding: 10px; border: 1px solid #000;">Dense</td>
|
<td style="padding: 10px; border: 1px solid #000;">Dense</td>
|
||||||
<td style="padding: 10px; border: 1px solid #000;">MoE</td>
|
<td style="padding: 10px; border: 1px solid #000;">MoE</td>
|
||||||
<td style="padding: 10px; border: 1px solid #000;">Dense</td>
|
<td style="padding: 10px; border: 1px solid #000;">Dense/MoE</td>
|
||||||
<td style="padding: 10px; border: 1px solid #000;">MoE</td>
|
|
||||||
</tr>
|
</tr>
|
||||||
</thead>
|
</thead>
|
||||||
<tbody>
|
<tbody>
|
||||||
@@ -32,14 +33,16 @@ Like vLLM, we now support quantization methods such as compressed-tensors, AWQ,
|
|||||||
<td style="padding: 10px; border: 1px solid #000;">✅</td>
|
<td style="padding: 10px; border: 1px solid #000;">✅</td>
|
||||||
<td style="padding: 10px; border: 1px solid #000;">✅</td>
|
<td style="padding: 10px; border: 1px solid #000;">✅</td>
|
||||||
<td style="padding: 10px; border: 1px solid #000;">✅</td>
|
<td style="padding: 10px; border: 1px solid #000;">✅</td>
|
||||||
<td style="padding: 10px; border: 1px solid #000;">✅</td>
|
|
||||||
<td style="padding: 10px; border: 1px solid #000;">WIP</td>
|
<td style="padding: 10px; border: 1px solid #000;">WIP</td>
|
||||||
|
<td style="padding: 10px; border: 1px solid #000;">✅</td>
|
||||||
</tr>
|
</tr>
|
||||||
</tbody>
|
</tbody>
|
||||||
</table>
|
</table>
|
||||||
|
|
||||||
+ W8A8 dynamic and static quantization are now supported for all LLMs and VLMs.
|
+ Compressed-Tensors w8a8-Int8 dynamic and static quantization are supported for all LLMs and VLMs.
|
||||||
+ AWQ/GPTQ quantization is supported for all dense models.
|
+ Compressed-Tensors w4a16 are supported for all LLMs and VLMs.
|
||||||
|
+ AWQ(w4a16) quantization is supported for all LLMs and VLMs.
|
||||||
|
+ GPTQ (w4a16/w8a16) quantization is supported for all dense models.
|
||||||
|
|
||||||
## Usages
|
## Usages
|
||||||
|
|
||||||
|
|||||||
@@ -15,13 +15,20 @@
|
|||||||
# This file is a part of the vllm-ascend project.
|
# This file is a part of the vllm-ascend project.
|
||||||
#
|
#
|
||||||
|
|
||||||
|
# embedding
|
||||||
import vllm_kunlun.ops.rotary_embedding
|
import vllm_kunlun.ops.rotary_embedding
|
||||||
import vllm_kunlun.ops.layernorm
|
import vllm_kunlun.ops.vocab_parallel_embedding
|
||||||
|
|
||||||
|
# quantization
|
||||||
import vllm_kunlun.ops.quantization.awq
|
import vllm_kunlun.ops.quantization.awq
|
||||||
import vllm_kunlun.ops.quantization.gptq
|
import vllm_kunlun.ops.quantization.gptq
|
||||||
import vllm_kunlun.ops.quantization.moe_wna16
|
import vllm_kunlun.ops.quantization.moe_wna16
|
||||||
import vllm_kunlun.ops.vocab_parallel_embedding
|
import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors
|
||||||
|
import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors_moe
|
||||||
|
import vllm_kunlun.ops.quantization.kernels.kunlun_scale_mm
|
||||||
|
import vllm_kunlun.ops.quantization.kernels.kunlun_exllama_linear
|
||||||
|
|
||||||
|
# base layers
|
||||||
|
import vllm_kunlun.ops.layernorm
|
||||||
import vllm_kunlun.ops.linear
|
import vllm_kunlun.ops.linear
|
||||||
import vllm_kunlun.ops.fused_moe.layer
|
import vllm_kunlun.ops.fused_moe.layer
|
||||||
import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors_moe
|
|
||||||
import vllm_kunlun.ops.quantization.kernels.kunlun_scale_mm
|
|
||||||
@@ -0,0 +1,86 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2026 Baidu, Inc. All Rights Reserved.
|
||||||
|
# Author: Li Wei, Tang Shiwen
|
||||||
|
# Email: liwei157@baidu.com, tangshiwen@baidu.com
|
||||||
|
# This file is a part of the vllm-kunlun project.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
import torch
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
|
||||||
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
|
from vllm.model_executor.layers.linear import (
|
||||||
|
LinearBase,
|
||||||
|
LinearMethodBase,
|
||||||
|
UnquantizedLinearMethod,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
|
||||||
|
CompressedTensorsConfig,
|
||||||
|
CompressedTensorsLinearMethod,
|
||||||
|
CompressedTensorsKVCacheMethod,
|
||||||
|
CompressedTensorsLinearTransformMethod,
|
||||||
|
get_linear_transform_schemes,
|
||||||
|
)
|
||||||
|
from vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors_moe import (
|
||||||
|
KunlunCompressedTensorsMoEMethod,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class KunlunCompressedTensorsConfig(CompressedTensorsConfig):
|
||||||
|
def get_quant_method(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
prefix: str,
|
||||||
|
) -> Optional["QuantizeMethodBase"]:
|
||||||
|
from vllm.attention.layer import Attention # Avoid circular import
|
||||||
|
|
||||||
|
if isinstance(layer, LinearBase):
|
||||||
|
# collect schemes
|
||||||
|
quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
|
||||||
|
input_tfms, output_tfms = get_linear_transform_schemes(
|
||||||
|
layer, prefix, self.transform_config, self.packed_modules_mapping
|
||||||
|
)
|
||||||
|
|
||||||
|
# choose quantization method
|
||||||
|
quant_method: LinearMethodBase = UnquantizedLinearMethod()
|
||||||
|
if quant_scheme is not None:
|
||||||
|
layer.scheme = quant_scheme
|
||||||
|
quant_method = CompressedTensorsLinearMethod(self)
|
||||||
|
|
||||||
|
# choose transform method
|
||||||
|
if any((input_tfms, output_tfms)):
|
||||||
|
return CompressedTensorsLinearTransformMethod.from_schemes(
|
||||||
|
quant_method, quant_scheme, input_tfms, output_tfms
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
return quant_method
|
||||||
|
|
||||||
|
if isinstance(layer, Attention):
|
||||||
|
return CompressedTensorsKVCacheMethod(self)
|
||||||
|
if isinstance(layer, FusedMoE):
|
||||||
|
return KunlunCompressedTensorsMoEMethod.get_moe_method(self, layer)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# monkey patch
|
||||||
|
from vllm.model_executor.layers.quantization.compressed_tensors import (
|
||||||
|
compressed_tensors,
|
||||||
|
)
|
||||||
|
|
||||||
|
compressed_tensors.CompressedTensorsConfig = KunlunCompressedTensorsConfig
|
||||||
|
print(
|
||||||
|
"[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors.CompressedTensorsConfig \
|
||||||
|
--> vllm_kunlun.ops.quantization.compressed_tensors.KunlunCompressedTensorsConfig"
|
||||||
|
)
|
||||||
@@ -19,9 +19,95 @@
|
|||||||
from typing import Callable, Optional, Union
|
from typing import Callable, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from compressed_tensors.quantization import ActivationOrdering, QuantizationStrategy
|
||||||
|
from vllm.model_executor.layers.fused_moe import FusedMoEConfig, FusedMoEMethodBase
|
||||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
|
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
|
||||||
|
CompressedTensorsW4A4MoeMethod,
|
||||||
|
CompressedTensorsW4A8Int8MoEMethod,
|
||||||
CompressedTensorsW8A8Int8MoEMethod,
|
CompressedTensorsW8A8Int8MoEMethod,
|
||||||
|
CompressedTensorsW8A8Int8MoEMethod,
|
||||||
|
CompressedTensorsW8A8Fp8MoEMethod,
|
||||||
|
CompressedTensorsWNA16MoEMethod,
|
||||||
|
find_matched_target,
|
||||||
)
|
)
|
||||||
|
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
|
||||||
|
from vllm_kunlun.ops.quantization.kernels.quant_ops import dequant_int4_native
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class KunlunCompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||||
|
|
||||||
|
def __init_(self, moe: FusedMoEConfig):
|
||||||
|
super().__init__(moe)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_moe_method(
|
||||||
|
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
) -> "KunlunCompressedTensorsMoEMethod":
|
||||||
|
# TODO: @dsikka: refactor this to use schemes as other kernels
|
||||||
|
# are supported + check if the layer is being ignored.
|
||||||
|
# Check if a using "Linear" to select schemes
|
||||||
|
if "Linear" in quant_config.target_scheme_map:
|
||||||
|
matched_target = "Linear"
|
||||||
|
else:
|
||||||
|
# May have instead defined the linear layers in the fused model
|
||||||
|
fused_layers = ["re:.*down_proj.*", "re:.*gate_proj.*", "re:.*up_proj.*"]
|
||||||
|
current_scheme = None
|
||||||
|
for fused_layer in fused_layers:
|
||||||
|
# Check if one of the fused layers are defined in quant_config
|
||||||
|
matched_target = find_matched_target(
|
||||||
|
layer_name=fused_layer,
|
||||||
|
module=layer,
|
||||||
|
targets=quant_config.target_scheme_map.keys(),
|
||||||
|
fused_mapping=quant_config.packed_modules_mapping,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only valid if down_proj, gate_proj, and up_proj
|
||||||
|
# are mapped to the same quant scheme in the quant_config
|
||||||
|
if current_scheme is None:
|
||||||
|
current_scheme = quant_config.target_scheme_map.get(matched_target)
|
||||||
|
else:
|
||||||
|
assert current_scheme == quant_config.target_scheme_map.get(
|
||||||
|
matched_target
|
||||||
|
)
|
||||||
|
|
||||||
|
weight_quant = quant_config.target_scheme_map[matched_target].get("weights")
|
||||||
|
input_quant = quant_config.target_scheme_map[matched_target].get(
|
||||||
|
"input_activations"
|
||||||
|
)
|
||||||
|
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
|
||||||
|
if (
|
||||||
|
weight_quant.strategy in QuantizationStrategy.GROUP
|
||||||
|
and weight_quant.actorder
|
||||||
|
in (ActivationOrdering.GROUP, ActivationOrdering.DYNAMIC)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"WNA16MoE is not supported with actorder=group/dynamic."
|
||||||
|
)
|
||||||
|
# MarlinMoE kernel is not supported on XPU.
|
||||||
|
logger.warning_once(f"Using KunlunCompressedTensorsWNA16MoEMethod")
|
||||||
|
return KunlunCompressedTensorsWNA16MoEMethod(quant_config, layer.moe_config)
|
||||||
|
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
|
||||||
|
return CompressedTensorsW4A4MoeMethod(layer.moe_config)
|
||||||
|
elif (
|
||||||
|
quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
|
||||||
|
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
|
||||||
|
or quant_config._is_fp8_w8a8(weight_quant, input_quant)
|
||||||
|
):
|
||||||
|
return CompressedTensorsW8A8Fp8MoEMethod(quant_config, layer.moe_config)
|
||||||
|
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
|
||||||
|
return KunlunCompressedTensorsW8A8Int8MoEMethod(
|
||||||
|
quant_config, layer.moe_config
|
||||||
|
)
|
||||||
|
elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant):
|
||||||
|
return CompressedTensorsW4A8Int8MoEMethod(quant_config, layer.moe_config)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMethod):
|
class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMethod):
|
||||||
@@ -184,7 +270,7 @@ class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMetho
|
|||||||
# sort_mode=False,
|
# sort_mode=False,
|
||||||
act=None,
|
act=None,
|
||||||
)
|
)
|
||||||
del x_q, x_scale, sorted_tokens_num_lod,expert_m
|
del x_q, x_scale, sorted_tokens_num_lod, expert_m
|
||||||
|
|
||||||
dequant_scale = torch.ones([M, top_k], dtype=torch.float32, device=out.device)
|
dequant_scale = torch.ones([M, top_k], dtype=torch.float32, device=out.device)
|
||||||
output = torch.empty(
|
output = torch.empty(
|
||||||
@@ -202,6 +288,75 @@ class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMetho
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class KunlunCompressedTensorsWNA16MoEMethod(CompressedTensorsWNA16MoEMethod):
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
renormalize: bool,
|
||||||
|
use_grouped_topk: bool = False,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
global_num_experts: int = -1,
|
||||||
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
scoring_func: str = "softmax",
|
||||||
|
routed_scaling_factor: float = 1.0,
|
||||||
|
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
activation: str = "silu",
|
||||||
|
enable_eplb: bool = False,
|
||||||
|
expert_load_view: Optional[torch.Tensor] = None,
|
||||||
|
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||||
|
logical_replica_count: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
# dequant packed weights to float16
|
||||||
|
w13_weight = dequant_int4_native(
|
||||||
|
weight_packed_uint8=layer.w13_weight_packed,
|
||||||
|
scale=self.moe_quant_config.w1_scale,
|
||||||
|
)
|
||||||
|
w2_weight = dequant_int4_native(
|
||||||
|
weight_packed_uint8=layer.w2_weight_packed,
|
||||||
|
scale=self.moe_quant_config.w2_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.moe.use_ep:
|
||||||
|
return ops.fused_moe_ep(
|
||||||
|
x,
|
||||||
|
w13_weight,
|
||||||
|
w2_weight,
|
||||||
|
router_logits,
|
||||||
|
self.moe.ep_rank,
|
||||||
|
top_k,
|
||||||
|
renormalize=renormalize,
|
||||||
|
inplace=True,
|
||||||
|
use_grouped_topk=use_grouped_topk,
|
||||||
|
num_expert_group=num_expert_group,
|
||||||
|
topk_group=topk_group,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return ops.fused_moe(
|
||||||
|
x,
|
||||||
|
w13_weight,
|
||||||
|
w2_weight,
|
||||||
|
router_logits,
|
||||||
|
self.moe.ep_rank,
|
||||||
|
top_k,
|
||||||
|
renormalize=renormalize,
|
||||||
|
inplace=True,
|
||||||
|
use_grouped_topk=use_grouped_topk,
|
||||||
|
num_expert_group=num_expert_group,
|
||||||
|
topk_group=topk_group,
|
||||||
|
scoring_func=scoring_func,
|
||||||
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
|
w1_bias=getattr(layer, "w13_bias", None),
|
||||||
|
w2_bias=getattr(layer, "w2_bias", None),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# monkey patch
|
# monkey patch
|
||||||
from vllm.model_executor.layers.quantization.compressed_tensors import (
|
from vllm.model_executor.layers.quantization.compressed_tensors import (
|
||||||
compressed_tensors_moe,
|
compressed_tensors_moe,
|
||||||
@@ -210,7 +365,21 @@ from vllm.model_executor.layers.quantization.compressed_tensors import (
|
|||||||
compressed_tensors_moe.CompressedTensorsW8A8Int8MoEMethod = (
|
compressed_tensors_moe.CompressedTensorsW8A8Int8MoEMethod = (
|
||||||
KunlunCompressedTensorsW8A8Int8MoEMethod
|
KunlunCompressedTensorsW8A8Int8MoEMethod
|
||||||
)
|
)
|
||||||
print(
|
compressed_tensors_moe.CompressedTensorsMoEMethod = KunlunCompressedTensorsMoEMethod
|
||||||
"[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW8A8Int8MoEMethod \
|
compressed_tensors_moe.CompressedTensorsWNA16MoEMethod = (
|
||||||
--> vllm_kunlun.ops.quantization.compressed_tensors_moe.py:KunlunCompressedTensorsW8A8Int8MoEMethod"
|
KunlunCompressedTensorsWNA16MoEMethod
|
||||||
|
)
|
||||||
|
KunlunCompressedTensorsWNA16MoEMethod.__name__ = "CompressedTensorsWNA16MoEMethod"
|
||||||
|
|
||||||
|
logger.info_once(
|
||||||
|
"[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW8A8Int8MoEMethod \
|
||||||
|
--> vllm_kunlun.ops.quantization.compressed_tensors_moe.KunlunCompressedTensorsW8A8Int8MoEMethod"
|
||||||
|
)
|
||||||
|
logger.info_once(
|
||||||
|
"[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsMoEMethod \
|
||||||
|
--> vllm_kunlun.ops.quantization.compressed_tensors_moe.KunlunCompressedTensorsMoEMethod"
|
||||||
|
)
|
||||||
|
logger.info_once(
|
||||||
|
"[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsWNA16MoEMethod \
|
||||||
|
--> vllm_kunlun.ops.quantization.compressed_tensors_moe.KunlunCompressedTensorsWNA16MoEMethod"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -0,0 +1,57 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2026 Baidu, Inc. All Rights Reserved.
|
||||||
|
# Author: Li Wei
|
||||||
|
# Email: liwei157@baidu.com
|
||||||
|
# This file is a part of the vllm-kunlun project.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import xspeedgate_ops
|
||||||
|
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||||
|
ExllamaLinearKernel,
|
||||||
|
_POSSIBLE_KERNELS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class KunlunExllamaLinearKernel(ExllamaLinearKernel):
|
||||||
|
|
||||||
|
def apply_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
c = self.config
|
||||||
|
|
||||||
|
x_2d = x.reshape(-1, x.shape[-1])
|
||||||
|
out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)
|
||||||
|
|
||||||
|
w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer)
|
||||||
|
|
||||||
|
assert w_zp is not None, "Zero points are required by Exllama"
|
||||||
|
assert w_g_idx is not None, "Group index is required by Exllama"
|
||||||
|
output = torch.ops.xspeedgate_ops.gptq_gemm(
|
||||||
|
x_2d, w_q, w_zp, w_s, w_g_idx, True, c.weight_type.size_bits
|
||||||
|
)
|
||||||
|
|
||||||
|
if bias is not None:
|
||||||
|
output.add_(bias)
|
||||||
|
return output.reshape(out_shape)
|
||||||
|
|
||||||
|
|
||||||
|
# remove ExllamaLinearKernel and add KunlunExllamaLinearKernel
|
||||||
|
_POSSIBLE_KERNELS.remove(ExllamaLinearKernel)
|
||||||
|
_POSSIBLE_KERNELS.append(KunlunExllamaLinearKernel)
|
||||||
@@ -99,12 +99,5 @@ class KunlunScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
|||||||
# )
|
# )
|
||||||
|
|
||||||
|
|
||||||
# monkey patch
|
# replace CutlassScaledMMLinearKernel with KunlunScaledMMLinearKernel
|
||||||
_POSSIBLE_KERNELS[PlatformEnum.CUDA] = [KunlunScaledMMLinearKernel]
|
_POSSIBLE_KERNELS[PlatformEnum.CUDA] = [KunlunScaledMMLinearKernel]
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import cutlass
|
|
||||||
|
|
||||||
cutlass.CutlassScaledMMLinearKernel = KunlunScaledMMLinearKernel
|
|
||||||
print(
|
|
||||||
"[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass.CutlassScaledMMLinearKernel \
|
|
||||||
--> vllm_kunlun.ops.quantization.kernels.kunlun_scale_mm.KunlunScaledMMLinearKernel"
|
|
||||||
)
|
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
#
|
#
|
||||||
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
# Copyright (c) 2026 Baidu, Inc. All Rights Reserved.
|
||||||
# Author: Tang Shiwen
|
# Author: Tang Shiwen, Li Wei
|
||||||
# Email: tangshiwen@baidu.com
|
# Email: tangshiwen@baidu.com, liwei157@baidu.com
|
||||||
# This file is a part of the vllm-kunlun project.
|
# This file is a part of the vllm-kunlun project.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@@ -66,3 +66,21 @@ def dequant_int4(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return fpweight.transpose(1, 2).contiguous()
|
return fpweight.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
def dequant_int4_native(weight_packed_uint8: torch.Tensor, scale: torch.Tensor):
|
||||||
|
"""Unpack uint4 weight from packed uint8 weight and dequant it to float16."""
|
||||||
|
weight_upacked_fp16 = (
|
||||||
|
torch.stack(
|
||||||
|
(weight_packed_uint8 & 0xF, (weight_packed_uint8 >> 4) & 0xF),
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
.reshape(*weight_packed_uint8.shape[:-1], -1)
|
||||||
|
.contiguous()
|
||||||
|
.to(torch.float16)
|
||||||
|
- 8.0
|
||||||
|
)
|
||||||
|
weight_upacked_fp16 *= scale.repeat(
|
||||||
|
1, 1, weight_upacked_fp16.shape[-1] // scale.shape[-1]
|
||||||
|
)
|
||||||
|
return weight_upacked_fp16
|
||||||
|
|||||||
@@ -2275,7 +2275,7 @@ fwd_kvcache_mla.register_fake(_fake_fwd_kvcache_mla)
|
|||||||
|
|
||||||
|
|
||||||
##################################################
|
##################################################
|
||||||
# --------------- dequant_int4 -----------------
|
# --------------- dequant_int4 -------------------
|
||||||
##################################################
|
##################################################
|
||||||
@custom_op("_C::dequant_int4", mutates_args=())
|
@custom_op("_C::dequant_int4", mutates_args=())
|
||||||
def dequant_int4(
|
def dequant_int4(
|
||||||
|
|||||||
Reference in New Issue
Block a user