[Feature] Support AWQ MoE W4A16 Quantization (#142)
Signed-off-by: tangshiwen <tangshiwen@baidu.com>
Co-authored-by: Li Wei <liwei.109@outlook.com>
@@ -12,6 +12,7 @@ from torch.library import register_fake
import vllm_kunlun._kunlun
+import vllm.envs as envs


def patch_annotations_for_schema(func):
    """patch_annotations_for_schema"""
    sig = inspect.signature(func)
@@ -128,7 +129,10 @@ def vllm_kunlun_weak_ref_tensors(
        return tuple(vllm_kunlun_weak_ref_tensor(t) for t in tensors)
    raise ValueError("Invalid type for tensors")


-vllm_port=envs.VLLM_PORT
+vllm_port = envs.VLLM_PORT


def _get_open_port() -> int:
    global vllm_port
    try:
@@ -142,6 +146,7 @@ def _get_open_port() -> int:
        s.bind(("", 0))
        return s.getsockname()[1]


_wrapped = SimpleNamespace(**_orig.__dict__)
_wrapped.direct_register_custom_op = direct_register_custom_op
_wrapped.weak_ref_tensor = vllm_kunlun_weak_ref_tensor
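Note: _get_open_port relies on the port-0 idiom: binding to port 0 asks the OS for a free ephemeral port, which is then read back. A minimal standalone sketch of the same trick (names here are illustrative, not from this patch):

import socket

# Bind to port 0; the OS assigns an unused ephemeral port, which we read back.
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.bind(("", 0))
    free_port = s.getsockname()[1]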
@@ -1897,33 +1902,35 @@ def apply_repetition_penalties_(
    logits: torch.Tensor,
    prompt_mask: torch.Tensor,
    output_mask: torch.Tensor,
-    repetition_penalties: torch.Tensor
+    repetition_penalties: torch.Tensor,
) -> None:
    repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
-        1, logits.size(1))
+        1, logits.size(1)
+    )
    # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
-    penalties = torch.where(prompt_mask | output_mask, repetition_penalties,
-                            1.0)
+    penalties = torch.where(prompt_mask | output_mask, repetition_penalties, 1.0)
    # If logits are positive, divide by penalty, otherwise multiply by penalty.
    scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
    logits *= scaling


@impl("_C::apply_repetition_penalties_", "CUDA")
def apply_repetition_penalties_(
    logits: torch.Tensor,
    prompt_mask: torch.Tensor,
    output_mask: torch.Tensor,
-    repetition_penalties: torch.Tensor
+    repetition_penalties: torch.Tensor,
) -> None:
    repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
-        1, logits.size(1))
+        1, logits.size(1)
+    )
    # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
-    penalties = torch.where(prompt_mask | output_mask, repetition_penalties,
-                            1.0)
+    penalties = torch.where(prompt_mask | output_mask, repetition_penalties, 1.0)
    # If logits are positive, divide by penalty, otherwise multiply by penalty.
    scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
    logits *= scaling
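For reference, the penalty rule above can be checked in isolation. A minimal sketch with illustrative values (shapes and numbers are not from this patch):

import torch

logits = torch.tensor([[2.0, -1.0, 0.5]])           # [num_seqs, vocab_size]
prompt_mask = torch.tensor([[True, False, False]])
output_mask = torch.tensor([[False, True, False]])
penalty = torch.tensor([2.0])                        # one penalty per sequence

p = penalty.unsqueeze(1).repeat(1, logits.size(1))
p = torch.where(prompt_mask | output_mask, p, 1.0)   # 1.0 = no-op for unseen tokens
# Positive logits shrink (divide), negative logits grow more negative (multiply).
logits *= torch.where(logits > 0, 1.0 / p, p)
# -> tensor([[ 1.0000, -2.0000,  0.5000]])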
##################################################
# --------------- I8_mqa_logits -----------------
##################################################
@@ -1937,10 +1944,10 @@ def I8_mqa_logits(
    logits: torch.Tensor,
    clean_logits: bool,
    max_seq_q: Optional[int] = 0,
    max_seq_k: Optional[int] = 0,
    is_causal: Optional[bool] = False,
    use_xfa_boost: Optional[bool] = False,
) -> None:
    xtorch_ops.I8_mqa_logits(
        q=q,
        fused_kv_cache=fused_kv_cache,
@@ -1956,6 +1963,7 @@ def I8_mqa_logits(
    )
    return None


@impl("_C::I8_mqa_logits", "CUDA")
def I8_mqa_logits_cuda(
    q: torch.Tensor,
@@ -1966,10 +1974,10 @@ def I8_mqa_logits_cuda(
    logits: torch.Tensor,
    clean_logits: bool,
    max_seq_q: Optional[int] = 0,
    max_seq_k: Optional[int] = 0,
    is_causal: Optional[bool] = False,
    use_xfa_boost: Optional[bool] = False,
) -> None:
    xtorch_ops.I8_mqa_logits(
        q=q,
        fused_kv_cache=fused_kv_cache,
@@ -1985,6 +1993,7 @@ def I8_mqa_logits_cuda(
    )
    return None


def _fake_I8_mqa_logits(
    q: torch.Tensor,
    fused_kv_cache: List[torch.Tensor],
@@ -1994,14 +2003,16 @@ def _fake_I8_mqa_logits(
    logits: torch.Tensor,
    clean_logits: bool,
    max_seq_q: Optional[int] = 0,
    max_seq_k: Optional[int] = 0,
    is_causal: Optional[bool] = False,
    use_xfa_boost: Optional[bool] = False,
) -> None:
    return None


I8_mqa_logits.register_fake(_fake_I8_mqa_logits)
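Every op in this file follows the same registration pattern: a custom op that forwards to xtorch_ops, a CUDA-keyed impl, and a fake (meta) implementation registered via register_fake so torch.compile can trace shapes without running the kernel. The fakes here simply return None because each real op writes into a preallocated output tensor. A minimal self-contained sketch of that pattern on recent PyTorch (2.4+); the op name and body are illustrative only:

import torch
from torch.library import custom_op

@custom_op("demo::scale_add", mutates_args=())
def scale_add(x: torch.Tensor, s: float) -> torch.Tensor:
    return x * s + 1.0

@scale_add.register_fake
def _(x: torch.Tensor, s: float) -> torch.Tensor:
    # Fake impl: produce only metadata (shape/dtype), no real compute.
    return torch.empty_like(x)

print(torch.ops.demo.scale_add(torch.ones(3), 2.0))  # tensor([3., 3., 3.])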
##################################################
# ------------- I8_paged_mqa_logits --------------
##################################################
@@ -2015,7 +2026,8 @@ def I8_paged_mqa_logits(
    max_context_len: int,
    clean_logits: bool,
    out: torch.Tensor,
-    use_xfa_boost: Optional[bool] = False) -> None:
+    use_xfa_boost: Optional[bool] = False,
+) -> None:
    xtorch_ops.I8_paged_mqa_logits(
        q=q,
        fused_kv_cache=fused_kv_cache,
@@ -2025,9 +2037,11 @@ def I8_paged_mqa_logits(
        max_context_len=max_context_len,
        clean_logits=clean_logits,
        out=out,
-        use_xfa_boost=use_xfa_boost)
+        use_xfa_boost=use_xfa_boost,
+    )
    return None


@impl("_C::I8_paged_mqa_logits", "CUDA")
def I8_paged_mqa_logits_cuda(
    q: torch.Tensor,
@@ -2038,7 +2052,8 @@ def I8_paged_mqa_logits_cuda(
    max_context_len: int,
    clean_logits: bool,
    out: torch.Tensor,
-    use_xfa_boost: Optional[bool] = False) -> None:
+    use_xfa_boost: Optional[bool] = False,
+) -> None:
    xtorch_ops.I8_paged_mqa_logits(
        q=q,
        fused_kv_cache=fused_kv_cache,
@@ -2048,42 +2063,48 @@ def I8_paged_mqa_logits_cuda(
        max_context_len=max_context_len,
        clean_logits=clean_logits,
        out=out,
-        use_xfa_boost=use_xfa_boost)
+        use_xfa_boost=use_xfa_boost,
+    )
    return None


def _fake_I8_paged_mqa_logits(
-    q: torch.Tensor,
-    fused_kv_cache: List[torch.Tensor],
-    weights: torch.Tensor,
-    context_lens: List[torch.Tensor],
-    block_table: torch.Tensor,
-    max_context_len: int,
-    clean_logits: bool,
-    out: torch.Tensor,
-    use_xfa_boost: Optional[bool] = False) -> None:
+    q: torch.Tensor,
+    fused_kv_cache: List[torch.Tensor],
+    weights: torch.Tensor,
+    context_lens: List[torch.Tensor],
+    block_table: torch.Tensor,
+    max_context_len: int,
+    clean_logits: bool,
+    out: torch.Tensor,
+    use_xfa_boost: Optional[bool] = False,
+) -> None:
    return None


I8_paged_mqa_logits.register_fake(_fake_I8_paged_mqa_logits)
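The paged variant takes a block_table that maps each sequence's logical KV blocks to physical blocks in the cache, plus context_lens to trim padding. A toy gather illustrating that indexing (shapes are illustrative, not this kernel's layout):

import torch

num_blocks, block_size, d = 8, 4, 16
kv_cache = torch.randn(num_blocks, block_size, d)
block_table = torch.tensor([[3, 0, 5]])            # seq 0 uses physical blocks 3, 0, 5
context_len = 10                                    # tokens actually valid for seq 0

flat = kv_cache[block_table[0]].reshape(-1, d)      # [3 * block_size, d]
kv_for_seq = flat[:context_len]                     # drop padding in the last block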
##################################################
# ----------- sparse_prefill_fwd_opt -------------
##################################################
@custom_op("_C::sparse_prefill_fwd_opt", mutates_args=())
def sparse_prefill_fwd_opt(
-    q: torch.Tensor,
-    kv: torch.Tensor,
-    indices: torch.Tensor,
-    out: torch.Tensor,
-    max_logits: torch.Tensor,
-    lse: torch.Tensor,
-    sm_scale: float,
-    qlod_cpu: Optional[torch.Tensor] = None,
-    qlod_xpu: Optional[torch.Tensor] = None,
-    kvlod_cpu: Optional[torch.Tensor] = None,
-    kvlod_xpu: Optional[torch.Tensor] = None,
-    d_v: Optional[int] = -1,
-    is_causal: Optional[bool] = True,
-    use_xfa_boost: Optional[bool] = False) -> None:
+    q: torch.Tensor,
+    kv: torch.Tensor,
+    indices: torch.Tensor,
+    out: torch.Tensor,
+    max_logits: torch.Tensor,
+    lse: torch.Tensor,
+    sm_scale: float,
+    qlod_cpu: Optional[torch.Tensor] = None,
+    qlod_xpu: Optional[torch.Tensor] = None,
+    kvlod_cpu: Optional[torch.Tensor] = None,
+    kvlod_xpu: Optional[torch.Tensor] = None,
+    d_v: Optional[int] = -1,
+    is_causal: Optional[bool] = True,
+    use_xfa_boost: Optional[bool] = False,
+) -> None:
    xtorch_ops.sparse_prefill_fwd_opt(
        q=q,
        kv=kv,
@@ -2098,25 +2119,28 @@ def sparse_prefill_fwd_opt(
        kvlod_xpu=kvlod_xpu,
        d_v=d_v,
        is_causal=is_causal,
-        use_xfa_boost=use_xfa_boost)
+        use_xfa_boost=use_xfa_boost,
+    )
    return None


@impl("_C::sparse_prefill_fwd_opt", "CUDA")
def sparse_prefill_fwd_opt_cuda(
-    q: torch.Tensor,
-    kv: torch.Tensor,
-    indices: torch.Tensor,
-    out: torch.Tensor,
-    max_logits: torch.Tensor,
-    lse: torch.Tensor,
-    sm_scale: float,
-    qlod_cpu: Optional[torch.Tensor] = None,
-    qlod_xpu: Optional[torch.Tensor] = None,
-    kvlod_cpu: Optional[torch.Tensor] = None,
-    kvlod_xpu: Optional[torch.Tensor] = None,
-    d_v: Optional[int] = -1,
-    is_causal: Optional[bool] = True,
-    use_xfa_boost: Optional[bool] = False) -> None:
+    q: torch.Tensor,
+    kv: torch.Tensor,
+    indices: torch.Tensor,
+    out: torch.Tensor,
+    max_logits: torch.Tensor,
+    lse: torch.Tensor,
+    sm_scale: float,
+    qlod_cpu: Optional[torch.Tensor] = None,
+    qlod_xpu: Optional[torch.Tensor] = None,
+    kvlod_cpu: Optional[torch.Tensor] = None,
+    kvlod_xpu: Optional[torch.Tensor] = None,
+    d_v: Optional[int] = -1,
+    is_causal: Optional[bool] = True,
+    use_xfa_boost: Optional[bool] = False,
+) -> None:
    xtorch_ops.sparse_prefill_fwd_opt(
        q=q,
        kv=kv,
@@ -2131,46 +2155,52 @@ def sparse_prefill_fwd_opt_cuda(
        kvlod_xpu=kvlod_xpu,
        d_v=d_v,
        is_causal=is_causal,
-        use_xfa_boost=use_xfa_boost)
+        use_xfa_boost=use_xfa_boost,
+    )
    return None


def _fake_sparse_prefill_fwd_opt(
-    q: torch.Tensor,
-    kv: torch.Tensor,
-    indices: torch.Tensor,
-    out: torch.Tensor,
-    max_logits: torch.Tensor,
-    lse: torch.Tensor,
-    sm_scale: float,
-    qlod_cpu: Optional[torch.Tensor] = None,
-    qlod_xpu: Optional[torch.Tensor] = None,
-    kvlod_cpu: Optional[torch.Tensor] = None,
-    kvlod_xpu: Optional[torch.Tensor] = None,
-    d_v: Optional[int] = -1,
-    is_causal: Optional[bool] = True,
-    use_xfa_boost: Optional[bool] = False) -> None:
+    q: torch.Tensor,
+    kv: torch.Tensor,
+    indices: torch.Tensor,
+    out: torch.Tensor,
+    max_logits: torch.Tensor,
+    lse: torch.Tensor,
+    sm_scale: float,
+    qlod_cpu: Optional[torch.Tensor] = None,
+    qlod_xpu: Optional[torch.Tensor] = None,
+    kvlod_cpu: Optional[torch.Tensor] = None,
+    kvlod_xpu: Optional[torch.Tensor] = None,
+    d_v: Optional[int] = -1,
+    is_causal: Optional[bool] = True,
+    use_xfa_boost: Optional[bool] = False,
+) -> None:
    return None


sparse_prefill_fwd_opt.register_fake(_fake_sparse_prefill_fwd_opt)
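Sparse prefill attends only to a per-query subset of keys selected by indices, rather than the full KV range. A toy dense-PyTorch version of that gather-then-attend step (illustrative only; the real kernel also handles LOD batching and the boost path):

import torch

seq_q, seq_k, d, topk = 4, 64, 32, 8
q = torch.randn(seq_q, d)
kv = torch.randn(seq_k, 2, d)                       # fused K and V
indices = torch.randint(seq_k, (seq_q, topk))       # per-query selected keys

k_sel = kv[indices, 0]                              # [seq_q, topk, d]
v_sel = kv[indices, 1]
logits = torch.einsum("qd,qkd->qk", q, k_sel) * d ** -0.5
out = torch.einsum("qk,qkd->qd", logits.softmax(-1), v_sel)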
##################################################
# ------------------ fwd_kvcache_mla -------------
##################################################
@custom_op("_C::fwd_kvcache_mla", mutates_args=())
def fwd_kvcache_mla(
-    q_c: torch.Tensor,
-    kv_cache: torch.Tensor,
-    indices: torch.Tensor,
-    kv_lod_cpu: torch.Tensor,
-    out: torch.Tensor,
-    max_logits: torch.Tensor,
-    p_sums: torch.Tensor,
-    softmax_scale: float,
-    max_seq_kv: int,
-    q_r: Optional[torch.Tensor] = None,
-    pe_cache: Optional[torch.Tensor] = None,
-    use_xfa_boost: Optional[bool] = False,
-    kv_lod_xpu: Optional[torch.Tensor] = None) -> None:
+    q_c: torch.Tensor,
+    kv_cache: torch.Tensor,
+    indices: torch.Tensor,
+    kv_lod_cpu: torch.Tensor,
+    out: torch.Tensor,
+    max_logits: torch.Tensor,
+    p_sums: torch.Tensor,
+    softmax_scale: float,
+    max_seq_kv: int,
+    q_r: Optional[torch.Tensor] = None,
+    pe_cache: Optional[torch.Tensor] = None,
+    use_xfa_boost: Optional[bool] = False,
+    kv_lod_xpu: Optional[torch.Tensor] = None,
+) -> None:
    xtorch_ops.fwd_kvcache_mla(
        q_c=q_c,
        kv_cache=kv_cache,
@@ -2184,24 +2214,27 @@ def fwd_kvcache_mla(
        q_r=q_r,
        pe_cache=pe_cache,
        use_xfa_boost=use_xfa_boost,
-        kv_lod_xpu=kv_lod_xpu)
+        kv_lod_xpu=kv_lod_xpu,
+    )
    return None


@impl("_C::fwd_kvcache_mla", "CUDA")
def fwd_kvcache_mla_cuda(
-    q_c: torch.Tensor,
-    kv_cache: torch.Tensor,
-    indices: torch.Tensor,
-    kv_lod_cpu: torch.Tensor,
-    out: torch.Tensor,
-    max_logits: torch.Tensor,
-    p_sums: torch.Tensor,
-    softmax_scale: float,
-    max_seq_kv: int,
-    q_r: Optional[torch.Tensor] = None,
-    pe_cache: Optional[torch.Tensor] = None,
-    use_xfa_boost: Optional[bool] = False,
-    kv_lod_xpu: Optional[torch.Tensor] = None) -> None:
+    q_c: torch.Tensor,
+    kv_cache: torch.Tensor,
+    indices: torch.Tensor,
+    kv_lod_cpu: torch.Tensor,
+    out: torch.Tensor,
+    max_logits: torch.Tensor,
+    p_sums: torch.Tensor,
+    softmax_scale: float,
+    max_seq_kv: int,
+    q_r: Optional[torch.Tensor] = None,
+    pe_cache: Optional[torch.Tensor] = None,
+    use_xfa_boost: Optional[bool] = False,
+    kv_lod_xpu: Optional[torch.Tensor] = None,
+) -> None:
    xtorch_ops.fwd_kvcache_mla(
        q_c=q_c,
        kv_cache=kv_cache,
@@ -2215,27 +2248,94 @@ def fwd_kvcache_mla_cuda(
        q_r=q_r,
        pe_cache=pe_cache,
        use_xfa_boost=use_xfa_boost,
-        kv_lod_xpu=kv_lod_xpu)
+        kv_lod_xpu=kv_lod_xpu,
+    )
    return None


def _fake_fwd_kvcache_mla(
-    q_c: torch.Tensor,
-    kv_cache: torch.Tensor,
-    indices: torch.Tensor,
-    kv_lod_cpu: torch.Tensor,
-    out: torch.Tensor,
-    max_logits: torch.Tensor,
-    p_sums: torch.Tensor,
-    softmax_scale: float,
-    max_seq_kv: int,
-    q_r: Optional[torch.Tensor] = None,
-    pe_cache: Optional[torch.Tensor] = None,
-    use_xfa_boost: Optional[bool] = False,
-    kv_lod_xpu: Optional[torch.Tensor] = None) -> None:
+    q_c: torch.Tensor,
+    kv_cache: torch.Tensor,
+    indices: torch.Tensor,
+    kv_lod_cpu: torch.Tensor,
+    out: torch.Tensor,
+    max_logits: torch.Tensor,
+    p_sums: torch.Tensor,
+    softmax_scale: float,
+    max_seq_kv: int,
+    q_r: Optional[torch.Tensor] = None,
+    pe_cache: Optional[torch.Tensor] = None,
+    use_xfa_boost: Optional[bool] = False,
+    kv_lod_xpu: Optional[torch.Tensor] = None,
+) -> None:
    return None


fwd_kvcache_mla.register_fake(_fake_fwd_kvcache_mla)
+##################################################
+# --------------- dequant_int4 -----------------
+##################################################
+@custom_op("_C::dequant_int4", mutates_args=())
+def dequant_int4(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    zero: torch.Tensor,
+    y: torch.Tensor,
+    group_m: int,
+    int4_signed: bool = True,
+    use_mode_fast: bool = False,
+) -> None:
+    xtorch_ops.dequant_int4(
+        x=x,
+        scale=scale,
+        zero=zero,
+        y=y,
+        group_m=group_m,
+        int4_signed=int4_signed,
+        use_mode_fast=use_mode_fast,
+    )
+    return None
+
+
+@impl("_C::dequant_int4", "CUDA")
+def dequant_int4_cuda(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    zero: torch.Tensor,
+    y: torch.Tensor,
+    group_m: int,
+    int4_signed: bool = True,
+    use_mode_fast: bool = False,
+) -> None:
+    xtorch_ops.dequant_int4(
+        x=x,
+        scale=scale,
+        zero=zero,
+        y=y,
+        group_m=group_m,
+        int4_signed=int4_signed,
+        use_mode_fast=use_mode_fast,
+    )
+    return None
+
+
+def _fake_dequant_int4(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    zero: torch.Tensor,
+    y: torch.Tensor,
+    group_m: int,
+    int4_signed: bool = True,
+    use_mode_fast: bool = False,
+) -> None:
+    return None
+
+
+dequant_int4.register_fake(_fake_dequant_int4)
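The new dequant_int4 op is the building block for the AWQ MoE W4A16 path named in the commit title: weights are stored as 4-bit integers with one (scale, zero) pair per group of group_m rows and dequantized to 16-bit activations' precision before the matmul. A reference sketch of group-wise dequantization, y = (x - zero) * scale (shapes and packing here are illustrative, not this kernel's exact layout):

import torch

def dequant_int4_ref(x: torch.Tensor, scale: torch.Tensor,
                     zero: torch.Tensor, group_m: int) -> torch.Tensor:
    # x: [K, N] unpacked int4 values; scale/zero: [K // group_m, N].
    K, N = x.shape
    g = K // group_m
    xf = x.float().reshape(g, group_m, N)
    y = (xf - zero.reshape(g, 1, N)) * scale.reshape(g, 1, N)
    return y.reshape(K, N).half()

w_q = torch.randint(-8, 8, (128, 64))       # signed int4 range [-8, 7]
scale = torch.rand(128 // 32, 64)            # one scale per 32-row group
zero = torch.zeros(128 // 32, 64)            # symmetric case: zero point 0
w = dequant_int4_ref(w_q, scale, zero, group_m=32)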
+##################################################
+# ------------------ fast_topkv2 -------------
+##################################################