[Feature] Support AWQ MoE W4A16 Quantization (#142)

Signed-off-by: tangshiwen <tangshiwen@baidu.com>
Co-authored-by: Li Wei <liwei.109@outlook.com>
Shiwen Tang
2026-01-26 18:56:05 +08:00
committed by GitHub
parent 2a998286c0
commit 0711c1abfa
7 changed files with 639 additions and 126 deletions


@@ -12,6 +12,7 @@ from torch.library import register_fake
import vllm_kunlun._kunlun
import vllm.envs as envs

def patch_annotations_for_schema(func):
    """patch_annotations_for_schema"""
    sig = inspect.signature(func)
@@ -128,7 +129,10 @@ def vllm_kunlun_weak_ref_tensors(
        return tuple(vllm_kunlun_weak_ref_tensor(t) for t in tensors)
    raise ValueError("Invalid type for tensors")

vllm_port = envs.VLLM_PORT

def _get_open_port() -> int:
    global vllm_port
    try:
@@ -142,6 +146,7 @@ def _get_open_port() -> int:
            s.bind(("", 0))
            return s.getsockname()[1]

_wrapped = SimpleNamespace(**_orig.__dict__)
_wrapped.direct_register_custom_op = direct_register_custom_op
_wrapped.weak_ref_tensor = vllm_kunlun_weak_ref_tensor
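
Note: the SimpleNamespace shim above follows a common monkey-patching recipe: copy the original module's namespace, override selected attributes, then expose the copy to importers. A minimal standalone sketch of the same pattern (module and function names here are hypothetical, not taken from this diff):

import sys
import types
from types import SimpleNamespace

# Hypothetical stand-in for the module being wrapped.
_orig_demo = types.ModuleType("demo_utils")
_orig_demo.get_open_port = lambda: 8000

# Copy every attribute, then override only what the backend replaces.
_wrapped_demo = SimpleNamespace(**_orig_demo.__dict__)
_wrapped_demo.get_open_port = lambda: 9000

# `import demo_utils` now resolves to the patched namespace.
sys.modules["demo_utils"] = _wrapped_demo
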
@@ -1897,33 +1902,35 @@ def apply_repetition_penalties_(
    logits: torch.Tensor,
    prompt_mask: torch.Tensor,
    output_mask: torch.Tensor,
    repetition_penalties: torch.Tensor,
) -> None:
    repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
        1, logits.size(1)
    )
    # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
    penalties = torch.where(prompt_mask | output_mask, repetition_penalties, 1.0)
    # If logits are positive, divide by penalty, otherwise multiply by penalty.
    scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
    logits *= scaling

@impl("_C::apply_repetition_penalties_", "CUDA")
def apply_repetition_penalties_(
    logits: torch.Tensor,
    prompt_mask: torch.Tensor,
    output_mask: torch.Tensor,
    repetition_penalties: torch.Tensor,
) -> None:
    repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
        1, logits.size(1)
    )
    # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
    penalties = torch.where(prompt_mask | output_mask, repetition_penalties, 1.0)
    # If logits are positive, divide by penalty, otherwise multiply by penalty.
    scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
    logits *= scaling
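
For intuition, the penalty math above on a tiny example (a standalone sketch, not part of the diff): tokens seen in the prompt or output have positive logits divided by the penalty and negative logits multiplied by it, pushing both toward lower probability, while unseen tokens are untouched.

import torch

logits = torch.tensor([[2.0, -1.0, 0.5]])
prompt_mask = torch.tensor([[True, False, False]])
output_mask = torch.tensor([[False, True, False]])
repetition_penalties = torch.tensor([2.0])

penalties = torch.where(
    prompt_mask | output_mask,
    repetition_penalties.unsqueeze(dim=1).repeat(1, logits.size(1)),
    1.0,
)
scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
print(logits * scaling)  # tensor([[ 1.0000, -2.0000,  0.5000]])
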
##################################################
# --------------- I8_mqa_logits -----------------
##################################################
@@ -1937,10 +1944,10 @@ def I8_mqa_logits(
    logits: torch.Tensor,
    clean_logits: bool,
    max_seq_q: Optional[int] = 0,
    max_seq_k: Optional[int] = 0,
    is_causal: Optional[bool] = False,
    use_xfa_boost: Optional[bool] = False,
) -> None:
    xtorch_ops.I8_mqa_logits(
        q=q,
        fused_kv_cache=fused_kv_cache,
@@ -1956,6 +1963,7 @@ def I8_mqa_logits(
    )
    return None

@impl("_C::I8_mqa_logits", "CUDA")
def I8_mqa_logits_cuda(
    q: torch.Tensor,
@@ -1966,10 +1974,10 @@ def I8_mqa_logits_cuda(
    logits: torch.Tensor,
    clean_logits: bool,
    max_seq_q: Optional[int] = 0,
    max_seq_k: Optional[int] = 0,
    is_causal: Optional[bool] = False,
    use_xfa_boost: Optional[bool] = False,
) -> None:
    xtorch_ops.I8_mqa_logits(
        q=q,
        fused_kv_cache=fused_kv_cache,
@@ -1985,6 +1993,7 @@ def I8_mqa_logits_cuda(
    )
    return None

def _fake_I8_mqa_logits(
    q: torch.Tensor,
    fused_kv_cache: List[torch.Tensor],
@@ -1994,14 +2003,16 @@ def _fake_I8_mqa_logits(
    logits: torch.Tensor,
    clean_logits: bool,
    max_seq_q: Optional[int] = 0,
    max_seq_k: Optional[int] = 0,
    is_causal: Optional[bool] = False,
    use_xfa_boost: Optional[bool] = False,
) -> None:
    return None

I8_mqa_logits.register_fake(_fake_I8_mqa_logits)
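
The register_fake hook exists so torch.compile can trace these ops with FakeTensors: the fake variant computes output metadata only and never dispatches to the real kernel. The same pattern on a toy op (the name demo::my_inc is hypothetical; assumes PyTorch >= 2.4 for torch.library.custom_op):

import torch
from torch.library import custom_op

@custom_op("demo::my_inc", mutates_args=())
def my_inc(x: torch.Tensor) -> torch.Tensor:
    return x + 1

@my_inc.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Under tracing only shape/dtype/device matter.
    return torch.empty_like(x)
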
##################################################
# ------------- I8_paged_mqa_logits --------------
##################################################
@@ -2015,7 +2026,8 @@ def I8_paged_mqa_logits(
    max_context_len: int,
    clean_logits: bool,
    out: torch.Tensor,
    use_xfa_boost: Optional[bool] = False,
) -> None:
    xtorch_ops.I8_paged_mqa_logits(
        q=q,
        fused_kv_cache=fused_kv_cache,
@@ -2025,9 +2037,11 @@ def I8_paged_mqa_logits(
        max_context_len=max_context_len,
        clean_logits=clean_logits,
        out=out,
        use_xfa_boost=use_xfa_boost,
    )
    return None

@impl("_C::I8_paged_mqa_logits", "CUDA")
def I8_paged_mqa_logits_cuda(
    q: torch.Tensor,
@@ -2038,7 +2052,8 @@ def I8_paged_mqa_logits_cuda(
    max_context_len: int,
    clean_logits: bool,
    out: torch.Tensor,
    use_xfa_boost: Optional[bool] = False,
) -> None:
    xtorch_ops.I8_paged_mqa_logits(
        q=q,
        fused_kv_cache=fused_kv_cache,
@@ -2048,42 +2063,48 @@ def I8_paged_mqa_logits_cuda(
        max_context_len=max_context_len,
        clean_logits=clean_logits,
        out=out,
        use_xfa_boost=use_xfa_boost,
    )
    return None

def _fake_I8_paged_mqa_logits(
    q: torch.Tensor,
    fused_kv_cache: List[torch.Tensor],
    weights: torch.Tensor,
    context_lens: List[torch.Tensor],
    block_table: torch.Tensor,
    max_context_len: int,
    clean_logits: bool,
    out: torch.Tensor,
    use_xfa_boost: Optional[bool] = False,
) -> None:
    return None

I8_paged_mqa_logits.register_fake(_fake_I8_paged_mqa_logits)
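
For intuition on the paged arguments (an illustrative sketch, independent of this kernel's exact layout): block_table maps each sequence's logical KV positions to physical cache blocks, so position p of sequence s lives in block block_table[s][p // block_size] at offset p % block_size.

import torch

block_size = 4
# Two sequences, up to three blocks each; entries are physical block ids.
block_table = torch.tensor([[7, 2, 5], [0, 3, 1]])

def physical_slot(seq: int, pos: int) -> tuple:
    # (physical_block_id, offset_within_block)
    return int(block_table[seq, pos // block_size]), pos % block_size

print(physical_slot(0, 6))  # (2, 2)
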
##################################################
# ----------- sparse_prefill_fwd_opt -------------
##################################################
@custom_op("_C::sparse_prefill_fwd_opt", mutates_args=())
def sparse_prefill_fwd_opt(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    out: torch.Tensor,
    max_logits: torch.Tensor,
    lse: torch.Tensor,
    sm_scale: float,
    qlod_cpu: Optional[torch.Tensor] = None,
    qlod_xpu: Optional[torch.Tensor] = None,
    kvlod_cpu: Optional[torch.Tensor] = None,
    kvlod_xpu: Optional[torch.Tensor] = None,
    d_v: Optional[int] = -1,
    is_causal: Optional[bool] = True,
    use_xfa_boost: Optional[bool] = False,
) -> None:
    xtorch_ops.sparse_prefill_fwd_opt(
        q=q,
        kv=kv,
@@ -2098,25 +2119,28 @@ def sparse_prefill_fwd_opt(
        kvlod_xpu=kvlod_xpu,
        d_v=d_v,
        is_causal=is_causal,
        use_xfa_boost=use_xfa_boost,
    )
    return None

@impl("_C::sparse_prefill_fwd_opt", "CUDA")
def sparse_prefill_fwd_opt_cuda(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    out: torch.Tensor,
    max_logits: torch.Tensor,
    lse: torch.Tensor,
    sm_scale: float,
    qlod_cpu: Optional[torch.Tensor] = None,
    qlod_xpu: Optional[torch.Tensor] = None,
    kvlod_cpu: Optional[torch.Tensor] = None,
    kvlod_xpu: Optional[torch.Tensor] = None,
    d_v: Optional[int] = -1,
    is_causal: Optional[bool] = True,
    use_xfa_boost: Optional[bool] = False,
) -> None:
    xtorch_ops.sparse_prefill_fwd_opt(
        q=q,
        kv=kv,
@@ -2131,46 +2155,52 @@ def sparse_prefill_fwd_opt_cuda(
        kvlod_xpu=kvlod_xpu,
        d_v=d_v,
        is_causal=is_causal,
        use_xfa_boost=use_xfa_boost,
    )
    return None

def _fake_sparse_prefill_fwd_opt(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    out: torch.Tensor,
    max_logits: torch.Tensor,
    lse: torch.Tensor,
    sm_scale: float,
    qlod_cpu: Optional[torch.Tensor] = None,
    qlod_xpu: Optional[torch.Tensor] = None,
    kvlod_cpu: Optional[torch.Tensor] = None,
    kvlod_xpu: Optional[torch.Tensor] = None,
    d_v: Optional[int] = -1,
    is_causal: Optional[bool] = True,
    use_xfa_boost: Optional[bool] = False,
) -> None:
    return None

sparse_prefill_fwd_opt.register_fake(_fake_sparse_prefill_fwd_opt)
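
Note that the op writes into preallocated out, max_logits, and lse buffers instead of returning tensors, which lets the caller reuse workspace across steps. A hedged sketch of what a caller might allocate (all shapes are illustrative assumptions, not the kernel's documented contract):

import torch

num_tokens, num_heads, head_dim, d_v = 128, 16, 192, 128

q = torch.randn(num_tokens, num_heads, head_dim)
out = torch.empty(num_tokens, num_heads, d_v)
# Per-token, per-head running max and log-sum-exp for online softmax.
max_logits = torch.empty(num_tokens, num_heads)
lse = torch.empty(num_tokens, num_heads)
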
##################################################
# ------------------ fwd_kvcache_mla -------------
##################################################
@custom_op("_C::fwd_kvcache_mla", mutates_args=())
def fwd_kvcache_mla(
    q_c: torch.Tensor,
    kv_cache: torch.Tensor,
    indices: torch.Tensor,
    kv_lod_cpu: torch.Tensor,
    out: torch.Tensor,
    max_logits: torch.Tensor,
    p_sums: torch.Tensor,
    softmax_scale: float,
    max_seq_kv: int,
    q_r: Optional[torch.Tensor] = None,
    pe_cache: Optional[torch.Tensor] = None,
    use_xfa_boost: Optional[bool] = False,
    kv_lod_xpu: Optional[torch.Tensor] = None,
) -> None:
    xtorch_ops.fwd_kvcache_mla(
        q_c=q_c,
        kv_cache=kv_cache,
@@ -2184,24 +2214,27 @@ def fwd_kvcache_mla(
        q_r=q_r,
        pe_cache=pe_cache,
        use_xfa_boost=use_xfa_boost,
        kv_lod_xpu=kv_lod_xpu,
    )
    return None

@impl("_C::fwd_kvcache_mla", "CUDA")
def fwd_kvcache_mla_cuda(
    q_c: torch.Tensor,
    kv_cache: torch.Tensor,
    indices: torch.Tensor,
    kv_lod_cpu: torch.Tensor,
    out: torch.Tensor,
    max_logits: torch.Tensor,
    p_sums: torch.Tensor,
    softmax_scale: float,
    max_seq_kv: int,
    q_r: Optional[torch.Tensor] = None,
    pe_cache: Optional[torch.Tensor] = None,
    use_xfa_boost: Optional[bool] = False,
    kv_lod_xpu: Optional[torch.Tensor] = None,
) -> None:
    xtorch_ops.fwd_kvcache_mla(
        q_c=q_c,
        kv_cache=kv_cache,
@@ -2215,27 +2248,94 @@ def fwd_kvcache_mla_cuda(
        q_r=q_r,
        pe_cache=pe_cache,
        use_xfa_boost=use_xfa_boost,
        kv_lod_xpu=kv_lod_xpu,
    )
    return None

def _fake_fwd_kvcache_mla(
    q_c: torch.Tensor,
    kv_cache: torch.Tensor,
    indices: torch.Tensor,
    kv_lod_cpu: torch.Tensor,
    out: torch.Tensor,
    max_logits: torch.Tensor,
    p_sums: torch.Tensor,
    softmax_scale: float,
    max_seq_kv: int,
    q_r: Optional[torch.Tensor] = None,
    pe_cache: Optional[torch.Tensor] = None,
    use_xfa_boost: Optional[bool] = False,
    kv_lod_xpu: Optional[torch.Tensor] = None,
) -> None:
    return None

fwd_kvcache_mla.register_fake(_fake_fwd_kvcache_mla)
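
In MLA-style decoding the query is split into a compressed latent part q_c scored against kv_cache and a small rotary part q_r scored against pe_cache, and the kernel fuses the two dot products. A shape-only sketch of that split (all sizes hypothetical):

import torch

num_q, kv_lora_rank, rope_dim, num_kv = 8, 512, 64, 1024

q_c = torch.randn(num_q, kv_lora_rank)       # compressed latent query part
q_r = torch.randn(num_q, rope_dim)           # rotary positional query part
kv_cache = torch.randn(num_kv, kv_lora_rank)
pe_cache = torch.randn(num_kv, rope_dim)

# Attention logits combine both halves before softmax.
logits = q_c @ kv_cache.T + q_r @ pe_cache.T
print(logits.shape)  # torch.Size([8, 1024])
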
##################################################
# --------------- dequant_int4 -----------------
##################################################
@custom_op("_C::dequant_int4", mutates_args=())
def dequant_int4(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero: torch.Tensor,
    y: torch.Tensor,
    group_m: int,
    int4_signed: bool = True,
    use_mode_fast: bool = False,
) -> None:
    xtorch_ops.dequant_int4(
        x=x,
        scale=scale,
        zero=zero,
        y=y,
        group_m=group_m,
        int4_signed=int4_signed,
        use_mode_fast=use_mode_fast,
    )
    return None

@impl("_C::dequant_int4", "CUDA")
def dequant_int4_cuda(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero: torch.Tensor,
    y: torch.Tensor,
    group_m: int,
    int4_signed: bool = True,
    use_mode_fast: bool = False,
) -> None:
    xtorch_ops.dequant_int4(
        x=x,
        scale=scale,
        zero=zero,
        y=y,
        group_m=group_m,
        int4_signed=int4_signed,
        use_mode_fast=use_mode_fast,
    )
    return None

def _fake_dequant_int4(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero: torch.Tensor,
    y: torch.Tensor,
    group_m: int,
    int4_signed: bool = True,
    use_mode_fast: bool = False,
) -> None:
    return None

dequant_int4.register_fake(_fake_dequant_int4)
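
This dequant_int4 op is the new building block for AWQ MoE W4A16: int4 weights are expanded back to 16-bit with one (scale, zero) pair per group of group_m rows. A pure-PyTorch reference for the standard group-wise affine formula y = (x - zero) * scale (a sketch of the math only; the kernel's packing and layout are not taken from this diff):

import torch

def dequant_int4_ref(
    x: torch.Tensor,        # (M, N) unpacked int4 values stored as int8
    scale: torch.Tensor,    # (M // group_m, N) fp16 scales
    zero: torch.Tensor,     # (M // group_m, N) zero points
    group_m: int,
    int4_signed: bool = True,  # signed range [-8, 7] vs unsigned [0, 15];
                               # the affine formula itself is identical
) -> torch.Tensor:
    M, N = x.shape
    g = M // group_m
    xf = x.to(torch.float16).view(g, group_m, N)
    y = (xf - zero.to(torch.float16).unsqueeze(1)) * scale.unsqueeze(1)
    return y.view(M, N)

# Tiny smoke test with signed int4 values.
x = torch.randint(-8, 8, (8, 4), dtype=torch.int8)
scale = torch.rand(2, 4, dtype=torch.float16)
zero = torch.zeros(2, 4, dtype=torch.int8)
print(dequant_int4_ref(x, scale, zero, group_m=4).shape)  # torch.Size([8, 4])
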
##################################################
# ------------------ fast_topkv2 -------------
##################################################