diff --git a/requirements.txt b/requirements.txt
index 8efb40e..4893576 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,3 @@
-
 --index-url https://pip.baidu-int.com/simple/
 --trusted-host pip.baidu.com
diff --git a/vllm_kunlun/platforms/envs.py b/vllm_kunlun/platforms/envs.py
index 8993704..fcf972d 100644
--- a/vllm_kunlun/platforms/envs.py
+++ b/vllm_kunlun/platforms/envs.py
@@ -68,6 +68,11 @@ xvllm_environment_variables: dict[str, Callable[[], Any]] = {
     "ENABLE_VLLM_FUSED_QKV_SPLIT_NORM_ROPE":
     lambda: (os.environ.get("ENABLE_VLLM_FUSED_QKV_SPLIT_NORM_ROPE", "False").lower() in
              ("true", "1")),
+
+    # Use int8 BMM kernels for the MLA up-projections
+    "VLLM_KUNLUN_ENABLE_INT8_BMM":
+    lambda: (os.environ.get("VLLM_KUNLUN_ENABLE_INT8_BMM", "False").lower() in
+             ("true", "1")),
 }

 # end-env-vars-definition
diff --git a/vllm_kunlun/v1/attention/backends/mla/common.py b/vllm_kunlun/v1/attention/backends/mla/common.py
index d39b405..a4cd521 100644
--- a/vllm_kunlun/v1/attention/backends/mla/common.py
+++ b/vllm_kunlun/v1/attention/backends/mla/common.py
@@ -196,6 +196,7 @@ import torch
 from tqdm import tqdm

 import vllm.envs as envs
+import vllm_kunlun.platforms.envs as vllm_kunlun_envs
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionMetadata,
@@ -1081,7 +1082,6 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):

     def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
         # Convert from (B, N, L) to (N, B, L)
-        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
         if is_rocm_aiter_fp8bmm_enabled():
             # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
             x = aiter_triton_fp8_bmm(x,
@@ -1094,20 +1094,44 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
             # Copy result
             out.copy_(x)
         else:
-            # Convert from (B, N * V) to (N, B, V)
-            out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
+            if vllm_kunlun_envs.VLLM_KUNLUN_ENABLE_INT8_BMM:
+                x = x.view(-1, self.num_heads, self.kv_lora_rank)
+                out = out.view(-1, self.num_heads, self.v_head_dim)
+                q_len = x.shape[0]
+                extra_params = {"trans": False}
+                sorted_tokens_num_lod = torch.arange(
+                    self.num_heads + 1, dtype=torch.int, device="cuda"
+                ) * q_len
+                sorted_tokens_idx = torch.arange(
+                    self.num_heads * q_len, dtype=torch.int, device="cuda")
+                xtorch_ops.mla_bmm_I8(
+                    x.contiguous(),         # [1, 16, 512] torch.float16
+                    self.W_UV,              # [16, 128, 512] torch.int8
+                    self.W_UV_SCALE,        # [2048, 1] torch.float32
+                    out,                    # [1, 16, 128] torch.float16
+                    sorted_tokens_num_lod,  # [17]
+                    sorted_tokens_idx,      # [16]
+                    **extra_params
+                )
+                # out_new = out.reshape(-1, self.num_heads * self.v_head_dim)
+                # out.resize_(origin_out_shape)
+                # out.copy_(out_new)
+            else:
+                x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
+                # Convert from (B, N * V) to (N, B, V)
+                out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)

-            # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
-            torch.bmm(x, self.W_UV, out=out)  # Reuse "out" to make it "hot"
+                # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
+                torch.bmm(x, self.W_UV, out=out)  # Reuse "out" to make it "hot"

-            # Convert from (N, B, V) to (B, N * V)
-            out_new = out.transpose(0, 1).reshape(
-                -1, self.num_heads * self.v_head_dim)
+                # Convert from (N, B, V) to (B, N * V)
+                out_new = out.transpose(0, 1).reshape(
+                    -1, self.num_heads * self.v_head_dim)

-            # Adjust output buffer shape back to the original (B, N * V)
-            N, B, V = out.shape
-            out.resize_((B, N * V))
-            out.copy_(out_new)  # Copy result
+                # Adjust output buffer shape back to the original (B, N * V)
+                N, B, V = out.shape
+                out.resize_((B, N * V))
+                out.copy_(out_new)  # Copy result


 class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
@@ -1339,6 +1363,15 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
                 f"Layer '{layer}' has no recognized weight attribute:"
                 f" {WEIGHT_NAMES}.")

+        def get_layer_weight_scale(layer):
+            WEIGHT_SCALE_NAMES = ("weight_scale",)
+            for attr in WEIGHT_SCALE_NAMES:
+                if hasattr(layer, attr):
+                    return getattr(layer, attr)
+            raise AttributeError(
+                f"Layer '{layer}' has no recognized weight scale attribute:"
+                f" {WEIGHT_SCALE_NAMES}.")
+
         def get_and_maybe_dequant_weights(layer: LinearBase):
             if not isinstance(layer.quant_method, UnquantizedLinearMethod):
                 # NOTE: This should only be used offline, since it's O(N^3)
@@ -1353,71 +1386,93 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
                 return dequant_weights.T
             return layer.weight

+        if vllm_kunlun_envs.VLLM_KUNLUN_ENABLE_INT8_BMM:
+            kv_b_proj_weight = get_layer_weight(self.kv_b_proj).T
+            kv_b_proj_weight_scale = get_layer_weight_scale(self.kv_b_proj)
+            assert kv_b_proj_weight.dtype == torch.int8, \
+                f"weight dtype {kv_b_proj_weight.dtype} is not supported for int8 MLA BMM"
+            W_UK, W_UV = kv_b_proj_weight.unflatten(
+                0, (-1, self.qk_nope_head_dim + self.v_head_dim)
+            ).split([self.qk_nope_head_dim, self.v_head_dim], dim=1)
+            W_UK_SCALE, W_UV_SCALE = kv_b_proj_weight_scale.unflatten(
+                0, (-1, self.qk_nope_head_dim + self.v_head_dim)
+            ).split([self.qk_nope_head_dim, self.v_head_dim], dim=1)
+            W_UK_SCALE = W_UK_SCALE / 127.0
+            w_uk_dq = W_UK.contiguous().cpu().to(torch.bfloat16).to(kv_b_proj_weight.device) \
+                * W_UK_SCALE.contiguous().to(torch.bfloat16)
+            w_uk_dq_trans = w_uk_dq.transpose(1, 2).contiguous()
+            self.W_UK_T = W_UK.transpose(1, 2).contiguous()
+            self.W_UK_SCALE = torch.empty([W_UK.shape[0] * W_UK.shape[2], 1],
+                                          dtype=torch.float, device=kv_b_proj_weight.device)
+            xtorch_ops.quant2d(w_uk_dq_trans, self.W_UK_T, self.W_UK_SCALE)
+            self.W_UV = W_UV.contiguous()
+            self.W_UV_SCALE = W_UV_SCALE.contiguous().reshape(-1, 1)
+        else:
         # we currently do not have quantized bmm's which are needed for
         # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
         # the bmm's in 16-bit, the extra memory overhead of this is fairly low
-        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
-        assert kv_b_proj_weight.shape == (
-            self.kv_lora_rank,
-            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
-                f"{kv_b_proj_weight.shape=}, "
-                f"{self.kv_lora_rank=}, "
-                f"{self.num_heads=}, "
-                f"{self.qk_nope_head_dim=}, "
-                f"{self.v_head_dim=}")
-        kv_b_proj_weight = kv_b_proj_weight.view(
-            self.kv_lora_rank,
-            self.num_heads,
-            self.qk_nope_head_dim + self.v_head_dim,
-        )
+            kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
+            assert kv_b_proj_weight.shape == (
+                self.kv_lora_rank,
+                self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
+                    f"{kv_b_proj_weight.shape=}, "
+                    f"{self.kv_lora_rank=}, "
+                    f"{self.num_heads=}, "
+                    f"{self.qk_nope_head_dim=}, "
+                    f"{self.v_head_dim=}")
+            kv_b_proj_weight = kv_b_proj_weight.view(
+                self.kv_lora_rank,
+                self.num_heads,
+                self.qk_nope_head_dim + self.v_head_dim,
+            )

-        W_UK, W_UV = kv_b_proj_weight.split(
-            [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+            W_UK, W_UV = kv_b_proj_weight.split(
+                [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

-        if is_rocm_aiter_fp8bmm_enabled():
-            W_K = W_UK.transpose(0, 1)  # 16 512 128
-            W_V = W_UV.permute(1, 2, 0)  # 16 128 512
-            self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
-                W_K, dtype=current_platform.fp8_dtype())
-            self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
-                W_V, dtype=current_platform.fp8_dtype())
+            if is_rocm_aiter_fp8bmm_enabled():
+                W_K = W_UK.transpose(0, 1)  # 16 512 128
+                W_V = W_UV.permute(1, 2, 0)  # 16 128 512
+                self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
+                    W_K, dtype=current_platform.fp8_dtype())
+                self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
+                    W_V, dtype=current_platform.fp8_dtype())

-            # The kernel operates on non-padded inputs. Hence, pre-compiling
-            # triton kernel to avoid runtime compilation for unseen batch sizes
-            # Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
-            # On DS-R1, this step adds roughly 50s to the model loading time.
-            max_batch_size = 1024  # [ToDo] Find the optimal upper limit
-            pre_compilation_list = list(range(1, max_batch_size + 1))
-            if is_global_first_rank():
-                pre_compilation_list = tqdm(
-                    pre_compilation_list,
-                    desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
-                    total=max_batch_size,
-                )
+                # The kernel operates on non-padded inputs. Hence, pre-compiling
+                # triton kernel to avoid runtime compilation for unseen batch sizes
+                # Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
+                # On DS-R1, this step adds roughly 50s to the model loading time.
+                max_batch_size = 1024  # [ToDo] Find the optimal upper limit
+                pre_compilation_list = list(range(1, max_batch_size + 1))
+                if is_global_first_rank():
+                    pre_compilation_list = tqdm(
+                        pre_compilation_list,
+                        desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
+                        total=max_batch_size,
+                    )

-            for m in pre_compilation_list:
-                x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]),
-                                dtype=torch.bfloat16,
-                                device=self.W_K.device)
-                aiter_triton_fp8_bmm(x,
-                                     self.W_K,
-                                     self.W_K_scale,
-                                     group_size=128,
-                                     transpose_bm=True)
+                for m in pre_compilation_list:
+                    x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]),
+                                    dtype=torch.bfloat16,
+                                    device=self.W_K.device)
+                    aiter_triton_fp8_bmm(x,
+                                         self.W_K,
+                                         self.W_K_scale,
+                                         group_size=128,
+                                         transpose_bm=True)

-                x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]),
-                                dtype=torch.bfloat16,
-                                device=self.W_V.device)
-                aiter_triton_fp8_bmm(x,
-                                     self.W_V,
-                                     self.W_V_scale,
-                                     group_size=128,
-                                     transpose_bm=True)
-        else:
-            # Convert from (L, N, V) to (N, L, V)
-            self.W_UV = W_UV.transpose(0, 1)
-            # Convert from (L, N, P) to (N, P, L)
-            self.W_UK_T = W_UK.permute(1, 2, 0)
+                    x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]),
+                                    dtype=torch.bfloat16,
+                                    device=self.W_V.device)
+                    aiter_triton_fp8_bmm(x,
+                                         self.W_V,
+                                         self.W_V_scale,
+                                         group_size=128,
+                                         transpose_bm=True)
+            else:
+                # Convert from (L, N, V) to (N, L, V)
+                self.W_UV = W_UV.transpose(0, 1)
+                # Convert from (L, N, P) to (N, P, L)
+                self.W_UK_T = W_UK.permute(1, 2, 0)

     def gather_and_maybe_dequant_cache_py_optimized(
         self,
@@ -1796,8 +1851,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         assert attn_metadata.decode is not None
         decode_q_nope, decode_q_pe = decode_q.split(
             [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        # Convert from (B, N, P) to (N, B, P)
-        decode_q_nope = decode_q_nope.transpose(0, 1)

         # Pads the head_dim if necessary (for the underlying kernel)
         if self.q_pad_num_heads is not None:
@@ -1816,21 +1869,44 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
                                                   group_size=128,
                                                   transpose_bm=True)
         else:
-            # Pads the head_dim if necessary (for the underlying kernel)
-            N, B, P = decode_q_nope.shape
-            _, _, L = self.W_UK_T.shape
-            if self.q_pad_num_heads is not None:
+            if vllm_kunlun_envs.VLLM_KUNLUN_ENABLE_INT8_BMM:
+                q_len = decode_q_nope.shape[0]
                 decode_ql_nope = decode_q_nope.new_empty(
-                    (self.q_pad_num_heads, B, L))
-                decode_ql_nope.resize_((N, B, L))
-
+                    q_len, self.num_heads, self.kv_lora_rank,
+                    dtype=torch.float16,
+                )
+                sorted_tokens_num_lod = torch.arange(
+                    self.num_heads + 1, dtype=torch.int, device="cuda"
+                ) * q_len
+                sorted_tokens_idx = torch.arange(
+                    self.num_heads * q_len, dtype=torch.int, device="cuda")
+                extra_params = {"trans": False}
+                xtorch_ops.mla_bmm_I8(
+                    decode_q_nope.contiguous(),
+                    self.W_UK_T,
+                    self.W_UK_SCALE,
+                    decode_ql_nope,
+                    sorted_tokens_num_lod,
+                    sorted_tokens_idx,
+                    **extra_params
+                )
             else:
-                decode_ql_nope = decode_q_nope.new_empty((N, B, L))
+                # Convert from (B, N, P) to (N, B, P)
+                decode_q_nope = decode_q_nope.transpose(0, 1)
+                N, B, P = decode_q_nope.shape
+                _, _, L = self.W_UK_T.shape
+                if self.q_pad_num_heads is not None:
+                    decode_ql_nope = decode_q_nope.new_empty(
+                        (self.q_pad_num_heads, B, L))
+                    decode_ql_nope.resize_((N, B, L))

-            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
-            torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope)
-            # Convert from (N, B, L) to (B, N, L)
-            decode_ql_nope = decode_ql_nope.transpose(0, 1)
+                else:
+                    decode_ql_nope = decode_q_nope.new_empty((N, B, L))
+
+                # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+                torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope)
+                # Convert from (N, B, L) to (B, N, L)
+                decode_ql_nope = decode_ql_nope.transpose(0, 1)

         if fp8_attention:
             ql_nope_shape = decode_ql_nope.shape
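
A minimal usage sketch for the new flag, assuming vllm_kunlun.platforms.envs resolves
attribute access lazily through the xvllm_environment_variables lambda table (as the
dict[str, Callable[[], Any]] definition above suggests), so the variable only needs to
be set before the MLA backend is constructed:

    # Hedged sketch: enable the int8 MLA BMM path for the current process.
    # "1" and "true" (case-insensitive) parse as enabled; anything else is off.
    import os
    os.environ["VLLM_KUNLUN_ENABLE_INT8_BMM"] = "1"

    import vllm_kunlun.platforms.envs as vllm_kunlun_envs
    print(vllm_kunlun_envs.VLLM_KUNLUN_ENABLE_INT8_BMM)  # expected: True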
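
The W_UK preparation in process_weights_after_loading dequantizes the stored int8
weight with its per-channel scale, transposes the last two dims into the kernel's
layout, and re-quantizes per row via xtorch_ops.quant2d. Below is a plain-PyTorch
sketch of that round trip; quant2d_rowwise is a hypothetical stand-in for quant2d
and the shapes are purely illustrative:

    # Hedged sketch of the dequantize -> transpose -> requantize round trip for W_UK.
    import torch

    def quant2d_rowwise(x: torch.Tensor):
        """Symmetric per-row int8 quantization over the last dim; returns (int8, fp32 scales)."""
        flat = x.float().reshape(-1, x.shape[-1])                       # flatten leading dims
        scale = (flat.abs().amax(dim=1, keepdim=True) / 127.0).clamp_min(1e-8)
        q = torch.clamp(torch.round(flat / scale), -127, 127).to(torch.int8)
        return q.reshape(x.shape), scale                                # one scale per flattened row

    N, P, L = 16, 128, 512                       # illustrative: heads, nope head dim, kv_lora_rank
    w_uk_int8 = torch.randint(-127, 128, (N, P, L), dtype=torch.int8)
    w_uk_scale = torch.rand(N, P, 1)             # stand-in for the loaded weight_scale
    # Dequantize, move to the transposed (N, L, P) layout, then assign fresh per-row scales.
    w_uk_dq_trans = (w_uk_int8.to(torch.bfloat16) * w_uk_scale.to(torch.bfloat16)) \
        .transpose(1, 2).contiguous()
    W_UK_T, W_UK_SCALE = quant2d_rowwise(w_uk_dq_trans)                 # W_UK_SCALE: [N * L, 1]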