diff --git a/.github/workflows/pr-test-amd.yml b/.github/workflows/pr-test-amd.yml index 2f7bba308..c2c73a923 100644 --- a/.github/workflows/pr-test-amd.yml +++ b/.github/workflows/pr-test-amd.yml @@ -44,7 +44,7 @@ jobs: run: bash scripts/amd_ci_install_dependency.sh - name: Evaluate Accuracy - timeout-minutes: 20 + timeout-minutes: 30 run: | bash scripts/amd_ci_exec.sh python3 test_eval_accuracy_large.py bash scripts/amd_ci_exec.sh python3 test_eval_fp8_accuracy.py @@ -70,7 +70,7 @@ jobs: run: bash scripts/amd_ci_install_dependency.sh - name: Evaluate accuracy (TP=2) - timeout-minutes: 20 + timeout-minutes: 30 run: | bash scripts/amd_ci_exec.sh python3 test_moe_eval_accuracy_large.py @@ -94,7 +94,7 @@ jobs: run: bash scripts/amd_ci_install_dependency.sh - name: MLA TEST - timeout-minutes: 20 + timeout-minutes: 30 run: | bash scripts/amd_ci_exec.sh python3 test_mla.py @@ -118,28 +118,28 @@ jobs: run: bash scripts/amd_ci_install_dependency.sh - name: Benchmark single latency - timeout-minutes: 10 + timeout-minutes: 20 run: | bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_bs1_small bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_bs1_default - name: Benchmark online latency - timeout-minutes: 10 + timeout-minutes: 15 run: | bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_online_latency_default - name: Benchmark offline throughput - timeout-minutes: 10 + timeout-minutes: 15 run: | bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_default - name: Benchmark offline throughput (Non-streaming, small batch size) - timeout-minutes: 10 + timeout-minutes: 15 run: | bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size - name: Benchmark online latency (EAGLE) - timeout-minutes: 10 + timeout-minutes: 15 run: | bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_online_latency_eagle @@ -163,17 +163,17 @@ jobs: run: bash scripts/amd_ci_install_dependency.sh - name: Benchmark offline throughput (w/o RadixAttention) - timeout-minutes: 10 + timeout-minutes: 15 run: | bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_without_radix_cache - name: Benchmark offline throughput (w/ Triton) - timeout-minutes: 10 + timeout-minutes: 15 run: | bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_with_triton_attention_backend - name: Benchmark offline throughput (w/ FP8) - timeout-minutes: 10 + timeout-minutes: 15 run: | bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_default_fp8 @@ -197,27 +197,27 @@ jobs: run: bash scripts/amd_ci_install_dependency.sh - name: Benchmark dummy grok (TP=2) - timeout-minutes: 20 + timeout-minutes: 30 run: | bash scripts/amd_ci_exec.sh python3 models/test_dummy_grok_models.py - name: Benchmark single latency (TP=2) - timeout-minutes: 20 + timeout-minutes: 25 run: | bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_tp2_bs1 - name: Benchmark single latency + torch.compile (TP=2) - timeout-minutes: 20 + timeout-minutes: 25 run: | bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_torch_compile_tp2_bs1 - name: Benchmark offline throughput (TP=2) - timeout-minutes: 20 + timeout-minutes: 25 run: | bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_default - name: Benchmark offline throughput (w/o RadixAttention) (TP=2) - timeout-minutes: 20 + timeout-minutes: 25 run: | bash scripts/amd_ci_exec.sh python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_without_radix_cache @@ -241,7 +241,7 @@ jobs: run: bash scripts/amd_ci_install_dependency.sh - name: Run test - timeout-minutes: 30 + timeout-minutes: 40 run: | bash scripts/amd_ci_exec.sh python3 run_suite.py --suite per-commit-amd @@ -265,7 +265,7 @@ jobs: run: bash scripts/amd_ci_install_dependency.sh - name: Run test - timeout-minutes: 30 + timeout-minutes: 40 run: | bash scripts/amd_ci_exec.sh python3 run_suite.py --suite per-commit-2-gpu-amd @@ -289,7 +289,7 @@ jobs: run: bash scripts/amd_ci_install_dependency.sh - name: Run test - timeout-minutes: 30 + timeout-minutes: 40 run: | bash scripts/amd_ci_exec.sh python3 run_suite.py --suite per-commit-8-gpu-amd diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index f014f0cef..a1106d914 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -18,7 +18,7 @@ ARG TRITON_COMMIT="improve_fa_decode_3.0.0" ARG AITER_REPO="https://github.com/ROCm/aiter.git" -ARG AITER_COMMIT="v0.1.1" +ARG AITER_COMMIT="v0.1.2" RUN git clone ${SGL_REPO} \ && cd sglang \ diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py new file mode 100644 index 000000000..6adb276da --- /dev/null +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -0,0 +1,513 @@ +from __future__ import annotations + +""" +end to end attention solution with aiter kernels +""" + +import math +import os +from dataclasses import dataclass +from enum import Enum, auto +from functools import partial +from typing import TYPE_CHECKING, List, Optional, Union + +import torch +import triton +import triton.language as tl + +from sglang.global_config import global_config +from sglang.srt.layers.attention.base_attn_backend import AttentionBackend +from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton +from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode + +if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention + from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.speculative.spec_info import SpecInfo + +try: + from aiter import mha_batch_prefill_func, paged_attention_ragged +except ImportError: + print( + "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device." + ) + + +class WrapperDispatch(Enum): + SLIDING_WINDOW = auto() + CROSS_ATTENTION = auto() + + +@dataclass +class ForwardMetadata: + kv_indptr: torch.Tensor + kv_indices: torch.Tensor + max_q_len: int + max_kv_len: int + + +global_workspace_buffer = None + +_AITER_PARTITION_SIZE_ROCM = 256 + + +class AiterAttnBackend(AttentionBackend): + def __init__( + self, + model_runner: ModelRunner, + skip_prefill: bool = False, + kv_indptr_buf: Optional[torch.Tensor] = None, + ): + super().__init__() + + self.device = model_runner.device + self.is_multimodal = model_runner.model_config.is_multimodal + self.num_head = ( + model_runner.model_config.num_attention_heads // get_attention_tp_size() + ) + self.head_dim = model_runner.model_config.head_dim + self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1] + self.num_kv_head = model_runner.model_config.get_num_kv_heads( + get_attention_tp_size() + ) + self.kv_cache_dtype = model_runner.kv_cache_dtype + + self.req_to_token = model_runner.req_to_token_pool.req_to_token + + # Parse constants + self.max_context_len = model_runner.model_config.context_len + self.skip_prefill = skip_prefill + + max_bs = model_runner.req_to_token_pool.size + + if kv_indptr_buf is None: + self.kv_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=model_runner.device + ) + else: + self.kv_indptr = kv_indptr_buf + + self.kv_last_page_len = torch.ones( + (max_bs,), dtype=torch.int32, device=model_runner.device + ) + self.qo_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=model_runner.device + ) + + # Create prefill indices updater + if not skip_prefill: + self.indices_updater_prefill = AiterIndicesUpdaterPrefill( + model_runner, self + ) + + # aiter kernel related initialization + self.max_num_partitions = ( + self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1 + ) // _AITER_PARTITION_SIZE_ROCM + + nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8 + + self.workspace_buffer = torch.empty( + (max_bs * self.num_head * self.max_num_partitions * self.head_dim) + * nbyes_per_qo_elem + + 2 * (max_bs * self.num_head * self.max_num_partitions) * 4, + dtype=torch.uint8, + device=self.device, + ) + + self.scale = float(1.0 / (self.head_dim**0.5)) + self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to( + self.device + ) + self.kv_last_page_lens = torch.ones((max_bs,), dtype=torch.int32).to( + self.device + ) + + self.logits_soft_cap = 0.0 + + self.forward_metadata: ForwardMetadata = None + + def init_forward_metadata(self, forward_batch: ForwardBatch): + if forward_batch.forward_mode.is_decode_or_idle(): + # update for aiter + # create kv_indices and kv_inptr + bs = forward_batch.batch_size + kv_indptr = self.kv_indptr + spec_info = forward_batch.spec_info + if spec_info is None: + kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.zeros( + forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + else: + kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices + bs = kv_indptr.shape[0] - 1 + + self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None) + + elif forward_batch.forward_mode.is_draft_extend(): + self.indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + prefix_lens=None, + encoder_lens=forward_batch.encoder_lens, + spec_info=forward_batch.spec_info, + ) + self.forward_metadata = ForwardMetadata( + self.indices_updater_prefill.kv_indptr, + self.indices_updater_prefill.kv_indices, + self.indices_updater_prefill.max_q_len, + self.indices_updater_prefill.max_kv_len, + ) + elif forward_batch.forward_mode.is_target_verify(): + self.indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + prefix_lens=None, + encoder_lens=forward_batch.encoder_lens, + spec_info=forward_batch.spec_info, + ) + self.forward_metadata = ForwardMetadata( + self.indices_updater_prefill.kv_indptr, + self.indices_updater_prefill.kv_indices, + self.indices_updater_prefill.max_q_len, + self.indices_updater_prefill.max_kv_len, + ) + else: + prefix_lens = forward_batch.extend_prefix_lens + + if self.is_multimodal: + extend_no_prefix = False + else: + extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) + + self.indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + prefix_lens, + encoder_lens=forward_batch.encoder_lens, + spec_info=None, + ) + self.forward_metadata = ForwardMetadata( + self.indices_updater_prefill.kv_indptr, + self.indices_updater_prefill.kv_indices, + self.indices_updater_prefill.max_q_len, + self.indices_updater_prefill.max_kv_len, + ) + + def init_cuda_graph_state( + self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None + ): + if kv_indices_buf is None: + self.cuda_graph_kv_indices = torch.zeros( + (max_bs * self.max_context_len), + dtype=torch.int32, + device=self.device, + ) + else: + self.cuda_graph_kv_indices = kv_indices_buf + + if not self.skip_prefill: + self.cuda_graph_custom_mask = torch.zeros( + (max_bs * self.max_context_len), + dtype=torch.uint8, + device=self.device, + ) + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInfo], + ): + if forward_mode.is_decode_or_idle(): + if spec_info is None: + kv_indptr = self.kv_indptr + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = self.cuda_graph_kv_indices + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + else: + kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices + self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None) + + elif forward_mode.is_target_verify(): + seq_lens_sum = seq_lens.sum().item() + self.indices_updater_prefill.update( + req_pool_indices, + seq_lens, + seq_lens_sum, + prefix_lens=None, + encoder_lens=encoder_lens, + spec_info=spec_info, + ) + self.forward_metadata = ForwardMetadata( + self.indices_updater_prefill.kv_indptr, + self.indices_updater_prefill.kv_indices, + self.indices_updater_prefill.max_q_len, + self.indices_updater_prefill.max_kv_len, + ) + + else: + raise ValueError(f"Invalid mode: {forward_mode=}") + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInfo], + seq_lens_cpu: Optional[torch.Tensor], + ): + if forward_mode.is_decode_or_idle(): + kv_indptr = self.kv_indptr + kv_indices = self.cuda_graph_kv_indices + if spec_info is None: + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0) + kv_indptr = kv_indptr[: bs + 1] + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices[:bs], + seq_lens[:bs], + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + else: + kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr + kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices + + elif forward_mode.is_target_verify(): + self.indices_updater_prefill.update( + req_pool_indices[:bs], + seq_lens[:bs], + seq_lens_sum, + prefix_lens=None, + encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None, + spec_info=spec_info, + ) + else: + raise ValueError("Invalid forward mode") + + def get_cuda_graph_seq_len_fill_value(self): + return 1 + + def forward_extend( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): + cache_loc = ( + forward_batch.out_cache_loc + if not layer.is_cross_attention + else forward_batch.encoder_out_cache_loc + ) + + self.logits_soft_cap = layer.logit_cap + + if k is not None: + assert v is not None + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc, k, v, layer.k_scale, layer.v_scale + ) + + k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + + bs0 = forward_batch.batch_size + 1 + + o = mha_batch_prefill_func( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + k_cache, + v_cache, + self.qo_indptr[:bs0], + self.forward_metadata.kv_indptr[:bs0], + self.forward_metadata.kv_indices, + self.forward_metadata.max_q_len, + self.forward_metadata.max_kv_len, + causal=True, + logits_soft_cap=self.logits_soft_cap, + alibi_slopes=None, + return_lse=False, + return_attn_probs=False, + ) + + return o.view(-1, layer.tp_q_head_num * layer.head_dim) + + def forward_decode( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache=True, + ): + q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim) + + if layer.qk_head_dim != layer.v_head_dim: + o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) + else: + o = torch.empty_like(q) + + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, forward_batch.out_cache_loc, k, v + ) + + self.logits_soft_cap = layer.logit_cap + paged_attention_ragged( + o.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + self.workspace_buffer, + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view( + -1, 1, layer.tp_k_head_num, layer.qk_head_dim + ), + forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view( + -1, 1, layer.tp_v_head_num, layer.v_head_dim + ), + self.scale, + self.forward_metadata.kv_indptr, + self.forward_metadata.kv_indices, + self.kv_last_page_lens, + 1, + self.max_num_partitions, + None, + "auto", + "NHD", + self.logits_soft_cap, + self.k_scale, + self.v_scale, + None, + _AITER_PARTITION_SIZE_ROCM, + ) + + return o + + +class AiterIndicesUpdaterPrefill: + def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): + # Parse Constants + self.num_qo_heads = ( + model_runner.model_config.num_attention_heads // get_attention_tp_size() + ) + self.num_kv_heads = model_runner.model_config.get_num_kv_heads( + get_attention_tp_size() + ) + self.head_dim = model_runner.model_config.head_dim + self.data_type = model_runner.kv_cache_dtype + self.q_data_type = model_runner.dtype + self.sliding_window_size = model_runner.sliding_window_size + self.attn_backend = attn_backend + + # Buffers and wrappers + self.kv_indptr = attn_backend.kv_indptr + self.kv_last_page_len = attn_backend.kv_last_page_len + self.qo_indptr = attn_backend.qo_indptr + self.req_to_token = model_runner.req_to_token_pool.req_to_token + self.update = self.update_single_wrapper + + self.kv_indices = None + self.max_q_len = 0 + self.max_kv_len = 0 + + def update( + self, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + prefix_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + spec_info: Optional[SpecInfo], + ): + # Keep the signature for type checking. It will be assigned during runtime. + raise NotImplementedError() + + def update_single_wrapper( + self, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + prefix_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + spec_info: Optional[SpecInfo], + ): + + kv_start_idx = None + kv_indptr = self.kv_indptr + qo_indptr = self.qo_indptr + paged_kernel_lens = seq_lens + paged_kernel_lens_sum = seq_lens_sum + + bs = len(req_pool_indices) + if spec_info is None: + # Normal extend + kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.empty( + paged_kernel_lens_sum + 256, + dtype=torch.int32, + device=req_pool_indices.device, + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + paged_kernel_lens, + kv_indptr, + kv_start_idx, + kv_indices, + self.req_to_token.shape[1], + ) + + self.max_kv_len = torch.max(paged_kernel_lens).item() + + extend_lens = seq_lens - prefix_lens + self.max_q_len = torch.max(extend_lens).item() + + qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0) + qo_indptr = qo_indptr[: bs + 1] + custom_mask = None + else: + kv_indices, kv_indptr, qo_indptr, custom_mask = ( + spec_info.generate_attn_arg_prefill( + req_pool_indices, + paged_kernel_lens, + self.req_to_token, + ) + ) + + self.kv_indices = kv_indices diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 5b4614585..5fcc33865 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -103,6 +103,8 @@ from sglang.srt.utils import ( set_cuda_arch, ) +_is_hip = is_hip() + # Use a small KV cache pool size for tests in CI SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None) @@ -318,6 +320,8 @@ class ModelRunner: and is_fa3_default_architecture(self.model_config.hf_config) ): server_args.attention_backend = "fa3" + elif _is_hip: + server_args.attention_backend = "aiter" else: server_args.attention_backend = ( "flashinfer" if is_flashinfer_available() else "triton" @@ -794,7 +798,7 @@ class ModelRunner: if self.server_args.kv_cache_dtype == "auto": self.kv_cache_dtype = self.dtype elif self.server_args.kv_cache_dtype == "fp8_e5m2": - if is_hip(): # Using natively supported format + if _is_hip: # Using natively supported format self.kv_cache_dtype = torch.float8_e5m2fnuz else: self.kv_cache_dtype = torch.float8_e5m2 @@ -972,6 +976,10 @@ class ModelRunner: ) self.attn_backend = FlashInferMLAAttnBackend(self) + elif self.server_args.attention_backend == "aiter": + from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend + + self.attn_backend = AiterAttnBackend(self) elif self.server_args.attention_backend == "triton": assert self.sliding_window_size is None, ( "Window attention is not supported in the triton attention backend. " diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f0c862cc4..ed9e92641 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -957,6 +957,7 @@ class ServerArgs: "--attention-backend", type=str, choices=[ + "aiter", "flashinfer", "triton", "torch_native", diff --git a/scripts/amd_ci_install_dependency.sh b/scripts/amd_ci_install_dependency.sh index eedbed020..12052ed6e 100755 --- a/scripts/amd_ci_install_dependency.sh +++ b/scripts/amd_ci_install_dependency.sh @@ -5,7 +5,6 @@ set -euo pipefail docker exec ci_sglang pip install --upgrade pip docker exec ci_sglang pip uninstall sgl-kernel -y || true docker exec -w /sglang-checkout/sgl-kernel ci_sglang bash -c "rm -f pyproject.toml && mv pyproject_rocm.toml pyproject.toml && python3 setup_rocm.py install" -docker exec ci_sglang pip install -e "python[dev_hip]" docker exec -w / ci_sglang git clone https://github.com/merrymercy/human-eval.git docker exec -w /human-eval ci_sglang pip install -e . diff --git a/scripts/amd_ci_start_container.sh b/scripts/amd_ci_start_container.sh index 30fd26d05..bf0f891ea 100755 --- a/scripts/amd_ci_start_container.sh +++ b/scripts/amd_ci_start_container.sh @@ -9,7 +9,7 @@ else fi # Pull the image -IMAGE="lmsysorg/sglang:v0.4.6.post3-rocm630" +IMAGE="ghcr.io/saienduri/sglang-aiter-backend-v0.1.2:518" echo "Pulling Docker image: $IMAGE" docker pull "$IMAGE" diff --git a/test/srt/models/test_dummy_grok_models.py b/test/srt/models/test_dummy_grok_models.py index 6f1815fae..290c49164 100644 --- a/test/srt/models/test_dummy_grok_models.py +++ b/test/srt/models/test_dummy_grok_models.py @@ -4,6 +4,7 @@ from sglang.test.test_utils import CustomTestCase, is_in_ci, run_bench_one_batch class TestDummyGrok1(CustomTestCase): + def test_dummy_grok_1(self): output_throughput = run_bench_one_batch( None, diff --git a/test/srt/test_eval_accuracy_large.py b/test/srt/test_eval_accuracy_large.py index efb202463..becf07a09 100644 --- a/test/srt/test_eval_accuracy_large.py +++ b/test/srt/test_eval_accuracy_large.py @@ -3,6 +3,8 @@ Usage: python -m unittest test_eval_accuracy_large.TestEvalAccuracyLarge.test_mmlu """ +import os +import time import unittest from types import SimpleNamespace @@ -35,6 +37,11 @@ class TestEvalAccuracyLarge(CustomTestCase): def tearDownClass(cls): kill_process_tree(cls.process.pid) + def tearDown(self): + # Delay between tests to allow GPU memory cleanup + if os.getenv("SGLANG_AMD_CI") == "1": + time.sleep(180) + def test_mmlu(self): args = SimpleNamespace( base_url=self.base_url,