diff --git a/docs/platforms/ascend_npu.md b/docs/platforms/ascend_npu.md
index 77d073acb..b9840ad97 100644
--- a/docs/platforms/ascend_npu.md
+++ b/docs/platforms/ascend_npu.md
@@ -118,7 +118,7 @@ git clone https://github.com/sgl-project/sglang.git
 cd sglang/docker
 
 # Build the docker image
-docker build -t sglang-npu:main -f Dockerfile.npu .
+docker build -t <image_name> -f Dockerfile.npu .
 
 alias drun='docker run -it --rm --privileged --network=host --ipc=host --shm-size=16g \
 --device=/dev/davinci0 --device=/dev/davinci1 --device=/dev/davinci2 --device=/dev/davinci3 \
@@ -132,7 +132,7 @@ alias drun='docker run -it --rm --privileged --network=host --ipc=host --shm-siz
 --volume /var/queue_schedule:/var/queue_schedule --volume ~/.cache/:/root/.cache/'
 
 drun --env "HF_TOKEN=<secret>" \
-    sglang-npu:main \
+    <image_name> \
     python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --attention-backend ascend --host 0.0.0.0 --port 30000
 ```
 
@@ -149,7 +149,7 @@ Prefill:
 export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
 export ASCEND_MF_STORE_URL="tcp://<local_ip>:<port>"
 
-drun sglang-npu:main \
+drun <image_name> \
     python3 -m sglang.launch_server --model-path State_Cloud/DeepSeek-R1-bf16-hfd-w8a8 \
     --trust-remote-code \
     --attention-backend ascend \
@@ -174,8 +174,9 @@ export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
 export ASCEND_MF_STORE_URL="tcp://<local_ip>:<port>"
 export HCCL_BUFFSIZE=200
 export SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=24
+export SGLANG_NPU_USE_MLAPO=1
 
-drun sglang-npu:main \
+drun <image_name> \
     python3 -m sglang.launch_server --model-path State_Cloud/DeepSeek-R1-bf16-hfd-w8a8 \
     --trust-remote-code \
     --attention-backend ascend \
@@ -198,7 +199,7 @@ drun sglang-npu:main \
 Mini_LB:
 
 ```shell
-drun sglang-npu:main \
+drun <image_name> \
     python -m sglang.srt.disaggregation.launch_lb \
     --prefill http://<prefill_node_ip>:8000 \
     --decode http://<decode_node_ip>:8001 \
diff --git a/python/sglang/srt/layers/attention/ascend_backend.py b/python/sglang/srt/layers/attention/ascend_backend.py
index 6d5ed0a5c..52192b7bc 100644
--- a/python/sglang/srt/layers/attention/ascend_backend.py
+++ b/python/sglang/srt/layers/attention/ascend_backend.py
@@ -9,6 +9,7 @@ from torch.nn.functional import scaled_dot_product_attention
 
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.npu_ops.mla_preprocess import is_mla_preprocess_enabled
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
@@ -401,7 +402,7 @@ class AscendAttnBackend(AttentionBackend):
             antiquant_scale=None,
             sparse_mode=0,
         )
-        output = torch.zeros_like(q_nope, dtype=q.dtype, device=q.device)
+        output = torch.empty_like(q_nope, dtype=q.dtype, device=q.device)
         softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
 
         torch_npu.npu_fused_infer_attention_score.out(
@@ -437,6 +438,10 @@ class AscendAttnBackend(AttentionBackend):
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
     ):
+        if is_mla_preprocess_enabled():
+            # MLAPO already saves the kv_cache itself
+            save_kv_cache = False
+
         if self.graph_mode:
             return self.forward_decode_graph(
                 q,
diff --git a/python/sglang/srt/layers/attention/npu_ops/mla_preprocess.py b/python/sglang/srt/layers/attention/npu_ops/mla_preprocess.py
new file mode 100644
index 000000000..84efe2ce4
--- /dev/null
+++ b/python/sglang/srt/layers/attention/npu_ops/mla_preprocess.py
@@ -0,0 +1,300 @@
+import torch
+import torch.nn.functional as F
+
+from sglang.srt.utils import get_bool_env_var, is_npu
+
+_is_npu = is_npu()
+_ENABLE_MLA_PREPROCESS_FLAG = get_bool_env_var("SGLANG_NPU_USE_MLAPO")
+_NPU_FORMAT_NZ = 29
+
+
+def is_mla_preprocess_enabled() -> bool:
+    return _is_npu and _ENABLE_MLA_PREPROCESS_FLAG
+
+
+if is_mla_preprocess_enabled():
+    import sgl_kernel_npu
+    import torch_npu
+
+    torch.npu.config.allow_internal_format = True
+    torch.npu.set_compile_mode(jit_compile=False)
+
+
+def round_up(val: int, align: int) -> int:
+    if align == 0:
+        return 0
+    return -(val // -align) * align
+
+
+def transdata(nd_mat, block_size: tuple = (16, 16)):
+    r = round_up(nd_mat.shape[0], block_size[0])
+    c = round_up(nd_mat.shape[1], block_size[1])
+    r_pad = r - nd_mat.shape[0]
+    c_pad = c - nd_mat.shape[1]
+    # F.pad pads the last dim first: (left, right, top, bottom) for a 2D tensor,
+    # so pad columns up to c and rows up to r.
+    nd_mat = F.pad(nd_mat, (0, c_pad, 0, r_pad))
+    nz_mat = torch.permute(
+        torch.reshape(
+            nd_mat,
+            (r // block_size[0], block_size[0], c // block_size[1], block_size[1]),
+        ),
+        [2, 0, 1, 3],
+    )
+    nz_mat = torch.reshape(
+        nz_mat, (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3])
+    )
+    return nz_mat
+
+
+def trans_rope_weight(weight, rope_dim):
+    weight_1 = weight[..., -rope_dim::2, :].contiguous()
+    weight_2 = weight[..., -rope_dim + 1 :: 2, :].contiguous()
+    weight[..., -rope_dim:, :] = torch.cat([weight_1, weight_2], dim=-2)
+
+    return weight.contiguous()
+
+
+class NPUFusedMLAPreprocess(torch.nn.Module):
+    def __init__(
+        self,
+        fused_qkv_a_proj_with_mqa,
+        q_a_layernorm,
+        kv_a_layernorm,
+        q_b_proj,
+        w_kc,
+        rotary_emb,
+        layer_id,
+        num_local_heads,
+        qk_nope_head_dim,
+        qk_rope_head_dim,
+    ):
+        super().__init__()
+        self.qkv_a_proj = fused_qkv_a_proj_with_mqa
+        self.q_a_layernorm = q_a_layernorm
+        self.kv_a_layernorm = kv_a_layernorm
+        self.q_b_proj = q_b_proj
+        self.w_kc = w_kc.contiguous()
+        self.rotary_emb = rotary_emb
+        self.layer_id = layer_id
+        self.has_preprocess_weights = False
+
+        self.q_lora_rank = self.q_b_proj.input_size  # 1536
+        self.kv_lora_rank = self.kv_a_layernorm.hidden_size  # 512
+        self.num_local_heads = num_local_heads  # attention heads per TP rank
+        self.qk_nope_head_dim = qk_nope_head_dim  # 128
+        self.qk_rope_head_dim = qk_rope_head_dim  # 64
+
+    def preprocess_weights(self, hidden_states):
+        self.dummy = torch.empty(
+            (hidden_states.shape[-1]),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+        self.qkv_a_proj_input_offset = self.qkv_a_proj.input_offset.to(dtype=torch.int8)
+        self.q_b_proj_input_offset = self.q_b_proj.input_offset.to(dtype=torch.int8)
+
+        # matmul_0 weight [7168, 2112]
+        fused_qkv_a_proj_with_mqa_weight_q = self.qkv_a_proj.weight.data[
+            :, : self.q_lora_rank
+        ].clone()  # [7168, 1536]
+        fused_qkv_a_proj_with_mqa_weight_kv = self.qkv_a_proj.weight.data[
+            :, self.q_lora_rank :
+        ].clone()  # [7168, 576]
+        # rope fit
+        fused_qkv_a_proj_with_mqa_weight_kv_t = (
+            fused_qkv_a_proj_with_mqa_weight_kv.t().contiguous()
+        )
+        fused_qkv_a_proj_with_mqa_weight_kv_t = trans_rope_weight(
+            fused_qkv_a_proj_with_mqa_weight_kv_t, self.qk_rope_head_dim
+        )
+        fused_qkv_a_proj_with_mqa_weight_kv = (
+            fused_qkv_a_proj_with_mqa_weight_kv_t.t().contiguous()
+        )
+        # cat nz
+        fused_qkv_a_proj_with_mqa_weight_new = torch.cat(
+            (fused_qkv_a_proj_with_mqa_weight_kv, fused_qkv_a_proj_with_mqa_weight_q),
+            dim=-1,
+        )
+        fused_qkv_a_proj_with_mqa_weight = (
+            fused_qkv_a_proj_with_mqa_weight_new.t().contiguous()
+        )
+        fused_qkv_a_proj_with_mqa_weight_nz = (
+            transdata(fused_qkv_a_proj_with_mqa_weight, block_size=(16, 32))
+            .unsqueeze(0)
+            .contiguous()
+        )
+        self.qkv_a_proj_weight_nz = torch_npu.npu_format_cast(
+            fused_qkv_a_proj_with_mqa_weight_nz, _NPU_FORMAT_NZ
+        )
+
+        # matmul_0 deq_scale [2112]
+        fused_qkv_a_proj_with_mqa_deq_scale_q = self.qkv_a_proj.deq_scale.data[
+            : self.q_lora_rank
+        ].clone()  # [1536]
+        fused_qkv_a_proj_with_mqa_deq_scale_kv = self.qkv_a_proj.deq_scale.data[
+            self.q_lora_rank :
+        ].clone()  # [576]
+        # rope fit
+        fused_qkv_a_proj_with_mqa_deq_scale_kv = (
+            fused_qkv_a_proj_with_mqa_deq_scale_kv.reshape(
+                self.kv_lora_rank + self.qk_rope_head_dim, -1
+            ).contiguous()
+        )
+        fused_qkv_a_proj_with_mqa_deq_scale_kv = trans_rope_weight(
+            fused_qkv_a_proj_with_mqa_deq_scale_kv, self.qk_rope_head_dim
+        )
+        fused_qkv_a_proj_with_mqa_deq_scale_kv = (
+            fused_qkv_a_proj_with_mqa_deq_scale_kv.view(
+                self.kv_lora_rank + self.qk_rope_head_dim
+            ).contiguous()
+        )
+        self.qkv_a_proj_deq_scale_kvq = torch.cat(
+            (
+                fused_qkv_a_proj_with_mqa_deq_scale_kv,
+                fused_qkv_a_proj_with_mqa_deq_scale_q,
+            ),
+            dim=-1,
+        )
+
+        # matmul_0 quant_bias [2112]
+        fused_qkv_a_proj_with_mqa_quant_bias_q = self.qkv_a_proj.quant_bias.data[
+            : self.q_lora_rank
+        ].clone()  # [1536]
+        fused_qkv_a_proj_with_mqa_quant_bias_kv = self.qkv_a_proj.quant_bias.data[
+            self.q_lora_rank :
+        ].clone()  # [576]
+        # rope fit
+        fused_qkv_a_proj_with_mqa_quant_bias_kv = (
+            fused_qkv_a_proj_with_mqa_quant_bias_kv.reshape(
+                self.kv_lora_rank + self.qk_rope_head_dim, -1
+            ).contiguous()
+        )
+        fused_qkv_a_proj_with_mqa_quant_bias_kv = trans_rope_weight(
+            fused_qkv_a_proj_with_mqa_quant_bias_kv, self.qk_rope_head_dim
+        )
+        fused_qkv_a_proj_with_mqa_quant_bias_kv = (
+            fused_qkv_a_proj_with_mqa_quant_bias_kv.view(
+                self.kv_lora_rank + self.qk_rope_head_dim
+            ).contiguous()
+        )
+        self.qkv_a_proj_quant_bias_kvq = torch.cat(
+            (
+                fused_qkv_a_proj_with_mqa_quant_bias_kv,
+                fused_qkv_a_proj_with_mqa_quant_bias_q,
+            ),
+            dim=-1,
+        )
+
+        # matmul_1 weight [1536, num_head * 192]
+        q_b_proj_weight = self.q_b_proj.weight.data.clone()
+        q_b_proj_weight = q_b_proj_weight.t().reshape(
+            self.num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1
+        )
+        q_b_proj_weight = trans_rope_weight(q_b_proj_weight, self.qk_rope_head_dim)
+        q_b_proj_weight = q_b_proj_weight.reshape(
+            self.num_local_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim), -1
+        )
+        q_b_proj_weight_nz = (
+            transdata(q_b_proj_weight, block_size=(16, 32)).unsqueeze(0).contiguous()
+        )
+        self.q_b_proj_weight_nz = torch_npu.npu_format_cast(
+            q_b_proj_weight_nz, _NPU_FORMAT_NZ
+        )
+
+        # matmul_1 deq_scale [num_head * 192]
+        q_b_proj_deq_scale = self.q_b_proj.deq_scale.data.clone()
+        q_b_proj_deq_scale = q_b_proj_deq_scale.reshape(
+            self.num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1
+        )
+        q_b_proj_deq_scale = trans_rope_weight(
+            q_b_proj_deq_scale, self.qk_rope_head_dim
+        )
+        self.q_b_proj_deq_scale = q_b_proj_deq_scale.reshape(
+            self.num_local_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)
+        )
+
+        # matmul_1 quant_bias [num_head * 192]
+        q_b_proj_quant_bias = self.q_b_proj.quant_bias.data.clone()
+        q_b_proj_quant_bias = q_b_proj_quant_bias.reshape(
+            self.num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1
+        )
+        q_b_proj_quant_bias = trans_rope_weight(
+            q_b_proj_quant_bias, self.qk_rope_head_dim
+        )
+        self.q_b_proj_quant_bias = q_b_proj_quant_bias.reshape(
+            self.num_local_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)
+        )
+
+    def get_sin_cos(self, positions):
+        cos_sin = self.rotary_emb.cos_sin_cache[positions]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        cos = cos.repeat(1, 2)
+        sin = sin.repeat(1, 2)
+        return cos, sin
+
+    def get_kv_cache_and_cache_idx(self, forward_batch):
+        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(self.layer_id)
+        slot_mapping = forward_batch.out_cache_loc.to(dtype=torch.int32)
+        return k_cache, v_cache, slot_mapping
+
+    def forward(self, positions, hidden_states, forward_batch, zero_allocator):
+        input_dtype = hidden_states.dtype
+        if not self.has_preprocess_weights:
+            self.preprocess_weights(hidden_states)
+            self.has_preprocess_weights = True
+            self.dtype = hidden_states.dtype
+
+        cos, sin = self.get_sin_cos(positions)
+        k_cache, v_cache, slot_mapping = self.get_kv_cache_and_cache_idx(forward_batch)
+
+        q_nope_out = torch.empty(
+            (hidden_states.shape[0], self.w_kc.shape[0], k_cache.shape[-1]),
+            dtype=input_dtype,
+            device=hidden_states.device,
+        )
+        q_rope_out = torch.empty(
+            (hidden_states.shape[0], self.w_kc.shape[0], v_cache.shape[-1]),
+            dtype=input_dtype,
+            device=hidden_states.device,
+        )
+
+        # TODO: dummy inputs to be removed
+        # https://github.com/sgl-project/sgl-kernel-npu/issues/78
+        torch.ops.npu.mla_preprocess(
+            hidden_states,
+            self.dummy,
+            self.dummy,
+            self.qkv_a_proj_weight_nz,
+            self.qkv_a_proj_deq_scale_kvq,
+            self.q_a_layernorm.weight,
+            self.q_a_layernorm.bias,
+            self.q_b_proj_weight_nz,
+            self.q_b_proj_deq_scale,
+            self.kv_a_layernorm.weight,
+            cos,
+            sin,
+            self.w_kc,
+            k_cache,
+            v_cache,
+            slot_mapping,
+            quant_scale0=self.qkv_a_proj.input_scale,
+            quant_offset0=self.qkv_a_proj_input_offset,
+            bias0=self.qkv_a_proj_quant_bias_kvq,
+            quant_scale1=self.q_b_proj.input_scale,
+            quant_offset1=self.q_b_proj_input_offset,
+            bias1=self.q_b_proj_quant_bias,
+            cache_mode="krope_ctkv",
+            quant_mode="per_tensor_quant_asymm",
+            q_out0=q_nope_out,
+            kv_cache_out0=k_cache,
+            q_out1=q_rope_out,
+            kv_cache_out1=v_cache,
+        )
+        return (
+            q_rope_out,
+            v_cache,
+            q_nope_out,
+            k_cache,
+            forward_batch,
+            zero_allocator,
+            positions,
+        )
diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index 05f068557..f0e9e5a7b 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -782,27 +782,33 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # NOTE: now npu_mrope can only support `numQHeads*headSize <= 4096` pattern,
-        # and generalization to more scenarios will be supported in the future.
-        if query.shape[1] * query.shape[2] > 4096:
-            return self.forward_native(positions, query, key, offsets)
-        num_tokens = query.shape[0]
-        rotary_mode = "half" if self.is_neox_style else "interleave"
+        num_tokens, num_q_heads, _ = query.shape
+        num_k_heads = key.shape[1]
+        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
+        cos_sin = self.cos_sin_cache[
+            torch.add(positions, offsets) if offsets is not None else positions
+        ]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        # Reshape cos/sin to [num_tokens, 1, 1, rotary_dim] so they broadcast over heads
+        cos = cos.repeat(1, 2).unsqueeze(-2).unsqueeze(-2)
+        sin = sin.repeat(1, 2).unsqueeze(-2).unsqueeze(-2)
+
         query_rot = query[..., : self.rotary_dim]
         key_rot = key[..., : self.rotary_dim]
 
         if self.rotary_dim < self.head_size:
             query_pass = query[..., self.rotary_dim :]
             key_pass = key[..., self.rotary_dim :]
 
-        query_rot, key_rot = torch_npu.npu_mrope(
-            torch.add(positions, offsets) if offsets is not None else positions,
-            query_rot.reshape(num_tokens, -1),
-            key_rot.reshape(num_tokens, -1),
-            self.cos_sin_cache,
-            self.rotary_dim,
-            mrope_section=[0, 0, 0],
-            rotary_mode=rotary_mode,
+        query_rot = torch_npu.npu_interleave_rope(
+            query_rot.reshape(num_tokens, num_q_heads, 1, self.rotary_dim),
+            cos,
+            sin,
+        )
+        key_rot = torch_npu.npu_interleave_rope(
+            key_rot.reshape(num_tokens, num_k_heads, 1, self.rotary_dim),
+            cos,
+            sin,
         )
         query_rot = query_rot.reshape(num_tokens, -1, self.rotary_dim)
         key_rot = key_rot.reshape(num_tokens, -1, self.rotary_dim)
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 91d0aa1a8..0db0ca164 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -43,6 +43,10 @@ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.amx_utils import PackWeightMethod
+from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
+    NPUFusedMLAPreprocess,
+    is_mla_preprocess_enabled,
+)
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
     LayerScatterModes,
@@ -1177,6 +1181,12 @@ class DeepseekV2AttentionMLA(nn.Module):
             self.weight_block_size = (
                 self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
             )
+        self.is_mla_preprocess_enabled = is_mla_preprocess_enabled()
+        if self.is_mla_preprocess_enabled:
+            assert (
+                quant_config.get_name() == "w8a8_int8"
+            ), "MLA Preprocess only works with W8A8Int8"
+            self.mla_preprocess = None
 
     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
@@ -1263,9 +1273,28 @@ class DeepseekV2AttentionMLA(nn.Module):
                 positions, hidden_states, forward_batch, zero_allocator
             )
         elif attn_forward_method == AttnForwardMethod.MLA:
-            inner_state = self.forward_absorb_prepare(
-                positions, hidden_states, forward_batch, zero_allocator
-            )
+            if not self.is_mla_preprocess_enabled:
+                inner_state = self.forward_absorb_prepare(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
+            else:
+                # TODO(iforgetmyname): to be separated as a standalone func
+                if self.mla_preprocess is None:
+                    self.mla_preprocess = NPUFusedMLAPreprocess(
+                        self.fused_qkv_a_proj_with_mqa,
+                        self.q_a_layernorm,
+                        self.kv_a_layernorm,
+                        self.q_b_proj,
+                        self.w_kc,
+                        self.rotary_emb,
+                        self.layer_id,
+                        self.num_local_heads,
+                        self.qk_nope_head_dim,
+                        self.qk_rope_head_dim,
+                    )
+                inner_state = self.mla_preprocess.forward(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
             inner_state = self.forward_absorb_fused_mla_rope_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index d6c939227..0681bdfe2 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -174,6 +174,8 @@ def is_blackwell():
 
 @lru_cache(maxsize=1)
 def is_sm100_supported(device=None) -> bool:
+    if not is_cuda_alike():
+        return False
     return (torch.cuda.get_device_capability(device)[0] == 10) and (
         torch.version.cuda >= "12.8"
     )
@@ -181,6 +183,8 @@
 def is_sm90_supported(device=None) -> bool:
+    if not is_cuda_alike():
+        return False
     return (torch.cuda.get_device_capability(device)[0] == 9) and (
         torch.version.cuda >= "12.3"
     )
diff --git a/test/srt/ascend/test_ascend_deepep.py b/test/srt/ascend/test_ascend_deepep.py
index 6ccd34d27..de51e35b3 100644
--- a/test/srt/ascend/test_ascend_deepep.py
+++ b/test/srt/ascend/test_ascend_deepep.py
@@ -60,6 +60,7 @@ class TestAscendDeepEP(CustomTestCase):
         cls.extra_envs = {
             "HCCL_BUFFSIZE": "500",
             "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "32",
+            "SGLANG_NPU_USE_MLAPO": "1",
         }
         os.environ.update(cls.extra_envs)
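
Reviewer note (not part of the patch): the weight preprocessing in `NPUFusedMLAPreprocess.preprocess_weights` hinges on `transdata()`, which pads a 2D weight to multiples of the (16, 32) block size before permuting it into the NZ layout that `torch_npu.npu_format_cast` expects. Below is a minimal sketch of just the padding and reshape math, runnable on CPU without torch_npu; the shapes mirror the transposed fused `qkv_a_proj` weight, and the helper names are illustrative, not from the patch.

```python
import torch
import torch.nn.functional as F


def round_up(val: int, align: int) -> int:
    # Same helper as in mla_preprocess.py: ceil(val / align) * align.
    return 0 if align == 0 else -(val // -align) * align


def to_nz_blocks(nd_mat: torch.Tensor, block_size=(16, 32)) -> torch.Tensor:
    # Pad rows/cols up to block multiples, then regroup the matrix into
    # (col_blocks, padded_rows, block_cols), the same steps transdata()
    # performs before the NPU format cast.
    r = round_up(nd_mat.shape[0], block_size[0])
    c = round_up(nd_mat.shape[1], block_size[1])
    padded = F.pad(nd_mat, (0, c - nd_mat.shape[1], 0, r - nd_mat.shape[0]))
    nz = padded.reshape(r // block_size[0], block_size[0], c // block_size[1], block_size[1])
    nz = nz.permute(2, 0, 1, 3)
    return nz.reshape(nz.shape[0], nz.shape[1] * nz.shape[2], nz.shape[3])


w = torch.randn(2112, 7168)          # transposed fused qkv_a_proj weight
blocks = to_nz_blocks(w, (16, 32))
print(blocks.shape)                  # torch.Size([224, 2112, 32])
```

For the DeepSeek shapes used here the dimensions are already block-aligned, so the padding is a no-op; the reshape/permute only reorders the data into the blocked layout the fused kernel consumes.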
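The fused path is opt-in via `SGLANG_NPU_USE_MLAPO`, and the flag is read once at module import time into `_ENABLE_MLA_PREPROCESS_FLAG`. A hedged sketch of how the gate composes end to end (the print branches are illustrative only):

```python
# Sketch only: the flag is captured when mla_preprocess is first imported,
# so it must be set in the environment before that import happens.
import os

os.environ["SGLANG_NPU_USE_MLAPO"] = "1"

from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
    is_mla_preprocess_enabled,
)

if is_mla_preprocess_enabled():
    # True only on NPU builds with the flag set; DeepseekV2AttentionMLA then
    # additionally asserts W8A8Int8 quantization before using the fused op.
    print("MLAPO fused preprocess active")
else:
    print("falling back to forward_absorb_prepare")
```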