AITER backend extension and workload optimizations (#6838)
Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: Hubert Lu <Hubert.Lu@amd.com>
.github/workflows/pr-test-amd.yml (2 changes, vendored)
@@ -72,7 +72,7 @@ jobs:
       - name: Evaluate accuracy (TP=2)
         timeout-minutes: 30
         run: |
-          bash scripts/amd_ci_exec.sh python3 test_moe_eval_accuracy_large.py
+          bash scripts/amd_ci_exec.sh -e SGLANG_USE_AITER=0 python3 test_moe_eval_accuracy_large.py

   mla-test-1-gpu-amd:
     if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
@@ -53,7 +53,7 @@ SGLang supports various environment variables that can be used to configure its

 | Environment Variable | Description | Default Value |
 | --- | --- | --- |
-| `SGLANG_AITER_MOE` | Use AITER MoE implementation | `false` |
+| `SGLANG_USE_AITER` | Use AITER optimized implementations | `false` |
 | `SGLANG_INT4_WEIGHT` | Enable INT4 weight quantization | `false` |
 | `SGLANG_MOE_PADDING` | Enable MoE padding (sets padding size to 128 if value is `1`, often set to `1` in Docker builds) | `0` |
 | `SGLANG_FORCE_FP8_MARLIN` | Force using FP8 MARLIN kernels even if other FP8 kernels are available | `false` |
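Note: these flags are ordinary environment variables. A minimal sketch of how a boolean flag like `SGLANG_USE_AITER` is typically parsed (illustrative only; the diff below imports the project's own `get_bool_env_var` from `sglang.srt.utils`, whose exact accepted spellings may differ):

```python
import os

def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Treat "1"/"true"/"yes" (case-insensitive) as truthy.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

# e.g. launched as: SGLANG_USE_AITER=1 python -m sglang.launch_server ...
if get_bool_env_var("SGLANG_USE_AITER"):
    print("AITER code paths enabled")
```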
@@ -27,12 +27,19 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.spec_info import SpecInfo

 try:
-    from aiter import mha_batch_prefill_func, paged_attention_ragged
+    from aiter import (
+        flash_attn_varlen_func,
+        mha_batch_prefill_func,
+        paged_attention_ragged,
+    )
+    from aiter.mla import mla_decode_fwd
 except ImportError:
     print(
         "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
     )

+from sglang.srt.configs.model_config import AttentionArch
+

 class WrapperDispatch(Enum):
     SLIDING_WINDOW = auto()
@@ -43,6 +50,10 @@ class WrapperDispatch(Enum):
 class ForwardMetadata:
     kv_indptr: torch.Tensor
     kv_indices: torch.Tensor
+    qo_indptr: torch.Tensor
+    kv_last_page_len: torch.Tensor
+    max_extend_len: int
+    max_prefix_extend_len: int
     max_q_len: int
     max_kv_len: int
@@ -63,6 +74,7 @@ class AiterAttnBackend(AttentionBackend):

         self.device = model_runner.device
         self.is_multimodal = model_runner.model_config.is_multimodal
+        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
@@ -75,6 +87,8 @@ class AiterAttnBackend(AttentionBackend):

         self.req_to_token = model_runner.req_to_token_pool.req_to_token

+        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
+
         # Parse constants
         self.max_context_len = model_runner.model_config.context_len
         self.skip_prefill = skip_prefill
@@ -100,6 +114,10 @@ class AiterAttnBackend(AttentionBackend):
             self.indices_updater_prefill = AiterIndicesUpdaterPrefill(
                 model_runner, self
             )
+            if self.use_mla:
+                self.mla_indices_updater_prefill = AiterMlaIndicesUpdaterPrefill(
+                    model_runner, self
+                )

         # aiter kernel related initialization
         self.max_num_partitions = (
@@ -108,33 +126,40 @@ class AiterAttnBackend(AttentionBackend):

         nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8

-        self.workspace_buffer = torch.empty(
-            (max_bs * self.num_head * self.max_num_partitions * self.head_dim)
-            * nbyes_per_qo_elem
-            + 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
-            dtype=torch.uint8,
-            device=self.device,
-        )
+        if not self.use_mla:
+            self.workspace_buffer = torch.empty(
+                (max_bs * self.num_head * self.max_num_partitions * self.head_dim)
+                * nbyes_per_qo_elem
+                + 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
+                dtype=torch.uint8,
+                device=self.device,
+            )

         self.scale = float(1.0 / (self.head_dim**0.5))
         self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to(
             self.device
         )
         self.kv_last_page_lens = torch.ones((max_bs,), dtype=torch.int32).to(
             self.device
         )

         self.logits_soft_cap = 0.0

         self.forward_metadata: ForwardMetadata = None

+        if self.use_mla:
+            self.qo_indptr_ = torch.zeros(
+                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+            )
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""

+        bs = forward_batch.batch_size
+        kv_indptr = self.kv_indptr
+        spec_info = forward_batch.spec_info
+        qo_indptr = None
+        kv_last_page_len = None
+        max_extend_len = None

         if forward_batch.forward_mode.is_decode_or_idle():
-            # update for aiter
-            # create kv_indices and kv_inptr
-            bs = forward_batch.batch_size
-            kv_indptr = self.kv_indptr
-            spec_info = forward_batch.spec_info
             if spec_info is None:
                 kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                 kv_indptr = kv_indptr[: bs + 1]
@@ -154,38 +179,103 @@ class AiterAttnBackend(AttentionBackend):
                 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                 bs = kv_indptr.shape[0] - 1

-            self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None)
+            if self.use_mla:
+                qo_indptr = self.qo_indptr_[: bs + 1]
+                qo_indptr[1 : bs + 1] = torch.cumsum(self.kv_last_page_len[:bs], dim=0)
+                kv_last_page_len = self.kv_last_page_len[:bs]
+                max_extend_len = 1
+
+            self.forward_metadata = ForwardMetadata(
+                kv_indptr,
+                kv_indices,
+                qo_indptr,
+                kv_last_page_len,
+                max_extend_len,
+                None,
+                None,
+                None,
+            )

         elif forward_batch.forward_mode.is_draft_extend():
-            self.indices_updater_prefill.update(
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                forward_batch.seq_lens_sum,
-                prefix_lens=None,
-                encoder_lens=forward_batch.encoder_lens,
-                spec_info=forward_batch.spec_info,
-            )
-            self.forward_metadata = ForwardMetadata(
-                self.indices_updater_prefill.kv_indptr,
-                self.indices_updater_prefill.kv_indices,
-                self.indices_updater_prefill.max_q_len,
-                self.indices_updater_prefill.max_kv_len,
-            )
+            if self.use_mla:
+                prefix_lens = forward_batch.extend_prefix_lens
+                self.mla_indices_updater_prefill.update(
+                    forward_batch.req_pool_indices,
+                    prefix_lens,
+                    prefix_lens.sum().item(),
+                    forward_batch.extend_seq_lens,
+                    encoder_lens=forward_batch.encoder_lens,
+                    spec_info=None,
+                )
+                self.forward_metadata = ForwardMetadata(
+                    self.mla_indices_updater_prefill.kv_indptr,
+                    self.mla_indices_updater_prefill.kv_indices,
+                    self.mla_indices_updater_prefill.qo_indptr,
+                    self.mla_indices_updater_prefill.kv_last_page_len,
+                    self.mla_indices_updater_prefill.max_extend_len,
+                    self.mla_indices_updater_prefill.max_prefix_extend_len,
+                    None,
+                    None,
+                )
+            else:
+                self.indices_updater_prefill.update(
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    forward_batch.seq_lens_sum,
+                    prefix_lens=None,
+                    encoder_lens=forward_batch.encoder_lens,
+                    spec_info=forward_batch.spec_info,
+                )
+                self.forward_metadata = ForwardMetadata(
+                    self.indices_updater_prefill.kv_indptr,
+                    self.indices_updater_prefill.kv_indices,
+                    None,
+                    None,
+                    None,
+                    None,
+                    self.indices_updater_prefill.max_q_len,
+                    self.indices_updater_prefill.max_kv_len,
+                )
         elif forward_batch.forward_mode.is_target_verify():
-            self.indices_updater_prefill.update(
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                forward_batch.seq_lens_sum,
-                prefix_lens=None,
-                encoder_lens=forward_batch.encoder_lens,
-                spec_info=forward_batch.spec_info,
-            )
-            self.forward_metadata = ForwardMetadata(
-                self.indices_updater_prefill.kv_indptr,
-                self.indices_updater_prefill.kv_indices,
-                self.indices_updater_prefill.max_q_len,
-                self.indices_updater_prefill.max_kv_len,
-            )
+            if self.use_mla:
+                prefix_lens = forward_batch.extend_prefix_lens
+                self.mla_indices_updater_prefill.update(
+                    forward_batch.req_pool_indices,
+                    prefix_lens,
+                    prefix_lens.sum().item(),
+                    forward_batch.extend_seq_lens,
+                    encoder_lens=forward_batch.encoder_lens,
+                    spec_info=None,
+                )
+                self.forward_metadata = ForwardMetadata(
+                    self.mla_indices_updater_prefill.kv_indptr,
+                    self.mla_indices_updater_prefill.kv_indices,
+                    self.mla_indices_updater_prefill.qo_indptr,
+                    self.mla_indices_updater_prefill.kv_last_page_len,
+                    self.mla_indices_updater_prefill.max_extend_len,
+                    self.mla_indices_updater_prefill.max_prefix_extend_len,
+                    None,
+                    None,
+                )
+            else:
+                self.indices_updater_prefill.update(
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    forward_batch.seq_lens_sum,
+                    prefix_lens=None,
+                    encoder_lens=forward_batch.encoder_lens,
+                    spec_info=forward_batch.spec_info,
+                )
+                self.forward_metadata = ForwardMetadata(
+                    self.indices_updater_prefill.kv_indptr,
+                    self.indices_updater_prefill.kv_indices,
+                    None,
+                    None,
+                    None,
+                    None,
+                    self.indices_updater_prefill.max_q_len,
+                    self.indices_updater_prefill.max_kv_len,
+                )
         else:
             prefix_lens = forward_batch.extend_prefix_lens
@@ -194,24 +284,49 @@ class AiterAttnBackend(AttentionBackend):
             else:
                 extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)

-            self.indices_updater_prefill.update(
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                forward_batch.seq_lens_sum,
-                prefix_lens,
-                encoder_lens=forward_batch.encoder_lens,
-                spec_info=None,
-            )
-            self.forward_metadata = ForwardMetadata(
-                self.indices_updater_prefill.kv_indptr,
-                self.indices_updater_prefill.kv_indices,
-                self.indices_updater_prefill.max_q_len,
-                self.indices_updater_prefill.max_kv_len,
-            )
+            if self.use_mla:
+                self.mla_indices_updater_prefill.update(
+                    forward_batch.req_pool_indices,
+                    prefix_lens,
+                    prefix_lens.sum().item(),
+                    forward_batch.extend_seq_lens,
+                    encoder_lens=forward_batch.encoder_lens,
+                    spec_info=None,
+                )
+                self.forward_metadata = ForwardMetadata(
+                    self.mla_indices_updater_prefill.kv_indptr,
+                    self.mla_indices_updater_prefill.kv_indices,
+                    self.mla_indices_updater_prefill.qo_indptr,
+                    self.mla_indices_updater_prefill.kv_last_page_len,
+                    self.mla_indices_updater_prefill.max_extend_len,
+                    self.mla_indices_updater_prefill.max_prefix_extend_len,
+                    None,
+                    None,
+                )
+            else:
+                self.indices_updater_prefill.update(
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    forward_batch.seq_lens_sum,
+                    prefix_lens,
+                    encoder_lens=forward_batch.encoder_lens,
+                    spec_info=None,
+                )
+                self.forward_metadata = ForwardMetadata(
+                    self.indices_updater_prefill.kv_indptr,
+                    self.indices_updater_prefill.kv_indices,
+                    None,
+                    None,
+                    None,
+                    None,
+                    self.indices_updater_prefill.max_q_len,
+                    self.indices_updater_prefill.max_kv_len,
+                )

     def init_cuda_graph_state(
         self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
     ):
+        self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int)
         if kv_indices_buf is None:
             self.cuda_graph_kv_indices = torch.zeros(
                 (max_bs * self.max_context_len),
@@ -239,6 +354,10 @@ class AiterAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInfo],
     ):
         if forward_mode.is_decode_or_idle():
+            qo_indptr = None
+            kv_last_page_len = None
+            max_extend_len = None
+
             if spec_info is None:
                 kv_indptr = self.kv_indptr
                 kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
@@ -255,24 +374,82 @@ class AiterAttnBackend(AttentionBackend):
                 )
             else:
                 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
-            self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None)

+            if self.use_mla:
+                qo_indptr = self.qo_indptr_[: bs + 1]
+                qo_indptr[1 : bs + 1] = torch.cumsum(
+                    self.cuda_graph_kv_last_page_len[:bs], dim=0
+                )
+                max_extend_len = 1
+                kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+
+            self.forward_metadata = ForwardMetadata(
+                kv_indptr,
+                kv_indices,
+                qo_indptr,
+                kv_last_page_len,
+                max_extend_len,
+                None,
+                None,
+                None,
+            )

         elif forward_mode.is_target_verify():
-            seq_lens_sum = seq_lens.sum().item()
-            self.indices_updater_prefill.update(
-                req_pool_indices,
-                seq_lens,
-                seq_lens_sum,
-                prefix_lens=None,
-                encoder_lens=encoder_lens,
-                spec_info=spec_info,
-            )
-            self.forward_metadata = ForwardMetadata(
-                self.indices_updater_prefill.kv_indptr,
-                self.indices_updater_prefill.kv_indices,
-                self.indices_updater_prefill.max_q_len,
-                self.indices_updater_prefill.max_kv_len,
-            )
+            if self.use_mla:
+                qo_indptr = self.qo_indptr[: bs + 1]
+                qo_indptr[: bs + 1] = torch.arange(
+                    0,
+                    (1 + bs) * self.num_draft_tokens,
+                    step=self.num_draft_tokens,
+                    dtype=torch.int32,
+                    device=self.device,
+                )
+                kv_indptr = self.kv_indptr[: bs + 1]
+                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+                kv_indices = self.cuda_graph_kv_indices
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices,
+                    seq_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+
+                max_extend_len = self.num_draft_tokens
+                kv_last_page_len = None
+
+                self.forward_metadata = ForwardMetadata(
+                    kv_indptr,
+                    kv_indices,
+                    qo_indptr,
+                    kv_last_page_len,
+                    max_extend_len,
+                    None,
+                    None,
+                    None,
+                )
+            else:
+                seq_lens_sum = seq_lens.sum().item()
+                self.indices_updater_prefill.update(
+                    req_pool_indices,
+                    seq_lens,
+                    seq_lens_sum,
+                    prefix_lens=None,
+                    encoder_lens=encoder_lens,
+                    spec_info=spec_info,
+                )
+                self.forward_metadata = ForwardMetadata(
+                    self.indices_updater_prefill.kv_indptr,
+                    self.indices_updater_prefill.kv_indices,
+                    None,
+                    None,
+                    None,
+                    None,
+                    self.indices_updater_prefill.max_q_len,
+                    self.indices_updater_prefill.max_kv_len,
+                )

         else:
             raise ValueError(f"Invalid mode: {forward_mode=}")
@@ -342,31 +519,113 @@ class AiterAttnBackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                )
+                if self.use_mla:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )

-        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-
-        bs0 = forward_batch.batch_size + 1
-
-        o = mha_batch_prefill_func(
-            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-            k_cache,
-            v_cache,
-            self.qo_indptr[:bs0],
-            self.forward_metadata.kv_indptr[:bs0],
-            self.forward_metadata.kv_indices,
-            self.forward_metadata.max_q_len,
-            self.forward_metadata.max_kv_len,
-            causal=True,
-            logits_soft_cap=self.logits_soft_cap,
-            alibi_slopes=None,
-            return_lse=False,
-            return_attn_probs=False,
-        )
-
-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+        if self.use_mla:
+            max_extend_len = self.forward_metadata.max_extend_len
+            max_prefix_extend_len = self.forward_metadata.max_prefix_extend_len
+            kv_indptr = self.forward_metadata.kv_indptr
+            kv_indices = self.forward_metadata.kv_indices
+            kv_last_page_lens = self.forward_metadata.kv_last_page_len
+            qo_indptr = self.forward_metadata.qo_indptr
+            K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+            kv_lora_rank = V_Buffer.shape[-1]
+            qk_rope_head_dim = K_Buffer.shape[-1] - kv_lora_rank
+            qk_nope_head_dim = k.shape[-1] - qk_rope_head_dim
+            assert len(q.shape) == 3
+            assert len(k.shape) == 3
+            assert len(v.shape) == 3
+
+            if kv_indices.shape[0] == 0:
+                o = flash_attn_varlen_func(
+                    q,
+                    k,
+                    v,
+                    qo_indptr,
+                    qo_indptr,
+                    max_extend_len,
+                    max_extend_len,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                )
+                return o
+            elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
+                K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
+                kvc, k_pe = torch.split(
+                    K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
+                )
+                kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
+                kvprefix = kvprefix.view(
+                    -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
+                )
+                k_prefix, v_prefix = torch.split(
+                    kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
+                )
+                k_prefix = torch.cat(
+                    [
+                        k_prefix,
+                        torch.broadcast_to(
+                            k_pe,
+                            (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
+                        ),
+                    ],
+                    dim=-1,
+                )
+                assert (
+                    forward_batch.extend_prefix_lens.shape
+                    == forward_batch.extend_seq_lens.shape
+                )
+                k_prefix = torch.split(k_prefix, forward_batch.extend_prefix_lens_cpu)
+                k_extend = torch.split(k, forward_batch.extend_seq_lens_cpu)
+                assert len(k_prefix) == len(forward_batch.extend_prefix_lens_cpu)
+                k = torch.cat([x for el in zip(k_prefix, k_extend) for x in el])
+                v_prefix = torch.split(v_prefix, forward_batch.extend_prefix_lens_cpu)
+                v_extend = torch.split(v, forward_batch.extend_seq_lens_cpu)
+                v = torch.cat([x for el in zip(v_prefix, v_extend) for x in el])
+
+                o = flash_attn_varlen_func(
+                    q,
+                    k,
+                    v,
+                    qo_indptr,
+                    kv_indptr,
+                    max_extend_len,
+                    max_prefix_extend_len,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                )
+                return o
+        else:
+            k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
+                layer.layer_id
+            )
+            bs0 = forward_batch.batch_size + 1
+
+            o = mha_batch_prefill_func(
+                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                k_cache,
+                v_cache,
+                self.qo_indptr[:bs0],
+                self.forward_metadata.kv_indptr[:bs0],
+                self.forward_metadata.kv_indices,
+                self.forward_metadata.max_q_len,
+                self.forward_metadata.max_kv_len,
+                causal=True,
+                logits_soft_cap=self.logits_soft_cap,
+                alibi_slopes=None,
+                return_lse=False,
+                return_attn_probs=False,
+            )
+
+            return o.view(-1, layer.tp_q_head_num * layer.head_dim)

     def forward_decode(
         self,
@@ -377,6 +636,7 @@ class AiterAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
+        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

         if layer.qk_head_dim != layer.v_head_dim:
@@ -389,32 +649,48 @@ class AiterAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )

-        self.logits_soft_cap = layer.logit_cap
-        paged_attention_ragged(
-            o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
-            self.workspace_buffer,
-            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
-            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
-                -1, 1, layer.tp_k_head_num, layer.qk_head_dim
-            ),
-            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
-                -1, 1, layer.tp_v_head_num, layer.v_head_dim
-            ),
-            self.scale,
-            self.forward_metadata.kv_indptr,
-            self.forward_metadata.kv_indices,
-            self.kv_last_page_lens,
-            1,
-            self.max_num_partitions,
-            None,
-            "auto",
-            "NHD",
-            self.logits_soft_cap,
-            self.k_scale,
-            self.v_scale,
-            None,
-            _AITER_PARTITION_SIZE_ROCM,
-        )
+        if self.use_mla:
+            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            mla_decode_fwd(
+                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                k_buffer.view(-1, 1, 1, layer.qk_head_dim),
+                o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+                self.forward_metadata.qo_indptr,
+                self.forward_metadata.kv_indptr,
+                self.forward_metadata.kv_indices,
+                self.forward_metadata.kv_last_page_len,
+                self.forward_metadata.max_extend_len,
+                layer.scaling,
+                layer.logit_cap,
+            )
+            k_buffer = k_buffer.view(-1, 1, layer.qk_head_dim)
+        else:
+            self.logits_soft_cap = layer.logit_cap
+            paged_attention_ragged(
+                o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                self.workspace_buffer,
+                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
+                    -1, 1, layer.tp_k_head_num, layer.qk_head_dim
+                ),
+                forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
+                    -1, 1, layer.tp_v_head_num, layer.v_head_dim
+                ),
+                self.scale,
+                self.forward_metadata.kv_indptr,
+                self.forward_metadata.kv_indices,
+                self.kv_last_page_len,
+                1,
+                self.max_num_partitions,
+                None,
+                "auto",
+                "NHD",
+                self.logits_soft_cap,
+                self.k_scale,
+                self.v_scale,
+                None,
+                _AITER_PARTITION_SIZE_ROCM,
+            )

         return o
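As an aside, the non-MLA `workspace_buffer` allocated in `__init__` above sizes one float32 partial-output element per (batch, head, partition, head_dim) plus two float32 side values per (batch, head, partition). A small sketch of that arithmetic with illustrative dimensions:

```python
def aiter_workspace_bytes(max_bs: int, num_head: int, max_num_partitions: int, head_dim: int) -> int:
    # Mirrors the torch.empty size above: float32 partial outputs plus
    # two float32 reduction buffers per (batch, head, partition).
    nbytes_per_qo_elem = 4  # torch.finfo(torch.float32).bits // 8
    return (
        max_bs * num_head * max_num_partitions * head_dim * nbytes_per_qo_elem
        + 2 * (max_bs * num_head * max_num_partitions) * 4
    )

# Illustrative: 128 requests, 8 local heads, 32 partitions, head_dim 128 -> ~16 MiB
print(aiter_workspace_bytes(128, 8, 32, 128) / 2**20)
```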
@@ -506,9 +782,97 @@ class AiterIndicesUpdaterPrefill:
                 spec_info.generate_attn_arg_prefill(
                     req_pool_indices,
                     paged_kernel_lens,
                     None,
                     paged_kernel_lens_sum,
                     self.req_to_token,
                 )
             )

         self.kv_indices = kv_indices
+
+
+class AiterMlaIndicesUpdaterPrefill:
+    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+        # Parse Constants
+        self.attn_backend = attn_backend
+
+        # Buffers and wrappers
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.update = self.update_single_wrapper
+
+        self.kv_indptr = None
+        self.kv_indices = None
+        self.qo_indptr = None
+        self.kv_last_page_len = None
+        self.max_extend_len = 0
+        self.max_prefix_extend_len = 0
+
+    def update(
+        self,
+        req_pool_indices: torch.Tensor,
+        prefix_lens: torch.Tensor,
+        prefix_lens_sum: int,
+        extend_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
+    ):
+        # Keep the signature for type checking. It will be assigned during runtime.
+        raise NotImplementedError()
+
+    def update_single_wrapper(
+        self,
+        req_pool_indices: torch.Tensor,
+        prefix_lens: torch.Tensor,
+        prefix_lens_sum: int,
+        extend_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
+    ):
+
+        paged_kernel_lens = prefix_lens
+        paged_kernel_lens_sum = prefix_lens_sum
+
+        bs = len(req_pool_indices)
+
+        kv_indptr = self.attn_backend.kv_indptr
+
+        if spec_info is None:
+            # Normal extend
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                paged_kernel_lens_sum,
+                dtype=torch.int32,
+                device=req_pool_indices.device,
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                paged_kernel_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+
+            qo_indptr = self.attn_backend.qo_indptr
+            qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
+            qo_indptr = qo_indptr[: bs + 1]
+
+            max_extend_len = torch.max(extend_lens).item()
+            max_prefix_extend_len = torch.max(extend_lens + paged_kernel_lens).item()
+            kv_indptr += qo_indptr
+        else:
+            kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                spec_info.generate_attn_arg_prefill(
+                    req_pool_indices,
+                    paged_kernel_lens,
+                    paged_kernel_lens_sum,
+                    self.req_to_token,
+                )
+            )
+
+        self.kv_indptr = kv_indptr
+        self.kv_indices = kv_indices
+        self.qo_indptr = qo_indptr
+        self.max_extend_len = max_extend_len
+        self.max_prefix_extend_len = max_prefix_extend_len
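The updaters above all build flashinfer-style ragged indices, where `kv_indptr` and `qo_indptr` are exclusive prefix sums of per-request lengths. A toy illustration of the layout (independent of the Triton kernel that fills `kv_indices`):

```python
import torch

prefix_lens = torch.tensor([3, 0, 5], dtype=torch.int32)  # cached KV tokens per request
extend_lens = torch.tensor([2, 4, 1], dtype=torch.int32)  # new tokens per request

kv_indptr = torch.zeros(len(prefix_lens) + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(prefix_lens, dim=0)  # tensor([0, 3, 3, 8])

qo_indptr = torch.zeros(len(extend_lens) + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(extend_lens, dim=0)  # tensor([0, 2, 6, 7])

# Request i's KV rows are kv_indices[kv_indptr[i] : kv_indptr[i + 1]];
# its query rows are qo_indptr[i] : qo_indptr[i + 1].
```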
@@ -20,10 +20,11 @@ import torch
 import torch.nn as nn

 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip

 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_cuda:
     from sgl_kernel import (
@@ -33,7 +34,10 @@ if _is_cuda:
         rmsnorm,
     )

-if _is_hip:
+if _use_aiter:
+    from aiter import rmsnorm2d_fwd as rms_norm
+    from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
+elif _is_hip:
     from vllm._custom_ops import fused_add_rms_norm, rms_norm

 logger = logging.getLogger(__name__)
@@ -48,6 +52,8 @@ class RMSNorm(CustomOp):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        if _use_aiter:
+            self._forward_method = self.forward_aiter

     def forward_cuda(
         self,
@@ -60,6 +66,25 @@ class RMSNorm(CustomOp):
         out = rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out

+    def forward_aiter(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            residual_out = torch.empty_like(x)
+            output = torch.empty_like(x)
+            fused_add_rms_norm(
+                output,
+                x,
+                residual,
+                residual_out,
+                self.weight.data,
+                self.variance_epsilon,
+            )
+            return output, residual_out
+        return rms_norm(x, self.weight.data, self.variance_epsilon)
+
     def forward_hip(
         self,
         x: torch.Tensor,
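For reference, both kernel families dispatched above compute the same RMSNorm; a plain-PyTorch equivalent of the two entry points (eps and shapes illustrative):

```python
import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root-mean-square over the hidden dim, then scale.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

def fused_add_rms_norm_ref(x, residual, weight, eps=1e-6):
    # Fused variant: add the residual first, normalize the sum, and
    # return both the normalized output and the updated residual.
    residual_out = x + residual
    return rms_norm_ref(residual_out, weight, eps), residual_out
```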
@@ -1332,7 +1332,7 @@ def fused_experts_impl(
     if (
         not (use_fp8_w8a8 or use_int8_w8a8)
         or block_shape is not None
-        or (_is_hip and get_bool_env_var("SGLANG_AITER_MOE"))
+        or (_is_hip and get_bool_env_var("SGLANG_USE_AITER"))
     ):
         padded_size = 0
@@ -28,8 +28,9 @@ else:
 import logging

 _is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

-if _is_hip:
+if _use_aiter:
     from aiter import ActivationType
     from aiter.fused_moe_bf16_asm import ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight
@@ -104,7 +105,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         set_weight_attrs(w2_weight, extra_weight_attrs)

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
+        if _use_aiter:
             layer.w13_weight = torch.nn.Parameter(
                 shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
@@ -188,7 +189,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             routed_scaling_factor=routed_scaling_factor,
         )

-        if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
+        if _use_aiter:
             assert not no_combine, "unsupported"
             if apply_router_weight_on_input:
                 assert (
@@ -77,8 +77,8 @@ _is_cuda = is_cuda()

 _is_fp8_fnuz = is_fp8_fnuz()

-use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
-use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
+_use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_hip:
     from aiter import ActivationType, QuantType
@@ -487,7 +487,7 @@ class Fp8MoEMethod:
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

         if self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = torch.uint32 if use_hip_int4 else torch.float8_e4m3fn
+            params_dtype = torch.uint32 if _use_hip_int4 else torch.float8_e4m3fn
             tp_size = get_tensor_model_parallel_world_size()
             if self.block_quant:
                 block_n, block_k = (
@@ -512,7 +512,7 @@ class Fp8MoEMethod:
         )

         # WEIGHTS
-        if _is_hip and use_hip_int4:
+        if _is_hip and _use_hip_int4:
             # INT4 MoE weight - INT32 packed
             w13_weight = torch.nn.Parameter(
                 torch.empty(
@@ -641,7 +641,7 @@ class Fp8MoEMethod:
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
         layer.register_parameter("w2_weight_scale", w2_weight_scale)

-        if _is_hip:  # and use_aiter_moe: TODO: add check back after triton kernel
+        if _is_hip:  # _use_aiter: TODO: add check back after triton kernel
             # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
             w13_weight_scale1 = torch.nn.Parameter(
                 torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
@@ -668,7 +668,7 @@ class Fp8MoEMethod:
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)

-        if _is_hip and use_hip_int4:
+        if _is_hip and _use_hip_int4:
             extra_weight_attrs.update(
                 {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
             )
@@ -700,7 +700,7 @@ class Fp8MoEMethod:
             layer.w2_input_scale = None

     def process_weights_after_loading(self, layer: Module) -> None:
-        if _is_hip and use_hip_int4:
+        if _is_hip and _use_hip_int4:
             self.process_weights_hip_int4(layer)
             return
@@ -731,7 +731,7 @@ class Fp8MoEMethod:
             )
             layer.w2_input_scale = None

-        if _is_hip and use_aiter_moe:
+        if _use_aiter:
             # Pre-shuffle weights
             layer.w13_weight.data = shuffle_weight(
                 layer.w13_weight.contiguous(), (16, 16)
@@ -853,7 +853,7 @@ class Fp8MoEMethod:
         return

     def process_weights_hip_int4(self, layer: Module):
-        # TODO: and use_aiter_moe: add after triton kernel added
+        # TODO: _use_aiter: add after triton kernel added
         # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
         # Weight Permutation
         layer.w13_weight = torch.nn.Parameter(
@@ -900,7 +900,7 @@ class Fp8MoEMethod:
             padding_size,  # Avoid circular import
         )

-        if use_aiter_moe:
+        if _use_aiter:
             layer.w13_weight = torch.nn.Parameter(
                 shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
@@ -911,7 +911,7 @@ class Fp8MoEMethod:
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
-        # ROCm (use_aiter_moe): using column-wise scaling
+        # ROCm (_use_aiter): using column-wise scaling
         layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
         layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
     elif get_bool_env_var("SGLANG_MOE_PADDING"):
@@ -1041,8 +1041,8 @@ class Fp8MoEMethod:
         activation: str = "silu",
         no_combine: bool = False,
     ) -> Optional[torch.Tensor]:
-        if use_hip_int4:
-            # TODO: add triton kernel and add check use_aiter_moe
+        if _use_hip_int4:
+            # TODO: add triton kernel and add check _use_aiter
             assert not no_combine, f"{no_combine=} is not supported."
             return ck_moe_2stages(
                 x,
@@ -1058,13 +1058,13 @@ class Fp8MoEMethod:
             ),
         )

-        if use_aiter_moe:
+        if _use_aiter:
             assert not no_combine, f"{no_combine=} is not supported."
             if self.block_quant:
-                # TODO(use_aiter_moe): FP8 block_quant only supports 'silu' for the time-being.
+                # TODO(_use_aiter): FP8 block_quant only supports 'silu' for the time-being.
                 assert (
                     activation == "silu"
-                ), f"use_aiter_moe: FP8 bloack_quant {activation=} will be supported later, unset use_aiter_moe"
+                ), f"_use_aiter: FP8 bloack_quant {activation=} will be supported later, unset _use_aiter"
                 return asm_moe(
                     x,
                     layer.w13_weight,
@@ -38,11 +38,10 @@ _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_fp8_fnuz = is_fp8_fnuz()

-use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
-
-if _is_hip and use_aiter_moe:
-    from aiter import gemm_a8w8_blockscale
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+
+if _use_aiter:
+    from aiter import gemm_a8w8_blockscale_CK

 if _is_cuda:
     from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
@@ -141,7 +140,7 @@ def dispatch_w8a8_block_fp8_linear() -> Callable:
         return flashinfer_gemm_w8a8_block_fp8_linear
     elif CUTLASS_BLOCK_FP8_SUPPORTED:
         return cutlass_w8a8_block_fp8_linear_with_fallback
-    elif _is_hip and use_aiter_moe:
+    elif _use_aiter:
         return aiter_w8a8_block_fp8_linear
     elif _ENABLE_JIT_DEEPGEMM:
         return deepgemm_w8a8_block_fp8_linear_with_fallback
@@ -268,12 +267,9 @@ def aiter_w8a8_block_fp8_linear(
     q_input, x_scale = per_token_group_quant_fp8(
         input_2d, block_size[1], column_major_scales=False
     )
-    output = torch.zeros(
-        [q_input.shape[0], weight.shape[0]],
-        dtype=input_2d.dtype,
-        device=q_input.device,
+    output = gemm_a8w8_blockscale_CK(
+        q_input, weight, x_scale, weight_scale, dtype=input.dtype
     )
-    gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)

     if bias is not None:
         output += bias
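The semantics of the block-scaled GEMM swapped in above can be written as dequantize-then-matmul. A slow reference under assumed conventions (128-wide groups, float tensors standing in for FP8; the real `gemm_a8w8_blockscale_CK` kernel fuses all of this):

```python
import torch

def blockscale_gemm_ref(q, x_scale, w, w_scale, block: int = 128) -> torch.Tensor:
    # q: (M, K) quantized activations; x_scale: (M, K // block) per-token-group scales
    # w: (N, K) quantized weights;     w_scale: (N // block, K // block) per-block scales
    qf = q.float() * x_scale.repeat_interleave(block, dim=1)
    wf = w.float() * w_scale.repeat_interleave(block, dim=0).repeat_interleave(block, dim=1)
    return qf @ wf.t()  # (M, N); the real kernel also casts to the activation dtype
```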
@@ -355,6 +355,15 @@ class ModelRunner:
                 # MLA architecture
                 if is_hopper_with_cuda_12_3():
                     server_args.attention_backend = "fa3"
+                elif _is_hip:
+                    head_num = self.model_config.get_num_kv_heads(self.tp_size)
+                    # TODO current aiter only support head number 16 or 128 head number
+                    if (
+                        head_num == 128 or head_num == 16
+                    ) and self.spec_algorithm.is_none():
+                        server_args.attention_backend = "aiter"
+                    else:
+                        server_args.attention_backend = "triton"
                 else:
                     server_args.attention_backend = "triton"
                 logger.info(
@@ -363,6 +372,7 @@ class ModelRunner:
         elif self.use_mla_backend:
             if server_args.device != "cpu":
                 if server_args.attention_backend in [
+                    "aiter",
                     "flashinfer",
                     "fa3",
                     "triton",
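The backend auto-selection above reduces to a small rule; a sketch as a pure function (names hypothetical, mirroring the branch logic in this hunk):

```python
def pick_mla_attention_backend(
    is_hopper_cuda_12_3: bool,
    is_hip: bool,
    num_kv_heads: int,
    has_spec_algorithm: bool,
) -> str:
    if is_hopper_cuda_12_3:
        return "fa3"
    if is_hip and num_kv_heads in (16, 128) and not has_spec_algorithm:
        return "aiter"  # current AITER MLA decode supports only 16 or 128 KV heads
    return "triton"
```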
@@ -105,6 +105,7 @@ from sglang.srt.utils import (
 _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_fp8_fnuz = is_fp8_fnuz()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_cuda:
     from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
@@ -120,6 +121,9 @@ if _is_hip:
         decode_attention_fwd_grouped_rope,
     )

+if _use_aiter:
+    from aiter.rotary_embedding import get_rope
+
 logger = logging.getLogger(__name__)
@@ -697,6 +701,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         )

         self.alt_stream = alt_stream
+        self.attn_mha.kv_b_proj = None

         self.w_kc = None
         self.w_vc = None
@@ -766,6 +771,15 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
                 return _dispatch_mla_subtype()
+        elif self.attention_backend == "aiter":
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                return AttnForwardMethod.MHA
+            else:
+                return AttnForwardMethod.MLA
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
             if (
@@ -813,6 +827,9 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
     ):
+        if self.attn_mha.kv_b_proj is None:
+            self.attn_mha.kv_b_proj = self.kv_b_proj
+
         if hidden_states.shape[0] == 0:
             assert (
                 not self.o_proj.reduce_results
@@ -1,15 +1,14 @@
 #!/bin/bash
 set -euo pipefail

 # Default working directory
 WORKDIR="/sglang-checkout/test/srt"
-ENV_ARGS=(
-    -e SGLANG_AMD_CI=1
-    -e SGLANG_IS_IN_CI=1
-    -e SGLANG_AITER_MOE=1
+declare -A ENV_MAP=(
+    [SGLANG_AMD_CI]=1
+    [SGLANG_IS_IN_CI]=1
+    [SGLANG_USE_AITER]=1
 )

-# Parse optional -w/--workdir and -e ENV=VAL flags
+# Parse -w/--workdir and -e ENV=VAL
 while [[ $# -gt 0 ]]; do
     case "$1" in
         -w|--workdir)
@@ -17,7 +16,8 @@ while [[ $# -gt 0 ]]; do
             shift 2
             ;;
         -e)
-            ENV_ARGS+=("-e" "$2")
+            IFS="=" read -r key val <<< "$2"
+            ENV_MAP["$key"]="$val"
             shift 2
             ;;
         --)
@@ -30,6 +30,12 @@ while [[ $# -gt 0 ]]; do
     esac
 done

+# Build final ENV_ARGS
+ENV_ARGS=()
+for key in "${!ENV_MAP[@]}"; do
+    ENV_ARGS+=("-e" "$key=${ENV_MAP[$key]}")
+done
+
 # Run docker exec
 docker exec \
     -w "$WORKDIR" \
@@ -171,7 +171,7 @@ class TestNightlyGsm8KEval(unittest.TestCase):
             os.environ["HF_HUB_DISABLE_XET"] = (
                 "1" if model in DISABLE_HF_XET_MODELS else "0"
             )
-            os.environ["SGLANG_AITER_MOE"] = (
+            os.environ["SGLANG_USE_AITER"] = (
                 "0" if model in TRITON_MOE_MODELS else "1"
            )