From 45de89719c3360ccadc5168b882b7bc174acac2f Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 12 Mar 2025 23:45:52 -0700 Subject: [PATCH] Revert "[XPU][CPU] Enable the native path of DeepSeek" (#4367) --- docs/start/install.md | 9 - python/pyproject.toml | 2 +- .../srt/layers/quantization/base_config.py | 14 +- .../srt/layers/quantization/blockwise_int8.py | 18 +- python/sglang/srt/layers/quantization/fp8.py | 95 +------ .../srt/layers/quantization/fp8_kernel.py | 263 ++++++------------ python/sglang/srt/layers/quantization/gptq.py | 47 +--- .../srt/layers/quantization/modelopt_quant.py | 17 +- .../srt/layers/quantization/w8a8_fp8.py | 17 +- .../srt/layers/quantization/w8a8_int8.py | 17 +- python/sglang/srt/layers/rotary_embedding.py | 88 ++---- .../sglang/srt/model_executor/model_runner.py | 2 - python/sglang/srt/model_loader/loader.py | 20 +- python/sglang/srt/models/deepseek_v2.py | 15 +- python/sglang/srt/utils.py | 9 - python/sglang/test/test_block_fp8.py | 87 +++++- 16 files changed, 221 insertions(+), 499 deletions(-) diff --git a/docs/start/install.md b/docs/start/install.md index f7fd11141..3b357fa3e 100644 --- a/docs/start/install.md +++ b/docs/start/install.md @@ -52,15 +52,6 @@ cd .. pip install -e "python[all_hip]" ``` -Note: To Intel GPU, do following instead: - -``` -git clone https://github.com/sgl-project/sglang.git -cd sglang -pip install --upgrade pip -pip install -e "python[all_xpu]" -``` - ## Method 3: Using docker The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](https://github.com/sgl-project/sglang/tree/main/docker). Replace `` below with your huggingface hub [token](https://huggingface.co/docs/hub/en/security-tokens). diff --git a/python/pyproject.toml b/python/pyproject.toml index e81edaa28..a0ad844bd 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -57,7 +57,7 @@ srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.7.dev2", "outlines==0.1 # xpu is not enabled in public vllm and torch whl, # need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.htmlinstall vllm -srt_xpu = ["sglang[runtime_common]", "vllm>=0.6.4.post1,<=0.7.2", "outlines>=0.0.44,<=0.1.11"] +srt_xpu = ["sglang[runtime_common]", "outlines>=0.0.44,<=0.1.11"] # For Intel Gaudi(device : hpu) follow the installation guide # https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html diff --git a/python/sglang/srt/layers/quantization/base_config.py b/python/sglang/srt/layers/quantization/base_config.py index 63c9a6874..e45dda2cc 100644 --- a/python/sglang/srt/layers/quantization/base_config.py +++ b/python/sglang/srt/layers/quantization/base_config.py @@ -54,18 +54,8 @@ class QuantizationConfig(ABC): """Minimum GPU capability to support the quantization method. E.g., 70 for Volta, 75 for Turing, 80 for Ampere. - This requirement is due to the custom kernels used by the - quantization method or the stock pytorch capability. - """ - raise NotImplementedError - - @classmethod - @abstractmethod - def get_availability(cls) -> bool: - """Whether the quantization config is available on current device. - - This requirement is due to the custom kernels used by the - quantization method or the stock pytorch capability. + This requirement is due to the custom CUDA kernels used by the + quantization method. """ raise NotImplementedError diff --git a/python/sglang/srt/layers/quantization/blockwise_int8.py b/python/sglang/srt/layers/quantization/blockwise_int8.py index 4c633b328..ce526cd6a 100644 --- a/python/sglang/srt/layers/quantization/blockwise_int8.py +++ b/python/sglang/srt/layers/quantization/blockwise_int8.py @@ -1,7 +1,6 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py import logging -import sys from typing import Any, Callable, Dict, List, Optional import torch @@ -20,7 +19,7 @@ from sglang.srt.layers.quantization.base_config import ( QuantizeMethodBase, ) from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear -from sglang.srt.utils import get_device_capability, set_weight_attrs +from sglang.srt.utils import set_weight_attrs ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -72,20 +71,7 @@ class BlockInt8Config(QuantizationConfig): @classmethod def get_min_capability(cls) -> int: - if hasattr(torch, "cuda") and torch.cuda.is_available(): - return 80 - - # Vendors can update - return sys.maxsize - - @classmethod - def get_availability(cls) -> bool: - major, minor = get_device_capability() - if hasattr(torch, "cuda") and torch.cuda.is_available(): - return major * 10 + minor > 80 - - # Vendors can update - return False + return 80 @classmethod def get_config_filenames(cls) -> List[str]: diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index a39149a2b..bff0fe96e 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -1,7 +1,6 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py import logging -import sys from typing import Any, Callable, Dict, List, Optional import torch @@ -37,11 +36,7 @@ from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) -from sglang.srt.layers.quantization.fp8_kernel import ( - native_per_token_group_quant_fp8, - native_w8a8_block_fp8_matmul, - per_token_group_quant_fp8, -) +from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 from sglang.srt.layers.quantization.fp8_utils import ( apply_fp8_linear, apply_w8a8_block_fp8_linear, @@ -51,8 +46,6 @@ from sglang.srt.layers.quantization.fp8_utils import ( ) from sglang.srt.utils import ( get_bool_env_var, - get_device_capability, - is_cuda, is_hip, permute_weight, print_warning_once, @@ -62,7 +55,6 @@ from sglang.srt.utils import ( ACTIVATION_SCHEMES = ["static", "dynamic"] _is_hip = is_hip() -_is_cuda = is_cuda() if _is_hip: from aiter.fused_moe_bf16_asm import asm_moe @@ -116,24 +108,7 @@ class Fp8Config(QuantizationConfig): @classmethod def get_min_capability(cls) -> int: - if hasattr(torch, "cuda") and torch.cuda.is_available(): - return 80 - if hasattr(torch, "xpu") and torch.xpu.is_available(): - return 0 - - # Vendors can update - return sys.maxsize - - @classmethod - def get_availability(cls) -> bool: - major, minor = get_device_capability() - if hasattr(torch, "cuda") and torch.cuda.is_available(): - return major * 10 + minor > 80 - if hasattr(torch, "xpu") and torch.xpu.is_available(): - return True - - # Vendors can update - return False + return 80 @classmethod def get_config_filenames(cls) -> List[str]: @@ -875,52 +850,6 @@ class Fp8MoEMethod: ) torch.cuda.empty_cache() - def torch_w8a8_block_fp8_moe( - self, a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, block_shape - ): - from sglang.srt.layers.activation import SiluAndMul - - """This function performs fused moe with block-wise quantization using native torch.""" - - B, D = a.shape - topk = topk_ids.shape[-1] - a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) - out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) - topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.view(-1) - - _, block_k = block_shape[0], block_shape[1] - a_q, a_s = native_per_token_group_quant_fp8(a, block_k) - # NOTE(HandH1998): Since "index_cuda" not implemented for 'Float8_e4m3fn', we need to cast `float8`` to `float32``. - a_q = a_q.to(torch.float32) - for i in range(w1.shape[0]): - mask = topk_ids == i - if mask.sum(): - inter_out = native_w8a8_block_fp8_matmul( - a_q[mask], - w1[i], - a_s[mask], - w1_s[i], - block_shape, - output_dtype=a.dtype, - ) - act_out = SiluAndMul().forward_native(inter_out) - act_out_q, act_out_s = native_per_token_group_quant_fp8( - act_out, block_k - ) - act_out = act_out.to(torch.float32) - out[mask] = native_w8a8_block_fp8_matmul( - act_out_q, - w2[i], - act_out_s, - w2_s[i], - block_shape, - output_dtype=a.dtype, - ) - return ( - out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) - ).sum(dim=1) - def apply( self, layer: torch.nn.Module, @@ -994,7 +923,7 @@ class Fp8MoEMethod: layer.w13_weight_scale1, layer.w2_weight_scale1, ) - elif _is_cuda: + else: # Expert fusion with FP8 quantization return fused_experts( x, @@ -1021,24 +950,6 @@ class Fp8MoEMethod: no_combine=no_combine, ) - # for CPU and other accelerators, fallback to native path - return self.torch_w8a8_block_fp8_moe( - a=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - w1_s=( - layer.w13_weight_scale_inv - if self.block_quant - else layer.w13_weight_scale - ), - w2_s=( - layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale - ), - topk_weight=topk_weights, - topk_ids=topk_ids, - block_shape=self.quant_config.weight_block_size, - ) - class Fp8KVCacheMethod(BaseKVCacheMethod): """ diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index 6f3c0d79e..3c60baf0a 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -28,13 +28,10 @@ from sglang.srt.utils import ( get_device_name, is_cuda, is_hip, - is_triton_available, supports_custom_op, ) _is_hip = is_hip() -_is_cuda = is_cuda() -_is_triton = is_triton_available() fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn _is_cuda = is_cuda() @@ -165,34 +162,6 @@ def _per_token_group_quant_fp8_colmajor( tl.store(y_s_ptr, y_s) -def native_per_token_group_quant_fp8( - x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn -): - """Function to perform per-token-group quantization on an input tensor `x` using native torch. - - It converts the tensor values into float8 values and returns the - quantized tensor along with the scaling factor used for quantization. - Note that only `torch.float8_e4m3fn` is supported for now. - """ - assert ( - x.shape[-1] % group_size == 0 - ), "the last dimension of `x` cannot be divisible by `group_size`" - assert x.is_contiguous(), "`x` is not contiguous" - - finfo = torch.finfo(dtype) - fp8_min = finfo.min - fp8_max = finfo.max - - x_ = x.reshape(x.numel() // group_size, group_size) - amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32) - x_s = amax / fp8_max - x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype) - x_q = x_q.reshape(x.shape) - x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,)) - - return x_q, x_s - - def per_token_group_quant_fp8( x: torch.Tensor, group_size: int, @@ -263,22 +232,19 @@ def per_token_group_quant_fp8( num_stages=num_stages, ) else: - if _is_triton: - _per_token_group_quant_fp8[(M,)]( - x, - x_q, - x_s, - group_size, - N, - eps, - fp8_min=fp8_min, - fp8_max=fp8_max, - BLOCK=BLOCK, - num_warps=num_warps, - num_stages=num_stages, - ) - else: - x_q, x_s = native_per_token_group_quant_fp8(x, group_size) + _per_token_group_quant_fp8[(M,)]( + x, + x_q, + x_s, + group_size, + N, + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) return x_q, x_s @@ -725,61 +691,6 @@ def get_w8a8_block_fp8_configs( return None -def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16): - """This function performs matrix multiplication with block-wise quantization using native torch. - - It takes two input tensors `A` and `B` with scales `As` and `Bs`. - The output is returned in the specified `output_dtype`. - """ - - A = A.to(torch.float32) - B = B.to(torch.float32) - assert A.shape[-1] == B.shape[-1] - assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 - assert len(block_size) == 2 - block_n, block_k = block_size[0], block_size[1] - assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1] - assert A.shape[:-1] == As.shape[:-1] - - M = A.numel() // A.shape[-1] - N, K = B.shape - origin_C_shape = A.shape[:-1] + (N,) - A = A.reshape(M, A.shape[-1]) - As = As.reshape(M, As.shape[-1]) - n_tiles = (N + block_n - 1) // block_n - k_tiles = (K + block_k - 1) // block_k - assert n_tiles == Bs.shape[0] - assert k_tiles == Bs.shape[1] - - C_shape = (M, N) - C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) - - A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)] - B_tiles = [ - [ - B[ - j * block_n : min((j + 1) * block_n, N), - i * block_k : min((i + 1) * block_k, K), - ] - for i in range(k_tiles) - ] - for j in range(n_tiles) - ] - C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)] - As_tiles = [As[:, i : i + 1] for i in range(k_tiles)] - - for i in range(k_tiles): - for j in range(n_tiles): - a = A_tiles[i] - b = B_tiles[j][i] - c = C_tiles[j] - s = As_tiles[i] * Bs[j][i] - c[:, :] += torch.matmul(a, b.t()) * s - - C = C.reshape(origin_C_shape).to(output_dtype) - return C - - def w8a8_block_fp8_matmul( A: torch.Tensor, B: torch.Tensor, @@ -804,90 +715,86 @@ def w8a8_block_fp8_matmul( Returns: torch.Tensor: The result of matmul. """ - if _is_triton: # pragma: no cover - assert len(block_size) == 2 - block_n, block_k = block_size[0], block_size[1] + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] - assert A.shape[-1] == B.shape[-1] - assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() - assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] - M = A.numel() // A.shape[-1] + assert A.shape[-1] == B.shape[-1] + assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() + assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] + M = A.numel() // A.shape[-1] - assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 - N, K = B.shape - assert triton.cdiv(N, block_n) == Bs.shape[0] - assert triton.cdiv(K, block_k) == Bs.shape[1] + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + N, K = B.shape + assert triton.cdiv(N, block_n) == Bs.shape[0] + assert triton.cdiv(K, block_k) == Bs.shape[1] - C_shape = A.shape[:-1] + (N,) - C = A.new_empty(C_shape, dtype=output_dtype) + C_shape = A.shape[:-1] + (N,) + C = A.new_empty(C_shape, dtype=output_dtype) - configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1]) - if configs: - # If an optimal configuration map has been found, look up the - # optimal config - config = configs[min(configs.keys(), key=lambda x: abs(x - M))] - else: - # Default config - # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1] - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": block_size[0], - "BLOCK_SIZE_K": block_size[1], - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3, - } + configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1]) + if configs: + # If an optimal configuration map has been found, look up the + # optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Default config + # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1] + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_size[0], + "BLOCK_SIZE_K": block_size[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, + } - def grid(META): - return ( - triton.cdiv(M, META["BLOCK_SIZE_M"]) - * triton.cdiv(N, META["BLOCK_SIZE_N"]), - ) - - # Use manually unrolledx4 kernel on AMD GPU when the grid size is small. - # Empirical testing shows the sweet spot lies when it's less than the # of - # compute units available on the device. - num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv( - N, config["BLOCK_SIZE_N"] + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) - # deepgemm only support bf16 - if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm: - if supports_custom_op(): - torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C) - else: - deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C) - else: - kernel = ( - _w8a8_block_fp8_matmul_unrolledx4 - if (_is_hip == True and num_workgroups <= get_device_core_count()) - else _w8a8_block_fp8_matmul - ) + # Use manually unrolledx4 kernel on AMD GPU when the grid size is small. + # Empirical testing shows the sweet spot lies when it's less than the # of + # compute units available on the device. + num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv( + N, config["BLOCK_SIZE_N"] + ) - kernel[grid]( - A, - B, - C, - As, - Bs, - M, - N, - K, - block_n, - block_k, - A.stride(-2), - A.stride(-1), - B.stride(1), - B.stride(0), - C.stride(-2), - C.stride(-1), - As.stride(-2), - As.stride(-1), - Bs.stride(1), - Bs.stride(0), - **config, - ) + # deepgemm only support bf16 + if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm: + if supports_custom_op(): + torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C) + else: + deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C) else: - C = native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype) + kernel = ( + _w8a8_block_fp8_matmul_unrolledx4 + if (_is_hip == True and num_workgroups <= get_device_core_count()) + else _w8a8_block_fp8_matmul + ) + + kernel[grid]( + A, + B, + C, + As, + Bs, + M, + N, + K, + block_n, + block_k, + A.stride(-2), + A.stride(-1), + B.stride(1), + B.stride(0), + C.stride(-2), + C.stride(-1), + As.stride(-2), + As.stride(-1), + Bs.stride(1), + Bs.stride(0), + **config, + ) return C diff --git a/python/sglang/srt/layers/quantization/gptq.py b/python/sglang/srt/layers/quantization/gptq.py index c2070d23c..b15498864 100644 --- a/python/sglang/srt/layers/quantization/gptq.py +++ b/python/sglang/srt/layers/quantization/gptq.py @@ -1,5 +1,4 @@ import logging -import sys from fractions import Fraction from typing import Any, Dict, List, Optional, Union @@ -9,7 +8,6 @@ from vllm.scalar_type import scalar_types from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead -from sglang.srt.utils import get_device_capability logger = logging.getLogger(__name__) @@ -92,20 +90,7 @@ class GPTQConfig(QuantizationConfig): @classmethod # Need to figure it out def get_min_capability(cls) -> int: - if hasattr(torch, "cuda") and torch.cuda.is_available(): - return 60 - - # Vendors can update - return sys.maxsize - - @classmethod - def get_availability(cls) -> bool: - major, minor = get_device_capability() - if hasattr(torch, "cuda") and torch.cuda.is_available(): - return major * 10 + minor > 60 - - # Vendors can update - return False + return 60 @classmethod def get_config_filenames(cls) -> List[str]: @@ -224,20 +209,7 @@ class GPTQMarlinConfig(QuantizationConfig): @classmethod def get_min_capability(cls) -> int: - if hasattr(torch, "cuda") and torch.cuda.is_available(): - return 80 - - # Vendors can update - return sys.maxsize - - @classmethod - def get_availability(cls) -> bool: - major, minor = get_device_capability() - if hasattr(torch, "cuda") and torch.cuda.is_available(): - return major * 10 + minor > 80 - - # Vendors can update - return False + return 80 @classmethod def get_config_filenames(cls) -> List[str]: @@ -399,20 +371,7 @@ class MarlinConfig(QuantizationConfig): @classmethod # Need to figure it out def get_min_capability(cls) -> int: - if hasattr(torch, "cuda") and torch.cuda.is_available(): - return 80 - - # Vendors can update - return sys.maxsize - - @classmethod - def get_availability(cls) -> bool: - major, minor = get_device_capability() - if hasattr(torch, "cuda") and torch.cuda.is_available(): - return major * 10 + minor > 80 - - # Vendors can update - return False + return 80 @classmethod def get_config_filenames(cls) -> List[str]: diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 694342000..c26012da2 100644 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -1,7 +1,6 @@ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py import logging -import sys from typing import Any, Dict, List, Optional import torch @@ -21,7 +20,6 @@ from sglang.srt.layers.quantization.base_config import ( QuantizeMethodBase, ) from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear -from sglang.srt.utils import get_device_capability # Initialize logger for the module logger = logging.getLogger(__name__) @@ -54,20 +52,7 @@ class ModelOptFp8Config(QuantizationConfig): @classmethod def get_min_capability(cls) -> int: - if hasattr(torch, "cuda") and torch.cuda.is_available(): - return 89 - - # Vendors can update - return sys.maxsize - - @classmethod - def get_availability(cls) -> bool: - major, minor = get_device_capability() - if hasattr(torch, "cuda") and torch.cuda.is_available(): - return major * 10 + minor > 89 - - # Vendors can update - return False + return 89 # Minimum hardware capability (e.g., Hopper GPUs). @classmethod def get_config_filenames(cls) -> List[str]: diff --git a/python/sglang/srt/layers/quantization/w8a8_fp8.py b/python/sglang/srt/layers/quantization/w8a8_fp8.py index 229b56bec..240d86927 100644 --- a/python/sglang/srt/layers/quantization/w8a8_fp8.py +++ b/python/sglang/srt/layers/quantization/w8a8_fp8.py @@ -14,7 +14,7 @@ from sglang.srt.layers.quantization.fp8_utils import ( cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, ) -from sglang.srt.utils import get_device_capability, is_hip +from sglang.srt.utils import is_hip _is_hip = is_hip() @@ -35,20 +35,7 @@ class W8A8Fp8Config(QuantizationConfig): @classmethod def get_min_capability(cls) -> int: - if hasattr(torch, "cuda") and torch.cuda.is_available(): - return 89 - - # Vendors can update - return sys.maxsize - - @classmethod - def get_availability(cls) -> bool: - major, minor = get_device_capability() - if hasattr(torch, "cuda") and torch.cuda.is_available(): - return major * 10 + minor > 89 - - # Vendors can update - return False + return 89 @classmethod def get_name(self) -> str: diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index 07bbde395..6467aca5f 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -1,4 +1,3 @@ -import sys from typing import Any, Callable, Dict, List, Optional import torch @@ -19,7 +18,6 @@ from sglang.srt.layers.quantization.base_config import ( QuantizeMethodBase, ) from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 -from sglang.srt.utils import get_device_capability class W8A8Int8Config(QuantizationConfig): @@ -38,20 +36,7 @@ class W8A8Int8Config(QuantizationConfig): @classmethod def get_min_capability(cls) -> int: - if hasattr(torch, "cuda") and torch.cuda.is_available(): - return 75 - - # Vendors can update - return sys.maxsize - - @classmethod - def get_availability(cls) -> int: - major, minor = get_device_capability() - if hasattr(torch, "cuda") and torch.cuda.is_available(): - return major * 10 + minor > 75 - - # Vendors can update - return False + return 75 @classmethod def get_name(self) -> str: diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 16dcda8cd..732395765 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -69,7 +69,6 @@ class RotaryEmbedding(CustomOp): base: int, is_neox_style: bool, dtype: torch.dtype, - device: str, ) -> None: super().__init__() self.head_size = head_size @@ -78,7 +77,6 @@ class RotaryEmbedding(CustomOp): self.base = base self.is_neox_style = is_neox_style self.dtype = dtype - self.device = device cache = self._compute_cos_sin_cache() # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability @@ -285,19 +283,12 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding): is_neox_style: bool, scaling_factors: Union[List[float], float], dtype: torch.dtype, - device: str, ) -> None: if isinstance(scaling_factors, float): scaling_factors = [scaling_factors] self.scaling_factors: List[float] = scaling_factors # noqa super().__init__( - head_size, - rotary_dim, - max_position_embeddings, - base, - is_neox_style, - dtype, - device, + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype ) # Lazy initialized. self._scaling_factor_to_offset: Dict[float, int] @@ -356,17 +347,10 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): is_neox_style: bool, scaling_factor: float, dtype: torch.dtype, - device: str, ) -> None: self.scaling_factor = scaling_factor super().__init__( - head_size, - rotary_dim, - max_position_embeddings, - base, - is_neox_style, - dtype, - device, + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype ) def _compute_cos_sin_cache(self) -> torch.Tensor: @@ -450,7 +434,6 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding): is_neox_style: bool, scaling_factor: float, dtype: torch.dtype, - device: str, *, extrapolation_factor: float = 1, attn_factor: float = 1, @@ -465,13 +448,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding): # Get n-d magnitude scaling corrected for interpolation self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor) super().__init__( - head_size, - rotary_dim, - max_position_embeddings, - base, - is_neox_style, - dtype, - device, + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype ) def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: @@ -668,7 +645,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): is_neox_style: bool, scaling_factor: float, dtype: torch.dtype, - device: str, *, extrapolation_factor: float = 1, attn_factor: float = 1, @@ -676,6 +652,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): beta_slow: int = 1, mscale: float = 1, mscale_all_dim: float = 0, + device: Optional[str] = "cuda", ) -> None: self.scaling_factor = scaling_factor self.extrapolation_factor = extrapolation_factor @@ -688,14 +665,9 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * attn_factor ) + self.device = device super().__init__( - head_size, - rotary_dim, - max_position_embeddings, - base, - is_neox_style, - dtype, - device, + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype ) def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: @@ -790,7 +762,6 @@ class Llama3RotaryEmbedding(RotaryEmbedding): base: int, is_neox_style: bool, dtype: torch.dtype, - device: str, scaling_factor: float, low_freq_factor: float, high_freq_factor: float, @@ -801,13 +772,7 @@ class Llama3RotaryEmbedding(RotaryEmbedding): self.high_freq_factor = high_freq_factor self.orig_max_position = orig_max_position super().__init__( - head_size, - rotary_dim, - max_position_embeddings, - base, - is_neox_style, - dtype, - str, + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype ) def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: @@ -845,17 +810,10 @@ class MRotaryEmbedding(RotaryEmbedding): base: int, is_neox_style: bool, dtype: torch.dtype, - device: str, mrope_section: Optional[List[int]] = None, ) -> None: super().__init__( - head_size, - rotary_dim, - max_position_embeddings, - base, - is_neox_style, - dtype, - device, + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype ) self.mrope_section = mrope_section @@ -1045,14 +1003,9 @@ def get_rope( rope_scaling: Optional[Dict[str, Any]] = None, dtype: Optional[torch.dtype] = None, partial_rotary_factor: float = 1.0, - device: str = None, ) -> RotaryEmbedding: if dtype is None: dtype = torch.get_default_dtype() - if device is None: - from sglang.srt.managers.schedule_batch import global_server_args_dict - - device = global_server_args_dict["device"] if rope_scaling is not None: # Transforms every value that is a list into a tuple for caching calls rope_scaling_tuple = { @@ -1077,7 +1030,7 @@ def get_rope( if rope_scaling is None: rotary_emb = RotaryEmbedding( - head_size, rotary_dim, max_position, base, is_neox_style, dtype, device + head_size, rotary_dim, max_position, base, is_neox_style, dtype ) else: if "rope_type" in rope_scaling: @@ -1099,7 +1052,6 @@ def get_rope( base, is_neox_style, dtype, - device, scaling_factor, low_freq_factor, high_freq_factor, @@ -1114,7 +1066,6 @@ def get_rope( base, is_neox_style, dtype, - device, mrope_section=rope_scaling["mrope_section"], ) else: @@ -1125,7 +1076,6 @@ def get_rope( base, is_neox_style, dtype, - device, ) elif scaling_type == "linear": scaling_factor = rope_scaling["factor"] @@ -1137,7 +1087,6 @@ def get_rope( is_neox_style, scaling_factor, dtype, - device, ) elif scaling_type == "dynamic": scaling_factor = rope_scaling["factor"] @@ -1149,7 +1098,6 @@ def get_rope( is_neox_style, scaling_factor, dtype, - device, ) elif scaling_type == "yarn": scaling_factor = rope_scaling["factor"] @@ -1168,7 +1116,6 @@ def get_rope( is_neox_style, scaling_factor, dtype, - device, **extra_kwargs, ) elif scaling_type == "deepseek_yarn": @@ -1196,7 +1143,6 @@ def get_rope( is_neox_style, scaling_factor, dtype, - device, **extra_kwargs, ) elif scaling_type == "longrope": @@ -1307,8 +1253,21 @@ def get_rope_wrapper( rope_scaling: Optional[Dict[str, Any]] = None, dtype: Optional[torch.dtype] = None, partial_rotary_factor: float = 1.0, + device: Optional[str] = None, ): - return get_rope( + if device != "cpu": + return get_rope( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + rope_scaling, + dtype, + partial_rotary_factor, + ) + + return get_rope_cpu( head_size, rotary_dim, max_position, @@ -1317,4 +1276,5 @@ def get_rope_wrapper( rope_scaling, dtype, partial_rotary_factor, + device, ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index dbbe56033..1b447b2b8 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -35,8 +35,6 @@ from sglang.srt.distributed import ( set_custom_all_reduce, ) from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state -from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend -from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend from sglang.srt.layers.dp_attention import ( get_attention_tp_group, get_attention_tp_size, diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 41531dd0c..c241fd9d6 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -108,25 +108,15 @@ def _get_quantization_config( quant_config = get_quant_config(model_config, load_config) major, minor = get_device_capability() - if not hasattr(quant_config, "get_availability"): - # Update VLLM to support get_available - if major is not None and minor is not None: - assert 0 <= minor < 10 - capability = major * 10 + minor - if capability < quant_config.get_min_capability(): - raise ValueError( - f"The quantization method {model_config.quantization} " - "is not supported for the current GPU. " - f"Minimum capability: {quant_config.get_min_capability()}. " - f"Current capability: {capability}." - ) - else: - if not quant_config.get_availability(): + if major is not None and minor is not None: + assert 0 <= minor < 10 + capability = major * 10 + minor + if capability < quant_config.get_min_capability(): raise ValueError( f"The quantization method {model_config.quantization} " "is not supported for the current GPU. " f"Minimum capability: {quant_config.get_min_capability()}. " - f"Current capability: {major, minor}." + f"Current capability: {capability}." ) supported_dtypes = quant_config.get_supported_act_dtypes() if model_config.dtype not in supported_dtypes: diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index ddd460b3f..d0ca14feb 100755 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -55,7 +55,7 @@ from sglang.srt.layers.quantization.int8_utils import ( block_dequant as int8_block_dequant, ) from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.rotary_embedding import get_rope_wrapper +from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -305,6 +305,7 @@ class DeepseekV2Attention(nn.Module): base=rope_theta, rope_scaling=rope_scaling, is_neox_style=False, + device=global_server_args_dict["device"], ) if rope_scaling: @@ -500,7 +501,7 @@ class DeepseekV2AttentionMLA(nn.Module): if rope_scaling: rope_scaling["rope_type"] = "deepseek_yarn" - self.rotary_emb = get_rope_wrapper( + self.rotary_emb = get_rope( qk_rope_head_dim, rotary_dim=qk_rope_head_dim, max_position=max_position_embeddings, @@ -645,20 +646,19 @@ class DeepseekV2AttentionMLA(nn.Module): ) q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - if self.w_kc.dtype == torch.float8_e4m3fnuz: # hip only + if self.w_kc.dtype == torch.float8_e4m3fnuz: # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz q_nope_out = torch.bmm( q_nope.to(torch.bfloat16).transpose(0, 1), self.w_kc.to(torch.bfloat16) * self.w_scale, ) - elif self.w_kc.dtype == torch.float8_e4m3fn and is_cuda_available(): + elif self.w_kc.dtype == torch.float8_e4m3fn: q_nope_val, q_nope_scale = input_to_float8( q_nope.transpose(0, 1), torch.float8_e4m3fn ) q_nope_out = bmm_fp8( q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16 ) - else: q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc) q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1) @@ -677,13 +677,13 @@ class DeepseekV2AttentionMLA(nn.Module): attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) - if self.w_vc.dtype == torch.float8_e4m3fnuz or not is_cuda_available(): + if self.w_vc.dtype == torch.float8_e4m3fnuz: # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz attn_bmm_output = torch.bmm( attn_output.to(torch.bfloat16).transpose(0, 1), self.w_vc.to(torch.bfloat16) * self.w_scale, ) - elif self.w_vc.dtype == torch.float8_e4m3fn and is_cuda_available(): + elif self.w_vc.dtype == torch.float8_e4m3fn: attn_output_val, attn_output_scale = input_to_float8( attn_output.transpose(0, 1), torch.float8_e4m3fn ) @@ -694,7 +694,6 @@ class DeepseekV2AttentionMLA(nn.Module): self.w_scale, torch.bfloat16, ) - else: attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc) attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 4dbc0ab04..19c06b607 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -126,14 +126,6 @@ def is_cuda_available(): return is_cuda() -def is_triton_available(): - if is_cuda() or is_xpu() or is_hip(): - return get_bool_env_var("TRITON_AVAILABLE", default="true") - else: - # update once CPU/HPU supports triton - return False - - def enable_show_time_cost(): global show_time_cost show_time_cost = True @@ -1144,7 +1136,6 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]: major, minor = None, None if hasattr(torch, "cuda") and torch.cuda.is_available(): major, minor = torch.cuda.get_device_capability(device_id) - assert 0 <= minor < 10 if hasattr(torch, "xpu") and torch.xpu.is_available(): major, minor, *_ = torch.xpu.get_device_capability(device_id)["version"].split( diff --git a/python/sglang/test/test_block_fp8.py b/python/sglang/test/test_block_fp8.py index a87354990..25aaf498a 100644 --- a/python/sglang/test/test_block_fp8.py +++ b/python/sglang/test/test_block_fp8.py @@ -7,8 +7,6 @@ import torch from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe from sglang.srt.layers.quantization.fp8_kernel import ( - native_per_token_group_quant_fp8, - native_w8a8_block_fp8_matmul, per_token_group_quant_fp8, static_quant_fp8, w8a8_block_fp8_matmul, @@ -17,6 +15,35 @@ from sglang.srt.layers.quantization.fp8_kernel import ( _is_cuda = torch.cuda.is_available() and torch.version.cuda +# For test +def native_per_token_group_quant_fp8( + x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn +): + """Function to perform per-token-group quantization on an input tensor `x` using native torch. + + It converts the tensor values into float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + Note that only `torch.float8_e4m3fn` is supported for now. + """ + assert ( + x.shape[-1] % group_size == 0 + ), "the last dimension of `x` cannot be divisible by `group_size`" + assert x.is_contiguous(), "`x` is not contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_ = x.reshape(x.numel() // group_size, group_size) + amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32) + x_s = amax / fp8_max + x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype) + x_q = x_q.reshape(x.shape) + x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,)) + + return x_q, x_s + + class TestPerTokenGroupQuantFP8(unittest.TestCase): DTYPES = [torch.half, torch.bfloat16, torch.float32] NUM_TOKENS = [7, 83, 2048] @@ -127,6 +154,62 @@ class TestStaticQuantFP8(unittest.TestCase): self._static_quant_fp8(*params) +# For test +def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16): + """This function performs matrix multiplication with block-wise quantization using native torch. + + It takes two input tensors `A` and `B` with scales `As` and `Bs`. + The output is returned in the specified `output_dtype`. + """ + + A = A.to(torch.float32) + B = B.to(torch.float32) + assert A.shape[-1] == B.shape[-1] + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1] + assert A.shape[:-1] == As.shape[:-1] + + M = A.numel() // A.shape[-1] + N, K = B.shape + origin_C_shape = A.shape[:-1] + (N,) + A = A.reshape(M, A.shape[-1]) + As = As.reshape(M, As.shape[-1]) + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + assert n_tiles == Bs.shape[0] + assert k_tiles == Bs.shape[1] + + C_shape = (M, N) + C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) + + A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)] + B_tiles = [ + [ + B[ + j * block_n : min((j + 1) * block_n, N), + i * block_k : min((i + 1) * block_k, K), + ] + for i in range(k_tiles) + ] + for j in range(n_tiles) + ] + C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)] + As_tiles = [As[:, i : i + 1] for i in range(k_tiles)] + + for i in range(k_tiles): + for j in range(n_tiles): + a = A_tiles[i] + b = B_tiles[j][i] + c = C_tiles[j] + s = As_tiles[i] * Bs[j][i] + c[:, :] += torch.matmul(a, b.t()) * s + + C = C.reshape(origin_C_shape).to(output_dtype) + return C + + class TestW8A8BlockFP8Matmul(unittest.TestCase): if not _is_cuda: