Update python API of activation, topk, norm and rope and remove vllm dependency (#6614)
Co-authored-by: Wu, Chunyuan <chunyuan.wu@intel.com> Co-authored-by: jianan-gu <jianan.gu@intel.com> Co-authored-by: sdp <sdp@gnr799219.jf.intel.com>
This commit is contained in:
@@ -29,11 +29,19 @@ from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.utils import is_cuda, is_npu, set_weight_attrs
|
||||
from sglang.srt.utils import (
|
||||
cpu_has_amx_support,
|
||||
is_cpu,
|
||||
is_cuda,
|
||||
is_npu,
|
||||
set_weight_attrs,
|
||||
)
|
||||
from sglang.utils import resolve_obj_by_qualname
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
_is_npu = is_npu()
|
||||
_is_cpu_amx_available = cpu_has_amx_support()
|
||||
_is_cpu = is_cpu()
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
|
||||
@@ -53,6 +61,15 @@ class SiluAndMul(CustomOp):
|
||||
silu_and_mul(x, out)
|
||||
return out
|
||||
|
||||
def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if _is_cpu_amx_available:
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = x.shape[:-1] + (d,)
|
||||
out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
|
||||
return out
|
||||
else:
|
||||
return self.forward_native(x)
|
||||
|
||||
|
||||
class GeluAndMul(CustomOp):
|
||||
def __init__(self, approximate="tanh"):
|
||||
@@ -185,8 +202,8 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
|
||||
return nn.Identity()
|
||||
|
||||
|
||||
if not _is_cuda and not _is_npu:
|
||||
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
|
||||
logger.info(
|
||||
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
|
||||
"sgl-kernel is not available on Non-NV platforms or Non-AMX CPUs. Fallback to other kernel libraries."
|
||||
)
|
||||
from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
|
||||
|
||||
Reference in New Issue
Block a user