diff --git a/.github/workflows/pr-test-amd.yml b/.github/workflows/pr-test-amd.yml index d5e0a13e2..daa4176fc 100644 --- a/.github/workflows/pr-test-amd.yml +++ b/.github/workflows/pr-test-amd.yml @@ -72,7 +72,7 @@ jobs: - name: Evaluate accuracy (TP=2) timeout-minutes: 30 run: | - bash scripts/amd_ci_exec.sh python3 test_moe_eval_accuracy_large.py + bash scripts/amd_ci_exec.sh -e SGLANG_USE_AITER=0 python3 test_moe_eval_accuracy_large.py mla-test-1-gpu-amd: if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && diff --git a/docs/references/environment_variables.md b/docs/references/environment_variables.md index 27b8b12fc..931fdbadb 100644 --- a/docs/references/environment_variables.md +++ b/docs/references/environment_variables.md @@ -53,7 +53,7 @@ SGLang supports various environment variables that can be used to configure its | Environment Variable | Description | Default Value | | --- | --- | --- | -| `SGLANG_AITER_MOE` | Use AITER MOE implementation | `false` | +| `SGLANG_USE_AITER` | Use AITER optimize implementation | `false` | | `SGLANG_INT4_WEIGHT` | Enable INT4 weight quantization | `false` | | `SGLANG_MOE_PADDING` | Enable MoE padding (sets padding size to 128 if value is `1`, often set to `1` in Docker builds) | `0` | | `SGLANG_FORCE_FP8_MARLIN` | Force using FP8 MARLIN kernels even if other FP8 kernels are available | `false` | diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py index 8be35ec89..49200c071 100644 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -27,12 +27,19 @@ if TYPE_CHECKING: from sglang.srt.speculative.spec_info import SpecInfo try: - from aiter import mha_batch_prefill_func, paged_attention_ragged + from aiter import ( + flash_attn_varlen_func, + mha_batch_prefill_func, + paged_attention_ragged, + ) + from aiter.mla import mla_decode_fwd except ImportError: print( "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device." 
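# Every AITER code path touched by this patch is gated on the boolean environment variable
# SGLANG_USE_AITER (replacing the old SGLANG_AITER_MOE) combined with a ROCm check.
# A minimal, self-contained sketch of that gating; the get_bool_env_var() and is_hip() bodies
# below are assumed stand-ins for the sglang.srt.utils helpers, not their exact implementations:
import os

import torch


def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Assumed semantics: "1"/"true" (any case) enable the flag, everything else disables it.
    return os.getenv(name, default).strip().lower() in ("1", "true")


def is_hip() -> bool:
    # Stand-in check for a ROCm build: torch.version.hip is None on CUDA/CPU builds.
    return torch.version.hip is not None


_is_hip = is_hip()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

if _use_aiter:
    # aiter only imports on supported AMD GPUs, so the import itself stays behind the flag,
    # mirroring the layernorm.py change later in this patch.
    from aiter import rmsnorm2d_fwd as rms_norm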
) +from sglang.srt.configs.model_config import AttentionArch + class WrapperDispatch(Enum): SLIDING_WINDOW = auto() @@ -43,6 +50,10 @@ class WrapperDispatch(Enum): class ForwardMetadata: kv_indptr: torch.Tensor kv_indices: torch.Tensor + qo_indptr: torch.Tensor + kv_last_page_len: torch.Tensor + max_extend_len: int + max_prefix_extend_len: int max_q_len: int max_kv_len: int @@ -63,6 +74,7 @@ class AiterAttnBackend(AttentionBackend): self.device = model_runner.device self.is_multimodal = model_runner.model_config.is_multimodal + self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens self.num_head = ( model_runner.model_config.num_attention_heads // get_attention_tp_size() ) @@ -75,6 +87,8 @@ class AiterAttnBackend(AttentionBackend): self.req_to_token = model_runner.req_to_token_pool.req_to_token + self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA + # Parse constants self.max_context_len = model_runner.model_config.context_len self.skip_prefill = skip_prefill @@ -100,6 +114,10 @@ class AiterAttnBackend(AttentionBackend): self.indices_updater_prefill = AiterIndicesUpdaterPrefill( model_runner, self ) + if self.use_mla: + self.mla_indices_updater_prefill = AiterMlaIndicesUpdaterPrefill( + model_runner, self + ) # aiter kernel related initialization self.max_num_partitions = ( @@ -108,33 +126,40 @@ class AiterAttnBackend(AttentionBackend): nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8 - self.workspace_buffer = torch.empty( - (max_bs * self.num_head * self.max_num_partitions * self.head_dim) - * nbyes_per_qo_elem - + 2 * (max_bs * self.num_head * self.max_num_partitions) * 4, - dtype=torch.uint8, - device=self.device, - ) + if not self.use_mla: + self.workspace_buffer = torch.empty( + (max_bs * self.num_head * self.max_num_partitions * self.head_dim) + * nbyes_per_qo_elem + + 2 * (max_bs * self.num_head * self.max_num_partitions) * 4, + dtype=torch.uint8, + device=self.device, + ) self.scale = float(1.0 / (self.head_dim**0.5)) self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to( self.device ) - self.kv_last_page_lens = torch.ones((max_bs,), dtype=torch.int32).to( - self.device - ) self.logits_soft_cap = 0.0 self.forward_metadata: ForwardMetadata = None + if self.use_mla: + self.qo_indptr_ = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=model_runner.device + ) + def init_forward_metadata(self, forward_batch: ForwardBatch): + """Init auxiliary variables for triton attention backend.""" + + bs = forward_batch.batch_size + kv_indptr = self.kv_indptr + spec_info = forward_batch.spec_info + qo_indptr = None + kv_last_page_len = None + max_extend_len = None + if forward_batch.forward_mode.is_decode_or_idle(): - # update for aiter - # create kv_indices and kv_inptr - bs = forward_batch.batch_size - kv_indptr = self.kv_indptr - spec_info = forward_batch.spec_info if spec_info is None: kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0) kv_indptr = kv_indptr[: bs + 1] @@ -154,38 +179,103 @@ class AiterAttnBackend(AttentionBackend): kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices bs = kv_indptr.shape[0] - 1 - self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None) + if self.use_mla: + qo_indptr = self.qo_indptr_[: bs + 1] + qo_indptr[1 : bs + 1] = torch.cumsum(self.kv_last_page_len[:bs], dim=0) + kv_last_page_len = self.kv_last_page_len[:bs] + max_extend_len = 1 + + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + 
kv_last_page_len, + max_extend_len, + None, + None, + None, + ) elif forward_batch.forward_mode.is_draft_extend(): - self.indices_updater_prefill.update( - forward_batch.req_pool_indices, - forward_batch.seq_lens, - forward_batch.seq_lens_sum, - prefix_lens=None, - encoder_lens=forward_batch.encoder_lens, - spec_info=forward_batch.spec_info, - ) - self.forward_metadata = ForwardMetadata( - self.indices_updater_prefill.kv_indptr, - self.indices_updater_prefill.kv_indices, - self.indices_updater_prefill.max_q_len, - self.indices_updater_prefill.max_kv_len, - ) + if self.use_mla: + prefix_lens = forward_batch.extend_prefix_lens + self.mla_indices_updater_prefill.update( + forward_batch.req_pool_indices, + prefix_lens, + prefix_lens.sum().item(), + forward_batch.extend_seq_lens, + encoder_lens=forward_batch.encoder_lens, + spec_info=None, + ) + self.forward_metadata = ForwardMetadata( + self.mla_indices_updater_prefill.kv_indptr, + self.mla_indices_updater_prefill.kv_indices, + self.mla_indices_updater_prefill.qo_indptr, + self.mla_indices_updater_prefill.kv_last_page_len, + self.mla_indices_updater_prefill.max_extend_len, + self.mla_indices_updater_prefill.max_prefix_extend_len, + None, + None, + ) + else: + self.indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + prefix_lens=None, + encoder_lens=forward_batch.encoder_lens, + spec_info=forward_batch.spec_info, + ) + self.forward_metadata = ForwardMetadata( + self.indices_updater_prefill.kv_indptr, + self.indices_updater_prefill.kv_indices, + None, + None, + None, + None, + self.indices_updater_prefill.max_q_len, + self.indices_updater_prefill.max_kv_len, + ) elif forward_batch.forward_mode.is_target_verify(): - self.indices_updater_prefill.update( - forward_batch.req_pool_indices, - forward_batch.seq_lens, - forward_batch.seq_lens_sum, - prefix_lens=None, - encoder_lens=forward_batch.encoder_lens, - spec_info=forward_batch.spec_info, - ) - self.forward_metadata = ForwardMetadata( - self.indices_updater_prefill.kv_indptr, - self.indices_updater_prefill.kv_indices, - self.indices_updater_prefill.max_q_len, - self.indices_updater_prefill.max_kv_len, - ) + if self.use_mla: + prefix_lens = forward_batch.extend_prefix_lens + self.mla_indices_updater_prefill.update( + forward_batch.req_pool_indices, + prefix_lens, + prefix_lens.sum().item(), + forward_batch.extend_seq_lens, + encoder_lens=forward_batch.encoder_lens, + spec_info=None, + ) + self.forward_metadata = ForwardMetadata( + self.mla_indices_updater_prefill.kv_indptr, + self.mla_indices_updater_prefill.kv_indices, + self.mla_indices_updater_prefill.qo_indptr, + self.mla_indices_updater_prefill.kv_last_page_len, + self.mla_indices_updater_prefill.max_extend_len, + self.mla_indices_updater_prefill.max_prefix_extend_len, + None, + None, + ) + else: + self.indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + prefix_lens=None, + encoder_lens=forward_batch.encoder_lens, + spec_info=forward_batch.spec_info, + ) + self.forward_metadata = ForwardMetadata( + self.indices_updater_prefill.kv_indptr, + self.indices_updater_prefill.kv_indices, + None, + None, + None, + None, + self.indices_updater_prefill.max_q_len, + self.indices_updater_prefill.max_kv_len, + ) else: prefix_lens = forward_batch.extend_prefix_lens @@ -194,24 +284,49 @@ class AiterAttnBackend(AttentionBackend): else: extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) - 
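# The MLA decode metadata built above reduces to exclusive prefix sums: kv_indptr is the
# cumulative sum of the per-request KV lengths, and because decode contributes exactly one
# query token per request (max_extend_len = 1), qo_indptr is the cumulative sum of
# kv_last_page_len, a vector of ones. A small numeric illustration with made-up lengths:
import torch

seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)   # KV length per request
bs = seq_lens.numel()

kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)            # -> [0, 5, 8, 15]

kv_last_page_len = torch.ones(bs, dtype=torch.int32)
qo_indptr = torch.zeros(bs + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(kv_last_page_len, dim=0)    # -> [0, 1, 2, 3]

max_extend_len = 1                                        # one new token per request in decode
assert kv_indptr[-1].item() == seq_lens.sum().item()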
self.indices_updater_prefill.update( - forward_batch.req_pool_indices, - forward_batch.seq_lens, - forward_batch.seq_lens_sum, - prefix_lens, - encoder_lens=forward_batch.encoder_lens, - spec_info=None, - ) - self.forward_metadata = ForwardMetadata( - self.indices_updater_prefill.kv_indptr, - self.indices_updater_prefill.kv_indices, - self.indices_updater_prefill.max_q_len, - self.indices_updater_prefill.max_kv_len, - ) + if self.use_mla: + self.mla_indices_updater_prefill.update( + forward_batch.req_pool_indices, + prefix_lens, + prefix_lens.sum().item(), + forward_batch.extend_seq_lens, + encoder_lens=forward_batch.encoder_lens, + spec_info=None, + ) + self.forward_metadata = ForwardMetadata( + self.mla_indices_updater_prefill.kv_indptr, + self.mla_indices_updater_prefill.kv_indices, + self.mla_indices_updater_prefill.qo_indptr, + self.mla_indices_updater_prefill.kv_last_page_len, + self.mla_indices_updater_prefill.max_extend_len, + self.mla_indices_updater_prefill.max_prefix_extend_len, + None, + None, + ) + else: + self.indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + prefix_lens, + encoder_lens=forward_batch.encoder_lens, + spec_info=None, + ) + self.forward_metadata = ForwardMetadata( + self.indices_updater_prefill.kv_indptr, + self.indices_updater_prefill.kv_indices, + None, + None, + None, + None, + self.indices_updater_prefill.max_q_len, + self.indices_updater_prefill.max_kv_len, + ) def init_cuda_graph_state( self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None ): + self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int) if kv_indices_buf is None: self.cuda_graph_kv_indices = torch.zeros( (max_bs * self.max_context_len), @@ -239,6 +354,10 @@ class AiterAttnBackend(AttentionBackend): spec_info: Optional[SpecInfo], ): if forward_mode.is_decode_or_idle(): + qo_indptr = None + kv_last_page_len = None + max_extend_len = None + if spec_info is None: kv_indptr = self.kv_indptr kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) @@ -255,24 +374,82 @@ class AiterAttnBackend(AttentionBackend): ) else: kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices - self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None) + + if self.use_mla: + qo_indptr = self.qo_indptr_[: bs + 1] + qo_indptr[1 : bs + 1] = torch.cumsum( + self.cuda_graph_kv_last_page_len[:bs], dim=0 + ) + max_extend_len = 1 + kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs] + + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + kv_last_page_len, + max_extend_len, + None, + None, + None, + ) elif forward_mode.is_target_verify(): - seq_lens_sum = seq_lens.sum().item() - self.indices_updater_prefill.update( - req_pool_indices, - seq_lens, - seq_lens_sum, - prefix_lens=None, - encoder_lens=encoder_lens, - spec_info=spec_info, - ) - self.forward_metadata = ForwardMetadata( - self.indices_updater_prefill.kv_indptr, - self.indices_updater_prefill.kv_indices, - self.indices_updater_prefill.max_q_len, - self.indices_updater_prefill.max_kv_len, - ) + if self.use_mla: + qo_indptr = self.qo_indptr[: bs + 1] + qo_indptr[: bs + 1] = torch.arange( + 0, + (1 + bs) * self.num_draft_tokens, + step=self.num_draft_tokens, + dtype=torch.int32, + device=self.device, + ) + kv_indptr = self.kv_indptr[: bs + 1] + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) + kv_indices = self.cuda_graph_kv_indices + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + 
req_pool_indices, + seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + + max_extend_len = self.num_draft_tokens + kv_last_page_len = None + + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + kv_last_page_len, + max_extend_len, + None, + None, + None, + ) + else: + seq_lens_sum = seq_lens.sum().item() + self.indices_updater_prefill.update( + req_pool_indices, + seq_lens, + seq_lens_sum, + prefix_lens=None, + encoder_lens=encoder_lens, + spec_info=spec_info, + ) + self.forward_metadata = ForwardMetadata( + self.indices_updater_prefill.kv_indptr, + self.indices_updater_prefill.kv_indices, + None, + None, + None, + None, + self.indices_updater_prefill.max_q_len, + self.indices_updater_prefill.max_kv_len, + ) else: raise ValueError(f"Invalid mode: {forward_mode=}") @@ -342,31 +519,113 @@ class AiterAttnBackend(AttentionBackend): if k is not None: assert v is not None if save_kv_cache: - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, cache_loc, k, v, layer.k_scale, layer.v_scale + if self.use_mla: + forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) + else: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc, k, v, layer.k_scale, layer.v_scale + ) + + if self.use_mla: + max_extend_len = self.forward_metadata.max_extend_len + max_prefix_extend_len = self.forward_metadata.max_prefix_extend_len + kv_indptr = self.forward_metadata.kv_indptr + kv_indices = self.forward_metadata.kv_indices + kv_last_page_lens = self.forward_metadata.kv_last_page_len + qo_indptr = self.forward_metadata.qo_indptr + K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id) + kv_lora_rank = V_Buffer.shape[-1] + qk_rope_head_dim = K_Buffer.shape[-1] - kv_lora_rank + qk_nope_head_dim = k.shape[-1] - qk_rope_head_dim + assert len(q.shape) == 3 + assert len(k.shape) == 3 + assert len(v.shape) == 3 + + if kv_indices.shape[0] == 0: + o = flash_attn_varlen_func( + q, + k, + v, + qo_indptr, + qo_indptr, + max_extend_len, + max_extend_len, + softmax_scale=layer.scaling, + causal=True, ) + return o + elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim): + K_Buffer = torch.index_select(K_Buffer, 0, kv_indices) + kvc, k_pe = torch.split( + K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1 + ) + kvprefix = layer.kv_b_proj(kvc.contiguous())[0] - k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + kvprefix = kvprefix.view( + -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim + ) + k_prefix, v_prefix = torch.split( + kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1 + ) + k_prefix = torch.cat( + [ + k_prefix, + torch.broadcast_to( + k_pe, + (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]), + ), + ], + dim=-1, + ) + assert ( + forward_batch.extend_prefix_lens.shape + == forward_batch.extend_seq_lens.shape + ) + k_prefix = torch.split(k_prefix, forward_batch.extend_prefix_lens_cpu) + k_extend = torch.split(k, forward_batch.extend_seq_lens_cpu) + assert len(k_prefix) == len(forward_batch.extend_prefix_lens_cpu) + k = torch.cat([x for el in zip(k_prefix, k_extend) for x in el]) + v_prefix = torch.split(v_prefix, forward_batch.extend_prefix_lens_cpu) + v_extend = torch.split(v, forward_batch.extend_seq_lens_cpu) + v = torch.cat([x for el in zip(v_prefix, v_extend) for x in el]) - bs0 = forward_batch.batch_size + 1 + o = flash_attn_varlen_func( + q, + k, + v, + qo_indptr, + kv_indptr, + 
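# The MHA prefill branch above rebuilds a contiguous K/V stream per request by splitting the
# cached prefix tokens and the newly extended tokens, then interleaving them as
# [prefix_0, extend_0, prefix_1, extend_1, ...] before calling the varlen attention kernel.
# A toy illustration of that split/zip/cat idiom with made-up lengths (two requests, the
# head and feature dimensions collapsed into a single column):
import torch

extend_prefix_lens_cpu = [2, 1]   # tokens already cached in the KV pool, per request
extend_seq_lens_cpu = [3, 2]      # freshly computed tokens, per request

k_prefix_all = torch.arange(sum(extend_prefix_lens_cpu)).unsqueeze(-1)           # 3 cached rows
k_extend_all = torch.arange(100, 100 + sum(extend_seq_lens_cpu)).unsqueeze(-1)   # 5 new rows

k_prefix = torch.split(k_prefix_all, extend_prefix_lens_cpu)
k_extend = torch.split(k_extend_all, extend_seq_lens_cpu)
k = torch.cat([x for pair in zip(k_prefix, k_extend) for x in pair])

# Rows are now ordered prefix_0 (2), extend_0 (3), prefix_1 (1), extend_1 (2): 8 rows total.
assert k.shape[0] == sum(extend_prefix_lens_cpu) + sum(extend_seq_lens_cpu)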
max_extend_len, + max_prefix_extend_len, + softmax_scale=layer.scaling, + causal=True, + ) + return o + else: + k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) - o = mha_batch_prefill_func( - q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), - k_cache, - v_cache, - self.qo_indptr[:bs0], - self.forward_metadata.kv_indptr[:bs0], - self.forward_metadata.kv_indices, - self.forward_metadata.max_q_len, - self.forward_metadata.max_kv_len, - causal=True, - logits_soft_cap=self.logits_soft_cap, - alibi_slopes=None, - return_lse=False, - return_attn_probs=False, - ) + bs0 = forward_batch.batch_size + 1 - return o.view(-1, layer.tp_q_head_num * layer.head_dim) + o = mha_batch_prefill_func( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + k_cache, + v_cache, + self.qo_indptr[:bs0], + self.forward_metadata.kv_indptr[:bs0], + self.forward_metadata.kv_indices, + self.forward_metadata.max_q_len, + self.forward_metadata.max_kv_len, + causal=True, + logits_soft_cap=self.logits_soft_cap, + alibi_slopes=None, + return_lse=False, + return_attn_probs=False, + ) + + return o.view(-1, layer.tp_q_head_num * layer.head_dim) def forward_decode( self, @@ -377,6 +636,7 @@ class AiterAttnBackend(AttentionBackend): forward_batch: ForwardBatch, save_kv_cache=True, ): + q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) if layer.qk_head_dim != layer.v_head_dim: @@ -389,32 +649,48 @@ class AiterAttnBackend(AttentionBackend): layer, forward_batch.out_cache_loc, k, v ) - self.logits_soft_cap = layer.logit_cap - paged_attention_ragged( - o.view(-1, layer.tp_q_head_num, layer.qk_head_dim), - self.workspace_buffer, - q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), - forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view( - -1, 1, layer.tp_k_head_num, layer.qk_head_dim - ), - forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view( - -1, 1, layer.tp_v_head_num, layer.v_head_dim - ), - self.scale, - self.forward_metadata.kv_indptr, - self.forward_metadata.kv_indices, - self.kv_last_page_lens, - 1, - self.max_num_partitions, - None, - "auto", - "NHD", - self.logits_soft_cap, - self.k_scale, - self.v_scale, - None, - _AITER_PARTITION_SIZE_ROCM, - ) + if self.use_mla: + k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + mla_decode_fwd( + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + k_buffer.view(-1, 1, 1, layer.qk_head_dim), + o.view(-1, layer.tp_q_head_num, layer.v_head_dim), + self.forward_metadata.qo_indptr, + self.forward_metadata.kv_indptr, + self.forward_metadata.kv_indices, + self.forward_metadata.kv_last_page_len, + self.forward_metadata.max_extend_len, + layer.scaling, + layer.logit_cap, + ) + k_buffer = k_buffer.view(-1, 1, layer.qk_head_dim) + else: + self.logits_soft_cap = layer.logit_cap + paged_attention_ragged( + o.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + self.workspace_buffer, + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view( + -1, 1, layer.tp_k_head_num, layer.qk_head_dim + ), + forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view( + -1, 1, layer.tp_v_head_num, layer.v_head_dim + ), + self.scale, + self.forward_metadata.kv_indptr, + self.forward_metadata.kv_indices, + self.kv_last_page_len, + 1, + self.max_num_partitions, + None, + "auto", + "NHD", + self.logits_soft_cap, + self.k_scale, + self.v_scale, + None, + _AITER_PARTITION_SIZE_ROCM, + ) return o @@ -506,9 +782,97 @@ class 
AiterIndicesUpdaterPrefill: spec_info.generate_attn_arg_prefill( req_pool_indices, paged_kernel_lens, - None, + paged_kernel_lens_sum, self.req_to_token, ) ) self.kv_indices = kv_indices + + +class AiterMlaIndicesUpdaterPrefill: + def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): + # Parse Constants + self.attn_backend = attn_backend + + # Buffers and wrappers + self.req_to_token = model_runner.req_to_token_pool.req_to_token + self.update = self.update_single_wrapper + + self.kv_indptr = None + self.kv_indices = None + self.qo_indptr = None + self.kv_last_page_len = None + self.max_extend_len = 0 + self.max_prefix_extend_len = 0 + + def update( + self, + req_pool_indices: torch.Tensor, + prefix_lens: torch.Tensor, + prefix_lens_sum: int, + extend_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + spec_info: Optional[SpecInfo], + ): + # Keep the signature for type checking. It will be assigned during runtime. + raise NotImplementedError() + + def update_single_wrapper( + self, + req_pool_indices: torch.Tensor, + prefix_lens: torch.Tensor, + prefix_lens_sum: int, + extend_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + spec_info: Optional[SpecInfo], + ): + + paged_kernel_lens = prefix_lens + paged_kernel_lens_sum = prefix_lens_sum + + bs = len(req_pool_indices) + + kv_indptr = self.attn_backend.kv_indptr + + if spec_info is None: + # Normal extend + kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.empty( + paged_kernel_lens_sum, + dtype=torch.int32, + device=req_pool_indices.device, + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + paged_kernel_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + + qo_indptr = self.attn_backend.qo_indptr + qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0) + qo_indptr = qo_indptr[: bs + 1] + + max_extend_len = torch.max(extend_lens).item() + max_prefix_extend_len = torch.max(extend_lens + paged_kernel_lens).item() + kv_indptr += qo_indptr + else: + kv_indices, kv_indptr, qo_indptr, custom_mask = ( + spec_info.generate_attn_arg_prefill( + req_pool_indices, + paged_kernel_lens, + paged_kernel_lens_sum, + self.req_to_token, + ) + ) + + self.kv_indptr = kv_indptr + self.kv_indices = kv_indices + self.qo_indptr = qo_indptr + self.max_extend_len = max_extend_len + self.max_prefix_extend_len = max_prefix_extend_len diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 1f398549b..0994a511e 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -20,10 +20,11 @@ import torch import torch.nn as nn from sglang.srt.custom_op import CustomOp -from sglang.srt.utils import is_cuda, is_hip +from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip _is_cuda = is_cuda() _is_hip = is_hip() +_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip if _is_cuda: from sgl_kernel import ( @@ -33,7 +34,10 @@ if _is_cuda: rmsnorm, ) -if _is_hip: +if _use_aiter: + from aiter import rmsnorm2d_fwd as rms_norm + from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm +elif _is_hip: from vllm._custom_ops import fused_add_rms_norm, rms_norm logger = logging.getLogger(__name__) @@ -48,6 +52,8 @@ class RMSNorm(CustomOp): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps + if _use_aiter: + self._forward_method = self.forward_aiter def forward_cuda( 
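# forward_aiter above routes RMSNorm through the aiter kernels whenever _use_aiter is set.
# For reference, a plain PyTorch sketch of the contract those kernels are assumed to follow
# (residual_out = x + residual, output = RMSNorm(residual_out) scaled by the learned weight);
# the aiter entry points write into preallocated output/residual_out buffers rather than
# returning new tensors, which is why forward_aiter creates them with torch.empty_like:
import torch


def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Normalize by the root-mean-square over the hidden dimension, then apply the weight.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight


def fused_add_rms_norm_ref(x, residual, weight, eps):
    residual_out = x + residual
    return rms_norm_ref(residual_out, weight, eps), residual_out


x = torch.randn(4, 8, dtype=torch.float16)
res = torch.randn(4, 8, dtype=torch.float16)
w = torch.ones(8, dtype=torch.float16)
out, res_out = fused_add_rms_norm_ref(x, res, w, eps=1e-6)
assert out.shape == x.shape and res_out.shape == x.shape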
self, @@ -60,6 +66,25 @@ class RMSNorm(CustomOp): out = rmsnorm(x, self.weight.data, self.variance_epsilon) return out + def forward_aiter( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if residual is not None: + residual_out = torch.empty_like(x) + output = torch.empty_like(x) + fused_add_rms_norm( + output, + x, + residual, + residual_out, + self.weight.data, + self.variance_epsilon, + ) + return output, residual_out + return rms_norm(x, self.weight.data, self.variance_epsilon) + def forward_hip( self, x: torch.Tensor, diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index c0fb09fec..bd1432d38 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -1332,7 +1332,7 @@ def fused_experts_impl( if ( not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None - or (_is_hip and get_bool_env_var("SGLANG_AITER_MOE")) + or (_is_hip and get_bool_env_var("SGLANG_USE_AITER")) ): padded_size = 0 diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 34be9ca12..981780f42 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -28,8 +28,9 @@ else: import logging _is_hip = is_hip() +_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip -if _is_hip: +if _use_aiter: from aiter import ActivationType from aiter.fused_moe_bf16_asm import ck_moe_2stages from aiter.ops.shuffle import shuffle_weight @@ -104,7 +105,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): set_weight_attrs(w2_weight, extra_weight_attrs) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"): + if _use_aiter: layer.w13_weight = torch.nn.Parameter( shuffle_weight(layer.w13_weight.data, (16, 16)), requires_grad=False, @@ -188,7 +189,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): routed_scaling_factor=routed_scaling_factor, ) - if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"): + if _use_aiter: assert not no_combine, "unsupported" if apply_router_weight_on_input: assert ( diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 3d8269a63..7847326a6 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -77,8 +77,8 @@ _is_cuda = is_cuda() _is_fp8_fnuz = is_fp8_fnuz() -use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT") -use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE") +_use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT") +_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip if _is_hip: from aiter import ActivationType, QuantType @@ -487,7 +487,7 @@ class Fp8MoEMethod: from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported if self.quant_config.is_checkpoint_fp8_serialized: - params_dtype = torch.uint32 if use_hip_int4 else torch.float8_e4m3fn + params_dtype = torch.uint32 if _use_hip_int4 else torch.float8_e4m3fn tp_size = get_tensor_model_parallel_world_size() if self.block_quant: block_n, block_k = ( @@ -512,7 +512,7 @@ class Fp8MoEMethod: ) # WEIGHTS - if _is_hip and use_hip_int4: + if _is_hip and _use_hip_int4: # INT4 MoE weight - INT32 packed w13_weight = 
torch.nn.Parameter( torch.empty( @@ -641,7 +641,7 @@ class Fp8MoEMethod: layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) - if _is_hip: # and use_aiter_moe: TODO: add check back after triton kernel + if _is_hip: # _use_aiter: TODO: add check back after triton kernel # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling w13_weight_scale1 = torch.nn.Parameter( torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32), @@ -668,7 +668,7 @@ class Fp8MoEMethod: set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) - if _is_hip and use_hip_int4: + if _is_hip and _use_hip_int4: extra_weight_attrs.update( {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} ) @@ -700,7 +700,7 @@ class Fp8MoEMethod: layer.w2_input_scale = None def process_weights_after_loading(self, layer: Module) -> None: - if _is_hip and use_hip_int4: + if _is_hip and _use_hip_int4: self.process_weights_hip_int4(layer) return @@ -731,7 +731,7 @@ class Fp8MoEMethod: ) layer.w2_input_scale = None - if _is_hip and use_aiter_moe: + if _use_aiter: # Pre-shuffle weights layer.w13_weight.data = shuffle_weight( layer.w13_weight.contiguous(), (16, 16) @@ -853,7 +853,7 @@ class Fp8MoEMethod: return def process_weights_hip_int4(self, layer: Module): - # TODO: and use_aiter_moe: add after triton kernel added + # TODO: _use_aiter: add after triton kernel added # INT4-FP8 (INT4 MoE Weight, FP8 Compute) # Weight Permutation layer.w13_weight = torch.nn.Parameter( @@ -900,7 +900,7 @@ class Fp8MoEMethod: padding_size, # Avoid circular import ) - if use_aiter_moe: + if _use_aiter: layer.w13_weight = torch.nn.Parameter( shuffle_weight(layer.w13_weight.data, (16, 16)), requires_grad=False, @@ -911,7 +911,7 @@ class Fp8MoEMethod: requires_grad=False, ) torch.cuda.empty_cache() - # ROCm (use_aiter_moe): using column-wise scaling + # ROCm (_use_aiter): using column-wise scaling layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1) layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1) elif get_bool_env_var("SGLANG_MOE_PADDING"): @@ -1041,8 +1041,8 @@ class Fp8MoEMethod: activation: str = "silu", no_combine: bool = False, ) -> Optional[torch.Tensor]: - if use_hip_int4: - # TODO: add triton kernel and add check use_aiter_moe + if _use_hip_int4: + # TODO: add triton kernel and add check _use_aiter assert not no_combine, f"{no_combine=} is not supported." return ck_moe_2stages( x, @@ -1058,13 +1058,13 @@ class Fp8MoEMethod: ), ) - if use_aiter_moe: + if _use_aiter: assert not no_combine, f"{no_combine=} is not supported." if self.block_quant: - # TODO(use_aiter_moe): FP8 block_quant only supports 'silu' for the time-being. + # TODO(_use_aiter): FP8 block_quant only supports 'silu' for the time-being. 
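# The ROCm FP8 MoE path above keeps per-column scale tensors (w13_weight_scale1,
# w2_weight_scale1 of shape [num_experts, columns]) even when the checkpoint only carries
# per-tensor scales, duplicating the per-tensor value across the columns before the weights
# are shuffled for the aiter kernels. A toy version of that broadcast with made-up sizes
# (real shapes depend on the layer and the fused gate/up layout):
import torch

num_experts, columns = 4, 16

per_tensor_scale = torch.rand(num_experts, dtype=torch.float32)          # one scale per expert
per_column_scale = torch.ones(num_experts, columns, dtype=torch.float32)

per_column_scale *= per_tensor_scale.unsqueeze(-1)                        # duplicate across columns

# Every column of an expert now carries that expert's per-tensor scale.
assert torch.allclose(
    per_column_scale, per_tensor_scale.unsqueeze(-1).expand_as(per_column_scale)
)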
assert ( activation == "silu" - ), f"use_aiter_moe: FP8 bloack_quant {activation=} will be supported later, unset use_aiter_moe" + ), f"_use_aiter: FP8 bloack_quant {activation=} will be supported later, unset _use_aiter" return asm_moe( x, layer.w13_weight, diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index e105e50c3..e9e663e59 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -38,11 +38,10 @@ _is_hip = is_hip() _is_cuda = is_cuda() _is_fp8_fnuz = is_fp8_fnuz() +_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip -use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE") - -if _is_hip and use_aiter_moe: - from aiter import gemm_a8w8_blockscale +if _use_aiter: + from aiter import gemm_a8w8_blockscale_CK if _is_cuda: from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm @@ -141,7 +140,7 @@ def dispatch_w8a8_block_fp8_linear() -> Callable: return flashinfer_gemm_w8a8_block_fp8_linear elif CUTLASS_BLOCK_FP8_SUPPORTED: return cutlass_w8a8_block_fp8_linear_with_fallback - elif _is_hip and use_aiter_moe: + elif _use_aiter: return aiter_w8a8_block_fp8_linear elif _ENABLE_JIT_DEEPGEMM: return deepgemm_w8a8_block_fp8_linear_with_fallback @@ -268,12 +267,9 @@ def aiter_w8a8_block_fp8_linear( q_input, x_scale = per_token_group_quant_fp8( input_2d, block_size[1], column_major_scales=False ) - output = torch.zeros( - [q_input.shape[0], weight.shape[0]], - dtype=input_2d.dtype, - device=q_input.device, + output = gemm_a8w8_blockscale_CK( + q_input, weight, x_scale, weight_scale, dtype=input.dtype ) - gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output) if bias is not None: output += bias diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 30f6a7929..990335200 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -355,6 +355,15 @@ class ModelRunner: # MLA architecture if is_hopper_with_cuda_12_3(): server_args.attention_backend = "fa3" + elif _is_hip: + head_num = self.model_config.get_num_kv_heads(self.tp_size) + # TODO current aiter only support head number 16 or 128 head number + if ( + head_num == 128 or head_num == 16 + ) and self.spec_algorithm.is_none(): + server_args.attention_backend = "aiter" + else: + server_args.attention_backend = "triton" else: server_args.attention_backend = "triton" logger.info( @@ -363,6 +372,7 @@ class ModelRunner: elif self.use_mla_backend: if server_args.device != "cpu": if server_args.attention_backend in [ + "aiter", "flashinfer", "fa3", "triton", diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 48cbceb2f..43c711cbd 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -105,6 +105,7 @@ from sglang.srt.utils import ( _is_hip = is_hip() _is_cuda = is_cuda() _is_fp8_fnuz = is_fp8_fnuz() +_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip if _is_cuda: from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2 @@ -120,6 +121,9 @@ if _is_hip: decode_attention_fwd_grouped_rope, ) +if _use_aiter: + from aiter.rotary_embedding import get_rope + logger = logging.getLogger(__name__) @@ -697,6 +701,7 @@ class DeepseekV2AttentionMLA(nn.Module): ) self.alt_stream = alt_stream + self.attn_mha.kv_b_proj = None self.w_kc = None self.w_vc = None @@ -766,6 +771,15 @@ 
class DeepseekV2AttentionMLA(nn.Module): return AttnForwardMethod.MHA_CHUNKED_KV else: return _dispatch_mla_subtype() + elif self.attention_backend == "aiter": + if ( + forward_batch.forward_mode.is_extend() + and not forward_batch.forward_mode.is_target_verify() + and not forward_batch.forward_mode.is_draft_extend() + ): + return AttnForwardMethod.MHA + else: + return AttnForwardMethod.MLA else: # Triton: Use normal computation for prefill and use weight absorption for extend/decode if ( @@ -813,6 +827,9 @@ class DeepseekV2AttentionMLA(nn.Module): forward_batch: ForwardBatch, zero_allocator: BumpAllocator, ): + if self.attn_mha.kv_b_proj is None: + self.attn_mha.kv_b_proj = self.kv_b_proj + if hidden_states.shape[0] == 0: assert ( not self.o_proj.reduce_results diff --git a/scripts/amd_ci_exec.sh b/scripts/amd_ci_exec.sh index a57e608e6..411fe2a75 100755 --- a/scripts/amd_ci_exec.sh +++ b/scripts/amd_ci_exec.sh @@ -1,15 +1,14 @@ #!/bin/bash set -euo pipefail -# Default working directory WORKDIR="/sglang-checkout/test/srt" -ENV_ARGS=( - -e SGLANG_AMD_CI=1 - -e SGLANG_IS_IN_CI=1 - -e SGLANG_AITER_MOE=1 +declare -A ENV_MAP=( + [SGLANG_AMD_CI]=1 + [SGLANG_IS_IN_CI]=1 + [SGLANG_USE_AITER]=1 ) -# Parse optional -w/--workdir and -e ENV=VAL flags +# Parse -w/--workdir and -e ENV=VAL while [[ $# -gt 0 ]]; do case "$1" in -w|--workdir) @@ -17,7 +16,8 @@ while [[ $# -gt 0 ]]; do shift 2 ;; -e) - ENV_ARGS+=("-e" "$2") + IFS="=" read -r key val <<< "$2" + ENV_MAP["$key"]="$val" shift 2 ;; --) @@ -30,6 +30,12 @@ while [[ $# -gt 0 ]]; do esac done +# Build final ENV_ARGS +ENV_ARGS=() +for key in "${!ENV_MAP[@]}"; do + ENV_ARGS+=("-e" "$key=${ENV_MAP[$key]}") +done + # Run docker exec docker exec \ -w "$WORKDIR" \ diff --git a/test/srt/test_nightly_gsm8k_eval_amd.py b/test/srt/test_nightly_gsm8k_eval_amd.py index e49cd107b..d726a8678 100644 --- a/test/srt/test_nightly_gsm8k_eval_amd.py +++ b/test/srt/test_nightly_gsm8k_eval_amd.py @@ -171,7 +171,7 @@ class TestNightlyGsm8KEval(unittest.TestCase): os.environ["HF_HUB_DISABLE_XET"] = ( "1" if model in DISABLE_HF_XET_MODELS else "0" ) - os.environ["SGLANG_AITER_MOE"] = ( + os.environ["SGLANG_USE_AITER"] = ( "0" if model in TRITON_MOE_MODELS else "1" )
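With the "aiter" backend, DeepseekV2AttentionMLA.dispatch_attn_forward_method (above) runs plain
multi-head attention for ordinary prefill/extend batches and falls back to the weight-absorbed MLA
path for decode, target-verify and draft-extend. A condensed sketch of that decision, assuming a
forward_mode object exposing the same predicate methods as sglang's ForwardMode (the helper name
here is hypothetical):

def dispatch_for_aiter(forward_mode) -> str:
    # Mirrors the new "aiter" branch added to deepseek_v2.py.
    if (
        forward_mode.is_extend()
        and not forward_mode.is_target_verify()
        and not forward_mode.is_draft_extend()
    ):
        return "MHA"   # ordinary prefill/extend: full multi-head attention
    return "MLA"       # decode / speculative verify / draft extend: absorbed MLA

The CI side exercises both settings: scripts/amd_ci_exec.sh now merges -e KEY=VAL overrides into its
default environment map, so the workflow's "bash scripts/amd_ci_exec.sh -e SGLANG_USE_AITER=0 python3
test_moe_eval_accuracy_large.py" disables the AITER paths for that one accuracy run while the other
jobs keep the SGLANG_USE_AITER=1 default.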