Update python API of activation, topk, norm and rope and remove vllm dependency (#6614)
Co-authored-by: Wu, Chunyuan <chunyuan.wu@intel.com>
Co-authored-by: jianan-gu <jianan.gu@intel.com>
Co-authored-by: sdp <sdp@gnr799219.jf.intel.com>
@@ -1,9 +1,11 @@
 from torch import nn

-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip

 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_is_cpu = is_cpu()
+_is_cpu_amx_available = cpu_has_amx_support()


 class CustomOp(nn.Module):
@@ -75,5 +77,7 @@ class CustomOp(nn.Module):
             return self.forward_cuda
         elif _is_hip:
             return self.forward_hip
+        elif _is_cpu and _is_cpu_amx_available:
+            return self.forward_cpu
         else:
             return self.forward_native
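For orientation, the hunk above extends CustomOp's per-platform dispatch: the concrete forward_* implementation is resolved once, from flags computed at import time. A minimal sketch of that pattern (the flag values are illustrative, not sglang's real detection):

import torch
from torch import nn

_is_cpu = True                 # illustrative; sglang derives these from is_cpu()
_is_cpu_amx_available = True   # and cpu_has_amx_support()

class ToyOp(nn.Module):
    def __init__(self):
        super().__init__()
        # Mirror of the dispatch logic: pick the backend once, at construction.
        if _is_cpu and _is_cpu_amx_available:
            self._forward_method = self.forward_cpu
        else:
            self._forward_method = self.forward_native

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._forward_method(x)

    def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
        # A real op would call an AMX kernel here; the toy reuses native.
        return self.forward_native(x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2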
@@ -29,11 +29,19 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.utils import is_cuda, is_npu, set_weight_attrs
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    is_cpu,
+    is_cuda,
+    is_npu,
+    set_weight_attrs,
+)
 from sglang.utils import resolve_obj_by_qualname

 _is_cuda = is_cuda()
 _is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

 if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
@@ -53,6 +61,15 @@ class SiluAndMul(CustomOp):
         silu_and_mul(x, out)
         return out

+    def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
+        if _is_cpu_amx_available:
+            d = x.shape[-1] // 2
+            output_shape = x.shape[:-1] + (d,)
+            out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
+            return out
+        else:
+            return self.forward_native(x)
+

 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
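torch.ops.sgl_kernel.silu_and_mul_cpu is expected to return a tensor whose last dimension is half of the input's, matching the native gated-SiLU semantics. A plain-PyTorch reference of what the kernel computes (a sketch of the semantics, not the AMX implementation):

import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim into gate and up halves: out = silu(gate) * up.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

x = torch.randn(4, 16)
print(silu_and_mul_reference(x).shape)  # torch.Size([4, 8])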
@@ -185,8 +202,8 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
     return nn.Identity()


-if not _is_cuda and not _is_npu:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     logger.info(
-        "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
+        "sgl-kernel is not available on Non-NV platforms or Non-AMX CPUs. Fallback to other kernel libraries."
     )
     from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
@@ -20,12 +20,21 @@ import torch
 import torch.nn as nn

 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip, is_npu
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+)

 _is_cuda = is_cuda()
 _is_hip = is_hip()
 _is_npu = is_npu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

 if _is_cuda:
     from sgl_kernel import (
@@ -122,6 +131,23 @@ class RMSNorm(CustomOp):
         else:
             return x, residual

+    def forward_cpu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if _is_cpu_amx_available:
+            if residual is not None:
+                torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
+                    x, residual, self.weight.data, self.variance_epsilon
+                )
+                return x, residual
+            return torch.ops.sgl_kernel.rmsnorm_cpu(
+                x, self.weight.data, self.variance_epsilon
+            )
+        else:
+            return self.forward_native(x, residual)
+

 class GemmaRMSNorm(CustomOp):
     def __init__(
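The two CPU ops map onto the usual RMSNorm arithmetic; fused_add_rmsnorm_cpu additionally folds the residual add in, mutating x and residual in place. An out-of-place PyTorch reference of the assumed semantics, for orientation only:

import torch

def rmsnorm_reference(x, weight, eps):
    # Normalize by the root-mean-square over the last dimension, then scale.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

def fused_add_rmsnorm_reference(x, residual, weight, eps):
    # The fused kernel updates its arguments in place; this returns new tensors.
    residual = x + residual
    return rmsnorm_reference(residual, weight, eps), residual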
@@ -188,7 +214,7 @@ class Gemma3RMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"


-if not (_is_cuda or _is_hip or _is_npu):
+if not (_is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     logger.info(
         "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
     )
@@ -25,9 +25,11 @@ from sglang.srt.layers.quantization.int8_kernel import (
     sglang_per_token_group_quant_int8,
 )
 from sglang.srt.utils import (
+    cpu_has_amx_support,
     direct_register_custom_op,
     get_bool_env_var,
     get_device_name,
+    is_cpu,
     is_cuda,
     is_hip,
     log_info_on_rank0,
@@ -36,9 +38,13 @@ from sglang.srt.utils import (

 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
+elif _is_cpu and _is_cpu_amx_available:
+    pass
 else:
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant
@@ -241,7 +241,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_fused_shared_experts: int = 0,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         return moe_forward_native(
             layer,
@@ -260,7 +264,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("The TPU backend currently does not support MoE.")

-    forward_native = forward_cuda
+    forward_native = forward_cpu


 class FusedMoE(torch.nn.Module):
@@ -28,10 +28,18 @@ from sglang.srt.managers.expert_location_dispatch import (
     topk_ids_logical_to_physical,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_compiler_backend,
+    is_cpu,
+    is_cuda,
+    is_hip,
+)

 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

 if _is_cuda:
     from sgl_kernel import moe_fused_gate
@@ -40,7 +48,7 @@ if _is_cuda or _is_hip:
     from sgl_kernel import topk_softmax


-def fused_topk_native(
+def fused_topk_torch_native(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
@@ -61,6 +69,20 @@ def fused_topk_native(
     return topk_weights, topk_ids


+def fused_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    return torch.ops.sgl_kernel.topk_softmax_cpu(
+        hidden_states=hidden_states,
+        gating_output=gating_output,
+        topk=topk,
+        renormalize=renormalize,
+    )
+
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
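topk_softmax_cpu is assumed to return (topk_weights, topk_ids) just like the torch-native path it replaces. In plain PyTorch, that path amounts to:

import torch

def fused_topk_reference(gating_output: torch.Tensor, topk: int, renormalize: bool):
    # Softmax over experts, pick top-k per token, optionally renormalize weights.
    scores = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

w, ids = fused_topk_reference(torch.randn(8, 64), topk=2, renormalize=True)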
@@ -115,7 +137,7 @@ def _fused_topk_postprocess(

 # This is used by the Deepseek V2/V3/R1 series models
 @torch.compile(dynamic=True, backend=get_compiler_backend())
-def grouped_topk(
+def grouped_topk_gpu(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
@@ -171,6 +193,32 @@ def grouped_topk(
     return topk_weights, topk_ids


+def grouped_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    num_fused_shared_experts: int = 0,
+    routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+):
+    assert expert_location_dispatch_info is None
+    return torch.ops.sgl_kernel.grouped_topk_cpu(
+        hidden_states,
+        gating_output,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+        num_fused_shared_experts,
+        routed_scaling_factor,
+        num_token_non_padded,
+    )
+
+
 def biased_grouped_topk_impl(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
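For reference, group-limited routing (used by the DeepSeek V2/V3/R1 models) first selects the best expert groups per token and only then takes a global top-k. A simplified PyTorch sketch of the semantics the CPU kernel is expected to match (shared-expert fusion, scaling, and padding handling omitted):

import torch

def grouped_topk_reference(gating_output, topk, num_expert_group, topk_group):
    scores = torch.softmax(gating_output.float(), dim=-1)
    num_tokens, num_experts = scores.shape
    # Score each group by its best expert, keep the top `topk_group` groups.
    group_scores = scores.view(num_tokens, num_expert_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
    # Zero out experts in discarded groups, then take a global top-k.
    expert_mask = group_mask.repeat_interleave(num_experts // num_expert_group, dim=1)
    masked_scores = scores.masked_fill(expert_mask == 0, 0.0)
    topk_weights, topk_ids = torch.topk(masked_scores, topk, dim=-1)
    return topk_weights, topk_ids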
@@ -258,7 +306,7 @@ def _biased_grouped_topk_postprocess(
     return topk_ids


-def biased_grouped_topk(
+def biased_grouped_topk_gpu(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     correction_bias: torch.Tensor,
@@ -322,6 +370,45 @@ def biased_grouped_topk(
     )


+def biased_grouped_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    correction_bias: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    compiled: bool = True,
+    num_fused_shared_experts: int = 0,
+    routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+):
+    assert expert_location_dispatch_info is None
+    return torch.ops.sgl_kernel.biased_grouped_topk_cpu(
+        hidden_states,
+        gating_output,
+        correction_bias,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+        num_fused_shared_experts,
+        routed_scaling_factor,
+        num_token_non_padded,
+    )
+
+
+if _is_cpu and _is_cpu_amx_available:
+    biased_grouped_topk = biased_grouped_topk_cpu
+    grouped_topk = grouped_topk_cpu
+    fused_topk_native = fused_topk_cpu
+else:
+    biased_grouped_topk = biased_grouped_topk_gpu
+    grouped_topk = grouped_topk_gpu
+    fused_topk_native = fused_topk_torch_native
+
+
 def select_experts(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
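Rebinding the public names once at import time leaves every call site untouched; the CPU and GPU variants just need compatible signatures. A toy version of the pattern:

import torch

def _topk_cpu(scores: torch.Tensor, k: int):
    return torch.topk(scores, k)   # stand-in for the sgl_kernel CPU op

def _topk_gpu(scores: torch.Tensor, k: int):
    return torch.topk(scores, k)   # stand-in for the CUDA/HIP path

# Decided once at import time, mirroring the block above.
_on_amx_cpu = False                # plays the role of _is_cpu and _is_cpu_amx_available
topk_impl = _topk_cpu if _on_amx_cpu else _topk_gpu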
@@ -14,15 +14,18 @@ from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
+    cpu_has_amx_support,
     per_tensor_dequantize,
     replace_parameter,
 )
-from sglang.srt.utils import is_cuda, is_npu, set_weight_attrs
+from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs

 _is_cuda = is_cuda()
 _is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

-if not _is_cuda and not _is_npu:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant
@@ -64,7 +64,9 @@ from sglang.srt.layers.quantization.utils import (
 )
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.utils import (
+    cpu_has_amx_support,
     get_bool_env_var,
+    is_cpu,
     is_cuda,
     is_hip,
     is_npu,
@@ -76,6 +78,8 @@ from sglang.srt.utils import (

 _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

 _is_fp8_fnuz = is_fp8_fnuz()
@@ -88,7 +92,7 @@ if _is_hip:
     from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight

-if not _is_cuda and not _is_npu:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     from vllm._custom_ops import scaled_fp8_quant
@@ -6,12 +6,14 @@ from typing import List, Mapping, Tuple, Union
 import torch

 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.utils import is_cuda, is_npu
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu

 _is_cuda = is_cuda()
 _is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

-if not _is_cuda and not _is_npu:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     from vllm._custom_ops import scaled_fp8_quant
@@ -8,11 +8,13 @@ import torch
 import torch.nn as nn

 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda, is_hip, is_npu
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu

 _is_cuda = is_cuda()
 _is_hip = is_hip()
 _is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

 if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
@@ -85,7 +87,9 @@ class RotaryEmbedding(CustomOp):
         if not _is_cuda:
             cache = cache.to(dtype)

-        if not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512]:
+        if (
+            not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512]
+        ) and not (_is_cpu and _is_cpu_amx_available):
             from vllm._custom_ops import rotary_embedding

             self.vllm_rotary_embedding = rotary_embedding
@@ -148,6 +152,26 @@ class RotaryEmbedding(CustomOp):
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key

+    def forward_cpu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        positions = torch.add(positions, offsets) if offsets is not None else positions
+        if _is_cpu_amx_available:
+            return torch.ops.sgl_kernel.rotary_embedding_cpu(
+                positions,
+                query,
+                key,
+                self.head_size,
+                self.cos_sin_cache,
+                self.is_neox_style,
+            )
+        else:
+            return self.forward_native(positions, query, key, offsets)
+
     def forward_cuda(
         self,
         positions: torch.Tensor,
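rotary_embedding_cpu is assumed to apply the same rotation as the native path, driven by cos_sin_cache and the neox flag. A compact reference of neox-style rotation for a single head (semantics only, not the kernel):

import torch

def rope_neox_reference(positions, x, inv_freq):
    # positions: (T,), x: (T, head_size), inv_freq: (head_size // 2,)
    angles = positions[:, None].float() * inv_freq[None, :]
    cos, sin = angles.cos(), angles.sin()
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    # neox style rotates the two halves of the head dimension as pairs.
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)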
@@ -697,6 +721,21 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         key = key_rot
         return query.to(dtype), key.to(dtype)

+    def forward_cpu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        positions = torch.add(positions, offsets) if offsets is not None else positions
+        if _is_cpu_amx_available:
+            return torch.ops.sgl_kernel.rotary_embedding_cpu(
+                positions, query, key, self.head_size, self.cos_sin_cache, False
+            )
+        else:
+            return self.forward_native(positions, query, key, offsets)
+

 class Llama3RotaryEmbedding(RotaryEmbedding):
@@ -111,6 +111,7 @@ from sglang.srt.utils import (
 )

 _is_hip = is_hip()
+_is_cpu_amx_available = cpu_has_amx_support()

 # Use a small KV cache pool size for tests in CI
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
@@ -302,7 +303,7 @@ class ModelRunner:
         if (
             server_args.attention_backend == "intel_amx"
             and server_args.device == "cpu"
-            and not cpu_has_amx_support()
+            and not _is_cpu_amx_available
         ):
             logger.info(
                 "The current platform does not support Intel AMX, will fallback to torch_native backend."
@@ -72,7 +72,7 @@ from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -95,8 +95,10 @@ from sglang.srt.utils import (
     LazyValue,
     add_prefix,
     bind_or_assign,
+    cpu_has_amx_support,
     get_bool_env_var,
     get_int_env_var,
+    is_cpu,
     is_cuda,
     is_hip,
     is_non_idle_and_non_empty,
@@ -107,9 +109,13 @@ _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

 if _is_cuda:
     from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
+elif _is_cpu and _is_cpu_amx_available:
+    pass
 else:
     from vllm._custom_ops import awq_dequantize
@@ -665,13 +671,14 @@ class DeepseekV2AttentionMLA(nn.Module):
         if rope_scaling:
             rope_scaling["rope_type"] = "deepseek_yarn"

-        self.rotary_emb = get_rope(
+        self.rotary_emb = get_rope_wrapper(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
             is_neox_style=False,
+            device=global_server_args_dict["device"],
         )

         if rope_scaling:
@@ -160,7 +160,7 @@ def is_npu() -> bool:
     return hasattr(torch, "npu") and torch.npu.is_available()


-def is_cpu() -> bool:
+def is_host_cpu_x86() -> bool:
     machine = platform.machine().lower()
     return (
         machine in ("x86_64", "amd64", "i386", "i686")
@@ -169,6 +169,10 @@ def is_cpu() -> bool:
     )


+def is_cpu() -> bool:
+    return os.getenv("SGLANG_USE_CPU_ENGINE", "0") == "1" and is_host_cpu_x86()
+
+
 def is_flashinfer_available():
     """
     Check whether flashinfer is available.
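The CPU engine is now double-gated: an explicit env-var opt-in plus an x86 host check. A condensed, self-contained restatement of the logic added above:

import os
import platform

def is_cpu_sketch() -> bool:
    opted_in = os.getenv("SGLANG_USE_CPU_ENGINE", "0") == "1"   # explicit opt-in
    x86 = platform.machine().lower() in ("x86_64", "amd64", "i386", "i686")
    return opted_in and x86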
@@ -1452,6 +1456,15 @@ def get_device(device_id: Optional[int] = None) -> str:
             "Habana frameworks detected, but failed to import 'habana_frameworks.torch.hpu'."
         )

+    if is_cpu():
+        if cpu_has_amx_support():
+            logger.info("Intel AMX is detected, using CPU with Intel AMX support.")
+        else:
+            logger.warning(
+                "CPU device enabled, using torch native backend, low performance expected."
+            )
+        return "cpu"
+
     raise RuntimeError("No accelerator (CUDA, XPU, HPU) is available.")
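The resulting device selection, as far as this hunk shows, behaves as follows:

# Illustrative outcomes of get_device() after this change:
#   SGLANG_USE_CPU_ENGINE=1 on an AMX host   -> "cpu" (info: AMX detected)
#   SGLANG_USE_CPU_ENGINE=1 without AMX      -> "cpu" (warning: slow native backend)
#   no opt-in and no accelerator available   -> RuntimeError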