enable int8 bmm

Author: wzh
Date: 2026-01-14 14:30:59 +08:00
Parent: f0bf384e2e
Commit: 115eb32068
3 changed files with 164 additions and 84 deletions


@@ -1,4 +1,3 @@
--index-url https://pip.baidu-int.com/simple/
--trusted-host pip.baidu.com


@@ -68,6 +68,11 @@ xvllm_environment_variables: dict[str, Callable[[], Any]] = {
"ENABLE_VLLM_FUSED_QKV_SPLIT_NORM_ROPE":
lambda: (os.environ.get("ENABLE_VLLM_FUSED_QKV_SPLIT_NORM_ROPE", "False").lower() in
("true", "1")),
# use int8 bmm
"VLLM_KUNLUN_ENABLE_INT8_BMM":
lambda: (os.environ.get("VLLM_KUNLUN_ENABLE_INT8_BMM", "False").lower() in
("true", "1")),
}
# end-env-vars-definition
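A minimal usage sketch for the new flag; the parsing below simply mirrors the lambda above (only "true" / "1", case-insensitive, enable it), and exporting the variable in the shell before launching works the same way:

import os

os.environ["VLLM_KUNLUN_ENABLE_INT8_BMM"] = "1"
enabled = os.environ.get("VLLM_KUNLUN_ENABLE_INT8_BMM", "False").lower() in ("true", "1")
assert enabled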


@@ -196,6 +196,7 @@ import torch
from tqdm import tqdm
import vllm.envs as envs
import vllm_kunlun.platforms.envs as vllm_kunlun_envs
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata,
@@ -1081,7 +1082,6 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
    def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
        # Convert from (B, N, L) to (N, B, L)
        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
        if is_rocm_aiter_fp8bmm_enabled():
            # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
            x = aiter_triton_fp8_bmm(x,
@@ -1094,6 +1094,30 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
            # Copy result
            out.copy_(x)
        else:
            if vllm_kunlun_envs.VLLM_KUNLUN_ENABLE_INT8_BMM:
                x = x.view(-1, self.num_heads, self.kv_lora_rank)
                out = out.view(-1, self.num_heads, self.v_head_dim)
                q_len = x.shape[0]
                extra_params = {"trans": False}
                sorted_tokens_num_lod = torch.arange(
                    self.num_heads + 1, dtype=torch.int, device="cuda") * q_len
                sorted_tokens_idx = torch.arange(
                    self.num_heads * q_len, dtype=torch.int, device="cuda")
                xtorch_ops.mla_bmm_I8(
                    x.contiguous(),         # [B, N, L], e.g. [1, 16, 512], torch.float16
                    self.W_UV,              # [N, V, L], e.g. [16, 128, 512], torch.int8
                    self.W_UV_SCALE,        # [N * V, 1], e.g. [2048, 1], torch.float32
                    out,                    # [B, N, V], e.g. [1, 16, 128], torch.float16
                    sorted_tokens_num_lod,  # [N + 1], e.g. [17]
                    sorted_tokens_idx,      # [N * B], e.g. [16]
                    **extra_params)
                # out_new = out.reshape(-1, self.num_heads * self.v_head_dim)
                # out.resize_(origin_out_shape)
                # out.copy_(out_new)
            else:
                x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
                # Convert from (B, N * V) to (N, B, V)
                out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
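For reference, a plain-PyTorch sketch of what the mla_bmm_I8 call above is expected to compute for the V up-projection, assuming per-output-channel int8 scales; the helper name and the einsum formulation are illustrative, not part of this change (xtorch_ops itself is assumed to be imported elsewhere in this file):

import torch

def v_up_proj_int8_reference(x: torch.Tensor,          # [B, N, L] activations
                             w_uv: torch.Tensor,       # [N, V, L] int8 weights
                             w_uv_scale: torch.Tensor  # [N * V, 1] float32 scales
                             ) -> torch.Tensor:        # [B, N, V]
    n, v, l = w_uv.shape
    # Dequantize with one scale per (head, output-channel) row.
    w = w_uv.float() * w_uv_scale.view(n, v, 1)
    # out[b, n, v] = sum_l x[b, n, l] * w[n, v, l]
    return torch.einsum("bnl,nvl->bnv", x.float(), w).to(x.dtype)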
@@ -1339,6 +1363,15 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
f"Layer '{layer}' has no recognized weight attribute:"
f" {WEIGHT_NAMES}.")
def get_layer_weight_scale(layer):
WEIGHT_SCALE_NAMES = ("weight_scale",)
for attr in WEIGHT_SCALE_NAMES:
if hasattr(layer, attr):
return getattr(layer, attr)
raise AttributeError(
f"Layer '{layer}' has no recognized weight scale attribute:"
f" {WEIGHT_SCALE_NAMES}.")
def get_and_maybe_dequant_weights(layer: LinearBase):
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
# NOTE: This should only be used offline, since it's O(N^3)
@@ -1353,6 +1386,28 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
                return dequant_weights.T
            return layer.weight

        if vllm_kunlun_envs.VLLM_KUNLUN_ENABLE_INT8_BMM:
            kv_b_proj_weight = get_layer_weight(self.kv_b_proj).T
            kv_b_proj_weight_scale = get_layer_weight_scale(self.kv_b_proj)
            assert kv_b_proj_weight.dtype == torch.int8, \
                f"weight type {kv_b_proj_weight.dtype} not supported for int8 MLA BMM"
            W_UK, W_UV = kv_b_proj_weight.unflatten(
                0, (-1, self.qk_nope_head_dim + self.v_head_dim)
            ).split([self.qk_nope_head_dim, self.v_head_dim], dim=1)
            W_UK_SCALE, W_UV_SCALE = kv_b_proj_weight_scale.unflatten(
                0, (-1, self.qk_nope_head_dim + self.v_head_dim)
            ).split([self.qk_nope_head_dim, self.v_head_dim], dim=1)
            W_UK_SCALE = W_UK_SCALE / 127.0
            # Dequantize W_UK, transpose it to [N, L, P], then requantize per
            # output channel so decode can feed W_UK_T straight into the int8 BMM.
            w_uk_dq = W_UK.contiguous().cpu().to(torch.bfloat16).to(
                kv_b_proj_weight.device) * W_UK_SCALE.contiguous().to(torch.bfloat16)
            w_uk_dq_trans = w_uk_dq.transpose(1, 2).contiguous()
            self.W_UK_T = W_UK.transpose(1, 2).contiguous()
            self.W_UK_SCALE = torch.empty(
                [W_UK.shape[0] * W_UK.shape[2], 1],
                dtype=torch.float, device=kv_b_proj_weight.device)
            xtorch_ops.quant2d(w_uk_dq_trans, self.W_UK_T, self.W_UK_SCALE)
            self.W_UV = W_UV.contiguous()
            self.W_UV_SCALE = W_UV_SCALE.contiguous().reshape(-1, 1)
        else:
            # we currently do not have quantized bmm's which are needed for
            # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
            # the bmm's in 16-bit, the extra memory overhead of this is fairly low
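For the W_UK_T preparation above, a plain-PyTorch sketch of the dequantize / transpose / requantize round-trip, assuming xtorch_ops.quant2d performs symmetric per-output-channel int8 quantization over the last dimension (the function name below is illustrative, not part of this change):

import torch

def requantize_w_uk_t(w_uk: torch.Tensor,         # [N, P, L] int8
                      w_uk_scale: torch.Tensor):  # [N, P, 1] float, already divided by 127
    # Dequantize to the bf16 up-projection weight, [N, P, L].
    w_uk_dq = w_uk.to(torch.bfloat16) * w_uk_scale.to(torch.bfloat16)
    # Transpose so the decode-time BMM contracts over P: [N, L, P].
    w_uk_dq_t = w_uk_dq.transpose(1, 2).contiguous().float()
    # Symmetric per-output-channel requantization over the last dim,
    # one scale per (n, l) row -> [N * L, 1], matching W_UK_SCALE's layout.
    max_abs = w_uk_dq_t.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
    scale = max_abs / 127.0
    w_uk_t_i8 = torch.clamp((w_uk_dq_t / scale).round(), -127, 127).to(torch.int8)
    return w_uk_t_i8, scale.reshape(-1, 1)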
@@ -1796,8 +1851,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
        assert attn_metadata.decode is not None
        decode_q_nope, decode_q_pe = decode_q.split(
            [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        # Convert from (B, N, P) to (N, B, P)
        decode_q_nope = decode_q_nope.transpose(0, 1)
        # Pads the head_dim if necessary (for the underlying kernel)
        if self.q_pad_num_heads is not None:
@@ -1816,7 +1869,30 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
                group_size=128,
                transpose_bm=True)
        else:
            # Pads the head_dim if necessary (for the underlying kernel)
            if vllm_kunlun_envs.VLLM_KUNLUN_ENABLE_INT8_BMM:
                q_len = decode_q_nope.shape[0]
                decode_ql_nope = decode_q_nope.new_empty(
                    q_len, self.num_heads, self.kv_lora_rank,
                    dtype=torch.float16)
                sorted_tokens_num_lod = torch.arange(
                    self.num_heads + 1, dtype=torch.int, device="cuda") * q_len
                sorted_tokens_idx = torch.arange(
                    self.num_heads * q_len, dtype=torch.int, device="cuda")
                extra_params = {"trans": False}
                xtorch_ops.mla_bmm_I8(
                    decode_q_nope.contiguous(),  # [B, N, P]
                    self.W_UK_T,                 # [N, L, P] torch.int8
                    self.W_UK_SCALE,             # [N * L, 1] torch.float32
                    decode_ql_nope,              # [B, N, L] torch.float16
                    sorted_tokens_num_lod,       # [N + 1]
                    sorted_tokens_idx,           # [N * B]
                    **extra_params)
            else:
                # Convert from (B, N, P) to (N, B, P)
                decode_q_nope = decode_q_nope.transpose(0, 1)
                N, B, P = decode_q_nope.shape
                _, _, L = self.W_UK_T.shape
                if self.q_pad_num_heads is not None: