@@ -68,6 +68,11 @@ xvllm_environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"ENABLE_VLLM_FUSED_QKV_SPLIT_NORM_ROPE":
|
"ENABLE_VLLM_FUSED_QKV_SPLIT_NORM_ROPE":
|
||||||
lambda: (os.environ.get("ENABLE_VLLM_FUSED_QKV_SPLIT_NORM_ROPE", "False").lower() in
|
lambda: (os.environ.get("ENABLE_VLLM_FUSED_QKV_SPLIT_NORM_ROPE", "False").lower() in
|
||||||
("true", "1")),
|
("true", "1")),
|
||||||
|
|
||||||
|
# use int8 bmm
|
||||||
|
"VLLM_KUNLUN_ENABLE_INT8_BMM":
|
||||||
|
lambda: (os.environ.get("VLLM_KUNLUN_ENABLE_INT8_BMM", "False").lower() in
|
||||||
|
("true", "1")),
|
||||||
}
|
}
|
||||||
|
|
||||||
# end-env-vars-definition
|
# end-env-vars-definition
|
||||||
|
|||||||
@@ -196,6 +196,7 @@ import torch
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
|
import vllm_kunlun.platforms.envs as vllm_kunlun_envs
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
||||||
AttentionMetadata,
|
AttentionMetadata,
|
||||||
@@ -1081,7 +1082,6 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
|
|||||||
|
|
||||||
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
|
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
|
||||||
# Convert from (B, N, L) to (N, B, L)
|
# Convert from (B, N, L) to (N, B, L)
|
||||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
|
||||||
if is_rocm_aiter_fp8bmm_enabled():
|
if is_rocm_aiter_fp8bmm_enabled():
|
||||||
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
|
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
|
||||||
x = aiter_triton_fp8_bmm(x,
|
x = aiter_triton_fp8_bmm(x,
|
||||||
@@ -1094,20 +1094,44 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
|
|||||||
# Copy result
|
# Copy result
|
||||||
out.copy_(x)
|
out.copy_(x)
|
||||||
else:
|
else:
|
||||||
# Convert from (B, N * V) to (N, B, V)
|
if vllm_kunlun_envs.VLLM_KUNLUN_ENABLE_INT8_BMM:
|
||||||
out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
|
x = x.view(-1, self.num_heads, self.kv_lora_rank)
|
||||||
|
out = out.view(-1, self.num_heads, self.v_head_dim)
|
||||||
|
q_len = x.shape[0]
|
||||||
|
extra_params = {"trans": False}
|
||||||
|
sorted_tokens_num_lod = torch.arange(
|
||||||
|
self.num_heads + 1, dtype=torch.int, device="cuda"
|
||||||
|
) * q_len
|
||||||
|
sorted_tokens_idx = torch.arange(
|
||||||
|
self.num_heads * q_len, dtype=torch.int, device="cuda")
|
||||||
|
xtorch_ops.mla_bmm_I8(
|
||||||
|
x.contiguous(), # [1, 16, 512] torch.float16
|
||||||
|
self.W_UV, # [16, 128, 512] torch.int8
|
||||||
|
self.W_UV_SCALE, # [2048, 1] torch.float32
|
||||||
|
out, # [1, 16, 128] torch.float16
|
||||||
|
sorted_tokens_num_lod, # [17]
|
||||||
|
sorted_tokens_idx, # [16]
|
||||||
|
**extra_params
|
||||||
|
)
|
||||||
|
# out_new = out.reshape(-1, self.num_heads * self.v_head_dim)
|
||||||
|
# out.resize_(origin_out_shape)
|
||||||
|
# out.copy_(out_new)
|
||||||
|
else:
|
||||||
|
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||||
|
# Convert from (B, N * V) to (N, B, V)
|
||||||
|
out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
|
||||||
|
|
||||||
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
||||||
torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot"
|
torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot"
|
||||||
|
|
||||||
# Convert from (N, B, V) to (B, N * V)
|
# Convert from (N, B, V) to (B, N * V)
|
||||||
out_new = out.transpose(0, 1).reshape(
|
out_new = out.transpose(0, 1).reshape(
|
||||||
-1, self.num_heads * self.v_head_dim)
|
-1, self.num_heads * self.v_head_dim)
|
||||||
|
|
||||||
# Adjust output buffer shape back to the original (B, N * V)
|
# Adjust output buffer shape back to the original (B, N * V)
|
||||||
N, B, V = out.shape
|
N, B, V = out.shape
|
||||||
out.resize_((B, N * V))
|
out.resize_((B, N * V))
|
||||||
out.copy_(out_new) # Copy result
|
out.copy_(out_new) # Copy result
|
||||||
|
|
||||||
|
|
||||||
class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||||
@@ -1339,6 +1363,15 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
f"Layer '{layer}' has no recognized weight attribute:"
|
f"Layer '{layer}' has no recognized weight attribute:"
|
||||||
f" {WEIGHT_NAMES}.")
|
f" {WEIGHT_NAMES}.")
|
||||||
|
|
||||||
|
def get_layer_weight_scale(layer):
|
||||||
|
WEIGHT_SCALE_NAMES = ("weight_scale",)
|
||||||
|
for attr in WEIGHT_SCALE_NAMES:
|
||||||
|
if hasattr(layer, attr):
|
||||||
|
return getattr(layer, attr)
|
||||||
|
raise AttributeError(
|
||||||
|
f"Layer '{layer}' has no recognized weight scale attribute:"
|
||||||
|
f" {WEIGHT_SCALE_NAMES}.")
|
||||||
|
|
||||||
def get_and_maybe_dequant_weights(layer: LinearBase):
|
def get_and_maybe_dequant_weights(layer: LinearBase):
|
||||||
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
|
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
|
||||||
# NOTE: This should only be used offline, since it's O(N^3)
|
# NOTE: This should only be used offline, since it's O(N^3)
|
||||||
@@ -1353,71 +1386,93 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
return dequant_weights.T
|
return dequant_weights.T
|
||||||
return layer.weight
|
return layer.weight
|
||||||
|
|
||||||
|
if vllm_kunlun_envs.VLLM_KUNLUN_ENABLE_INT8_BMM:
|
||||||
|
kv_b_proj_weight = get_layer_weight(self.kv_b_proj).T
|
||||||
|
kv_b_proj_weight_scale = get_layer_weight_scale(self.kv_b_proj)
|
||||||
|
assert kv_b_proj_weight.dtype == torch.int8, \
|
||||||
|
f"weight type {kv_b_proj_weight.dtype} not support for int8 MLA BMM"
|
||||||
|
W_UK, W_UV = kv_b_proj_weight.unflatten(
|
||||||
|
0, (-1, self.qk_nope_head_dim + self.v_head_dim)
|
||||||
|
).split([self.qk_nope_head_dim, self.v_head_dim], dim=1)
|
||||||
|
W_UK_SCALE, W_UV_SCALE = kv_b_proj_weight_scale.unflatten(
|
||||||
|
0, (-1, self.qk_nope_head_dim + self.v_head_dim)
|
||||||
|
).split([self.qk_nope_head_dim, self.v_head_dim], dim=1)
|
||||||
|
W_UK_SCALE = W_UK_SCALE / 127.0
|
||||||
|
w_uk_dq = W_UK.contiguous().cpu().to(torch.bfloat16).to(kv_b_proj_weight.device) \
|
||||||
|
* W_UK_SCALE.contiguous().to(torch.bfloat16)
|
||||||
|
w_uk_dq_trans = w_uk_dq.transpose(1, 2).contiguous()
|
||||||
|
self.W_UK_T = W_UK.transpose(1, 2).contiguous()
|
||||||
|
self.W_UK_SCALE = torch.empty([W_UK.shape[0] * W_UK.shape[2], 1],
|
||||||
|
dtype=torch.float, device=kv_b_proj_weight.device)
|
||||||
|
xtorch_ops.quant2d(w_uk_dq_trans, self.W_UK_T, self.W_UK_SCALE)
|
||||||
|
self.W_UV = W_UV.contiguous()
|
||||||
|
self.W_UV_SCALE = W_UV_SCALE.contiguous().reshape(-1, 1)
|
||||||
|
else:
|
||||||
# we currently do not have quantized bmm's which are needed for
|
# we currently do not have quantized bmm's which are needed for
|
||||||
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
|
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
|
||||||
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
|
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
|
||||||
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
|
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
|
||||||
assert kv_b_proj_weight.shape == (
|
assert kv_b_proj_weight.shape == (
|
||||||
self.kv_lora_rank,
|
self.kv_lora_rank,
|
||||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
|
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
|
||||||
f"{kv_b_proj_weight.shape=}, "
|
f"{kv_b_proj_weight.shape=}, "
|
||||||
f"{self.kv_lora_rank=}, "
|
f"{self.kv_lora_rank=}, "
|
||||||
f"{self.num_heads=}, "
|
f"{self.num_heads=}, "
|
||||||
f"{self.qk_nope_head_dim=}, "
|
f"{self.qk_nope_head_dim=}, "
|
||||||
f"{self.v_head_dim=}")
|
f"{self.v_head_dim=}")
|
||||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||||
self.kv_lora_rank,
|
self.kv_lora_rank,
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
self.qk_nope_head_dim + self.v_head_dim,
|
self.qk_nope_head_dim + self.v_head_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
W_UK, W_UV = kv_b_proj_weight.split(
|
W_UK, W_UV = kv_b_proj_weight.split(
|
||||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||||
|
|
||||||
if is_rocm_aiter_fp8bmm_enabled():
|
if is_rocm_aiter_fp8bmm_enabled():
|
||||||
W_K = W_UK.transpose(0, 1) # 16 512 128
|
W_K = W_UK.transpose(0, 1) # 16 512 128
|
||||||
W_V = W_UV.permute(1, 2, 0) # 16 128 512
|
W_V = W_UV.permute(1, 2, 0) # 16 128 512
|
||||||
self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
|
self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
|
||||||
W_K, dtype=current_platform.fp8_dtype())
|
W_K, dtype=current_platform.fp8_dtype())
|
||||||
self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
|
self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
|
||||||
W_V, dtype=current_platform.fp8_dtype())
|
W_V, dtype=current_platform.fp8_dtype())
|
||||||
|
|
||||||
# The kernel operates on non-padded inputs. Hence, pre-compiling
|
# The kernel operates on non-padded inputs. Hence, pre-compiling
|
||||||
# triton kernel to avoid runtime compilation for unseen batch sizes
|
# triton kernel to avoid runtime compilation for unseen batch sizes
|
||||||
# Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
|
# Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
|
||||||
# On DS-R1, this step adds roughly 50s to the model loading time.
|
# On DS-R1, this step adds roughly 50s to the model loading time.
|
||||||
max_batch_size = 1024 # [ToDo] Find the optimal upper limit
|
max_batch_size = 1024 # [ToDo] Find the optimal upper limit
|
||||||
pre_compilation_list = list(range(1, max_batch_size + 1))
|
pre_compilation_list = list(range(1, max_batch_size + 1))
|
||||||
if is_global_first_rank():
|
if is_global_first_rank():
|
||||||
pre_compilation_list = tqdm(
|
pre_compilation_list = tqdm(
|
||||||
pre_compilation_list,
|
pre_compilation_list,
|
||||||
desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
|
desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
|
||||||
total=max_batch_size,
|
total=max_batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
for m in pre_compilation_list:
|
for m in pre_compilation_list:
|
||||||
x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]),
|
x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]),
|
||||||
dtype=torch.bfloat16,
|
dtype=torch.bfloat16,
|
||||||
device=self.W_K.device)
|
device=self.W_K.device)
|
||||||
aiter_triton_fp8_bmm(x,
|
aiter_triton_fp8_bmm(x,
|
||||||
self.W_K,
|
self.W_K,
|
||||||
self.W_K_scale,
|
self.W_K_scale,
|
||||||
group_size=128,
|
group_size=128,
|
||||||
transpose_bm=True)
|
transpose_bm=True)
|
||||||
|
|
||||||
x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]),
|
x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]),
|
||||||
dtype=torch.bfloat16,
|
dtype=torch.bfloat16,
|
||||||
device=self.W_V.device)
|
device=self.W_V.device)
|
||||||
aiter_triton_fp8_bmm(x,
|
aiter_triton_fp8_bmm(x,
|
||||||
self.W_V,
|
self.W_V,
|
||||||
self.W_V_scale,
|
self.W_V_scale,
|
||||||
group_size=128,
|
group_size=128,
|
||||||
transpose_bm=True)
|
transpose_bm=True)
|
||||||
else:
|
else:
|
||||||
# Convert from (L, N, V) to (N, L, V)
|
# Convert from (L, N, V) to (N, L, V)
|
||||||
self.W_UV = W_UV.transpose(0, 1)
|
self.W_UV = W_UV.transpose(0, 1)
|
||||||
# Convert from (L, N, P) to (N, P, L)
|
# Convert from (L, N, P) to (N, P, L)
|
||||||
self.W_UK_T = W_UK.permute(1, 2, 0)
|
self.W_UK_T = W_UK.permute(1, 2, 0)
|
||||||
|
|
||||||
def gather_and_maybe_dequant_cache_py_optimized(
|
def gather_and_maybe_dequant_cache_py_optimized(
|
||||||
self,
|
self,
|
||||||
@@ -1796,8 +1851,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
assert attn_metadata.decode is not None
|
assert attn_metadata.decode is not None
|
||||||
decode_q_nope, decode_q_pe = decode_q.split(
|
decode_q_nope, decode_q_pe = decode_q.split(
|
||||||
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||||
# Convert from (B, N, P) to (N, B, P)
|
|
||||||
decode_q_nope = decode_q_nope.transpose(0, 1)
|
|
||||||
|
|
||||||
# Pads the head_dim if necessary (for the underlying kernel)
|
# Pads the head_dim if necessary (for the underlying kernel)
|
||||||
if self.q_pad_num_heads is not None:
|
if self.q_pad_num_heads is not None:
|
||||||
@@ -1816,21 +1869,44 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
group_size=128,
|
group_size=128,
|
||||||
transpose_bm=True)
|
transpose_bm=True)
|
||||||
else:
|
else:
|
||||||
# Pads the head_dim if necessary (for the underlying kernel)
|
if vllm_kunlun_envs.VLLM_KUNLUN_ENABLE_INT8_BMM:
|
||||||
N, B, P = decode_q_nope.shape
|
q_len = decode_q_nope.shape[0]
|
||||||
_, _, L = self.W_UK_T.shape
|
|
||||||
if self.q_pad_num_heads is not None:
|
|
||||||
decode_ql_nope = decode_q_nope.new_empty(
|
decode_ql_nope = decode_q_nope.new_empty(
|
||||||
(self.q_pad_num_heads, B, L))
|
q_len, self.num_heads, self.kv_lora_rank,
|
||||||
decode_ql_nope.resize_((N, B, L))
|
dtype=torch.float16,
|
||||||
|
)
|
||||||
|
sorted_tokens_num_lod = torch.arange(
|
||||||
|
self.num_heads + 1, dtype=torch.int, device="cuda"
|
||||||
|
) * q_len
|
||||||
|
sorted_tokens_idx = torch.arange(
|
||||||
|
self.num_heads * q_len, dtype=torch.int, device="cuda")
|
||||||
|
extra_params = {"trans": False}
|
||||||
|
xtorch_ops.mla_bmm_I8(
|
||||||
|
decode_q_nope.contiguous(),
|
||||||
|
self.W_UK_T,
|
||||||
|
self.W_UK_SCALE,
|
||||||
|
decode_ql_nope,
|
||||||
|
sorted_tokens_num_lod,
|
||||||
|
sorted_tokens_idx,
|
||||||
|
**extra_params
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
decode_ql_nope = decode_q_nope.new_empty((N, B, L))
|
# Convert from (B, N, P) to (N, B, P)
|
||||||
|
decode_q_nope = decode_q_nope.transpose(0, 1)
|
||||||
|
N, B, P = decode_q_nope.shape
|
||||||
|
_, _, L = self.W_UK_T.shape
|
||||||
|
if self.q_pad_num_heads is not None:
|
||||||
|
decode_ql_nope = decode_q_nope.new_empty(
|
||||||
|
(self.q_pad_num_heads, B, L))
|
||||||
|
decode_ql_nope.resize_((N, B, L))
|
||||||
|
|
||||||
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
|
else:
|
||||||
torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope)
|
decode_ql_nope = decode_q_nope.new_empty((N, B, L))
|
||||||
# Convert from (N, B, L) to (B, N, L)
|
|
||||||
decode_ql_nope = decode_ql_nope.transpose(0, 1)
|
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
|
||||||
|
torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope)
|
||||||
|
# Convert from (N, B, L) to (B, N, L)
|
||||||
|
decode_ql_nope = decode_ql_nope.transpose(0, 1)
|
||||||
|
|
||||||
if fp8_attention:
|
if fp8_attention:
|
||||||
ql_nope_shape = decode_ql_nope.shape
|
ql_nope_shape = decode_ql_nope.shape
|
||||||
|
|||||||
Reference in New Issue
Block a user