diff --git a/docs/start/install.md b/docs/start/install.md index 3b357fa3e..f7fd11141 100644 --- a/docs/start/install.md +++ b/docs/start/install.md @@ -52,6 +52,15 @@ cd .. pip install -e "python[all_hip]" ``` +Note: To Intel GPU, do following instead: + +``` +git clone https://github.com/sgl-project/sglang.git +cd sglang +pip install --upgrade pip +pip install -e "python[all_xpu]" +``` + ## Method 3: Using docker The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](https://github.com/sgl-project/sglang/tree/main/docker). Replace `` below with your huggingface hub [token](https://huggingface.co/docs/hub/en/security-tokens). diff --git a/python/pyproject.toml b/python/pyproject.toml index a0ad844bd..e81edaa28 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -57,7 +57,7 @@ srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.7.dev2", "outlines==0.1 # xpu is not enabled in public vllm and torch whl, # need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.htmlinstall vllm -srt_xpu = ["sglang[runtime_common]", "outlines>=0.0.44,<=0.1.11"] +srt_xpu = ["sglang[runtime_common]", "vllm>=0.6.4.post1,<=0.7.2", "outlines>=0.0.44,<=0.1.11"] # For Intel Gaudi(device : hpu) follow the installation guide # https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html diff --git a/python/sglang/srt/layers/quantization/base_config.py b/python/sglang/srt/layers/quantization/base_config.py index e45dda2cc..63c9a6874 100644 --- a/python/sglang/srt/layers/quantization/base_config.py +++ b/python/sglang/srt/layers/quantization/base_config.py @@ -54,8 +54,18 @@ class QuantizationConfig(ABC): """Minimum GPU capability to support the quantization method. E.g., 70 for Volta, 75 for Turing, 80 for Ampere. - This requirement is due to the custom CUDA kernels used by the - quantization method. + This requirement is due to the custom kernels used by the + quantization method or the stock pytorch capability. + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def get_availability(cls) -> bool: + """Whether the quantization config is available on current device. + + This requirement is due to the custom kernels used by the + quantization method or the stock pytorch capability. """ raise NotImplementedError diff --git a/python/sglang/srt/layers/quantization/blockwise_int8.py b/python/sglang/srt/layers/quantization/blockwise_int8.py index ce526cd6a..4c633b328 100644 --- a/python/sglang/srt/layers/quantization/blockwise_int8.py +++ b/python/sglang/srt/layers/quantization/blockwise_int8.py @@ -1,6 +1,7 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py import logging +import sys from typing import Any, Callable, Dict, List, Optional import torch @@ -19,7 +20,7 @@ from sglang.srt.layers.quantization.base_config import ( QuantizeMethodBase, ) from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear -from sglang.srt.utils import set_weight_attrs +from sglang.srt.utils import get_device_capability, set_weight_attrs ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -71,7 +72,20 @@ class BlockInt8Config(QuantizationConfig): @classmethod def get_min_capability(cls) -> int: - return 80 + if hasattr(torch, "cuda") and torch.cuda.is_available(): + return 80 + + # Vendors can update + return sys.maxsize + + @classmethod + def get_availability(cls) -> bool: + major, minor = get_device_capability() + if hasattr(torch, "cuda") and torch.cuda.is_available(): + return major * 10 + minor > 80 + + # Vendors can update + return False @classmethod def get_config_filenames(cls) -> List[str]: diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index bff0fe96e..a39149a2b 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -1,6 +1,7 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py import logging +import sys from typing import Any, Callable, Dict, List, Optional import torch @@ -36,7 +37,11 @@ from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) -from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 +from sglang.srt.layers.quantization.fp8_kernel import ( + native_per_token_group_quant_fp8, + native_w8a8_block_fp8_matmul, + per_token_group_quant_fp8, +) from sglang.srt.layers.quantization.fp8_utils import ( apply_fp8_linear, apply_w8a8_block_fp8_linear, @@ -46,6 +51,8 @@ from sglang.srt.layers.quantization.fp8_utils import ( ) from sglang.srt.utils import ( get_bool_env_var, + get_device_capability, + is_cuda, is_hip, permute_weight, print_warning_once, @@ -55,6 +62,7 @@ from sglang.srt.utils import ( ACTIVATION_SCHEMES = ["static", "dynamic"] _is_hip = is_hip() +_is_cuda = is_cuda() if _is_hip: from aiter.fused_moe_bf16_asm import asm_moe @@ -108,7 +116,24 @@ class Fp8Config(QuantizationConfig): @classmethod def get_min_capability(cls) -> int: - return 80 + if hasattr(torch, "cuda") and torch.cuda.is_available(): + return 80 + if hasattr(torch, "xpu") and torch.xpu.is_available(): + return 0 + + # Vendors can update + return sys.maxsize + + @classmethod + def get_availability(cls) -> bool: + major, minor = get_device_capability() + if hasattr(torch, "cuda") and torch.cuda.is_available(): + return major * 10 + minor > 80 + if hasattr(torch, "xpu") and torch.xpu.is_available(): + return True + + # Vendors can update + return False @classmethod def get_config_filenames(cls) -> List[str]: @@ -850,6 +875,52 @@ class Fp8MoEMethod: ) torch.cuda.empty_cache() + def torch_w8a8_block_fp8_moe( + self, a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, block_shape + ): + from sglang.srt.layers.activation import SiluAndMul + + """This function performs fused moe with block-wise quantization using native torch.""" + + B, D = a.shape + topk = topk_ids.shape[-1] + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + + _, block_k = block_shape[0], block_shape[1] + a_q, a_s = native_per_token_group_quant_fp8(a, block_k) + # NOTE(HandH1998): Since "index_cuda" not implemented for 'Float8_e4m3fn', we need to cast `float8`` to `float32``. + a_q = a_q.to(torch.float32) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + inter_out = native_w8a8_block_fp8_matmul( + a_q[mask], + w1[i], + a_s[mask], + w1_s[i], + block_shape, + output_dtype=a.dtype, + ) + act_out = SiluAndMul().forward_native(inter_out) + act_out_q, act_out_s = native_per_token_group_quant_fp8( + act_out, block_k + ) + act_out = act_out.to(torch.float32) + out[mask] = native_w8a8_block_fp8_matmul( + act_out_q, + w2[i], + act_out_s, + w2_s[i], + block_shape, + output_dtype=a.dtype, + ) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) + def apply( self, layer: torch.nn.Module, @@ -923,7 +994,7 @@ class Fp8MoEMethod: layer.w13_weight_scale1, layer.w2_weight_scale1, ) - else: + elif _is_cuda: # Expert fusion with FP8 quantization return fused_experts( x, @@ -950,6 +1021,24 @@ class Fp8MoEMethod: no_combine=no_combine, ) + # for CPU and other accelerators, fallback to native path + return self.torch_w8a8_block_fp8_moe( + a=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + w1_s=( + layer.w13_weight_scale_inv + if self.block_quant + else layer.w13_weight_scale + ), + w2_s=( + layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale + ), + topk_weight=topk_weights, + topk_ids=topk_ids, + block_shape=self.quant_config.weight_block_size, + ) + class Fp8KVCacheMethod(BaseKVCacheMethod): """ diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index 3c60baf0a..6f3c0d79e 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -28,10 +28,13 @@ from sglang.srt.utils import ( get_device_name, is_cuda, is_hip, + is_triton_available, supports_custom_op, ) _is_hip = is_hip() +_is_cuda = is_cuda() +_is_triton = is_triton_available() fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn _is_cuda = is_cuda() @@ -162,6 +165,34 @@ def _per_token_group_quant_fp8_colmajor( tl.store(y_s_ptr, y_s) +def native_per_token_group_quant_fp8( + x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn +): + """Function to perform per-token-group quantization on an input tensor `x` using native torch. + + It converts the tensor values into float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + Note that only `torch.float8_e4m3fn` is supported for now. + """ + assert ( + x.shape[-1] % group_size == 0 + ), "the last dimension of `x` cannot be divisible by `group_size`" + assert x.is_contiguous(), "`x` is not contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_ = x.reshape(x.numel() // group_size, group_size) + amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32) + x_s = amax / fp8_max + x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype) + x_q = x_q.reshape(x.shape) + x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,)) + + return x_q, x_s + + def per_token_group_quant_fp8( x: torch.Tensor, group_size: int, @@ -232,19 +263,22 @@ def per_token_group_quant_fp8( num_stages=num_stages, ) else: - _per_token_group_quant_fp8[(M,)]( - x, - x_q, - x_s, - group_size, - N, - eps, - fp8_min=fp8_min, - fp8_max=fp8_max, - BLOCK=BLOCK, - num_warps=num_warps, - num_stages=num_stages, - ) + if _is_triton: + _per_token_group_quant_fp8[(M,)]( + x, + x_q, + x_s, + group_size, + N, + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + x_q, x_s = native_per_token_group_quant_fp8(x, group_size) return x_q, x_s @@ -691,6 +725,61 @@ def get_w8a8_block_fp8_configs( return None +def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16): + """This function performs matrix multiplication with block-wise quantization using native torch. + + It takes two input tensors `A` and `B` with scales `As` and `Bs`. + The output is returned in the specified `output_dtype`. + """ + + A = A.to(torch.float32) + B = B.to(torch.float32) + assert A.shape[-1] == B.shape[-1] + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1] + assert A.shape[:-1] == As.shape[:-1] + + M = A.numel() // A.shape[-1] + N, K = B.shape + origin_C_shape = A.shape[:-1] + (N,) + A = A.reshape(M, A.shape[-1]) + As = As.reshape(M, As.shape[-1]) + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + assert n_tiles == Bs.shape[0] + assert k_tiles == Bs.shape[1] + + C_shape = (M, N) + C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) + + A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)] + B_tiles = [ + [ + B[ + j * block_n : min((j + 1) * block_n, N), + i * block_k : min((i + 1) * block_k, K), + ] + for i in range(k_tiles) + ] + for j in range(n_tiles) + ] + C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)] + As_tiles = [As[:, i : i + 1] for i in range(k_tiles)] + + for i in range(k_tiles): + for j in range(n_tiles): + a = A_tiles[i] + b = B_tiles[j][i] + c = C_tiles[j] + s = As_tiles[i] * Bs[j][i] + c[:, :] += torch.matmul(a, b.t()) * s + + C = C.reshape(origin_C_shape).to(output_dtype) + return C + + def w8a8_block_fp8_matmul( A: torch.Tensor, B: torch.Tensor, @@ -715,86 +804,90 @@ def w8a8_block_fp8_matmul( Returns: torch.Tensor: The result of matmul. """ - assert len(block_size) == 2 - block_n, block_k = block_size[0], block_size[1] + if _is_triton: # pragma: no cover + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] - assert A.shape[-1] == B.shape[-1] - assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() - assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] - M = A.numel() // A.shape[-1] + assert A.shape[-1] == B.shape[-1] + assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() + assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] + M = A.numel() // A.shape[-1] - assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 - N, K = B.shape - assert triton.cdiv(N, block_n) == Bs.shape[0] - assert triton.cdiv(K, block_k) == Bs.shape[1] + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + N, K = B.shape + assert triton.cdiv(N, block_n) == Bs.shape[0] + assert triton.cdiv(K, block_k) == Bs.shape[1] - C_shape = A.shape[:-1] + (N,) - C = A.new_empty(C_shape, dtype=output_dtype) + C_shape = A.shape[:-1] + (N,) + C = A.new_empty(C_shape, dtype=output_dtype) - configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1]) - if configs: - # If an optimal configuration map has been found, look up the - # optimal config - config = configs[min(configs.keys(), key=lambda x: abs(x - M))] - else: - # Default config - # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1] - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": block_size[0], - "BLOCK_SIZE_K": block_size[1], - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3, - } - - def grid(META): - return ( - triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), - ) - - # Use manually unrolledx4 kernel on AMD GPU when the grid size is small. - # Empirical testing shows the sweet spot lies when it's less than the # of - # compute units available on the device. - num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv( - N, config["BLOCK_SIZE_N"] - ) - - # deepgemm only support bf16 - if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm: - if supports_custom_op(): - torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C) + configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1]) + if configs: + # If an optimal configuration map has been found, look up the + # optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: - deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C) - else: - kernel = ( - _w8a8_block_fp8_matmul_unrolledx4 - if (_is_hip == True and num_workgroups <= get_device_core_count()) - else _w8a8_block_fp8_matmul + # Default config + # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1] + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_size[0], + "BLOCK_SIZE_K": block_size[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, + } + + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) + * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + + # Use manually unrolledx4 kernel on AMD GPU when the grid size is small. + # Empirical testing shows the sweet spot lies when it's less than the # of + # compute units available on the device. + num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv( + N, config["BLOCK_SIZE_N"] ) - kernel[grid]( - A, - B, - C, - As, - Bs, - M, - N, - K, - block_n, - block_k, - A.stride(-2), - A.stride(-1), - B.stride(1), - B.stride(0), - C.stride(-2), - C.stride(-1), - As.stride(-2), - As.stride(-1), - Bs.stride(1), - Bs.stride(0), - **config, - ) + # deepgemm only support bf16 + if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm: + if supports_custom_op(): + torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C) + else: + deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C) + else: + kernel = ( + _w8a8_block_fp8_matmul_unrolledx4 + if (_is_hip == True and num_workgroups <= get_device_core_count()) + else _w8a8_block_fp8_matmul + ) + + kernel[grid]( + A, + B, + C, + As, + Bs, + M, + N, + K, + block_n, + block_k, + A.stride(-2), + A.stride(-1), + B.stride(1), + B.stride(0), + C.stride(-2), + C.stride(-1), + As.stride(-2), + As.stride(-1), + Bs.stride(1), + Bs.stride(0), + **config, + ) + else: + C = native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype) return C diff --git a/python/sglang/srt/layers/quantization/gptq.py b/python/sglang/srt/layers/quantization/gptq.py index b15498864..c2070d23c 100644 --- a/python/sglang/srt/layers/quantization/gptq.py +++ b/python/sglang/srt/layers/quantization/gptq.py @@ -1,4 +1,5 @@ import logging +import sys from fractions import Fraction from typing import Any, Dict, List, Optional, Union @@ -8,6 +9,7 @@ from vllm.scalar_type import scalar_types from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead +from sglang.srt.utils import get_device_capability logger = logging.getLogger(__name__) @@ -90,7 +92,20 @@ class GPTQConfig(QuantizationConfig): @classmethod # Need to figure it out def get_min_capability(cls) -> int: - return 60 + if hasattr(torch, "cuda") and torch.cuda.is_available(): + return 60 + + # Vendors can update + return sys.maxsize + + @classmethod + def get_availability(cls) -> bool: + major, minor = get_device_capability() + if hasattr(torch, "cuda") and torch.cuda.is_available(): + return major * 10 + minor > 60 + + # Vendors can update + return False @classmethod def get_config_filenames(cls) -> List[str]: @@ -209,7 +224,20 @@ class GPTQMarlinConfig(QuantizationConfig): @classmethod def get_min_capability(cls) -> int: - return 80 + if hasattr(torch, "cuda") and torch.cuda.is_available(): + return 80 + + # Vendors can update + return sys.maxsize + + @classmethod + def get_availability(cls) -> bool: + major, minor = get_device_capability() + if hasattr(torch, "cuda") and torch.cuda.is_available(): + return major * 10 + minor > 80 + + # Vendors can update + return False @classmethod def get_config_filenames(cls) -> List[str]: @@ -371,7 +399,20 @@ class MarlinConfig(QuantizationConfig): @classmethod # Need to figure it out def get_min_capability(cls) -> int: - return 80 + if hasattr(torch, "cuda") and torch.cuda.is_available(): + return 80 + + # Vendors can update + return sys.maxsize + + @classmethod + def get_availability(cls) -> bool: + major, minor = get_device_capability() + if hasattr(torch, "cuda") and torch.cuda.is_available(): + return major * 10 + minor > 80 + + # Vendors can update + return False @classmethod def get_config_filenames(cls) -> List[str]: diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index c26012da2..694342000 100644 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -1,6 +1,7 @@ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py import logging +import sys from typing import Any, Dict, List, Optional import torch @@ -20,6 +21,7 @@ from sglang.srt.layers.quantization.base_config import ( QuantizeMethodBase, ) from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear +from sglang.srt.utils import get_device_capability # Initialize logger for the module logger = logging.getLogger(__name__) @@ -52,7 +54,20 @@ class ModelOptFp8Config(QuantizationConfig): @classmethod def get_min_capability(cls) -> int: - return 89 # Minimum hardware capability (e.g., Hopper GPUs). + if hasattr(torch, "cuda") and torch.cuda.is_available(): + return 89 + + # Vendors can update + return sys.maxsize + + @classmethod + def get_availability(cls) -> bool: + major, minor = get_device_capability() + if hasattr(torch, "cuda") and torch.cuda.is_available(): + return major * 10 + minor > 89 + + # Vendors can update + return False @classmethod def get_config_filenames(cls) -> List[str]: diff --git a/python/sglang/srt/layers/quantization/w8a8_fp8.py b/python/sglang/srt/layers/quantization/w8a8_fp8.py index 240d86927..229b56bec 100644 --- a/python/sglang/srt/layers/quantization/w8a8_fp8.py +++ b/python/sglang/srt/layers/quantization/w8a8_fp8.py @@ -14,7 +14,7 @@ from sglang.srt.layers.quantization.fp8_utils import ( cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, ) -from sglang.srt.utils import is_hip +from sglang.srt.utils import get_device_capability, is_hip _is_hip = is_hip() @@ -35,7 +35,20 @@ class W8A8Fp8Config(QuantizationConfig): @classmethod def get_min_capability(cls) -> int: - return 89 + if hasattr(torch, "cuda") and torch.cuda.is_available(): + return 89 + + # Vendors can update + return sys.maxsize + + @classmethod + def get_availability(cls) -> bool: + major, minor = get_device_capability() + if hasattr(torch, "cuda") and torch.cuda.is_available(): + return major * 10 + minor > 89 + + # Vendors can update + return False @classmethod def get_name(self) -> str: diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index 6467aca5f..07bbde395 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -1,3 +1,4 @@ +import sys from typing import Any, Callable, Dict, List, Optional import torch @@ -18,6 +19,7 @@ from sglang.srt.layers.quantization.base_config import ( QuantizeMethodBase, ) from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 +from sglang.srt.utils import get_device_capability class W8A8Int8Config(QuantizationConfig): @@ -36,7 +38,20 @@ class W8A8Int8Config(QuantizationConfig): @classmethod def get_min_capability(cls) -> int: - return 75 + if hasattr(torch, "cuda") and torch.cuda.is_available(): + return 75 + + # Vendors can update + return sys.maxsize + + @classmethod + def get_availability(cls) -> int: + major, minor = get_device_capability() + if hasattr(torch, "cuda") and torch.cuda.is_available(): + return major * 10 + minor > 75 + + # Vendors can update + return False @classmethod def get_name(self) -> str: diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 732395765..16dcda8cd 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -69,6 +69,7 @@ class RotaryEmbedding(CustomOp): base: int, is_neox_style: bool, dtype: torch.dtype, + device: str, ) -> None: super().__init__() self.head_size = head_size @@ -77,6 +78,7 @@ class RotaryEmbedding(CustomOp): self.base = base self.is_neox_style = is_neox_style self.dtype = dtype + self.device = device cache = self._compute_cos_sin_cache() # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability @@ -283,12 +285,19 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding): is_neox_style: bool, scaling_factors: Union[List[float], float], dtype: torch.dtype, + device: str, ) -> None: if isinstance(scaling_factors, float): scaling_factors = [scaling_factors] self.scaling_factors: List[float] = scaling_factors # noqa super().__init__( - head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + head_size, + rotary_dim, + max_position_embeddings, + base, + is_neox_style, + dtype, + device, ) # Lazy initialized. self._scaling_factor_to_offset: Dict[float, int] @@ -347,10 +356,17 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): is_neox_style: bool, scaling_factor: float, dtype: torch.dtype, + device: str, ) -> None: self.scaling_factor = scaling_factor super().__init__( - head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + head_size, + rotary_dim, + max_position_embeddings, + base, + is_neox_style, + dtype, + device, ) def _compute_cos_sin_cache(self) -> torch.Tensor: @@ -434,6 +450,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding): is_neox_style: bool, scaling_factor: float, dtype: torch.dtype, + device: str, *, extrapolation_factor: float = 1, attn_factor: float = 1, @@ -448,7 +465,13 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding): # Get n-d magnitude scaling corrected for interpolation self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor) super().__init__( - head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + head_size, + rotary_dim, + max_position_embeddings, + base, + is_neox_style, + dtype, + device, ) def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: @@ -645,6 +668,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): is_neox_style: bool, scaling_factor: float, dtype: torch.dtype, + device: str, *, extrapolation_factor: float = 1, attn_factor: float = 1, @@ -652,7 +676,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): beta_slow: int = 1, mscale: float = 1, mscale_all_dim: float = 0, - device: Optional[str] = "cuda", ) -> None: self.scaling_factor = scaling_factor self.extrapolation_factor = extrapolation_factor @@ -665,9 +688,14 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * attn_factor ) - self.device = device super().__init__( - head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + head_size, + rotary_dim, + max_position_embeddings, + base, + is_neox_style, + dtype, + device, ) def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: @@ -762,6 +790,7 @@ class Llama3RotaryEmbedding(RotaryEmbedding): base: int, is_neox_style: bool, dtype: torch.dtype, + device: str, scaling_factor: float, low_freq_factor: float, high_freq_factor: float, @@ -772,7 +801,13 @@ class Llama3RotaryEmbedding(RotaryEmbedding): self.high_freq_factor = high_freq_factor self.orig_max_position = orig_max_position super().__init__( - head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + head_size, + rotary_dim, + max_position_embeddings, + base, + is_neox_style, + dtype, + str, ) def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: @@ -810,10 +845,17 @@ class MRotaryEmbedding(RotaryEmbedding): base: int, is_neox_style: bool, dtype: torch.dtype, + device: str, mrope_section: Optional[List[int]] = None, ) -> None: super().__init__( - head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + head_size, + rotary_dim, + max_position_embeddings, + base, + is_neox_style, + dtype, + device, ) self.mrope_section = mrope_section @@ -1003,9 +1045,14 @@ def get_rope( rope_scaling: Optional[Dict[str, Any]] = None, dtype: Optional[torch.dtype] = None, partial_rotary_factor: float = 1.0, + device: str = None, ) -> RotaryEmbedding: if dtype is None: dtype = torch.get_default_dtype() + if device is None: + from sglang.srt.managers.schedule_batch import global_server_args_dict + + device = global_server_args_dict["device"] if rope_scaling is not None: # Transforms every value that is a list into a tuple for caching calls rope_scaling_tuple = { @@ -1030,7 +1077,7 @@ def get_rope( if rope_scaling is None: rotary_emb = RotaryEmbedding( - head_size, rotary_dim, max_position, base, is_neox_style, dtype + head_size, rotary_dim, max_position, base, is_neox_style, dtype, device ) else: if "rope_type" in rope_scaling: @@ -1052,6 +1099,7 @@ def get_rope( base, is_neox_style, dtype, + device, scaling_factor, low_freq_factor, high_freq_factor, @@ -1066,6 +1114,7 @@ def get_rope( base, is_neox_style, dtype, + device, mrope_section=rope_scaling["mrope_section"], ) else: @@ -1076,6 +1125,7 @@ def get_rope( base, is_neox_style, dtype, + device, ) elif scaling_type == "linear": scaling_factor = rope_scaling["factor"] @@ -1087,6 +1137,7 @@ def get_rope( is_neox_style, scaling_factor, dtype, + device, ) elif scaling_type == "dynamic": scaling_factor = rope_scaling["factor"] @@ -1098,6 +1149,7 @@ def get_rope( is_neox_style, scaling_factor, dtype, + device, ) elif scaling_type == "yarn": scaling_factor = rope_scaling["factor"] @@ -1116,6 +1168,7 @@ def get_rope( is_neox_style, scaling_factor, dtype, + device, **extra_kwargs, ) elif scaling_type == "deepseek_yarn": @@ -1143,6 +1196,7 @@ def get_rope( is_neox_style, scaling_factor, dtype, + device, **extra_kwargs, ) elif scaling_type == "longrope": @@ -1253,21 +1307,8 @@ def get_rope_wrapper( rope_scaling: Optional[Dict[str, Any]] = None, dtype: Optional[torch.dtype] = None, partial_rotary_factor: float = 1.0, - device: Optional[str] = None, ): - if device != "cpu": - return get_rope( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - rope_scaling, - dtype, - partial_rotary_factor, - ) - - return get_rope_cpu( + return get_rope( head_size, rotary_dim, max_position, @@ -1276,5 +1317,4 @@ def get_rope_wrapper( rope_scaling, dtype, partial_rotary_factor, - device, ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 1b447b2b8..dbbe56033 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -35,6 +35,8 @@ from sglang.srt.distributed import ( set_custom_all_reduce, ) from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state +from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend +from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend from sglang.srt.layers.dp_attention import ( get_attention_tp_group, get_attention_tp_size, diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index c241fd9d6..41531dd0c 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -108,15 +108,25 @@ def _get_quantization_config( quant_config = get_quant_config(model_config, load_config) major, minor = get_device_capability() - if major is not None and minor is not None: - assert 0 <= minor < 10 - capability = major * 10 + minor - if capability < quant_config.get_min_capability(): + if not hasattr(quant_config, "get_availability"): + # Update VLLM to support get_available + if major is not None and minor is not None: + assert 0 <= minor < 10 + capability = major * 10 + minor + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} " + "is not supported for the current GPU. " + f"Minimum capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}." + ) + else: + if not quant_config.get_availability(): raise ValueError( f"The quantization method {model_config.quantization} " "is not supported for the current GPU. " f"Minimum capability: {quant_config.get_min_capability()}. " - f"Current capability: {capability}." + f"Current capability: {major, minor}." ) supported_dtypes = quant_config.get_supported_act_dtypes() if model_config.dtype not in supported_dtypes: diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index d0ca14feb..ddd460b3f 100755 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -55,7 +55,7 @@ from sglang.srt.layers.quantization.int8_utils import ( block_dequant as int8_block_dequant, ) from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper +from sglang.srt.layers.rotary_embedding import get_rope_wrapper from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -305,7 +305,6 @@ class DeepseekV2Attention(nn.Module): base=rope_theta, rope_scaling=rope_scaling, is_neox_style=False, - device=global_server_args_dict["device"], ) if rope_scaling: @@ -501,7 +500,7 @@ class DeepseekV2AttentionMLA(nn.Module): if rope_scaling: rope_scaling["rope_type"] = "deepseek_yarn" - self.rotary_emb = get_rope( + self.rotary_emb = get_rope_wrapper( qk_rope_head_dim, rotary_dim=qk_rope_head_dim, max_position=max_position_embeddings, @@ -646,19 +645,20 @@ class DeepseekV2AttentionMLA(nn.Module): ) q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - if self.w_kc.dtype == torch.float8_e4m3fnuz: + if self.w_kc.dtype == torch.float8_e4m3fnuz: # hip only # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz q_nope_out = torch.bmm( q_nope.to(torch.bfloat16).transpose(0, 1), self.w_kc.to(torch.bfloat16) * self.w_scale, ) - elif self.w_kc.dtype == torch.float8_e4m3fn: + elif self.w_kc.dtype == torch.float8_e4m3fn and is_cuda_available(): q_nope_val, q_nope_scale = input_to_float8( q_nope.transpose(0, 1), torch.float8_e4m3fn ) q_nope_out = bmm_fp8( q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16 ) + else: q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc) q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1) @@ -677,13 +677,13 @@ class DeepseekV2AttentionMLA(nn.Module): attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) - if self.w_vc.dtype == torch.float8_e4m3fnuz: + if self.w_vc.dtype == torch.float8_e4m3fnuz or not is_cuda_available(): # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz attn_bmm_output = torch.bmm( attn_output.to(torch.bfloat16).transpose(0, 1), self.w_vc.to(torch.bfloat16) * self.w_scale, ) - elif self.w_vc.dtype == torch.float8_e4m3fn: + elif self.w_vc.dtype == torch.float8_e4m3fn and is_cuda_available(): attn_output_val, attn_output_scale = input_to_float8( attn_output.transpose(0, 1), torch.float8_e4m3fn ) @@ -694,6 +694,7 @@ class DeepseekV2AttentionMLA(nn.Module): self.w_scale, torch.bfloat16, ) + else: attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc) attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 19c06b607..4dbc0ab04 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -126,6 +126,14 @@ def is_cuda_available(): return is_cuda() +def is_triton_available(): + if is_cuda() or is_xpu() or is_hip(): + return get_bool_env_var("TRITON_AVAILABLE", default="true") + else: + # update once CPU/HPU supports triton + return False + + def enable_show_time_cost(): global show_time_cost show_time_cost = True @@ -1136,6 +1144,7 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]: major, minor = None, None if hasattr(torch, "cuda") and torch.cuda.is_available(): major, minor = torch.cuda.get_device_capability(device_id) + assert 0 <= minor < 10 if hasattr(torch, "xpu") and torch.xpu.is_available(): major, minor, *_ = torch.xpu.get_device_capability(device_id)["version"].split( diff --git a/python/sglang/test/test_block_fp8.py b/python/sglang/test/test_block_fp8.py index 25aaf498a..a87354990 100644 --- a/python/sglang/test/test_block_fp8.py +++ b/python/sglang/test/test_block_fp8.py @@ -7,6 +7,8 @@ import torch from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe from sglang.srt.layers.quantization.fp8_kernel import ( + native_per_token_group_quant_fp8, + native_w8a8_block_fp8_matmul, per_token_group_quant_fp8, static_quant_fp8, w8a8_block_fp8_matmul, @@ -15,35 +17,6 @@ from sglang.srt.layers.quantization.fp8_kernel import ( _is_cuda = torch.cuda.is_available() and torch.version.cuda -# For test -def native_per_token_group_quant_fp8( - x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn -): - """Function to perform per-token-group quantization on an input tensor `x` using native torch. - - It converts the tensor values into float8 values and returns the - quantized tensor along with the scaling factor used for quantization. - Note that only `torch.float8_e4m3fn` is supported for now. - """ - assert ( - x.shape[-1] % group_size == 0 - ), "the last dimension of `x` cannot be divisible by `group_size`" - assert x.is_contiguous(), "`x` is not contiguous" - - finfo = torch.finfo(dtype) - fp8_min = finfo.min - fp8_max = finfo.max - - x_ = x.reshape(x.numel() // group_size, group_size) - amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32) - x_s = amax / fp8_max - x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype) - x_q = x_q.reshape(x.shape) - x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,)) - - return x_q, x_s - - class TestPerTokenGroupQuantFP8(unittest.TestCase): DTYPES = [torch.half, torch.bfloat16, torch.float32] NUM_TOKENS = [7, 83, 2048] @@ -154,62 +127,6 @@ class TestStaticQuantFP8(unittest.TestCase): self._static_quant_fp8(*params) -# For test -def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16): - """This function performs matrix multiplication with block-wise quantization using native torch. - - It takes two input tensors `A` and `B` with scales `As` and `Bs`. - The output is returned in the specified `output_dtype`. - """ - - A = A.to(torch.float32) - B = B.to(torch.float32) - assert A.shape[-1] == B.shape[-1] - assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 - assert len(block_size) == 2 - block_n, block_k = block_size[0], block_size[1] - assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1] - assert A.shape[:-1] == As.shape[:-1] - - M = A.numel() // A.shape[-1] - N, K = B.shape - origin_C_shape = A.shape[:-1] + (N,) - A = A.reshape(M, A.shape[-1]) - As = As.reshape(M, As.shape[-1]) - n_tiles = (N + block_n - 1) // block_n - k_tiles = (K + block_k - 1) // block_k - assert n_tiles == Bs.shape[0] - assert k_tiles == Bs.shape[1] - - C_shape = (M, N) - C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) - - A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)] - B_tiles = [ - [ - B[ - j * block_n : min((j + 1) * block_n, N), - i * block_k : min((i + 1) * block_k, K), - ] - for i in range(k_tiles) - ] - for j in range(n_tiles) - ] - C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)] - As_tiles = [As[:, i : i + 1] for i in range(k_tiles)] - - for i in range(k_tiles): - for j in range(n_tiles): - a = A_tiles[i] - b = B_tiles[j][i] - c = C_tiles[j] - s = As_tiles[i] * Bs[j][i] - c[:, :] += torch.matmul(a, b.t()) * s - - C = C.reshape(origin_C_shape).to(output_dtype) - return C - - class TestW8A8BlockFP8Matmul(unittest.TestCase): if not _is_cuda: