diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 286618d6b..e8dab5c80 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -55,7 +55,7 @@ _is_npu = is_npu() @dataclass class GraphCaptureContext: - stream: torch.cuda.Stream + stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) @@ -252,8 +252,11 @@ class GroupCoordinator: if is_cuda_alike(): self.device = torch.device(f"cuda:{local_rank}") + elif _is_npu: + self.device = torch.device(f"npu:{local_rank}") else: self.device = torch.device("cpu") + self.device_module = torch.get_device_module(self.device) self.use_pynccl = use_pynccl self.use_pymscclpp = use_pymscclpp @@ -402,7 +405,7 @@ class GroupCoordinator: self, graph_capture_context: Optional[GraphCaptureContext] = None ): if graph_capture_context is None: - stream = torch.cuda.Stream() + stream = self.device_module.Stream() graph_capture_context = GraphCaptureContext(stream) else: stream = graph_capture_context.stream @@ -413,11 +416,11 @@ class GroupCoordinator: # ensure all initialization operations complete before attempting to # capture the graph on another stream - curr_stream = torch.cuda.current_stream() + curr_stream = self.device_module.current_stream() if curr_stream != stream: stream.wait_stream(curr_stream) - with torch.cuda.stream(stream), maybe_ca_context: + with self.device_module.stream(stream), maybe_ca_context: # In graph mode, we have to be very careful about the collective # operations. The current status is: # allreduce \ Mode | Eager | Graph | @@ -1641,6 +1644,8 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False): ) elif hasattr(torch, "xpu") and torch.xpu.is_available(): torch.xpu.empty_cache() + elif hasattr(torch, "npu") and torch.npu.is_available(): + torch.npu.empty_cache() def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]: diff --git a/python/sglang/srt/layers/attention/ascend_backend.py b/python/sglang/srt/layers/attention/ascend_backend.py index 020f04dcd..c1f4c2785 100644 --- a/python/sglang/srt/layers/attention/ascend_backend.py +++ b/python/sglang/srt/layers/attention/ascend_backend.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, List, Optional import torch import torch_npu @@ -27,6 +27,7 @@ class ForwardMetadata: # seq len inputs extend_seq_lens_cpu_int: Optional[torch.Tensor] = None seq_lens_cpu_int: Optional[torch.Tensor] = None + seq_lens_cpu_list: Optional[List[int]] = None class AscendAttnBackend(AttentionBackend): @@ -51,7 +52,7 @@ class AscendAttnBackend(AttentionBackend): def __init__(self, model_runner: ModelRunner): super().__init__() - self.forward_metadata = ForwardMetadata() + self.forward_metadata = None self.device = model_runner.device self.gen_attention_mask(128, model_runner.dtype) self.page_size = model_runner.page_size @@ -60,9 +61,15 @@ class AscendAttnBackend(AttentionBackend): self.kv_lora_rank = model_runner.model_config.kv_lora_rank self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim self.native_attn = TorchNativeAttnBackend(model_runner) + self.graph_metadata = {} + self.max_context_len = model_runner.model_config.context_len + self.req_to_token = model_runner.req_to_token_pool.req_to_token + self.graph_mode = False def init_forward_metadata(self, forward_batch: ForwardBatch): """Init the metadata for a forward pass.""" + self.forward_metadata = ForwardMetadata() + self.forward_metadata.block_tables = ( forward_batch.req_to_token_pool.req_to_token[ forward_batch.req_pool_indices, : forward_batch.seq_lens.max() @@ -75,6 +82,63 @@ class AscendAttnBackend(AttentionBackend): ) self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int() + self.graph_mode = False + + def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): + self.graph_metadata = { + "block_tables": torch.empty( + (max_bs, self.max_context_len // self.page_size), + dtype=torch.int32, + device=self.device, + ), + } + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + ): + metadata = ForwardMetadata() + + metadata.block_tables = self.graph_metadata["block_tables"][:bs, :] + metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist() + + self.graph_metadata[bs] = metadata + self.forward_metadata = metadata + + self.graph_mode = True + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + seq_lens_cpu: Optional[torch.Tensor], + ): + metadata = self.graph_metadata[bs] + max_len = seq_lens_cpu[:bs].max().item() + max_seq_pages = (max_len + self.page_size - 1) // self.page_size + + metadata.block_tables[:bs, :max_seq_pages].copy_( + self.req_to_token[req_pool_indices[:bs], :max_len][:, :: self.page_size] + // self.page_size + ) + metadata.block_tables[:bs, max_seq_pages:].fill_(0) + metadata.block_tables[bs:, :].fill_(0) + + self.forward_metadata = metadata + + self.graph_mode = True + def get_cuda_graph_seq_len_fill_value(self): return 1 @@ -167,28 +231,74 @@ class AscendAttnBackend(AttentionBackend): layer, forward_batch.out_cache_loc, k, v ) if not self.use_mla: - k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) - v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id) + if self.graph_mode: + k_cache = forward_batch.token_to_kv_pool.get_key_buffer( + layer.layer_id + ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim) + v_cache = forward_batch.token_to_kv_pool.get_value_buffer( + layer.layer_id + ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim) + query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim) + num_tokens = query.shape[0] + workspace = ( + torch_npu._npu_fused_infer_attention_score_get_max_workspace( + query, + k_cache, + v_cache, + block_table=self.forward_metadata.block_tables, + block_size=self.page_size, + num_heads=layer.tp_q_head_num, + num_key_value_heads=layer.tp_k_head_num, + input_layout="BSH", + scale=layer.scaling, + actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list, + ) + ) + output = torch.empty( + (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim), + dtype=q.dtype, + device=q.device, + ) + softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device) + torch_npu.npu_fused_infer_attention_score.out( + query, + k_cache, + v_cache, + block_table=self.forward_metadata.block_tables, + block_size=self.page_size, + num_heads=layer.tp_q_head_num, + num_key_value_heads=layer.tp_k_head_num, + input_layout="BSH", + scale=layer.scaling, + actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list, + workspace=workspace, + out=[output, softmax_lse], + ) + else: + k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + v_cache = forward_batch.token_to_kv_pool.get_value_buffer( + layer.layer_id + ) - query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim) - num_tokens = query.shape[0] - output = torch.empty( - (num_tokens, layer.tp_q_head_num, layer.v_head_dim), - dtype=query.dtype, - device=query.device, - ) + query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim) + num_tokens = query.shape[0] + output = torch.empty( + (num_tokens, layer.tp_q_head_num, layer.v_head_dim), + dtype=query.dtype, + device=query.device, + ) - torch_npu._npu_paged_attention( - query=query, - key_cache=k_cache, - value_cache=v_cache, - num_heads=layer.tp_q_head_num, - num_kv_heads=layer.tp_k_head_num, - scale_value=layer.scaling, - block_table=self.forward_metadata.block_tables, - context_lens=self.forward_metadata.seq_lens_cpu_int, - out=output, - ) + torch_npu._npu_paged_attention( + query=query, + key_cache=k_cache, + value_cache=v_cache, + num_heads=layer.tp_q_head_num, + num_kv_heads=layer.tp_k_head_num, + scale_value=layer.scaling, + block_table=self.forward_metadata.block_tables, + context_lens=self.forward_metadata.seq_lens_cpu_int, + out=output, + ) return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim) else: query = q.view(-1, layer.tp_q_head_num, layer.head_dim) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index cc87910ac..abf95d4d0 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -240,6 +240,8 @@ class CudaGraphRunner: def __init__(self, model_runner: ModelRunner): # Parse args self.model_runner = model_runner + self.device = model_runner.device + self.device_module = torch.get_device_module(self.device) self.graphs = {} self.output_buffers = {} self.enable_torch_compile = model_runner.server_args.enable_torch_compile @@ -305,13 +307,15 @@ class CudaGraphRunner: self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs) # Graph inputs - with torch.device("cuda"): + with torch.device(self.device): self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64) self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32) self.seq_lens = torch.full( (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 ) - self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64) + self.out_cache_loc = torch.zeros( + (self.max_num_token,), dtype=self._cache_loc_dtype() + ) self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64) self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64) self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32) @@ -366,12 +370,12 @@ class CudaGraphRunner: * self.num_tokens_per_bs ), dtype=torch.bool, - device="cuda", + device=self.device, ) self.next_token_logits_buffer = torch.zeros( (self.max_num_token, self.model_runner.model_config.vocab_size), dtype=torch.float, - device="cuda", + device=self.device, ) # Capture @@ -383,6 +387,9 @@ class CudaGraphRunner: f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}" ) + def _cache_loc_dtype(self): + return torch.int64 + def can_run(self, forward_batch: ForwardBatch): if self.require_mlp_tp_gather: cuda_graph_bs = ( @@ -502,8 +509,16 @@ class CudaGraphRunner: ) logger.info(log_message) + def _capture_graph(self, graph, pool, stream, run_once_fn): + with self.device_module.graph(graph, pool=pool, stream=stream): + out = run_once_fn() + return out + + def _create_device_graph(self): + return torch.cuda.CUDAGraph() + def capture_one_batch_size(self, bs: int, forward: Callable): - graph = torch.cuda.CUDAGraph() + graph = self._create_device_graph() stream = self.stream num_tokens = bs * self.num_tokens_per_bs @@ -643,19 +658,17 @@ class CudaGraphRunner: return logits_output_or_pp_proxy_tensors for _ in range(2): - torch.cuda.synchronize() + self.device_module.synchronize() self.model_runner.tp_group.barrier() - run_once() if get_global_graph_memory_pool() is None: - set_global_graph_memory_pool(torch.cuda.graph_pool_handle()) + set_global_graph_memory_pool(self.device_module.graph_pool_handle()) # Set graph pool id globally to be able to use symmetric memory set_graph_pool_id(get_global_graph_memory_pool()) - with torch.cuda.graph( - graph, pool=get_global_graph_memory_pool(), stream=stream - ): - out = run_once() + out = self._capture_graph( + graph, get_global_graph_memory_pool(), stream, run_once + ) return graph, out diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index a30fb897f..a20571253 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -91,6 +91,7 @@ from sglang.srt.mem_cache.memory_pool import ( ) from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors +from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner from sglang.srt.model_loader import get_model from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader from sglang.srt.model_loader.utils import set_default_torch_dtype @@ -341,9 +342,12 @@ class ModelRunner: if self.device == "cuda": self.init_cublas() self.init_attention_backend() - self.init_cuda_graphs() + self.init_device_graphs() + elif self.device == "npu": + self.init_attention_backend() + self.init_device_graphs() else: - self.cuda_graph_runner = None + self.graph_runner = None self.cuda_graph_mem_usage = 0 self.init_attention_backend() @@ -917,7 +921,8 @@ class ModelRunner: ) # We need to get device after patch otherwise the device would be wrong - infered_device = torch.cuda.current_device() + self.device_module = torch.get_device_module(self.device) + infered_device = self.device_module.current_device() named_tensors = [ (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank, device=infered_device)) @@ -1592,9 +1597,9 @@ class ModelRunner: .cuda() ) - def init_cuda_graphs(self): + def init_device_graphs(self): """Capture cuda graphs.""" - self.cuda_graph_runner = None + self.graph_runner = None self.cuda_graph_mem_usage = 0 if not self.is_generation: @@ -1609,8 +1614,9 @@ class ModelRunner: logger.info( f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB" ) - self.cuda_graph_runner = CudaGraphRunner(self) - + self.graph_runner = ( + CudaGraphRunner(self) if not _is_npu else NPUGraphRunner(self) + ) after_mem = get_available_gpu_memory(self.device, self.gpu_id) self.cuda_graph_mem_usage = before_mem - after_mem logger.info( @@ -1762,11 +1768,11 @@ class ModelRunner: ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]: can_run_cuda_graph = bool( forward_batch.forward_mode.is_cuda_graph() - and self.cuda_graph_runner - and self.cuda_graph_runner.can_run(forward_batch) + and self.graph_runner + and self.graph_runner.can_run(forward_batch) ) if can_run_cuda_graph: - ret = self.cuda_graph_runner.replay( + ret = self.graph_runner.replay( forward_batch, skip_attn_backend_init=skip_attn_backend_init, pp_proxy_tensors=pp_proxy_tensors, diff --git a/python/sglang/srt/model_executor/npu_graph_runner.py b/python/sglang/srt/model_executor/npu_graph_runner.py new file mode 100644 index 000000000..0ff19d582 --- /dev/null +++ b/python/sglang/srt/model_executor/npu_graph_runner.py @@ -0,0 +1,94 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Run the model with npu graph and torch.compile.""" + +from __future__ import annotations + +import logging +import threading +from typing import TYPE_CHECKING, Optional, Union + +import torch + +from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from sglang.srt.model_executor.model_runner import ModelRunner + +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors + + +class NPUGraphRunner(CudaGraphRunner): + """A NPUGraphRunner runs the forward pass of a model with npu graph and torch.compile.""" + + def __init__(self, model_runner: ModelRunner): + super().__init__(model_runner) + + def _create_device_graph(self): + return torch.npu.NPUGraph() + + def _capture_graph(self, graph, pool, stream, run_once_fn): + with torch.npu.graph( + graph, + pool=pool, + stream=stream, + auto_dispatch_capture=True, + ): + out = run_once_fn() + return out + + def _update_inputs(self, seq_lens): + self.graphs[self.bs].update( + cpu_update_input=[{"actual_seq_lengths_kv": seq_lens}] + ) + + def _cache_loc_dtype(self): + return torch.int32 + + def replay( + self, + forward_batch: ForwardBatch, + skip_attn_backend_init: bool = False, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> Union[LogitsProcessorOutput, PPProxyTensors]: + if not skip_attn_backend_init: + self.replay_prepare(forward_batch, pp_proxy_tensors) + else: + # In speculative decoding, these two fields are still needed. + self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids) + self.positions[: self.raw_num_token].copy_(forward_batch.positions) + + # Replay + seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (self.bs - self.raw_bs) + thread = threading.Thread(target=self._update_inputs, args=(seq_lens,)) + thread.start() + self.graphs[self.bs].replay() + thread.join() + + output = self.output_buffers[self.bs] + if isinstance(output, LogitsProcessorOutput): + return LogitsProcessorOutput( + next_token_logits=output.next_token_logits[: self.raw_num_token], + hidden_states=( + output.hidden_states[: self.raw_num_token] + if output.hidden_states is not None + else None + ), + ) + else: + assert isinstance(output, PPProxyTensors) + return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()}) diff --git a/test/srt/ascend/test_ascend_graph_tp1_bf16.py b/test/srt/ascend/test_ascend_graph_tp1_bf16.py new file mode 100644 index 000000000..95c6b7bcf --- /dev/null +++ b/test/srt/ascend/test_ascend_graph_tp1_bf16.py @@ -0,0 +1,95 @@ +import unittest +from types import SimpleNamespace +from urllib.parse import urlparse + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + is_in_ci, + popen_launch_server, + run_bench_offline_throughput, +) + +TEST_MODEL_MATRIX = { + "Qwen/Qwen2.5-7B-Instruct": { + "accuracy": 0.85, + "latency": 150, + "output_throughput": 30, + }, +} + + +class TestAscendGraphTp1Bf16(CustomTestCase): + + @classmethod + def setUpClass(cls): + cls.models = TEST_MODEL_MATRIX.keys() + cls.base_url = DEFAULT_URL_FOR_TEST + cls.url = urlparse(DEFAULT_URL_FOR_TEST) + cls.common_args = [ + "--trust-remote-code", + "--mem-fraction-static", + 0.8, + "--attention-backend", + "ascend", + ] + + def test_a_gsm8k(self): + for model in self.models: + with self.subTest(model=model): + print(f"##=== Testing accuracy: {model} ===##") + + process = popen_launch_server( + model, + self.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + *self.common_args, + ], + ) + + try: + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=1319, + max_new_tokens=512, + parallel=128, + host=f"http://{self.url.hostname}", + port=int(self.url.port), + ) + + metrics = run_eval_few_shot_gsm8k(args) + self.assertGreaterEqual( + metrics["accuracy"], + TEST_MODEL_MATRIX[model]["accuracy"], + ) + finally: + kill_process_tree(process.pid) + + def test_b_throughput(self): + for model in self.models: + with self.subTest(model=model): + print(f"##=== Testing throughput: {model} ===##") + + output_throughput = run_bench_offline_throughput( + model, + [ + *self.common_args, + ], + ) + + print(f"##=== {model} throughput: {output_throughput} ===##") + + if is_in_ci(): + self.assertGreater( + output_throughput, + TEST_MODEL_MATRIX[model]["output_throughput"], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/ascend/test_ascend_graph_tp2_bf16.py b/test/srt/ascend/test_ascend_graph_tp2_bf16.py new file mode 100644 index 000000000..f7c3c6537 --- /dev/null +++ b/test/srt/ascend/test_ascend_graph_tp2_bf16.py @@ -0,0 +1,97 @@ +import unittest +from types import SimpleNamespace +from urllib.parse import urlparse + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + is_in_ci, + popen_launch_server, + run_bench_offline_throughput, +) + +TEST_MODEL_MATRIX = { + "Qwen/Qwen2.5-7B-Instruct": { + "accuracy": 0.85, + "latency": 180, + "output_throughput": 20, + }, +} + + +class TestAscendGraphTp2Bf16(CustomTestCase): + + @classmethod + def setUpClass(cls): + cls.models = TEST_MODEL_MATRIX.keys() + cls.base_url = DEFAULT_URL_FOR_TEST + cls.url = urlparse(DEFAULT_URL_FOR_TEST) + cls.common_args = [ + "--trust-remote-code", + "--mem-fraction-static", + 0.8, + "--attention-backend", + "ascend", + "--tp-size", + 2, + ] + + def test_a_gsm8k(self): + for model in self.models: + with self.subTest(model=model): + print(f"##=== Testing accuracy: {model} ===##") + + process = popen_launch_server( + model, + self.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + *self.common_args, + ], + ) + + try: + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=1319, + max_new_tokens=512, + parallel=128, + host=f"http://{self.url.hostname}", + port=int(self.url.port), + ) + + metrics = run_eval_few_shot_gsm8k(args) + self.assertGreaterEqual( + metrics["accuracy"], + TEST_MODEL_MATRIX[model]["accuracy"], + ) + finally: + kill_process_tree(process.pid) + + def test_b_throughput(self): + for model in self.models: + with self.subTest(model=model): + print(f"##=== Testing throughput: {model} ===##") + + output_throughput = run_bench_offline_throughput( + model, + [ + *self.common_args, + ], + ) + + print(f"##=== {model} throughput: {output_throughput} ===##") + + if is_in_ci(): + self.assertGreater( + output_throughput, + TEST_MODEL_MATRIX[model]["output_throughput"], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_ascend_w8a8_quantization.py b/test/srt/ascend/test_ascend_w8a8_quantization.py similarity index 100% rename from test/srt/test_ascend_w8a8_quantization.py rename to test/srt/ascend/test_ascend_w8a8_quantization.py diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index b948bc82e..4c98dc585 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -269,9 +269,11 @@ suite_xeon = { suite_ascend = { "per-commit-1-ascend-npu": [ TestFile("ascend/test_ascend_tp1_bf16.py", 400), + TestFile("ascend/test_ascend_graph_tp1_bf16.py", 400), ], "per-commit-2-ascend-npu": [ TestFile("ascend/test_ascend_tp2_bf16.py", 400), + TestFile("ascend/test_ascend_graph_tp2_bf16.py", 400), ], "per-commit-4-ascend-npu": [ TestFile("ascend/test_ascend_mla_w8a8int8.py", 400),