Add graph runner support with torch compile on CPU (#7843)
.github/workflows/pr-test-xeon.yml (vendored)
@@ -70,7 +70,7 @@ jobs:
      - name: Run unit tests
        if: steps.check_amx.outcome == 'success'
-       timeout-minutes: 30
+       timeout-minutes: 36
        run: |
          docker exec -w /sglang-checkout/ ci_sglang_xeon \
            bash -c "cd ./test/srt && python3 run_suite.py --suite per-commit-cpu"
@@ -134,7 +134,12 @@ Notes:
    export SGLANG_CPU_OMP_THREADS_BIND="0-39|43-82|86-125|128-167|171-210|214-253"
    ```

-3. A warmup step is automatically triggered when the service is started.
+3. For optimizing decoding with torch.compile, add the flag `--enable-torch-compile`.
+   To specify the maximum batch size used with torch compile, set the flag `--torch-compile-max-bs`.
+   For example, `--enable-torch-compile --torch-compile-max-bs 4` enables torch compile with a
+   maximum batch size of 4.
+
+4. A warmup step is automatically triggered when the service is started.
    The server is ready when you see the log `The server is fired up and ready to roll!`.

 ## Benchmarking with Requests
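For illustration only (not part of this diff), the same flags can be passed through the offline engine API; this is a hedged sketch that assumes `sglang.Engine` forwards ServerArgs fields as keyword arguments, and the model path is a placeholder:

```python
# Hypothetical sketch: enable CPU torch.compile from the Python engine API.
import sglang as sgl

llm = sgl.Engine(
    model_path="<your-model>",       # placeholder, not from the diff
    device="cpu",
    attention_backend="intel_amx",
    enable_torch_compile=True,       # same as --enable-torch-compile
    torch_compile_max_bs=4,          # same as --torch-compile-max-bs 4
)
print(llm.generate("The capital of France is", {"max_new_tokens": 8}))
```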
@@ -64,6 +64,9 @@ class GraphCaptureContext:

 TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])

+# use int value instead of ReduceOp.SUM to support torch compile
+REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
+

 def _split_tensor_dict(
     tensor_dict: Dict[str, Union[torch.Tensor, Any]]

@@ -489,9 +492,7 @@ class GroupCoordinator:

         if input_.is_cpu:
             if is_shm_available(input_.dtype, self.world_size, self.local_size):
-                torch.ops.sgl_kernel.shm_allreduce(
-                    input_, torch.distributed.ReduceOp.SUM
-                )
+                torch.ops.sgl_kernel.shm_allreduce(input_, REDUCE_OP_SUM)
             else:
                 torch.distributed.all_reduce(input_, group=self.device_group)
             return input_
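A brief illustration of why the integer constant helps; this is a sketch under the assumption that the `sgl_kernel` extension is loaded, and it relies on the `shm_allreduce` schema (updated later in this PR) declaring an `int reduce_op` argument:

```python
import torch

# The shm_allreduce schema takes `int reduce_op`, so the enum is converted once at
# import time; inside the compiled region only a plain int constant is used, which
# torch.compile can treat as a literal instead of tracing the enum object.
REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)

@torch.compile(dynamic=False)
def _allreduce_sum(x: torch.Tensor) -> torch.Tensor:
    torch.ops.sgl_kernel.shm_allreduce(x, REDUCE_OP_SUM)
    return x
```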
@@ -49,6 +49,9 @@ class IntelAMXAttnBackend(AttentionBackend):
         max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
         self.forward_metadata = (attn_logits, max_extend_len)

+    def get_graph_seq_len_fill_value(self):
+        return 1
+
     def forward_extend(
         self,
         q,
@@ -352,6 +352,9 @@ class Fp8LinearMethod(LinearMethodBase):
                 _is_cpu_amx_available
             ), "Fp8LinearMethod on CPU requires that CPU has AMX support"
             _amx_process_weight_after_loading(layer, ["weight"])
+            layer.weight_scale_inv = torch.nn.Parameter(
+                layer.weight_scale_inv.data, requires_grad=False
+            )
             return
         else:
             weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
@@ -343,9 +343,8 @@ class W8A8Int8LinearMethod(LinearMethodBase):
                _is_cpu_amx_available
            ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
            _amx_process_weight_after_loading(layer, ["weight"])
            return

            layer.weight = Parameter(layer.weight.t(), requires_grad=False)
        else:
            layer.weight = Parameter(layer.weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)

    def create_weights(
@@ -486,10 +485,9 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
                _is_cpu_amx_available
            ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
            return

            layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
            layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
        else:
            layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
            layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
            layer.w13_weight_scale = Parameter(
                layer.w13_weight_scale.data, requires_grad=False
            )
@@ -414,7 +414,7 @@ class Scheduler(
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
             f"context_len={self.model_config.context_len}, "
-            f"available_gpu_mem={avail_mem:.2f} GB"
+            f"{'available_cpu_mem' if self.device == 'cpu' else 'available_gpu_mem'}={avail_mem:.2f} GB"
         )

         # Init memory pool and cache
@@ -2252,10 +2252,9 @@ class Scheduler(
             "token_capacity": int(self.max_total_num_tokens),
         }

-        if not _is_cpu:
-            ret["memory_usage"]["cuda_graph"] = round(
-                self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
-            )
+        ret["memory_usage"]["graph"] = round(
+            self.tp_worker.worker.model_runner.graph_mem_usage, 2
+        )

         if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
             ret["avg_spec_accept_length"] = (
@@ -214,7 +214,7 @@ class SchedulerMetricsMixin:
             msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "

         msg += (
-            f"cuda graph: {can_run_cuda_graph}, "
+            f"{'cpu graph' if self.device == 'cpu' else 'cuda graph'}: {can_run_cuda_graph}, "
            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}, "
        )
python/sglang/srt/model_executor/cpu_graph_runner.py (new file, 640 lines)
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run the model with cpu torch compile."""

# The implementation of CPUGraphRunner follows the CudaGraphRunner

from __future__ import annotations

import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Optional, Union

import psutil
import torch
import tqdm

from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed.parallel_state import GroupCoordinator
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.model_executor.forward_batch_info import (
    CaptureHiddenMode,
    ForwardBatch,
    ForwardMode,
    PPProxyTensors,
)
from sglang.srt.patch_torch import monkey_patch_torch_compile
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
    log_info_on_rank0,
    require_attn_tp_gather,
    require_gathered_buffer,
    require_mlp_sync,
    require_mlp_tp_gather,
)

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner

@contextmanager
def patch_model(
    model: torch.nn.Module,
    enable_compile: bool,
    num_tokens: int,
    tp_group: GroupCoordinator,
):
    """Patch the model to make it compatible with torch.compile"""
    backup_ca_comm = None

    try:
        if enable_compile:
            backup_ca_comm = tp_group.ca_comm
            # Keep the custom allreduce (i.e., do not set tp_group.ca_comm = None);
            # we found the custom allreduce is much faster than the built-in allreduce
            # in torch, even with ENABLE_INTRA_NODE_COMM=1.
            # tp_group.ca_comm = None
            yield torch.compile(
                torch.no_grad()(model.forward),
                dynamic=False,
            )
        else:
            yield model.forward
    finally:
        if enable_compile:
            tp_group.ca_comm = backup_ca_comm

def set_torch_compile_config():
    import torch._dynamo.config
    import torch._inductor.config

    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
    torch._inductor.config.freezing = True
    torch._dynamo.config.accumulated_cache_size_limit = 1024
    if hasattr(torch._dynamo.config, "cache_size_limit"):
        torch._dynamo.config.cache_size_limit = 1024
    monkey_patch_torch_compile()

def get_batch_sizes_to_capture(model_runner: ModelRunner):
    server_args = model_runner.server_args
    # On CPU, torch compile only speeds up decoding by reducing Python overhead
    # when the batch size is small.
    capture_bs = list(range(1, 17))
    capture_bs = [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
    capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
    capture_bs = list(sorted(set(capture_bs)))
    assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"
    return capture_bs

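As an editorial aside (not part of this file), the selection above can be illustrated with a tiny standalone check:

```python
# With --torch-compile-max-bs 4 and a request pool larger than 4, the captured
# decode batch sizes reduce to [1, 2, 3, 4].
torch_compile_max_bs, pool_size = 4, 2048
capture_bs = [bs for bs in range(1, 17) if bs <= torch_compile_max_bs and bs <= pool_size]
assert capture_bs == [1, 2, 3, 4]
```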
def register_fake_ops():
    """
    Registers fake/meta implementations for all custom sgl_kernel CPU operators
    using torch.library.register_fake to support torch.compile.
    """

    none_return_ops = [
        "shm_allreduce",
        "bmm_cpu",
        "fused_add_rmsnorm_cpu",
        "decode_attention_cpu",
        "extend_attention_cpu",
    ]
    for op in none_return_ops:

        @torch.library.register_fake(f"sgl_kernel::{op}")
        def _(*args, **kwargs):
            return

    for op in [
        "rmsnorm_cpu",
        "l2norm_cpu",
        "fused_experts_cpu",
        "shared_expert_cpu",
    ]:

        @torch.library.register_fake(f"sgl_kernel::{op}")
        def _(input, *args, **kwargs):
            return torch.empty_like(input)

    @torch.library.register_fake("sgl_kernel::qkv_proj_with_rope")
    def _(
        hidden_states,
        q_a_proj_weight,
        q_b_proj_weight,
        kv_a_proj_weight,
        w_kc,
        q_a_layernorm_weight,
        kv_a_layernorm_weight,
        positions,
        cos_sin_cache,
        eps,
        use_int8_w8a8,
        use_fp8_w8a16,
        q_a_proj_scale,
        q_b_proj_scale,
        kv_a_proj_scale,
        is_vnni,
        block_size,
    ):
        num_seqs = hidden_states.shape[0]
        num_heads = w_kc.shape[0]
        kv_lora_rank = w_kc.shape[1]
        qk_rope_head_dim = kv_a_proj_weight.shape[0] - kv_lora_rank
        q_input = torch.empty(
            num_seqs,
            num_heads,
            kv_lora_rank + qk_rope_head_dim,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        k_input = torch.empty(
            num_seqs,
            1,
            kv_lora_rank + qk_rope_head_dim,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        v_input = k_input.narrow(-1, 0, kv_lora_rank)
        return q_input, k_input, v_input

    @torch.library.register_fake("sgl_kernel::rotary_embedding_cpu")
    def _(positions, query, key, head_size, cos_sin_cache, is_neox):
        if query.ndim == 2:
            return query, key
        else:
            return torch.empty_like(query), torch.empty_like(key)

    @torch.library.register_fake("sgl_kernel::qkv_proj_with_rope_fused_weight")
    def _(
        hidden_states,
        q_a_proj_weight,
        q_b_proj_weight,
        w_kc,
        q_a_layernorm_weight,
        kv_a_layernorm_weight,
        positions,
        cos_sin_cache,
        eps,
        use_int8_w8a8,
        use_fp8_w8a16,
        qkv_a_proj_scale,
        q_b_proj_scale,
        is_vnni,
        block_size,
        q_lora_rank,
        kv_lora_rank,
        qk_rope_head_dim,
    ):
        num_seqs = hidden_states.shape[0]
        num_heads = w_kc.shape[0]
        kv_lora_rank = w_kc.shape[1]
        weight_chunks = torch.split(
            q_a_proj_weight, [q_lora_rank, kv_lora_rank + qk_rope_head_dim], dim=0
        )
        qk_rope_head_dim = weight_chunks[1].shape[0] - kv_lora_rank
        q_input = torch.empty(
            num_seqs,
            num_heads,
            kv_lora_rank + qk_rope_head_dim,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        k_input = torch.empty(
            num_seqs,
            1,
            kv_lora_rank + qk_rope_head_dim,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        v_input = k_input.narrow(-1, 0, kv_lora_rank)
        return q_input, k_input, v_input

    @torch.library.register_fake("sgl_kernel::weight_packed_linear")
    def _(x, weight, bias, is_vnni):
        return x.new_empty(x.shape[0], weight.shape[0])

    @torch.library.register_fake("sgl_kernel::per_token_quant_int8_cpu")
    def _(input):
        M = input.shape[0]
        K = input.shape[1]
        Aq = input.new_empty(M, K, dtype=torch.int8)
        As = input.new_empty(M, dtype=torch.float32)
        return Aq, As

    @torch.library.register_fake("sgl_kernel::int8_scaled_mm_cpu")
    def _(mat1, mat2, scales1, scales2, bias, out_dtype, is_vnni):
        M = mat1.shape[0]
        N = mat2.shape[0]
        out = mat1.new_empty(M, N, dtype=out_dtype)
        return out

    @torch.library.register_fake("sgl_kernel::grouped_topk_cpu")
    def _(
        hidden_states,
        gating_output,
        topk,
        renormalize,
        num_expert_group,
        topk_group,
        num_fused_shared_experts,
        routed_scaling_factor,
        num_token_non_padded,
    ):
        num_tokens = hidden_states.shape[0]
        shape = (num_tokens, topk)
        device = hidden_states.device
        topk_weights = torch.empty(shape, device=device, dtype=torch.float32)
        topk_ids = torch.empty(shape, device=device, dtype=torch.int)
        return topk_weights, topk_ids

    @torch.library.register_fake("sgl_kernel::biased_grouped_topk_cpu")
    def _(
        hidden_states,
        gating_output,
        correction_bias,
        topk,
        renormalize,
        num_expert_group,
        topk_group,
        num_fused_shared_experts,
        routed_scaling_factor,
        num_token_non_padded,
    ):
        num_tokens = hidden_states.shape[0]
        shape = (num_tokens, topk)
        device = hidden_states.device
        topk_weights = torch.empty(shape, device=device, dtype=torch.float32)
        topk_ids = torch.empty(shape, device=device, dtype=torch.int)
        return topk_weights, topk_ids

    @torch.library.register_fake("sgl_kernel::topk_sigmoid_cpu")
    def _(hidden_states, gating_output, topk, renormalize):
        num_tokens = hidden_states.shape[0]
        shape = (num_tokens, topk)
        return (
            torch.empty(shape, device=hidden_states.device, dtype=torch.float),
            torch.empty(shape, device=hidden_states.device, dtype=torch.int),
        )

    @torch.library.register_fake("sgl_kernel::topk_softmax_cpu")
    def _(
        hidden_states,
        gating_output,
        topk,
        renormalize,
    ):
        num_tokens = hidden_states.shape[0]
        shape = (num_tokens, topk)
        return (
            torch.empty(shape, device=hidden_states.device, dtype=torch.float),
            torch.empty(shape, device=hidden_states.device, dtype=torch.int),
        )

    @torch.library.register_fake("sgl_kernel::silu_and_mul_cpu")
    def _(input):
        return input.new_empty(input.shape[0], input.shape[1] // 2)

    @torch.library.register_fake("sgl_kernel::int8_scaled_mm_with_quant")
    def _(
        mat1,
        mat2,
        scales2,
        bias,
        out_dtype,
        is_vnni,
    ):
        M = mat1.shape[0]
        N = mat2.shape[0]
        return mat1.new_empty(M, N, dtype=out_dtype)

    @torch.library.register_fake("sgl_kernel::fp8_scaled_mm_cpu")
    def _(
        mat1,
        mat2,
        scales2,
        block_size,
        bias,
        out_dtype,
        is_vnni,
    ):
        M = mat1.shape[0]
        N = mat2.shape[0]
        return mat1.new_empty(M, N, dtype=out_dtype)

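For readers unfamiliar with fake registrations, here is a minimal self-contained sketch (hypothetical `demo` namespace, not part of sgl_kernel; assumes PyTorch 2.4+ where `torch.library.register_fake` exists) of how a fake implementation lets torch.compile trace a custom op by describing only output shapes and dtypes:

```python
import torch

# Define a toy custom op, give it a real CPU kernel plus a fake (meta) kernel.
torch.library.define("demo::row_sums", "(Tensor x) -> Tensor")

@torch.library.impl("demo::row_sums", "cpu")
def _(x: torch.Tensor) -> torch.Tensor:
    return x.sum(dim=-1)

@torch.library.register_fake("demo::row_sums")
def _(x: torch.Tensor) -> torch.Tensor:
    # Only shapes/dtypes matter here; no real computation happens during tracing.
    return x.new_empty(x.shape[:-1])

@torch.compile(fullgraph=True)
def f(x):
    return torch.ops.demo.row_sums(x) + 1.0

print(f(torch.randn(4, 8)))
```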
# TODO Remove unnecessary settings for CPUGraphRunner.
# Re-abstract the graph runner and restructure CPUGraphRunner to reuse the same logic.
class CPUGraphRunner:
    """A CPUGraphRunner runs the forward pass of a model with cpu torch.compile."""

    def __init__(self, model_runner: ModelRunner):
        # Parse args
        self.model_runner = model_runner
        self.device = model_runner.device
        self.graphs = {}
        self.output_buffers = {}
        self.enable_torch_compile = model_runner.server_args.enable_torch_compile
        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
        self.enable_two_batch_overlap = (
            model_runner.server_args.enable_two_batch_overlap
        )
        self.speculative_algorithm = model_runner.server_args.speculative_algorithm
        self.enable_profile_cuda_graph = (
            model_runner.server_args.enable_profile_cuda_graph
        )
        self.tp_size = model_runner.server_args.tp_size
        self.dp_size = model_runner.server_args.dp_size
        self.pp_size = model_runner.server_args.pp_size

        self.capture_forward_mode = ForwardMode.DECODE
        self.capture_hidden_mode = CaptureHiddenMode.NULL
        self.num_tokens_per_bs = 1

        # If returning hidden states is enabled, set the initial capture hidden mode
        # to FULL to avoid double-capture on startup.
        if model_runner.server_args.enable_return_hidden_states:
            self.capture_hidden_mode = CaptureHiddenMode.FULL

        assert (
            not self.model_runner.server_args.enable_lora
        ), "CPUGraphRunner does not support LoRA yet."
        assert (
            not self.enable_two_batch_overlap
        ), "CPUGraphRunner does not support two batch overlap yet."
        assert (
            not self.require_mlp_tp_gather
        ), "CPUGraphRunner does not support MLP TP gather yet."
        assert (
            not self.require_mlp_sync
        ), "CPUGraphRunner does not support MLP sync yet."
        assert (
            not self.require_gathered_buffer
        ), "CPUGraphRunner does not support gathered buffer yet."
        assert (
            model_runner.spec_algorithm == SpeculativeAlgorithm.NONE
        ), "CPUGraphRunner does not support speculative inference yet."
        # TODO: add compile support for encoder-decoder models
        assert (
            not self.is_encoder_decoder
        ), "CPUGraphRunner does not support encoder-decoder models yet."
        assert self.dp_size == 1, "CPUGraphRunner does not support DP yet."
        assert self.pp_size == 1, "CPUGraphRunner does not support PP yet."

        # Batch sizes to capture
        self.capture_bs = get_batch_sizes_to_capture(model_runner)
        log_info_on_rank0(logger, f"Capture cpu graph bs {self.capture_bs}")
        # Attention backend
        self.max_bs = max(self.capture_bs)
        self.max_num_token = self.max_bs * self.num_tokens_per_bs

        self.seq_len_fill_value = (
            self.model_runner.attn_backend.get_graph_seq_len_fill_value()
        )

        if self.enable_torch_compile:
            register_fake_ops()
            set_torch_compile_config()

        # Graph inputs
        with torch.device(self.device):
            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int64)
            self.seq_lens = torch.full(
                (self.max_bs,), self.seq_len_fill_value, dtype=torch.int64
            )
            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
            self.num_token_non_padded = torch.zeros((1,), dtype=torch.int64)
            self.custom_mask = torch.ones(
                (
                    (self.seq_lens.sum().item() + self.max_num_token)
                    * self.num_tokens_per_bs
                ),
                dtype=torch.bool,
                device=self.device,
            )

        # Capture
        try:
            self.capture()
        except RuntimeError as e:
            raise Exception(
                f"Capture CPU graph failed: {e}\n{CPU_GRAPH_CAPTURE_FAILED_MSG}"
            )

    def can_run(self, forward_batch: ForwardBatch):
        is_bs_supported = forward_batch.batch_size in self.graphs

        requested_capture_hidden_mode = max(
            forward_batch.capture_hidden_mode,
            (
                forward_batch.spec_info.capture_hidden_mode
                if getattr(forward_batch.spec_info, "capture_hidden_mode", None)
                is not None
                else CaptureHiddenMode.NULL
            ),
        )
        capture_hidden_mode_matches = (
            requested_capture_hidden_mode == CaptureHiddenMode.NULL
            or requested_capture_hidden_mode == self.capture_hidden_mode
        )

        return is_bs_supported and capture_hidden_mode_matches

    def capture(self) -> None:
        capture_range = (
            tqdm.tqdm(list(reversed(self.capture_bs)))
            if get_tensor_model_parallel_rank() == 0
            else reversed(self.capture_bs)
        )
        for bs in capture_range:
            if get_tensor_model_parallel_rank() == 0:
                avail_mem = psutil.virtual_memory().available / (1 << 30)
                capture_range.set_description(
                    f"Capturing batches ({bs=} {avail_mem=:.2f} GB)"
                )

            with patch_model(
                self.model_runner.model,
                bs in self.capture_bs,
                num_tokens=bs * self.num_tokens_per_bs,
                tp_group=self.model_runner.tp_group,
            ) as forward:
                (
                    graph,
                    output_buffers,
                ) = self.capture_one_batch_size(bs, forward)
                self.graphs[bs] = graph
                self.output_buffers[bs] = output_buffers

    def capture_one_batch_size(self, bs: int, forward: Callable):
        num_tokens = bs * self.num_tokens_per_bs

        # Graph inputs
        input_ids = self.input_ids[:num_tokens]
        req_pool_indices = self.req_pool_indices[:bs]
        seq_lens = self.seq_lens[:bs]
        out_cache_loc = self.out_cache_loc[:num_tokens]
        positions = self.positions[:num_tokens]
        mrope_positions = self.mrope_positions[:, :bs]
        self.num_token_non_padded[...] = num_tokens

        spec_info = self.get_spec_info(num_tokens)
        if self.capture_hidden_mode != CaptureHiddenMode.FULL:
            self.capture_hidden_mode = (
                spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
            )

        forward_batch = ForwardBatch(
            forward_mode=self.capture_forward_mode,
            batch_size=bs,
            input_ids=input_ids,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool=self.model_runner.token_to_kv_pool,
            attn_backend=self.model_runner.attn_backend,
            out_cache_loc=out_cache_loc,
            seq_lens_sum=seq_lens.sum().item(),
            return_logprob=False,
            positions=positions,
            mrope_positions=mrope_positions,
            spec_algorithm=self.model_runner.spec_algorithm,
            spec_info=spec_info,
            capture_hidden_mode=self.capture_hidden_mode,
            num_token_non_padded=self.num_token_non_padded,
            global_forward_mode=self.capture_forward_mode,
        )

        # Attention backend
        self.model_runner.attn_backend.init_forward_metadata(forward_batch)
        # Run one eager inference first to avoid setting attributes at runtime,
        # e.g., self.attn_mha.kv_b_proj = self.kv_b_proj, which would break
        # full-graph compile on CPU.
        self.model_runner.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
        )

        # Run and capture
        def run_once():
            # Clean intermediate result cache for DP attention
            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
            logits_output_or_pp_proxy_tensors = forward(
                input_ids,
                forward_batch.positions,
                forward_batch,
            )
            return logits_output_or_pp_proxy_tensors

        with torch.no_grad():
            for _ in range(2):
                self.model_runner.tp_group.barrier()
                out = run_once()
        return forward, out

    def recapture_if_needed(self, forward_batch: ForwardBatch):

        # If the required capture_hidden_mode changes, we need to recapture the graph.

        # These are the different factors that can influence the capture_hidden_mode.
        capture_hidden_mode_required_by_forward_batch = (
            forward_batch.capture_hidden_mode
        )
        capture_hidden_mode_required_by_spec_info = getattr(
            forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
        )
        capture_hidden_mode_required_for_returning_hidden_states = (
            CaptureHiddenMode.FULL
            if self.model_runner.server_args.enable_return_hidden_states
            else CaptureHiddenMode.NULL
        )

        # Determine the highest capture_hidden_mode required
        # (If we have FULL, we can emulate LAST or NULL)
        # (If we have LAST, we can emulate NULL)
        required_capture_hidden_mode = max(
            capture_hidden_mode_required_by_forward_batch,
            capture_hidden_mode_required_by_spec_info,
            capture_hidden_mode_required_for_returning_hidden_states,
        )

        # If the current hidden mode is no longer aligned with the required hidden mode,
        # set it to what is required and re-capture.
        if self.capture_hidden_mode != required_capture_hidden_mode:
            self.capture_hidden_mode = required_capture_hidden_mode
            self.capture()

    # TODO: add padding support for CPUGraphRunner
    def replay(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
        assert (
            pp_proxy_tensors is None
        ), "PPProxyTensors is not supported in CPUGraphRunner yet."
        self.recapture_if_needed(forward_batch)
        self.model_runner.attn_backend.init_forward_metadata(forward_batch)
        output = self.graphs[forward_batch.batch_size](
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
        )
        return output

    def get_spec_info(self, num_tokens: int):
        spec_info = None
        if self.model_runner.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_utils import EagleVerifyInput

            if self.model_runner.is_draft_worker:
                raise RuntimeError("This should not happen.")
            else:
                spec_info = EagleVerifyInput(
                    draft_token=None,
                    custom_mask=self.custom_mask,
                    positions=None,
                    retrive_index=None,
                    retrive_next_token=None,
                    retrive_next_sibling=None,
                    retrive_cum_len=None,
                    spec_steps=self.model_runner.server_args.speculative_num_steps,
                    topk=self.model_runner.server_args.speculative_eagle_topk,
                    draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
                    capture_hidden_mode=CaptureHiddenMode.FULL,
                    seq_lens_sum=None,
                    seq_lens_cpu=None,
                )

        return spec_info


CPU_GRAPH_CAPTURE_FAILED_MSG = (
    "Possible solutions:\n"
    "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
    "2. set --torch-compile-max-bs to a smaller value (e.g., 8)\n"
    "3. disable torch compile by not using --enable-torch-compile\n"
    "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
)
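The runner's core idea, stripped of SGLang specifics, is to pre-trigger compilation for a fixed set of decode batch sizes and then dispatch replay to the matching compiled callable. The following standalone sketch (toy model and shapes are assumptions, not the actual runner) mirrors the capture/can_run/replay flow:

```python
import torch

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)

model = ToyModel().eval()
capture_bs = [1, 2, 4]
graphs = {}

# "Capture": compile once with dynamic=False and warm up each batch size so every
# bs gets a specialized graph, mirroring CPUGraphRunner.capture().
compiled = torch.compile(torch.no_grad()(model.forward), dynamic=False)
for bs in capture_bs:
    compiled(torch.randn(bs, 16))   # trigger compilation for this batch size
    graphs[bs] = compiled

def replay(x):
    # Mirrors can_run() + replay(): only captured batch sizes reuse the compiled
    # callable; anything else falls back to the eager forward.
    fn = graphs.get(x.shape[0], model.forward)
    return fn(x)

print(replay(torch.randn(2, 16)).shape)
```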
@@ -132,6 +132,9 @@ class ForwardMode(IntEnum):
             or self == ForwardMode.IDLE
         )

+    def is_cpu_graph(self):
+        return self == ForwardMode.DECODE
+
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST

@@ -20,6 +20,7 @@ import json
 import logging
 import os
 import time
+from collections import defaultdict
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

@@ -89,6 +90,7 @@ from sglang.srt.mem_cache.memory_pool import (
     ReqToTokenPool,
     SWAKVPool,
 )
+from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
@@ -360,12 +362,12 @@ class ModelRunner:
             self.init_cublas()
             self.init_attention_backend()
             self.init_device_graphs()
-        elif self.device == "npu":
+        elif self.device in ["npu", "cpu"]:
             self.init_attention_backend()
             self.init_device_graphs()
         else:
             self.graph_runner = None
-            self.cuda_graph_mem_usage = 0
+            self.graph_mem_usage = 0
             self.init_attention_backend()

         # auxiliary hidden capture mode. TODO: expose this to server args?
@@ -608,6 +610,11 @@ class ModelRunner:
             # Set local size to hint SGLang to use shared memory based AllReduce
             os.environ["LOCAL_SIZE"] = str(self.tp_size)
             torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)
+
+            @torch.library.register_fake("sgl_kernel::shm_allgather")
+            def _(data, dim):
+                return torch.cat([data] * self.tp_size, dim=dim)
+
         else:
             logger.warning(
                 "init_cpu_threads_env and shared memory based AllReduce is disabled since intel amx backend is not available"
@@ -1619,30 +1626,39 @@ class ModelRunner:
         )

     def init_device_graphs(self):
-        """Capture cuda graphs."""
+        """Capture device graphs."""
         self.graph_runner = None
-        self.cuda_graph_mem_usage = 0
+        self.graph_mem_usage = 0

         if not self.is_generation:
             # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
             return

-        if self.server_args.disable_cuda_graph:
+        if self.device != "cpu" and self.server_args.disable_cuda_graph:
             return

+        if self.device == "cpu" and not self.server_args.enable_torch_compile:
+            return
+
         tic = time.perf_counter()
         before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
+            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
-        self.graph_runner = (
-            CudaGraphRunner(self) if not _is_npu else NPUGraphRunner(self)
+        graph_runners = defaultdict(
+            lambda: CudaGraphRunner,
+            {
+                "cpu": CPUGraphRunner,
+                "npu": NPUGraphRunner,
+            },
         )
+        self.graph_runner = graph_runners[self.device](self)

         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
-        self.cuda_graph_mem_usage = before_mem - after_mem
+        self.graph_mem_usage = before_mem - after_mem
         logger.info(
-            f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
-            f"mem usage={self.cuda_graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
+            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
+            f"mem usage={self.graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
         )

     def init_threads_binding(self):
@@ -1787,18 +1803,24 @@ class ModelRunner:
         reinit_attn_backend: bool = False,
         split_forward_count: int = 1,
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
-        can_run_cuda_graph = bool(
-            forward_batch.forward_mode.is_cuda_graph()
+        mode_check = (
+            forward_batch.forward_mode.is_cpu_graph
+            if self.device == "cpu"
+            else forward_batch.forward_mode.is_cuda_graph
+        )
+        can_run_graph = bool(
+            mode_check()
             and self.graph_runner
             and self.graph_runner.can_run(forward_batch)
         )
-        if can_run_cuda_graph:
+
+        if can_run_graph:
             ret = self.graph_runner.replay(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-            return ret, can_run_cuda_graph
+            return ret, can_run_graph

         # For MLP sync
         if forward_batch.global_num_tokens_cpu is not None:
@@ -1833,7 +1855,7 @@ class ModelRunner:
         ):
             forward_batch.post_forward_mlp_sync_batch(ret)

-        return ret, can_run_cuda_graph
+        return ret, can_run_graph

     def _preprocess_logits(
         self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
@@ -230,8 +230,16 @@ except:
     is_intel_amx_backend_available = False


+try:
+    # move torch._C._cpu._is_amx_tile_supported() from cpu_has_amx_support
+    # to support torch compile
+    is_amx_tile_supported = torch._C._cpu._is_amx_tile_supported()
+except:
+    is_amx_tile_supported = False
+
+
 def cpu_has_amx_support():
-    return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
+    return is_amx_tile_supported and is_intel_amx_backend_available


 def use_intel_amx_backend(layer):
@@ -239,7 +239,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("rmsnorm_cpu", torch::kCPU, &rmsnorm_cpu);
   m.def("l2norm_cpu(Tensor input, float eps) -> Tensor");
   m.impl("l2norm_cpu", torch::kCPU, &l2norm_cpu);
-  m.def("fused_add_rmsnorm_cpu(Tensor input, Tensor residual, Tensor weight, float eps) -> ()");
+  m.def("fused_add_rmsnorm_cpu(Tensor(a!) input, Tensor residual, Tensor weight, float eps) -> ()");
   m.impl("fused_add_rmsnorm_cpu", torch::kCPU, &fused_add_rmsnorm_cpu);

   // topk
@@ -262,14 +262,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {

   // decode
   m.def(
-      "decode_attention_cpu(Tensor query, Tensor k_cache, Tensor v_cahce, Tensor output, Tensor key, Tensor value, "
+      "decode_attention_cpu(Tensor query, Tensor k_cache, Tensor v_cahce, Tensor(a!) output, Tensor key, Tensor value, "
      "Tensor loc, Tensor attn_logits, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, "
      "float logit_cap) -> ()");
   m.impl("decode_attention_cpu", torch::kCPU, &decode_attention_cpu);

   // extend
   m.def(
-      "extend_attention_cpu(Tensor q_extend, Tensor k_extend, Tensor v_extend, Tensor o_extend, Tensor k_buffer, "
+      "extend_attention_cpu(Tensor q_extend, Tensor k_extend, Tensor v_extend, Tensor(a!) o_extend, Tensor k_buffer, "
      "Tensor v_buffer, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, Tensor extend_seq_lens, Tensor "
      "extend_start_loc, int max_len_extend, float sm_scale, float logit_cap) -> ()");
   m.impl("extend_attention_cpu", torch::kCPU, &extend_attention_cpu);
@@ -305,7 +305,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("int8_scaled_mm_with_quant", torch::kCPU, &int8_scaled_mm_with_quant);

   // bmm
-  m.def("bmm_cpu(Tensor out, Tensor mat1, Tensor mat2, bool is_vnni, Tensor? scale) -> ()");
+  m.def("bmm_cpu(Tensor(a!) out, Tensor mat1, Tensor mat2, bool is_vnni, Tensor? scale) -> ()");
   m.impl("bmm_cpu", torch::kCPU, &bmm_cpu);

   // moe
@@ -342,7 +342,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {

   // all reduce
   m.def("initialize(int size, int rank) -> ()");
-  m.def("shm_allreduce(Tensor data, int reduce_op) -> ()");
+  m.def("shm_allreduce(Tensor(a!) data, int reduce_op) -> ()");
   m.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
   m.def("shm_allgather(Tensor data, int dim) -> Tensor");
   m.impl("shm_allgather", torch::kCPU, &shm_allgather);
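The schema edits above all add the `(a!)` alias annotation, which declares that the tensor argument is mutated in place; torch.compile's functionalization relies on this information. A minimal Python-side sketch of the same idea, using a hypothetical `demo` op rather than sgl_kernel (assumes PyTorch 2.4+):

```python
import torch

# Declare an in-place op: "(a!)" marks the first argument as mutated.
torch.library.define("demo::scale_", "(Tensor(a!) x, float factor) -> ()")

@torch.library.impl("demo::scale_", "cpu")
def _(x: torch.Tensor, factor: float) -> None:
    x.mul_(factor)  # mutation matches the (a!) annotation in the schema

@torch.library.register_fake("demo::scale_")
def _(x: torch.Tensor, factor: float) -> None:
    return  # mutating op with no outputs: nothing to allocate for meta tensors

@torch.compile(fullgraph=True)
def f(x):
    torch.ops.demo.scale_(x, 2.0)  # compile can functionalize this correctly
    return x + 1.0

print(f(torch.ones(3)))
```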
@@ -276,6 +276,7 @@ suite_xeon = {
         TestFile("cpu/test_shared_expert.py"),
         TestFile("cpu/test_topk.py"),
         TestFile("test_intel_amx_attention_backend.py"),
+        TestFile("test_cpu_graph.py"),
     ],
 }
test/srt/test_cpu_graph.py (new file, 87 lines)
"""
Usage:
python3 -m unittest test_cpu_graph.TestCPUGraph.test_mmlu_torch_compile_cpu
"""

import copy
import os
import unittest
from types import SimpleNamespace

from test_intel_amx_attention_backend import intel_amx_benchmark

from sglang.srt.utils import get_cpu_ids_by_node, kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MLA_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_ci,
    popen_launch_server,
)


class TestCPUGraph(CustomTestCase):

    @intel_amx_benchmark(
        extra_args=[
            "--batch-size",
            "1",
            "--mem-fraction-static",
            "0.05",
            "--enable-torch-compile",
            "--torch-compile-max-bs",
            "1",
        ],
        min_throughput=10,
    )
    def test_latency_torch_compile_cpu(self):
        return DEFAULT_MLA_MODEL_NAME_FOR_TEST

    def test_mmlu_torch_compile_cpu(self):
        model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        base_url = DEFAULT_URL_FOR_TEST
        cpu_ids_by_node = get_cpu_ids_by_node()
        n_numa_node = len(cpu_ids_by_node)
        env = copy.deepcopy(os.environ)
        env["SGLANG_CPU_OMP_THREADS_BIND"] = "all"
        process = popen_launch_server(
            model,
            base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--attention-backend",
                "intel_amx",
                "--mem-fraction-static",
                "0.05",
                "--disable-radix",
                "--trust-remote-code",
                "--disable-overlap-schedule",
                "--enable-torch-compile",
                "--torch-compile-max-bs",
                "1",
                "--tp",
                f"{n_numa_node}",
            ],
            env=env,
        )

        try:
            args = SimpleNamespace(
                base_url=base_url,
                model=model,
                eval_name="mmlu",
                num_examples=64,
                num_threads=32,
            )

            metrics = run_eval(args)
            if is_in_ci():
                self.assertGreater(metrics["score"], 0.45)
        finally:
            kill_process_tree(process.pid)


if __name__ == "__main__":
    unittest.main()
@@ -3,7 +3,6 @@ Usage:
 python3 -m unittest test_intel_amx_attention_backend.TestIntelAMXAttnBackend.test_mmlu
 """

-import os
 import unittest
 from functools import wraps
 from types import SimpleNamespace
@@ -35,8 +34,6 @@ def intel_amx_benchmark(extra_args=None, min_throughput=None):
                 "intel_amx",
                 "--disable-radix",
                 "--trust-remote-code",
-                "--batch-size",
-                "4",
             ]
             full_args = common_args + (extra_args or [])

@@ -60,28 +57,33 @@ def intel_amx_benchmark(extra_args=None, min_throughput=None):

 class TestIntelAMXAttnBackend(CustomTestCase):

-    @intel_amx_benchmark(min_throughput=10)
+    @intel_amx_benchmark(extra_args=["--batch-size", "4"], min_throughput=10)
     def test_latency_mla_model(self):
         return DEFAULT_MLA_MODEL_NAME_FOR_TEST

-    @intel_amx_benchmark(min_throughput=40)
+    @intel_amx_benchmark(extra_args=["--batch-size", "4"], min_throughput=40)
     def test_latency_default_model(self):
         return DEFAULT_MODEL_NAME_FOR_TEST

-    @intel_amx_benchmark(min_throughput=150)
+    @intel_amx_benchmark(extra_args=["--batch-size", "4"], min_throughput=150)
     def test_latency_fp8_qwen(self):
         return DEFAULT_MODEL_NAME_FOR_TEST_QWEN_FP8

-    @intel_amx_benchmark(min_throughput=50)
+    @intel_amx_benchmark(extra_args=["--batch-size", "4"], min_throughput=50)
     def test_latency_fp8_moe_model(self):
         return DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE

-    @intel_amx_benchmark(extra_args=["--quantization", "w8a8_int8"], min_throughput=100)
+    @intel_amx_benchmark(
+        extra_args=["--batch-size", "4", "--quantization", "w8a8_int8"],
+        min_throughput=100,
+    )
     def test_latency_w8a8_default_model(self):
         return DEFAULT_MODEL_NAME_FOR_TEST_W8A8

     @intel_amx_benchmark(
         extra_args=[
             "--batch-size",
             "4",
             "--quantization",
             "w8a8_int8",
             "--mem-fraction-static",