AITER backend extension and workload optimizations (#6838)
Co-authored-by: wunhuang <wunhuang@amd.com> Co-authored-by: Hubert Lu <Hubert.Lu@amd.com>
This commit is contained in:
2
.github/workflows/pr-test-amd.yml
vendored
2
.github/workflows/pr-test-amd.yml
vendored
@@ -72,7 +72,7 @@ jobs:
|
|||||||
- name: Evaluate accuracy (TP=2)
|
- name: Evaluate accuracy (TP=2)
|
||||||
timeout-minutes: 30
|
timeout-minutes: 30
|
||||||
run: |
|
run: |
|
||||||
bash scripts/amd_ci_exec.sh python3 test_moe_eval_accuracy_large.py
|
bash scripts/amd_ci_exec.sh -e SGLANG_USE_AITER=0 python3 test_moe_eval_accuracy_large.py
|
||||||
|
|
||||||
mla-test-1-gpu-amd:
|
mla-test-1-gpu-amd:
|
||||||
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
|
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ SGLang supports various environment variables that can be used to configure its
|
|||||||
|
|
||||||
| Environment Variable | Description | Default Value |
|
| Environment Variable | Description | Default Value |
|
||||||
| --- | --- | --- |
|
| --- | --- | --- |
|
||||||
| `SGLANG_AITER_MOE` | Use AITER MOE implementation | `false` |
|
| `SGLANG_USE_AITER` | Use AITER optimize implementation | `false` |
|
||||||
| `SGLANG_INT4_WEIGHT` | Enable INT4 weight quantization | `false` |
|
| `SGLANG_INT4_WEIGHT` | Enable INT4 weight quantization | `false` |
|
||||||
| `SGLANG_MOE_PADDING` | Enable MoE padding (sets padding size to 128 if value is `1`, often set to `1` in Docker builds) | `0` |
|
| `SGLANG_MOE_PADDING` | Enable MoE padding (sets padding size to 128 if value is `1`, often set to `1` in Docker builds) | `0` |
|
||||||
| `SGLANG_FORCE_FP8_MARLIN` | Force using FP8 MARLIN kernels even if other FP8 kernels are available | `false` |
|
| `SGLANG_FORCE_FP8_MARLIN` | Force using FP8 MARLIN kernels even if other FP8 kernels are available | `false` |
|
||||||
|
|||||||
@@ -27,12 +27,19 @@ if TYPE_CHECKING:
|
|||||||
from sglang.srt.speculative.spec_info import SpecInfo
|
from sglang.srt.speculative.spec_info import SpecInfo
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from aiter import mha_batch_prefill_func, paged_attention_ragged
|
from aiter import (
|
||||||
|
flash_attn_varlen_func,
|
||||||
|
mha_batch_prefill_func,
|
||||||
|
paged_attention_ragged,
|
||||||
|
)
|
||||||
|
from aiter.mla import mla_decode_fwd
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print(
|
print(
|
||||||
"aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
|
"aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from sglang.srt.configs.model_config import AttentionArch
|
||||||
|
|
||||||
|
|
||||||
class WrapperDispatch(Enum):
|
class WrapperDispatch(Enum):
|
||||||
SLIDING_WINDOW = auto()
|
SLIDING_WINDOW = auto()
|
||||||
@@ -43,6 +50,10 @@ class WrapperDispatch(Enum):
|
|||||||
class ForwardMetadata:
|
class ForwardMetadata:
|
||||||
kv_indptr: torch.Tensor
|
kv_indptr: torch.Tensor
|
||||||
kv_indices: torch.Tensor
|
kv_indices: torch.Tensor
|
||||||
|
qo_indptr: torch.Tensor
|
||||||
|
kv_last_page_len: torch.Tensor
|
||||||
|
max_extend_len: int
|
||||||
|
max_prefix_extend_len: int
|
||||||
max_q_len: int
|
max_q_len: int
|
||||||
max_kv_len: int
|
max_kv_len: int
|
||||||
|
|
||||||
@@ -63,6 +74,7 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
self.device = model_runner.device
|
self.device = model_runner.device
|
||||||
self.is_multimodal = model_runner.model_config.is_multimodal
|
self.is_multimodal = model_runner.model_config.is_multimodal
|
||||||
|
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
|
||||||
self.num_head = (
|
self.num_head = (
|
||||||
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
||||||
)
|
)
|
||||||
@@ -75,6 +87,8 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||||
|
|
||||||
|
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
|
||||||
|
|
||||||
# Parse constants
|
# Parse constants
|
||||||
self.max_context_len = model_runner.model_config.context_len
|
self.max_context_len = model_runner.model_config.context_len
|
||||||
self.skip_prefill = skip_prefill
|
self.skip_prefill = skip_prefill
|
||||||
@@ -100,6 +114,10 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
self.indices_updater_prefill = AiterIndicesUpdaterPrefill(
|
self.indices_updater_prefill = AiterIndicesUpdaterPrefill(
|
||||||
model_runner, self
|
model_runner, self
|
||||||
)
|
)
|
||||||
|
if self.use_mla:
|
||||||
|
self.mla_indices_updater_prefill = AiterMlaIndicesUpdaterPrefill(
|
||||||
|
model_runner, self
|
||||||
|
)
|
||||||
|
|
||||||
# aiter kernel related initialization
|
# aiter kernel related initialization
|
||||||
self.max_num_partitions = (
|
self.max_num_partitions = (
|
||||||
@@ -108,33 +126,40 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8
|
nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8
|
||||||
|
|
||||||
self.workspace_buffer = torch.empty(
|
if not self.use_mla:
|
||||||
(max_bs * self.num_head * self.max_num_partitions * self.head_dim)
|
self.workspace_buffer = torch.empty(
|
||||||
* nbyes_per_qo_elem
|
(max_bs * self.num_head * self.max_num_partitions * self.head_dim)
|
||||||
+ 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
|
* nbyes_per_qo_elem
|
||||||
dtype=torch.uint8,
|
+ 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
|
||||||
device=self.device,
|
dtype=torch.uint8,
|
||||||
)
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
self.scale = float(1.0 / (self.head_dim**0.5))
|
self.scale = float(1.0 / (self.head_dim**0.5))
|
||||||
self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to(
|
self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to(
|
||||||
self.device
|
self.device
|
||||||
)
|
)
|
||||||
self.kv_last_page_lens = torch.ones((max_bs,), dtype=torch.int32).to(
|
|
||||||
self.device
|
|
||||||
)
|
|
||||||
|
|
||||||
self.logits_soft_cap = 0.0
|
self.logits_soft_cap = 0.0
|
||||||
|
|
||||||
self.forward_metadata: ForwardMetadata = None
|
self.forward_metadata: ForwardMetadata = None
|
||||||
|
|
||||||
|
if self.use_mla:
|
||||||
|
self.qo_indptr_ = torch.zeros(
|
||||||
|
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
||||||
|
)
|
||||||
|
|
||||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
|
"""Init auxiliary variables for triton attention backend."""
|
||||||
|
|
||||||
|
bs = forward_batch.batch_size
|
||||||
|
kv_indptr = self.kv_indptr
|
||||||
|
spec_info = forward_batch.spec_info
|
||||||
|
qo_indptr = None
|
||||||
|
kv_last_page_len = None
|
||||||
|
max_extend_len = None
|
||||||
|
|
||||||
if forward_batch.forward_mode.is_decode_or_idle():
|
if forward_batch.forward_mode.is_decode_or_idle():
|
||||||
# update for aiter
|
|
||||||
# create kv_indices and kv_inptr
|
|
||||||
bs = forward_batch.batch_size
|
|
||||||
kv_indptr = self.kv_indptr
|
|
||||||
spec_info = forward_batch.spec_info
|
|
||||||
if spec_info is None:
|
if spec_info is None:
|
||||||
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
|
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
|
||||||
kv_indptr = kv_indptr[: bs + 1]
|
kv_indptr = kv_indptr[: bs + 1]
|
||||||
@@ -154,38 +179,103 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
||||||
bs = kv_indptr.shape[0] - 1
|
bs = kv_indptr.shape[0] - 1
|
||||||
|
|
||||||
self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None)
|
if self.use_mla:
|
||||||
|
qo_indptr = self.qo_indptr_[: bs + 1]
|
||||||
|
qo_indptr[1 : bs + 1] = torch.cumsum(self.kv_last_page_len[:bs], dim=0)
|
||||||
|
kv_last_page_len = self.kv_last_page_len[:bs]
|
||||||
|
max_extend_len = 1
|
||||||
|
|
||||||
|
self.forward_metadata = ForwardMetadata(
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
qo_indptr,
|
||||||
|
kv_last_page_len,
|
||||||
|
max_extend_len,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
elif forward_batch.forward_mode.is_draft_extend():
|
elif forward_batch.forward_mode.is_draft_extend():
|
||||||
self.indices_updater_prefill.update(
|
if self.use_mla:
|
||||||
forward_batch.req_pool_indices,
|
prefix_lens = forward_batch.extend_prefix_lens
|
||||||
forward_batch.seq_lens,
|
self.mla_indices_updater_prefill.update(
|
||||||
forward_batch.seq_lens_sum,
|
forward_batch.req_pool_indices,
|
||||||
prefix_lens=None,
|
prefix_lens,
|
||||||
encoder_lens=forward_batch.encoder_lens,
|
prefix_lens.sum().item(),
|
||||||
spec_info=forward_batch.spec_info,
|
forward_batch.extend_seq_lens,
|
||||||
)
|
encoder_lens=forward_batch.encoder_lens,
|
||||||
self.forward_metadata = ForwardMetadata(
|
spec_info=None,
|
||||||
self.indices_updater_prefill.kv_indptr,
|
)
|
||||||
self.indices_updater_prefill.kv_indices,
|
self.forward_metadata = ForwardMetadata(
|
||||||
self.indices_updater_prefill.max_q_len,
|
self.mla_indices_updater_prefill.kv_indptr,
|
||||||
self.indices_updater_prefill.max_kv_len,
|
self.mla_indices_updater_prefill.kv_indices,
|
||||||
)
|
self.mla_indices_updater_prefill.qo_indptr,
|
||||||
|
self.mla_indices_updater_prefill.kv_last_page_len,
|
||||||
|
self.mla_indices_updater_prefill.max_extend_len,
|
||||||
|
self.mla_indices_updater_prefill.max_prefix_extend_len,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.indices_updater_prefill.update(
|
||||||
|
forward_batch.req_pool_indices,
|
||||||
|
forward_batch.seq_lens,
|
||||||
|
forward_batch.seq_lens_sum,
|
||||||
|
prefix_lens=None,
|
||||||
|
encoder_lens=forward_batch.encoder_lens,
|
||||||
|
spec_info=forward_batch.spec_info,
|
||||||
|
)
|
||||||
|
self.forward_metadata = ForwardMetadata(
|
||||||
|
self.indices_updater_prefill.kv_indptr,
|
||||||
|
self.indices_updater_prefill.kv_indices,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
self.indices_updater_prefill.max_q_len,
|
||||||
|
self.indices_updater_prefill.max_kv_len,
|
||||||
|
)
|
||||||
elif forward_batch.forward_mode.is_target_verify():
|
elif forward_batch.forward_mode.is_target_verify():
|
||||||
self.indices_updater_prefill.update(
|
if self.use_mla:
|
||||||
forward_batch.req_pool_indices,
|
prefix_lens = forward_batch.extend_prefix_lens
|
||||||
forward_batch.seq_lens,
|
self.mla_indices_updater_prefill.update(
|
||||||
forward_batch.seq_lens_sum,
|
forward_batch.req_pool_indices,
|
||||||
prefix_lens=None,
|
prefix_lens,
|
||||||
encoder_lens=forward_batch.encoder_lens,
|
prefix_lens.sum().item(),
|
||||||
spec_info=forward_batch.spec_info,
|
forward_batch.extend_seq_lens,
|
||||||
)
|
encoder_lens=forward_batch.encoder_lens,
|
||||||
self.forward_metadata = ForwardMetadata(
|
spec_info=None,
|
||||||
self.indices_updater_prefill.kv_indptr,
|
)
|
||||||
self.indices_updater_prefill.kv_indices,
|
self.forward_metadata = ForwardMetadata(
|
||||||
self.indices_updater_prefill.max_q_len,
|
self.mla_indices_updater_prefill.kv_indptr,
|
||||||
self.indices_updater_prefill.max_kv_len,
|
self.mla_indices_updater_prefill.kv_indices,
|
||||||
)
|
self.mla_indices_updater_prefill.qo_indptr,
|
||||||
|
self.mla_indices_updater_prefill.kv_last_page_len,
|
||||||
|
self.mla_indices_updater_prefill.max_extend_len,
|
||||||
|
self.mla_indices_updater_prefill.max_prefix_extend_len,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.indices_updater_prefill.update(
|
||||||
|
forward_batch.req_pool_indices,
|
||||||
|
forward_batch.seq_lens,
|
||||||
|
forward_batch.seq_lens_sum,
|
||||||
|
prefix_lens=None,
|
||||||
|
encoder_lens=forward_batch.encoder_lens,
|
||||||
|
spec_info=forward_batch.spec_info,
|
||||||
|
)
|
||||||
|
self.forward_metadata = ForwardMetadata(
|
||||||
|
self.indices_updater_prefill.kv_indptr,
|
||||||
|
self.indices_updater_prefill.kv_indices,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
self.indices_updater_prefill.max_q_len,
|
||||||
|
self.indices_updater_prefill.max_kv_len,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
prefix_lens = forward_batch.extend_prefix_lens
|
prefix_lens = forward_batch.extend_prefix_lens
|
||||||
|
|
||||||
@@ -194,24 +284,49 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
else:
|
else:
|
||||||
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
|
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
|
||||||
|
|
||||||
self.indices_updater_prefill.update(
|
if self.use_mla:
|
||||||
forward_batch.req_pool_indices,
|
self.mla_indices_updater_prefill.update(
|
||||||
forward_batch.seq_lens,
|
forward_batch.req_pool_indices,
|
||||||
forward_batch.seq_lens_sum,
|
prefix_lens,
|
||||||
prefix_lens,
|
prefix_lens.sum().item(),
|
||||||
encoder_lens=forward_batch.encoder_lens,
|
forward_batch.extend_seq_lens,
|
||||||
spec_info=None,
|
encoder_lens=forward_batch.encoder_lens,
|
||||||
)
|
spec_info=None,
|
||||||
self.forward_metadata = ForwardMetadata(
|
)
|
||||||
self.indices_updater_prefill.kv_indptr,
|
self.forward_metadata = ForwardMetadata(
|
||||||
self.indices_updater_prefill.kv_indices,
|
self.mla_indices_updater_prefill.kv_indptr,
|
||||||
self.indices_updater_prefill.max_q_len,
|
self.mla_indices_updater_prefill.kv_indices,
|
||||||
self.indices_updater_prefill.max_kv_len,
|
self.mla_indices_updater_prefill.qo_indptr,
|
||||||
)
|
self.mla_indices_updater_prefill.kv_last_page_len,
|
||||||
|
self.mla_indices_updater_prefill.max_extend_len,
|
||||||
|
self.mla_indices_updater_prefill.max_prefix_extend_len,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.indices_updater_prefill.update(
|
||||||
|
forward_batch.req_pool_indices,
|
||||||
|
forward_batch.seq_lens,
|
||||||
|
forward_batch.seq_lens_sum,
|
||||||
|
prefix_lens,
|
||||||
|
encoder_lens=forward_batch.encoder_lens,
|
||||||
|
spec_info=None,
|
||||||
|
)
|
||||||
|
self.forward_metadata = ForwardMetadata(
|
||||||
|
self.indices_updater_prefill.kv_indptr,
|
||||||
|
self.indices_updater_prefill.kv_indices,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
self.indices_updater_prefill.max_q_len,
|
||||||
|
self.indices_updater_prefill.max_kv_len,
|
||||||
|
)
|
||||||
|
|
||||||
def init_cuda_graph_state(
|
def init_cuda_graph_state(
|
||||||
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
|
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
|
||||||
):
|
):
|
||||||
|
self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int)
|
||||||
if kv_indices_buf is None:
|
if kv_indices_buf is None:
|
||||||
self.cuda_graph_kv_indices = torch.zeros(
|
self.cuda_graph_kv_indices = torch.zeros(
|
||||||
(max_bs * self.max_context_len),
|
(max_bs * self.max_context_len),
|
||||||
@@ -239,6 +354,10 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
spec_info: Optional[SpecInfo],
|
spec_info: Optional[SpecInfo],
|
||||||
):
|
):
|
||||||
if forward_mode.is_decode_or_idle():
|
if forward_mode.is_decode_or_idle():
|
||||||
|
qo_indptr = None
|
||||||
|
kv_last_page_len = None
|
||||||
|
max_extend_len = None
|
||||||
|
|
||||||
if spec_info is None:
|
if spec_info is None:
|
||||||
kv_indptr = self.kv_indptr
|
kv_indptr = self.kv_indptr
|
||||||
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
|
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
|
||||||
@@ -255,24 +374,82 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
||||||
self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None)
|
|
||||||
|
if self.use_mla:
|
||||||
|
qo_indptr = self.qo_indptr_[: bs + 1]
|
||||||
|
qo_indptr[1 : bs + 1] = torch.cumsum(
|
||||||
|
self.cuda_graph_kv_last_page_len[:bs], dim=0
|
||||||
|
)
|
||||||
|
max_extend_len = 1
|
||||||
|
kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
|
||||||
|
|
||||||
|
self.forward_metadata = ForwardMetadata(
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
qo_indptr,
|
||||||
|
kv_last_page_len,
|
||||||
|
max_extend_len,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
elif forward_mode.is_target_verify():
|
elif forward_mode.is_target_verify():
|
||||||
seq_lens_sum = seq_lens.sum().item()
|
if self.use_mla:
|
||||||
self.indices_updater_prefill.update(
|
qo_indptr = self.qo_indptr[: bs + 1]
|
||||||
req_pool_indices,
|
qo_indptr[: bs + 1] = torch.arange(
|
||||||
seq_lens,
|
0,
|
||||||
seq_lens_sum,
|
(1 + bs) * self.num_draft_tokens,
|
||||||
prefix_lens=None,
|
step=self.num_draft_tokens,
|
||||||
encoder_lens=encoder_lens,
|
dtype=torch.int32,
|
||||||
spec_info=spec_info,
|
device=self.device,
|
||||||
)
|
)
|
||||||
self.forward_metadata = ForwardMetadata(
|
kv_indptr = self.kv_indptr[: bs + 1]
|
||||||
self.indices_updater_prefill.kv_indptr,
|
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
|
||||||
self.indices_updater_prefill.kv_indices,
|
kv_indices = self.cuda_graph_kv_indices
|
||||||
self.indices_updater_prefill.max_q_len,
|
create_flashinfer_kv_indices_triton[(bs,)](
|
||||||
self.indices_updater_prefill.max_kv_len,
|
self.req_to_token,
|
||||||
)
|
req_pool_indices,
|
||||||
|
seq_lens,
|
||||||
|
kv_indptr,
|
||||||
|
None,
|
||||||
|
kv_indices,
|
||||||
|
self.req_to_token.stride(0),
|
||||||
|
)
|
||||||
|
|
||||||
|
max_extend_len = self.num_draft_tokens
|
||||||
|
kv_last_page_len = None
|
||||||
|
|
||||||
|
self.forward_metadata = ForwardMetadata(
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
qo_indptr,
|
||||||
|
kv_last_page_len,
|
||||||
|
max_extend_len,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
seq_lens_sum = seq_lens.sum().item()
|
||||||
|
self.indices_updater_prefill.update(
|
||||||
|
req_pool_indices,
|
||||||
|
seq_lens,
|
||||||
|
seq_lens_sum,
|
||||||
|
prefix_lens=None,
|
||||||
|
encoder_lens=encoder_lens,
|
||||||
|
spec_info=spec_info,
|
||||||
|
)
|
||||||
|
self.forward_metadata = ForwardMetadata(
|
||||||
|
self.indices_updater_prefill.kv_indptr,
|
||||||
|
self.indices_updater_prefill.kv_indices,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
self.indices_updater_prefill.max_q_len,
|
||||||
|
self.indices_updater_prefill.max_kv_len,
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid mode: {forward_mode=}")
|
raise ValueError(f"Invalid mode: {forward_mode=}")
|
||||||
@@ -342,31 +519,113 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
if k is not None:
|
if k is not None:
|
||||||
assert v is not None
|
assert v is not None
|
||||||
if save_kv_cache:
|
if save_kv_cache:
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
if self.use_mla:
|
||||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
|
||||||
|
else:
|
||||||
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.use_mla:
|
||||||
|
max_extend_len = self.forward_metadata.max_extend_len
|
||||||
|
max_prefix_extend_len = self.forward_metadata.max_prefix_extend_len
|
||||||
|
kv_indptr = self.forward_metadata.kv_indptr
|
||||||
|
kv_indices = self.forward_metadata.kv_indices
|
||||||
|
kv_last_page_lens = self.forward_metadata.kv_last_page_len
|
||||||
|
qo_indptr = self.forward_metadata.qo_indptr
|
||||||
|
K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||||
|
V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
|
||||||
|
kv_lora_rank = V_Buffer.shape[-1]
|
||||||
|
qk_rope_head_dim = K_Buffer.shape[-1] - kv_lora_rank
|
||||||
|
qk_nope_head_dim = k.shape[-1] - qk_rope_head_dim
|
||||||
|
assert len(q.shape) == 3
|
||||||
|
assert len(k.shape) == 3
|
||||||
|
assert len(v.shape) == 3
|
||||||
|
|
||||||
|
if kv_indices.shape[0] == 0:
|
||||||
|
o = flash_attn_varlen_func(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
qo_indptr,
|
||||||
|
qo_indptr,
|
||||||
|
max_extend_len,
|
||||||
|
max_extend_len,
|
||||||
|
softmax_scale=layer.scaling,
|
||||||
|
causal=True,
|
||||||
)
|
)
|
||||||
|
return o
|
||||||
|
elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
|
||||||
|
K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
|
||||||
|
kvc, k_pe = torch.split(
|
||||||
|
K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
|
||||||
|
)
|
||||||
|
kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
|
||||||
|
|
||||||
k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
|
kvprefix = kvprefix.view(
|
||||||
|
-1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
|
||||||
|
)
|
||||||
|
k_prefix, v_prefix = torch.split(
|
||||||
|
kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
|
||||||
|
)
|
||||||
|
k_prefix = torch.cat(
|
||||||
|
[
|
||||||
|
k_prefix,
|
||||||
|
torch.broadcast_to(
|
||||||
|
k_pe,
|
||||||
|
(k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
forward_batch.extend_prefix_lens.shape
|
||||||
|
== forward_batch.extend_seq_lens.shape
|
||||||
|
)
|
||||||
|
k_prefix = torch.split(k_prefix, forward_batch.extend_prefix_lens_cpu)
|
||||||
|
k_extend = torch.split(k, forward_batch.extend_seq_lens_cpu)
|
||||||
|
assert len(k_prefix) == len(forward_batch.extend_prefix_lens_cpu)
|
||||||
|
k = torch.cat([x for el in zip(k_prefix, k_extend) for x in el])
|
||||||
|
v_prefix = torch.split(v_prefix, forward_batch.extend_prefix_lens_cpu)
|
||||||
|
v_extend = torch.split(v, forward_batch.extend_seq_lens_cpu)
|
||||||
|
v = torch.cat([x for el in zip(v_prefix, v_extend) for x in el])
|
||||||
|
|
||||||
bs0 = forward_batch.batch_size + 1
|
o = flash_attn_varlen_func(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
qo_indptr,
|
||||||
|
kv_indptr,
|
||||||
|
max_extend_len,
|
||||||
|
max_prefix_extend_len,
|
||||||
|
softmax_scale=layer.scaling,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
return o
|
||||||
|
else:
|
||||||
|
k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
|
||||||
|
layer.layer_id
|
||||||
|
)
|
||||||
|
|
||||||
o = mha_batch_prefill_func(
|
bs0 = forward_batch.batch_size + 1
|
||||||
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
|
||||||
k_cache,
|
|
||||||
v_cache,
|
|
||||||
self.qo_indptr[:bs0],
|
|
||||||
self.forward_metadata.kv_indptr[:bs0],
|
|
||||||
self.forward_metadata.kv_indices,
|
|
||||||
self.forward_metadata.max_q_len,
|
|
||||||
self.forward_metadata.max_kv_len,
|
|
||||||
causal=True,
|
|
||||||
logits_soft_cap=self.logits_soft_cap,
|
|
||||||
alibi_slopes=None,
|
|
||||||
return_lse=False,
|
|
||||||
return_attn_probs=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
o = mha_batch_prefill_func(
|
||||||
|
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||||
|
k_cache,
|
||||||
|
v_cache,
|
||||||
|
self.qo_indptr[:bs0],
|
||||||
|
self.forward_metadata.kv_indptr[:bs0],
|
||||||
|
self.forward_metadata.kv_indices,
|
||||||
|
self.forward_metadata.max_q_len,
|
||||||
|
self.forward_metadata.max_kv_len,
|
||||||
|
causal=True,
|
||||||
|
logits_soft_cap=self.logits_soft_cap,
|
||||||
|
alibi_slopes=None,
|
||||||
|
return_lse=False,
|
||||||
|
return_attn_probs=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||||
|
|
||||||
def forward_decode(
|
def forward_decode(
|
||||||
self,
|
self,
|
||||||
@@ -377,6 +636,7 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
save_kv_cache=True,
|
save_kv_cache=True,
|
||||||
):
|
):
|
||||||
|
|
||||||
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
|
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
|
||||||
|
|
||||||
if layer.qk_head_dim != layer.v_head_dim:
|
if layer.qk_head_dim != layer.v_head_dim:
|
||||||
@@ -389,32 +649,48 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
layer, forward_batch.out_cache_loc, k, v
|
layer, forward_batch.out_cache_loc, k, v
|
||||||
)
|
)
|
||||||
|
|
||||||
self.logits_soft_cap = layer.logit_cap
|
if self.use_mla:
|
||||||
paged_attention_ragged(
|
k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||||
o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
mla_decode_fwd(
|
||||||
self.workspace_buffer,
|
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
||||||
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
k_buffer.view(-1, 1, 1, layer.qk_head_dim),
|
||||||
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
|
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
|
||||||
-1, 1, layer.tp_k_head_num, layer.qk_head_dim
|
self.forward_metadata.qo_indptr,
|
||||||
),
|
self.forward_metadata.kv_indptr,
|
||||||
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
|
self.forward_metadata.kv_indices,
|
||||||
-1, 1, layer.tp_v_head_num, layer.v_head_dim
|
self.forward_metadata.kv_last_page_len,
|
||||||
),
|
self.forward_metadata.max_extend_len,
|
||||||
self.scale,
|
layer.scaling,
|
||||||
self.forward_metadata.kv_indptr,
|
layer.logit_cap,
|
||||||
self.forward_metadata.kv_indices,
|
)
|
||||||
self.kv_last_page_lens,
|
k_buffer = k_buffer.view(-1, 1, layer.qk_head_dim)
|
||||||
1,
|
else:
|
||||||
self.max_num_partitions,
|
self.logits_soft_cap = layer.logit_cap
|
||||||
None,
|
paged_attention_ragged(
|
||||||
"auto",
|
o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
||||||
"NHD",
|
self.workspace_buffer,
|
||||||
self.logits_soft_cap,
|
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
||||||
self.k_scale,
|
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
|
||||||
self.v_scale,
|
-1, 1, layer.tp_k_head_num, layer.qk_head_dim
|
||||||
None,
|
),
|
||||||
_AITER_PARTITION_SIZE_ROCM,
|
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
|
||||||
)
|
-1, 1, layer.tp_v_head_num, layer.v_head_dim
|
||||||
|
),
|
||||||
|
self.scale,
|
||||||
|
self.forward_metadata.kv_indptr,
|
||||||
|
self.forward_metadata.kv_indices,
|
||||||
|
self.kv_last_page_len,
|
||||||
|
1,
|
||||||
|
self.max_num_partitions,
|
||||||
|
None,
|
||||||
|
"auto",
|
||||||
|
"NHD",
|
||||||
|
self.logits_soft_cap,
|
||||||
|
self.k_scale,
|
||||||
|
self.v_scale,
|
||||||
|
None,
|
||||||
|
_AITER_PARTITION_SIZE_ROCM,
|
||||||
|
)
|
||||||
|
|
||||||
return o
|
return o
|
||||||
|
|
||||||
@@ -506,9 +782,97 @@ class AiterIndicesUpdaterPrefill:
|
|||||||
spec_info.generate_attn_arg_prefill(
|
spec_info.generate_attn_arg_prefill(
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
paged_kernel_lens,
|
paged_kernel_lens,
|
||||||
None,
|
paged_kernel_lens_sum,
|
||||||
self.req_to_token,
|
self.req_to_token,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.kv_indices = kv_indices
|
self.kv_indices = kv_indices
|
||||||
|
|
||||||
|
|
||||||
|
class AiterMlaIndicesUpdaterPrefill:
|
||||||
|
def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
|
||||||
|
# Parse Constants
|
||||||
|
self.attn_backend = attn_backend
|
||||||
|
|
||||||
|
# Buffers and wrappers
|
||||||
|
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||||
|
self.update = self.update_single_wrapper
|
||||||
|
|
||||||
|
self.kv_indptr = None
|
||||||
|
self.kv_indices = None
|
||||||
|
self.qo_indptr = None
|
||||||
|
self.kv_last_page_len = None
|
||||||
|
self.max_extend_len = 0
|
||||||
|
self.max_prefix_extend_len = 0
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
prefix_lens: torch.Tensor,
|
||||||
|
prefix_lens_sum: int,
|
||||||
|
extend_lens: torch.Tensor,
|
||||||
|
encoder_lens: Optional[torch.Tensor],
|
||||||
|
spec_info: Optional[SpecInfo],
|
||||||
|
):
|
||||||
|
# Keep the signature for type checking. It will be assigned during runtime.
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def update_single_wrapper(
|
||||||
|
self,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
prefix_lens: torch.Tensor,
|
||||||
|
prefix_lens_sum: int,
|
||||||
|
extend_lens: torch.Tensor,
|
||||||
|
encoder_lens: Optional[torch.Tensor],
|
||||||
|
spec_info: Optional[SpecInfo],
|
||||||
|
):
|
||||||
|
|
||||||
|
paged_kernel_lens = prefix_lens
|
||||||
|
paged_kernel_lens_sum = prefix_lens_sum
|
||||||
|
|
||||||
|
bs = len(req_pool_indices)
|
||||||
|
|
||||||
|
kv_indptr = self.attn_backend.kv_indptr
|
||||||
|
|
||||||
|
if spec_info is None:
|
||||||
|
# Normal extend
|
||||||
|
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||||
|
kv_indptr = kv_indptr[: bs + 1]
|
||||||
|
kv_indices = torch.empty(
|
||||||
|
paged_kernel_lens_sum,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=req_pool_indices.device,
|
||||||
|
)
|
||||||
|
create_flashinfer_kv_indices_triton[(bs,)](
|
||||||
|
self.req_to_token,
|
||||||
|
req_pool_indices,
|
||||||
|
paged_kernel_lens,
|
||||||
|
kv_indptr,
|
||||||
|
None,
|
||||||
|
kv_indices,
|
||||||
|
self.req_to_token.stride(0),
|
||||||
|
)
|
||||||
|
|
||||||
|
qo_indptr = self.attn_backend.qo_indptr
|
||||||
|
qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
|
||||||
|
qo_indptr = qo_indptr[: bs + 1]
|
||||||
|
|
||||||
|
max_extend_len = torch.max(extend_lens).item()
|
||||||
|
max_prefix_extend_len = torch.max(extend_lens + paged_kernel_lens).item()
|
||||||
|
kv_indptr += qo_indptr
|
||||||
|
else:
|
||||||
|
kv_indices, kv_indptr, qo_indptr, custom_mask = (
|
||||||
|
spec_info.generate_attn_arg_prefill(
|
||||||
|
req_pool_indices,
|
||||||
|
paged_kernel_lens,
|
||||||
|
paged_kernel_lens_sum,
|
||||||
|
self.req_to_token,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.kv_indptr = kv_indptr
|
||||||
|
self.kv_indices = kv_indices
|
||||||
|
self.qo_indptr = qo_indptr
|
||||||
|
self.max_extend_len = max_extend_len
|
||||||
|
self.max_prefix_extend_len = max_prefix_extend_len
|
||||||
|
|||||||
@@ -20,10 +20,11 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from sglang.srt.custom_op import CustomOp
|
from sglang.srt.custom_op import CustomOp
|
||||||
from sglang.srt.utils import is_cuda, is_hip
|
from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
|
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||||
|
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sgl_kernel import (
|
from sgl_kernel import (
|
||||||
@@ -33,7 +34,10 @@ if _is_cuda:
|
|||||||
rmsnorm,
|
rmsnorm,
|
||||||
)
|
)
|
||||||
|
|
||||||
if _is_hip:
|
if _use_aiter:
|
||||||
|
from aiter import rmsnorm2d_fwd as rms_norm
|
||||||
|
from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
|
||||||
|
elif _is_hip:
|
||||||
from vllm._custom_ops import fused_add_rms_norm, rms_norm
|
from vllm._custom_ops import fused_add_rms_norm, rms_norm
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -48,6 +52,8 @@ class RMSNorm(CustomOp):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||||
self.variance_epsilon = eps
|
self.variance_epsilon = eps
|
||||||
|
if _use_aiter:
|
||||||
|
self._forward_method = self.forward_aiter
|
||||||
|
|
||||||
def forward_cuda(
|
def forward_cuda(
|
||||||
self,
|
self,
|
||||||
@@ -60,6 +66,25 @@ class RMSNorm(CustomOp):
|
|||||||
out = rmsnorm(x, self.weight.data, self.variance_epsilon)
|
out = rmsnorm(x, self.weight.data, self.variance_epsilon)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def forward_aiter(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
if residual is not None:
|
||||||
|
residual_out = torch.empty_like(x)
|
||||||
|
output = torch.empty_like(x)
|
||||||
|
fused_add_rms_norm(
|
||||||
|
output,
|
||||||
|
x,
|
||||||
|
residual,
|
||||||
|
residual_out,
|
||||||
|
self.weight.data,
|
||||||
|
self.variance_epsilon,
|
||||||
|
)
|
||||||
|
return output, residual_out
|
||||||
|
return rms_norm(x, self.weight.data, self.variance_epsilon)
|
||||||
|
|
||||||
def forward_hip(
|
def forward_hip(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
|
|||||||
@@ -1332,7 +1332,7 @@ def fused_experts_impl(
|
|||||||
if (
|
if (
|
||||||
not (use_fp8_w8a8 or use_int8_w8a8)
|
not (use_fp8_w8a8 or use_int8_w8a8)
|
||||||
or block_shape is not None
|
or block_shape is not None
|
||||||
or (_is_hip and get_bool_env_var("SGLANG_AITER_MOE"))
|
or (_is_hip and get_bool_env_var("SGLANG_USE_AITER"))
|
||||||
):
|
):
|
||||||
padded_size = 0
|
padded_size = 0
|
||||||
|
|
||||||
|
|||||||
@@ -28,8 +28,9 @@ else:
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
|
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||||
|
|
||||||
if _is_hip:
|
if _use_aiter:
|
||||||
from aiter import ActivationType
|
from aiter import ActivationType
|
||||||
from aiter.fused_moe_bf16_asm import ck_moe_2stages
|
from aiter.fused_moe_bf16_asm import ck_moe_2stages
|
||||||
from aiter.ops.shuffle import shuffle_weight
|
from aiter.ops.shuffle import shuffle_weight
|
||||||
@@ -104,7 +105,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
|
if _use_aiter:
|
||||||
layer.w13_weight = torch.nn.Parameter(
|
layer.w13_weight = torch.nn.Parameter(
|
||||||
shuffle_weight(layer.w13_weight.data, (16, 16)),
|
shuffle_weight(layer.w13_weight.data, (16, 16)),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
@@ -188,7 +189,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
)
|
)
|
||||||
|
|
||||||
if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
|
if _use_aiter:
|
||||||
assert not no_combine, "unsupported"
|
assert not no_combine, "unsupported"
|
||||||
if apply_router_weight_on_input:
|
if apply_router_weight_on_input:
|
||||||
assert (
|
assert (
|
||||||
|
|||||||
@@ -77,8 +77,8 @@ _is_cuda = is_cuda()
|
|||||||
|
|
||||||
_is_fp8_fnuz = is_fp8_fnuz()
|
_is_fp8_fnuz = is_fp8_fnuz()
|
||||||
|
|
||||||
use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
|
_use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
|
||||||
use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
|
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||||
|
|
||||||
if _is_hip:
|
if _is_hip:
|
||||||
from aiter import ActivationType, QuantType
|
from aiter import ActivationType, QuantType
|
||||||
@@ -487,7 +487,7 @@ class Fp8MoEMethod:
|
|||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||||
|
|
||||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
params_dtype = torch.uint32 if use_hip_int4 else torch.float8_e4m3fn
|
params_dtype = torch.uint32 if _use_hip_int4 else torch.float8_e4m3fn
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
block_n, block_k = (
|
block_n, block_k = (
|
||||||
@@ -512,7 +512,7 @@ class Fp8MoEMethod:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# WEIGHTS
|
# WEIGHTS
|
||||||
if _is_hip and use_hip_int4:
|
if _is_hip and _use_hip_int4:
|
||||||
# INT4 MoE weight - INT32 packed
|
# INT4 MoE weight - INT32 packed
|
||||||
w13_weight = torch.nn.Parameter(
|
w13_weight = torch.nn.Parameter(
|
||||||
torch.empty(
|
torch.empty(
|
||||||
@@ -641,7 +641,7 @@ class Fp8MoEMethod:
|
|||||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||||
|
|
||||||
if _is_hip: # and use_aiter_moe: TODO: add check back after triton kernel
|
if _is_hip: # _use_aiter: TODO: add check back after triton kernel
|
||||||
# ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
|
# ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
|
||||||
w13_weight_scale1 = torch.nn.Parameter(
|
w13_weight_scale1 = torch.nn.Parameter(
|
||||||
torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
|
torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
|
||||||
@@ -668,7 +668,7 @@ class Fp8MoEMethod:
|
|||||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||||
|
|
||||||
if _is_hip and use_hip_int4:
|
if _is_hip and _use_hip_int4:
|
||||||
extra_weight_attrs.update(
|
extra_weight_attrs.update(
|
||||||
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
|
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
|
||||||
)
|
)
|
||||||
@@ -700,7 +700,7 @@ class Fp8MoEMethod:
|
|||||||
layer.w2_input_scale = None
|
layer.w2_input_scale = None
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: Module) -> None:
|
def process_weights_after_loading(self, layer: Module) -> None:
|
||||||
if _is_hip and use_hip_int4:
|
if _is_hip and _use_hip_int4:
|
||||||
self.process_weights_hip_int4(layer)
|
self.process_weights_hip_int4(layer)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -731,7 +731,7 @@ class Fp8MoEMethod:
|
|||||||
)
|
)
|
||||||
layer.w2_input_scale = None
|
layer.w2_input_scale = None
|
||||||
|
|
||||||
if _is_hip and use_aiter_moe:
|
if _use_aiter:
|
||||||
# Pre-shuffle weights
|
# Pre-shuffle weights
|
||||||
layer.w13_weight.data = shuffle_weight(
|
layer.w13_weight.data = shuffle_weight(
|
||||||
layer.w13_weight.contiguous(), (16, 16)
|
layer.w13_weight.contiguous(), (16, 16)
|
||||||
@@ -853,7 +853,7 @@ class Fp8MoEMethod:
|
|||||||
return
|
return
|
||||||
|
|
||||||
def process_weights_hip_int4(self, layer: Module):
|
def process_weights_hip_int4(self, layer: Module):
|
||||||
# TODO: and use_aiter_moe: add after triton kernel added
|
# TODO: _use_aiter: add after triton kernel added
|
||||||
# INT4-FP8 (INT4 MoE Weight, FP8 Compute)
|
# INT4-FP8 (INT4 MoE Weight, FP8 Compute)
|
||||||
# Weight Permutation
|
# Weight Permutation
|
||||||
layer.w13_weight = torch.nn.Parameter(
|
layer.w13_weight = torch.nn.Parameter(
|
||||||
@@ -900,7 +900,7 @@ class Fp8MoEMethod:
|
|||||||
padding_size, # Avoid circular import
|
padding_size, # Avoid circular import
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_aiter_moe:
|
if _use_aiter:
|
||||||
layer.w13_weight = torch.nn.Parameter(
|
layer.w13_weight = torch.nn.Parameter(
|
||||||
shuffle_weight(layer.w13_weight.data, (16, 16)),
|
shuffle_weight(layer.w13_weight.data, (16, 16)),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
@@ -911,7 +911,7 @@ class Fp8MoEMethod:
|
|||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
# ROCm (use_aiter_moe): using column-wise scaling
|
# ROCm (_use_aiter): using column-wise scaling
|
||||||
layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
|
layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
|
||||||
layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
|
layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
|
||||||
elif get_bool_env_var("SGLANG_MOE_PADDING"):
|
elif get_bool_env_var("SGLANG_MOE_PADDING"):
|
||||||
@@ -1041,8 +1041,8 @@ class Fp8MoEMethod:
|
|||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
no_combine: bool = False,
|
no_combine: bool = False,
|
||||||
) -> Optional[torch.Tensor]:
|
) -> Optional[torch.Tensor]:
|
||||||
if use_hip_int4:
|
if _use_hip_int4:
|
||||||
# TODO: add triton kernel and add check use_aiter_moe
|
# TODO: add triton kernel and add check _use_aiter
|
||||||
assert not no_combine, f"{no_combine=} is not supported."
|
assert not no_combine, f"{no_combine=} is not supported."
|
||||||
return ck_moe_2stages(
|
return ck_moe_2stages(
|
||||||
x,
|
x,
|
||||||
@@ -1058,13 +1058,13 @@ class Fp8MoEMethod:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_aiter_moe:
|
if _use_aiter:
|
||||||
assert not no_combine, f"{no_combine=} is not supported."
|
assert not no_combine, f"{no_combine=} is not supported."
|
||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
# TODO(use_aiter_moe): FP8 block_quant only supports 'silu' for the time-being.
|
# TODO(_use_aiter): FP8 block_quant only supports 'silu' for the time-being.
|
||||||
assert (
|
assert (
|
||||||
activation == "silu"
|
activation == "silu"
|
||||||
), f"use_aiter_moe: FP8 bloack_quant {activation=} will be supported later, unset use_aiter_moe"
|
), f"_use_aiter: FP8 bloack_quant {activation=} will be supported later, unset _use_aiter"
|
||||||
return asm_moe(
|
return asm_moe(
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
|
|||||||
@@ -38,11 +38,10 @@ _is_hip = is_hip()
|
|||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
_is_fp8_fnuz = is_fp8_fnuz()
|
_is_fp8_fnuz = is_fp8_fnuz()
|
||||||
|
|
||||||
|
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||||
|
|
||||||
use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
|
if _use_aiter:
|
||||||
|
from aiter import gemm_a8w8_blockscale_CK
|
||||||
if _is_hip and use_aiter_moe:
|
|
||||||
from aiter import gemm_a8w8_blockscale
|
|
||||||
|
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
|
from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
|
||||||
@@ -141,7 +140,7 @@ def dispatch_w8a8_block_fp8_linear() -> Callable:
|
|||||||
return flashinfer_gemm_w8a8_block_fp8_linear
|
return flashinfer_gemm_w8a8_block_fp8_linear
|
||||||
elif CUTLASS_BLOCK_FP8_SUPPORTED:
|
elif CUTLASS_BLOCK_FP8_SUPPORTED:
|
||||||
return cutlass_w8a8_block_fp8_linear_with_fallback
|
return cutlass_w8a8_block_fp8_linear_with_fallback
|
||||||
elif _is_hip and use_aiter_moe:
|
elif _use_aiter:
|
||||||
return aiter_w8a8_block_fp8_linear
|
return aiter_w8a8_block_fp8_linear
|
||||||
elif _ENABLE_JIT_DEEPGEMM:
|
elif _ENABLE_JIT_DEEPGEMM:
|
||||||
return deepgemm_w8a8_block_fp8_linear_with_fallback
|
return deepgemm_w8a8_block_fp8_linear_with_fallback
|
||||||
@@ -268,12 +267,9 @@ def aiter_w8a8_block_fp8_linear(
|
|||||||
q_input, x_scale = per_token_group_quant_fp8(
|
q_input, x_scale = per_token_group_quant_fp8(
|
||||||
input_2d, block_size[1], column_major_scales=False
|
input_2d, block_size[1], column_major_scales=False
|
||||||
)
|
)
|
||||||
output = torch.zeros(
|
output = gemm_a8w8_blockscale_CK(
|
||||||
[q_input.shape[0], weight.shape[0]],
|
q_input, weight, x_scale, weight_scale, dtype=input.dtype
|
||||||
dtype=input_2d.dtype,
|
|
||||||
device=q_input.device,
|
|
||||||
)
|
)
|
||||||
gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
|
|
||||||
|
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
output += bias
|
output += bias
|
||||||
|
|||||||
@@ -355,6 +355,15 @@ class ModelRunner:
|
|||||||
# MLA architecture
|
# MLA architecture
|
||||||
if is_hopper_with_cuda_12_3():
|
if is_hopper_with_cuda_12_3():
|
||||||
server_args.attention_backend = "fa3"
|
server_args.attention_backend = "fa3"
|
||||||
|
elif _is_hip:
|
||||||
|
head_num = self.model_config.get_num_kv_heads(self.tp_size)
|
||||||
|
# TODO current aiter only support head number 16 or 128 head number
|
||||||
|
if (
|
||||||
|
head_num == 128 or head_num == 16
|
||||||
|
) and self.spec_algorithm.is_none():
|
||||||
|
server_args.attention_backend = "aiter"
|
||||||
|
else:
|
||||||
|
server_args.attention_backend = "triton"
|
||||||
else:
|
else:
|
||||||
server_args.attention_backend = "triton"
|
server_args.attention_backend = "triton"
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -363,6 +372,7 @@ class ModelRunner:
|
|||||||
elif self.use_mla_backend:
|
elif self.use_mla_backend:
|
||||||
if server_args.device != "cpu":
|
if server_args.device != "cpu":
|
||||||
if server_args.attention_backend in [
|
if server_args.attention_backend in [
|
||||||
|
"aiter",
|
||||||
"flashinfer",
|
"flashinfer",
|
||||||
"fa3",
|
"fa3",
|
||||||
"triton",
|
"triton",
|
||||||
|
|||||||
@@ -105,6 +105,7 @@ from sglang.srt.utils import (
|
|||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
_is_fp8_fnuz = is_fp8_fnuz()
|
_is_fp8_fnuz = is_fp8_fnuz()
|
||||||
|
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||||
|
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
|
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
|
||||||
@@ -120,6 +121,9 @@ if _is_hip:
|
|||||||
decode_attention_fwd_grouped_rope,
|
decode_attention_fwd_grouped_rope,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if _use_aiter:
|
||||||
|
from aiter.rotary_embedding import get_rope
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -697,6 +701,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.alt_stream = alt_stream
|
self.alt_stream = alt_stream
|
||||||
|
self.attn_mha.kv_b_proj = None
|
||||||
|
|
||||||
self.w_kc = None
|
self.w_kc = None
|
||||||
self.w_vc = None
|
self.w_vc = None
|
||||||
@@ -766,6 +771,15 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
return AttnForwardMethod.MHA_CHUNKED_KV
|
return AttnForwardMethod.MHA_CHUNKED_KV
|
||||||
else:
|
else:
|
||||||
return _dispatch_mla_subtype()
|
return _dispatch_mla_subtype()
|
||||||
|
elif self.attention_backend == "aiter":
|
||||||
|
if (
|
||||||
|
forward_batch.forward_mode.is_extend()
|
||||||
|
and not forward_batch.forward_mode.is_target_verify()
|
||||||
|
and not forward_batch.forward_mode.is_draft_extend()
|
||||||
|
):
|
||||||
|
return AttnForwardMethod.MHA
|
||||||
|
else:
|
||||||
|
return AttnForwardMethod.MLA
|
||||||
else:
|
else:
|
||||||
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
|
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
|
||||||
if (
|
if (
|
||||||
@@ -813,6 +827,9 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
zero_allocator: BumpAllocator,
|
zero_allocator: BumpAllocator,
|
||||||
):
|
):
|
||||||
|
if self.attn_mha.kv_b_proj is None:
|
||||||
|
self.attn_mha.kv_b_proj = self.kv_b_proj
|
||||||
|
|
||||||
if hidden_states.shape[0] == 0:
|
if hidden_states.shape[0] == 0:
|
||||||
assert (
|
assert (
|
||||||
not self.o_proj.reduce_results
|
not self.o_proj.reduce_results
|
||||||
|
|||||||
@@ -1,15 +1,14 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
|
|
||||||
# Default working directory
|
|
||||||
WORKDIR="/sglang-checkout/test/srt"
|
WORKDIR="/sglang-checkout/test/srt"
|
||||||
ENV_ARGS=(
|
declare -A ENV_MAP=(
|
||||||
-e SGLANG_AMD_CI=1
|
[SGLANG_AMD_CI]=1
|
||||||
-e SGLANG_IS_IN_CI=1
|
[SGLANG_IS_IN_CI]=1
|
||||||
-e SGLANG_AITER_MOE=1
|
[SGLANG_USE_AITER]=1
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parse optional -w/--workdir and -e ENV=VAL flags
|
# Parse -w/--workdir and -e ENV=VAL
|
||||||
while [[ $# -gt 0 ]]; do
|
while [[ $# -gt 0 ]]; do
|
||||||
case "$1" in
|
case "$1" in
|
||||||
-w|--workdir)
|
-w|--workdir)
|
||||||
@@ -17,7 +16,8 @@ while [[ $# -gt 0 ]]; do
|
|||||||
shift 2
|
shift 2
|
||||||
;;
|
;;
|
||||||
-e)
|
-e)
|
||||||
ENV_ARGS+=("-e" "$2")
|
IFS="=" read -r key val <<< "$2"
|
||||||
|
ENV_MAP["$key"]="$val"
|
||||||
shift 2
|
shift 2
|
||||||
;;
|
;;
|
||||||
--)
|
--)
|
||||||
@@ -30,6 +30,12 @@ while [[ $# -gt 0 ]]; do
|
|||||||
esac
|
esac
|
||||||
done
|
done
|
||||||
|
|
||||||
|
# Build final ENV_ARGS
|
||||||
|
ENV_ARGS=()
|
||||||
|
for key in "${!ENV_MAP[@]}"; do
|
||||||
|
ENV_ARGS+=("-e" "$key=${ENV_MAP[$key]}")
|
||||||
|
done
|
||||||
|
|
||||||
# Run docker exec
|
# Run docker exec
|
||||||
docker exec \
|
docker exec \
|
||||||
-w "$WORKDIR" \
|
-w "$WORKDIR" \
|
||||||
|
|||||||
@@ -171,7 +171,7 @@ class TestNightlyGsm8KEval(unittest.TestCase):
|
|||||||
os.environ["HF_HUB_DISABLE_XET"] = (
|
os.environ["HF_HUB_DISABLE_XET"] = (
|
||||||
"1" if model in DISABLE_HF_XET_MODELS else "0"
|
"1" if model in DISABLE_HF_XET_MODELS else "0"
|
||||||
)
|
)
|
||||||
os.environ["SGLANG_AITER_MOE"] = (
|
os.environ["SGLANG_USE_AITER"] = (
|
||||||
"0" if model in TRITON_MOE_MODELS else "1"
|
"0" if model in TRITON_MOE_MODELS else "1"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user