Clean up imports (#5467)
This commit is contained in:
@@ -24,6 +24,7 @@ from sglang.api import (
|
|||||||
user_end,
|
user_end,
|
||||||
video,
|
video,
|
||||||
)
|
)
|
||||||
|
from sglang.global_config import global_config
|
||||||
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
||||||
from sglang.lang.choices import (
|
from sglang.lang.choices import (
|
||||||
greedy_token_selection,
|
greedy_token_selection,
|
||||||
@@ -31,6 +32,7 @@ from sglang.lang.choices import (
|
|||||||
unconditional_likelihood_normalized,
|
unconditional_likelihood_normalized,
|
||||||
)
|
)
|
||||||
from sglang.utils import LazyImport
|
from sglang.utils import LazyImport
|
||||||
|
from sglang.version import __version__
|
||||||
|
|
||||||
ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
|
ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
|
||||||
Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
|
Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
|
||||||
@@ -38,10 +40,6 @@ LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
|
|||||||
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
|
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
|
||||||
VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
|
VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
|
||||||
|
|
||||||
# Other configs
|
|
||||||
from sglang.global_config import global_config
|
|
||||||
from sglang.version import __version__
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Engine",
|
"Engine",
|
||||||
"Runtime",
|
"Runtime",
|
||||||
|
|||||||
@@ -707,10 +707,6 @@ def sample_random_requests(
|
|||||||
|
|
||||||
# Download sharegpt if necessary
|
# Download sharegpt if necessary
|
||||||
if not os.path.isfile(dataset_path):
|
if not os.path.isfile(dataset_path):
|
||||||
print(
|
|
||||||
"If you do not want to randomly sample from a dataset,"
|
|
||||||
" please use --dataset-name random-ids."
|
|
||||||
)
|
|
||||||
dataset_path = download_and_cache_file(SHAREGPT_URL)
|
dataset_path = download_and_cache_file(SHAREGPT_URL)
|
||||||
|
|
||||||
# Load the dataset.
|
# Load the dataset.
|
||||||
|
|||||||
@@ -1,7 +1,3 @@
|
|||||||
from typing import List, Optional, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from sglang.lang.backend.base_backend import BaseBackend
|
from sglang.lang.backend.base_backend import BaseBackend
|
||||||
from sglang.lang.chat_template import get_chat_template
|
from sglang.lang.chat_template import get_chat_template
|
||||||
from sglang.lang.interpreter import StreamExecutor
|
from sglang.lang.interpreter import StreamExecutor
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Callable, List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
from sglang.lang.chat_template import get_chat_template
|
from sglang.lang.chat_template import get_chat_template
|
||||||
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
|
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import dataclasses
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Callable, List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from sglang.lang.backend.base_backend import BaseBackend
|
from sglang.lang.backend.base_backend import BaseBackend
|
||||||
from sglang.lang.chat_template import get_chat_template
|
from sglang.lang.chat_template import get_chat_template
|
||||||
|
|||||||
@@ -5,13 +5,7 @@ from typing import List, Union
|
|||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
|
from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
|
||||||
from sglang.lang.ir import (
|
from sglang.lang.ir import SglArgument, SglExpr, SglSamplingParams, SglVariable
|
||||||
SglArgument,
|
|
||||||
SglConstantText,
|
|
||||||
SglExpr,
|
|
||||||
SglSamplingParams,
|
|
||||||
SglVariable,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def compile_func(function, backend):
|
def compile_func(function, backend):
|
||||||
|
|||||||
@@ -1,20 +1,16 @@
|
|||||||
"""Tracing a program."""
|
"""Tracing a program."""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Callable, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
|
||||||
from sglang.lang.backend.base_backend import BaseBackend
|
from sglang.lang.backend.base_backend import BaseBackend
|
||||||
from sglang.lang.interpreter import ProgramState, ProgramStateGroup
|
from sglang.lang.interpreter import ProgramState, ProgramStateGroup
|
||||||
from sglang.lang.ir import (
|
from sglang.lang.ir import (
|
||||||
SglArgument,
|
SglArgument,
|
||||||
SglCommitLazy,
|
|
||||||
SglConcateAndAppend,
|
|
||||||
SglConstantText,
|
SglConstantText,
|
||||||
SglExpr,
|
SglExpr,
|
||||||
SglExprList,
|
SglExprList,
|
||||||
SglFork,
|
SglFork,
|
||||||
SglFunction,
|
|
||||||
SglGen,
|
SglGen,
|
||||||
SglGetForkItem,
|
SglGetForkItem,
|
||||||
SglRoleBegin,
|
SglRoleBegin,
|
||||||
@@ -230,8 +226,8 @@ class TracerProgramState(ProgramState):
|
|||||||
self.cur_role = None
|
self.cur_role = None
|
||||||
|
|
||||||
def _execute_var_scope_end(self, expr: SglVarScopeEnd):
|
def _execute_var_scope_end(self, expr: SglVarScopeEnd):
|
||||||
new_node = SglVariable(name, source=self.last_node)
|
new_node = SglVariable(expr.name, source=self.last_node)
|
||||||
self.variables[name] = new_node
|
self.variables[expr.name] = new_node
|
||||||
|
|
||||||
def get_var(self, name):
|
def get_var(self, name):
|
||||||
ret = self.arguments.get(name, None)
|
ret = self.arguments.get(name, None)
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.library
|
|
||||||
|
|
||||||
from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu
|
from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu
|
||||||
|
|
||||||
|
|||||||
@@ -42,65 +42,3 @@ class CustomOp(nn.Module):
|
|||||||
return self.forward_hip
|
return self.forward_hip
|
||||||
else:
|
else:
|
||||||
return self.forward_native
|
return self.forward_native
|
||||||
|
|
||||||
|
|
||||||
if _is_cuda:
|
|
||||||
from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
|
|
||||||
|
|
||||||
def scaled_fp8_quant(
|
|
||||||
input: torch.Tensor,
|
|
||||||
scale: Optional[torch.Tensor] = None,
|
|
||||||
num_token_padding: Optional[int] = None,
|
|
||||||
use_per_token_if_dynamic: bool = False,
|
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
Quantize input tensor to FP8 (8-bit floating point) format.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input (torch.Tensor): Input tensor to be quantized
|
|
||||||
scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
|
|
||||||
If None, scales will be computed dynamically.
|
|
||||||
num_token_padding (Optional[int]): If specified, pad the first dimension
|
|
||||||
of the output to at least this value.
|
|
||||||
use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
|
|
||||||
determines the quantization granularity:
|
|
||||||
- True: compute scale per token
|
|
||||||
- False: compute single scale per tensor
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
|
||||||
- quantized_tensor: The FP8 quantized version of input
|
|
||||||
- scale_tensor: The scaling factors used for quantization
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
AssertionError: If input is not 2D or if static scale's numel != 1
|
|
||||||
"""
|
|
||||||
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
|
|
||||||
shape = input.shape
|
|
||||||
out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
|
||||||
if num_token_padding:
|
|
||||||
shape = (max(num_token_padding, input.shape[0]), shape[1])
|
|
||||||
output = torch.empty(shape, device=input.device, dtype=out_dtype)
|
|
||||||
|
|
||||||
if scale is None:
|
|
||||||
# Dynamic scaling
|
|
||||||
if use_per_token_if_dynamic:
|
|
||||||
scale = torch.empty(
|
|
||||||
(shape[0], 1), device=input.device, dtype=torch.float32
|
|
||||||
)
|
|
||||||
sgl_per_token_quant_fp8(input, output, scale)
|
|
||||||
else:
|
|
||||||
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
|
||||||
sgl_per_tensor_quant_fp8(
|
|
||||||
input, output, scale, is_static=False
|
|
||||||
) # False for dynamic
|
|
||||||
else:
|
|
||||||
# Static scaling
|
|
||||||
assert (
|
|
||||||
scale.numel() == 1
|
|
||||||
), f"Expected scalar scale, got numel={scale.numel()}"
|
|
||||||
sgl_per_tensor_quant_fp8(
|
|
||||||
input, output, scale, is_static=True
|
|
||||||
) # True for static
|
|
||||||
|
|
||||||
return output, scale
|
|
||||||
|
|||||||
@@ -19,11 +19,10 @@ import torch.distributed as dist
|
|||||||
from PIL.Image import Image
|
from PIL.Image import Image
|
||||||
from torch.distributed.tensor import DeviceMesh, DTensor
|
from torch.distributed.tensor import DeviceMesh, DTensor
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.engine import Engine
|
||||||
from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter
|
from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter
|
||||||
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
|
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
|
||||||
from sglang.srt.patch_torch import monkey_patch_torch_reductions
|
from sglang.srt.patch_torch import monkey_patch_torch_reductions
|
||||||
from sglang.srt.server import Engine
|
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
|
||||||
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj
|
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -21,13 +21,6 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from sglang.srt.utils import is_cuda_available
|
|
||||||
|
|
||||||
_is_cuda = is_cuda_available()
|
|
||||||
|
|
||||||
if _is_cuda:
|
|
||||||
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
|
|
||||||
|
|
||||||
from sglang.srt.custom_op import CustomOp
|
from sglang.srt.custom_op import CustomOp
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import (
|
||||||
divide,
|
divide,
|
||||||
@@ -35,7 +28,12 @@ from sglang.srt.distributed import (
|
|||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.utils import set_weight_attrs
|
from sglang.srt.utils import is_cuda_available, set_weight_attrs
|
||||||
|
|
||||||
|
_is_cuda = is_cuda_available()
|
||||||
|
|
||||||
|
if _is_cuda:
|
||||||
|
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from sglang.srt.custom_op import CustomOp
|
||||||
from sglang.srt.utils import is_cuda_available
|
from sglang.srt.utils import is_cuda_available
|
||||||
|
|
||||||
_is_cuda = is_cuda_available()
|
_is_cuda = is_cuda_available()
|
||||||
@@ -31,7 +32,6 @@ if _is_cuda:
|
|||||||
rmsnorm,
|
rmsnorm,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sglang.srt.custom_op import CustomOp
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import logging
|
|||||||
from typing import Callable, List, Optional, Tuple
|
from typing import Callable, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch.nn import Module
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from deep_gemm import (
|
from deep_gemm import (
|
||||||
@@ -13,8 +14,6 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
use_deep_gemm = False
|
use_deep_gemm = False
|
||||||
|
|
||||||
from torch.nn import Module
|
|
||||||
|
|
||||||
from sglang.srt.custom_op import CustomOp
|
from sglang.srt.custom_op import CustomOp
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import (
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
@@ -37,21 +36,16 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
|
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
|
||||||
|
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
from sglang.srt.utils import DeepEPMode, is_cuda, is_hip, set_weight_attrs
|
from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
|
||||||
|
|
||||||
if _is_cuda:
|
|
||||||
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
|
|
||||||
else:
|
|
||||||
from vllm import _custom_ops as vllm_ops
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
|
|
||||||
_buffer = None
|
if _is_hip:
|
||||||
|
from vllm._custom_ops import scaled_fp8_quant
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class GroupedGemmRunner(torch.nn.Module):
|
class GroupedGemmRunner(torch.nn.Module):
|
||||||
@@ -740,20 +734,12 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
|
|||||||
)
|
)
|
||||||
|
|
||||||
for expert in range(layer.num_experts_per_partition):
|
for expert in range(layer.num_experts_per_partition):
|
||||||
if _is_cuda:
|
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
||||||
sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
)
|
||||||
)
|
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
||||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
||||||
sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
|
||||||
vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
|
||||||
)
|
|
||||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
|
||||||
vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
|
||||||
)
|
|
||||||
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
||||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import triton
|
|||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
|
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
direct_register_custom_op,
|
direct_register_custom_op,
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
@@ -22,28 +23,25 @@ from sglang.srt.utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
|
|
||||||
|
|
||||||
enable_moe_align_block_size_triton = bool(
|
|
||||||
int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
|
|
||||||
)
|
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
|
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sgl_kernel import gelu_and_mul, silu_and_mul
|
from sgl_kernel import gelu_and_mul, silu_and_mul
|
||||||
|
|
||||||
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
|
|
||||||
else:
|
else:
|
||||||
from vllm import _custom_ops as vllm_ops
|
from vllm import _custom_ops as vllm_ops
|
||||||
|
from vllm._custom_ops import scaled_fp8_quant
|
||||||
|
|
||||||
if _is_cuda or _is_hip:
|
if _is_cuda or _is_hip:
|
||||||
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
|
||||||
|
enable_moe_align_block_size_triton = bool(
|
||||||
|
int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def write_zeros_to_output(
|
def write_zeros_to_output(
|
||||||
c_ptr,
|
c_ptr,
|
||||||
@@ -770,14 +768,9 @@ def invoke_fused_moe_kernel(
|
|||||||
# activation tensor-wise fp8 quantization, dynamic or static
|
# activation tensor-wise fp8 quantization, dynamic or static
|
||||||
padded_size = padding_size
|
padded_size = padding_size
|
||||||
# activations apply per-token quantization when weights apply per-channel quantization by default
|
# activations apply per-token quantization when weights apply per-channel quantization by default
|
||||||
if _is_cuda:
|
A, A_scale = scaled_fp8_quant(
|
||||||
A, A_scale = sgl_scaled_fp8_quant(
|
A, A_scale, use_per_token_if_dynamic=per_channel_quant
|
||||||
A, A_scale, use_per_token_if_dynamic=per_channel_quant
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
A, A_scale = vllm_ops.scaled_fp8_quant(
|
|
||||||
A, A_scale, use_per_token_if_dynamic=per_channel_quant
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# activation block-wise fp8 quantization
|
# activation block-wise fp8 quantization
|
||||||
assert len(block_shape) == 2
|
assert len(block_shape) == 2
|
||||||
|
|||||||
@@ -13,7 +13,6 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import os
|
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -29,6 +28,10 @@ _is_hip = is_hip()
|
|||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sgl_kernel import moe_fused_gate
|
from sgl_kernel import moe_fused_gate
|
||||||
|
|
||||||
|
if _is_cuda or _is_hip:
|
||||||
|
from sgl_kernel import topk_softmax
|
||||||
|
|
||||||
|
|
||||||
expert_distribution_recorder = ExpertDistributionRecorder()
|
expert_distribution_recorder = ExpertDistributionRecorder()
|
||||||
|
|
||||||
|
|
||||||
@@ -59,11 +62,6 @@ def fused_topk(
|
|||||||
topk: int,
|
topk: int,
|
||||||
renormalize: bool,
|
renormalize: bool,
|
||||||
):
|
):
|
||||||
if _is_cuda or _is_hip:
|
|
||||||
from sgl_kernel import topk_softmax
|
|
||||||
else:
|
|
||||||
from vllm import _custom_ops as vllm_ops
|
|
||||||
|
|
||||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||||
|
|
||||||
M, _ = hidden_states.shape
|
M, _ = hidden_states.shape
|
||||||
@@ -76,20 +74,12 @@ def fused_topk(
|
|||||||
M, topk, dtype=torch.int32, device=hidden_states.device
|
M, topk, dtype=torch.int32, device=hidden_states.device
|
||||||
)
|
)
|
||||||
|
|
||||||
if _is_cuda or _is_hip:
|
topk_softmax(
|
||||||
topk_softmax(
|
topk_weights,
|
||||||
topk_weights,
|
topk_ids,
|
||||||
topk_ids,
|
token_expert_indicies,
|
||||||
token_expert_indicies,
|
gating_output.float(),
|
||||||
gating_output.float(),
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
vllm_ops.topk_softmax(
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
token_expert_indicies,
|
|
||||||
gating_output.float(),
|
|
||||||
)
|
|
||||||
del token_expert_indicies
|
del token_expert_indicies
|
||||||
|
|
||||||
if renormalize:
|
if renormalize:
|
||||||
|
|||||||
@@ -7,8 +7,6 @@ from typing import Callable, Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
|
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BasevLLMParameter",
|
"BasevLLMParameter",
|
||||||
"PackedvLLMParameter",
|
"PackedvLLMParameter",
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
|
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@@ -39,7 +39,6 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
|
|||||||
is_activation_quantization_format,
|
is_activation_quantization_format,
|
||||||
should_ignore_layer,
|
should_ignore_layer,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -1,22 +1,16 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
|
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
import logging
|
import logging
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, Callable, List, Optional
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from compressed_tensors import CompressionFormat
|
from compressed_tensors import CompressionFormat
|
||||||
from compressed_tensors.quantization import QuantizationStrategy
|
from compressed_tensors.quantization import QuantizationStrategy
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import (
|
|
||||||
FusedMoE,
|
|
||||||
FusedMoEMethodBase,
|
|
||||||
FusedMoeWeightScaleSupported,
|
|
||||||
)
|
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
||||||
from sglang.srt.layers.quantization.utils import (
|
from sglang.srt.layers.quantization.utils import (
|
||||||
all_close_1d,
|
all_close_1d,
|
||||||
@@ -29,10 +23,9 @@ from sglang.srt.utils import set_weight_attrs
|
|||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
|
|
||||||
if _is_cuda:
|
if not _is_cuda:
|
||||||
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
|
|
||||||
else:
|
|
||||||
from vllm import _custom_ops as vllm_ops
|
from vllm import _custom_ops as vllm_ops
|
||||||
|
from vllm._custom_ops import scaled_fp8_quant
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import vllm
|
import vllm
|
||||||
@@ -58,8 +51,6 @@ __all__ = [
|
|||||||
|
|
||||||
class CompressedTensorsMoEMethod:
|
class CompressedTensorsMoEMethod:
|
||||||
def __new__(cls, *args, **kwargs):
|
def __new__(cls, *args, **kwargs):
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
|
|
||||||
|
|
||||||
if cls is CompressedTensorsMoEMethod:
|
if cls is CompressedTensorsMoEMethod:
|
||||||
return super().__new__(cls)
|
return super().__new__(cls)
|
||||||
return super().__new__(cls)
|
return super().__new__(cls)
|
||||||
@@ -76,7 +67,7 @@ class CompressedTensorsMoEMethod:
|
|||||||
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
|
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
|
||||||
if not VLLM_AVAILABLE:
|
if not VLLM_AVAILABLE:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm"
|
"vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm."
|
||||||
)
|
)
|
||||||
return CompressedTensorsWNA16MoEMethod(quant_config)
|
return CompressedTensorsWNA16MoEMethod(quant_config)
|
||||||
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
|
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
|
||||||
@@ -92,11 +83,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
||||||
):
|
):
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import (
|
|
||||||
FusedMoEMethodBase,
|
|
||||||
FusedMoeWeightScaleSupported,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
|
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
|
||||||
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
|
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
|
||||||
@@ -267,19 +253,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
layer.w13_weight[expert_id][start : start + shard_size, :],
|
layer.w13_weight[expert_id][start : start + shard_size, :],
|
||||||
layer.w13_weight_scale[expert_id][shard_id],
|
layer.w13_weight_scale[expert_id][shard_id],
|
||||||
)
|
)
|
||||||
|
(
|
||||||
|
layer.w13_weight[expert_id][start : start + shard_size, :],
|
||||||
|
_,
|
||||||
|
) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
|
||||||
|
|
||||||
if _is_cuda:
|
|
||||||
(
|
|
||||||
layer.w13_weight[expert_id][start : start + shard_size, :],
|
|
||||||
_,
|
|
||||||
) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
|
|
||||||
else:
|
|
||||||
(
|
|
||||||
layer.w13_weight[expert_id][start : start + shard_size, :],
|
|
||||||
_,
|
|
||||||
) = vllm_ops.scaled_fp8_quant(
|
|
||||||
dq_weight, max_w13_scales[expert_id]
|
|
||||||
)
|
|
||||||
start += shard_size
|
start += shard_size
|
||||||
|
|
||||||
layer.w13_weight_scale = torch.nn.Parameter(
|
layer.w13_weight_scale = torch.nn.Parameter(
|
||||||
@@ -345,11 +323,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
||||||
):
|
):
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import (
|
|
||||||
FusedMoEMethodBase,
|
|
||||||
FusedMoeWeightScaleSupported,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
# TODO: @dsikka: refactor this to use schemes as other kernels
|
# TODO: @dsikka: refactor this to use schemes as other kernels
|
||||||
# are supported + check if the layer is being ignored.
|
# are supported + check if the layer is being ignored.
|
||||||
@@ -609,7 +582,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
|
marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
|
||||||
layer.w13_weight_packed,
|
layer.w13_weight_packed,
|
||||||
layer.w13_g_idx_sort_indices,
|
layer.w13_g_idx_sort_indices,
|
||||||
layer.w13_weight_packed.shape[1] * self.packed_factor,
|
layer.w13_weight_packed.shape[1] * self.packed_factor,
|
||||||
@@ -617,7 +590,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
self.num_bits,
|
self.num_bits,
|
||||||
)
|
)
|
||||||
replace_tensor("w13_weight_packed", marlin_w13_qweight)
|
replace_tensor("w13_weight_packed", marlin_w13_qweight)
|
||||||
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
|
marlin_w2_qweight = vllm_ops.gptq_marlin_moe_repack(
|
||||||
layer.w2_weight_packed,
|
layer.w2_weight_packed,
|
||||||
layer.w2_g_idx_sort_indices,
|
layer.w2_g_idx_sort_indices,
|
||||||
layer.w2_weight_packed.shape[1] * self.packed_factor,
|
layer.w2_weight_packed.shape[1] * self.packed_factor,
|
||||||
@@ -661,14 +634,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
|
|
||||||
assert activation == "silu", "Only SiLU activation is supported."
|
assert activation == "silu", "Only SiLU activation is supported."
|
||||||
if not VLLM_AVAILABLE:
|
|
||||||
raise ImportError(
|
|
||||||
"vllm is not installed, to use fused_marlin_moe, please install vllm"
|
|
||||||
)
|
|
||||||
if expert_map is not None:
|
if expert_map is not None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Expert Parallelism is not supported for " "fused Marlin MoE method."
|
"Expert Parallelism is not supported for " "fused Marlin MoE method."
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.fp8_utils import (
|
from sglang.srt.layers.quantization.fp8_utils import (
|
||||||
Fp8LinearOp,
|
Fp8LinearOp,
|
||||||
maybe_create_device_identity,
|
|
||||||
normalize_e4m3fn_to_e4m3fnuz,
|
normalize_e4m3fn_to_e4m3fnuz,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
|
from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
|
||||||
@@ -99,8 +98,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
|||||||
weight_loader: Callable,
|
weight_loader: Callable,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
maybe_create_device_identity()
|
|
||||||
|
|
||||||
output_size_per_partition = sum(output_partition_sizes)
|
output_size_per_partition = sum(output_partition_sizes)
|
||||||
layer.logical_widths = output_partition_sizes
|
layer.logical_widths = output_partition_sizes
|
||||||
|
|
||||||
|
|||||||
@@ -8,15 +8,6 @@ import torch.nn.functional as F
|
|||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
|
|
||||||
from sglang.srt.layers.quantization.utils import (
|
|
||||||
all_close_1d,
|
|
||||||
convert_to_channelwise,
|
|
||||||
is_layer_skipped,
|
|
||||||
per_tensor_dequantize,
|
|
||||||
requantize_with_max_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||||
apply_fp8_marlin_linear,
|
apply_fp8_marlin_linear,
|
||||||
@@ -27,11 +18,12 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
MARLIN_FP8_AVAILABLE = False
|
MARLIN_FP8_AVAILABLE = False
|
||||||
|
|
||||||
def apply_fp8_marlin_linear(*args, **kwargs):
|
def dummy_func(*args, **kwargs):
|
||||||
raise ImportError("vllm is not installed")
|
raise ImportError(
|
||||||
|
"marlin FP8 requires some operators from vllm. Please install vllm."
|
||||||
|
)
|
||||||
|
|
||||||
def prepare_fp8_layer_for_marlin(*args, **kwargs):
|
apply_fp8_marlin_linear = prepare_fp8_layer_for_marlin = dummy_func
|
||||||
raise ImportError("vllm is not installed")
|
|
||||||
|
|
||||||
|
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||||
@@ -49,7 +41,10 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
|
per_token_group_quant_fp8,
|
||||||
|
scaled_fp8_quant,
|
||||||
|
)
|
||||||
from sglang.srt.layers.quantization.fp8_utils import (
|
from sglang.srt.layers.quantization.fp8_utils import (
|
||||||
apply_fp8_linear,
|
apply_fp8_linear,
|
||||||
apply_w8a8_block_fp8_linear,
|
apply_w8a8_block_fp8_linear,
|
||||||
@@ -57,30 +52,35 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
|||||||
input_to_float8,
|
input_to_float8,
|
||||||
normalize_e4m3fn_to_e4m3fnuz,
|
normalize_e4m3fn_to_e4m3fnuz,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
|
from sglang.srt.layers.quantization.utils import (
|
||||||
|
all_close_1d,
|
||||||
|
convert_to_channelwise,
|
||||||
|
is_layer_skipped,
|
||||||
|
per_tensor_dequantize,
|
||||||
|
requantize_with_max_scale,
|
||||||
|
)
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
is_cuda,
|
is_cuda,
|
||||||
is_hip,
|
is_hip,
|
||||||
permute_weight,
|
|
||||||
print_warning_once,
|
print_warning_once,
|
||||||
set_weight_attrs,
|
set_weight_attrs,
|
||||||
)
|
)
|
||||||
|
|
||||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
|
_is_cuda = is_cuda()
|
||||||
|
|
||||||
if _is_hip:
|
if _is_hip:
|
||||||
from aiter import ActivationType
|
from aiter import ActivationType
|
||||||
from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
|
from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
|
||||||
from aiter.ops.shuffle import shuffle_weight
|
from aiter.ops.shuffle import shuffle_weight
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
if not _is_cuda:
|
||||||
|
from vllm._custom_ops import scaled_fp8_quant
|
||||||
|
|
||||||
if _is_cuda:
|
|
||||||
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
|
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||||
else:
|
|
||||||
from vllm import _custom_ops as vllm_ops
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -243,7 +243,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
layer.logical_widths = output_partition_sizes
|
layer.logical_widths = output_partition_sizes
|
||||||
|
|
||||||
layer.input_size_per_partition = input_size_per_partition
|
layer.input_size_per_partition = input_size_per_partition
|
||||||
layer.output_size_per_partition = output_size_per_partition
|
layer.output_size_per_partition = output_size_per_partition
|
||||||
layer.orig_dtype = params_dtype
|
layer.orig_dtype = params_dtype
|
||||||
@@ -327,7 +326,9 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
layer.weight_scale_inv.data, requires_grad=False
|
layer.weight_scale_inv.data, requires_grad=False
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
|
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
|
||||||
|
|
||||||
# If checkpoint not serialized fp8, quantize the weights.
|
# If checkpoint not serialized fp8, quantize the weights.
|
||||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
if self.cutlass_fp8_supported or self.use_marlin:
|
if self.cutlass_fp8_supported or self.use_marlin:
|
||||||
@@ -391,12 +392,9 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.use_marlin:
|
if self.use_marlin:
|
||||||
try:
|
prepare_fp8_layer_for_marlin(layer)
|
||||||
prepare_fp8_layer_for_marlin(layer)
|
# Activations not quantized for marlin.
|
||||||
# Activations not quantized for marlin.
|
del layer.input_scale
|
||||||
del layer.input_scale
|
|
||||||
except ImportError:
|
|
||||||
self.use_marlin = False
|
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
@@ -406,18 +404,15 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
if self.use_marlin:
|
if self.use_marlin:
|
||||||
try:
|
return apply_fp8_marlin_linear(
|
||||||
return apply_fp8_marlin_linear(
|
input=x,
|
||||||
input=x,
|
weight=layer.weight,
|
||||||
weight=layer.weight,
|
weight_scale=layer.weight_scale,
|
||||||
weight_scale=layer.weight_scale,
|
workspace=layer.workspace,
|
||||||
workspace=layer.workspace,
|
size_n=layer.output_size_per_partition,
|
||||||
size_n=layer.output_size_per_partition,
|
size_k=layer.input_size_per_partition,
|
||||||
size_k=layer.input_size_per_partition,
|
bias=bias,
|
||||||
bias=bias,
|
)
|
||||||
)
|
|
||||||
except ImportError:
|
|
||||||
self.use_marlin = False
|
|
||||||
|
|
||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
return apply_w8a8_block_fp8_linear(
|
return apply_w8a8_block_fp8_linear(
|
||||||
@@ -516,7 +511,7 @@ class Fp8MoEMethod:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# WEIGHTS
|
# WEIGHTS
|
||||||
if get_bool_env_var("USE_INT4_WEIGHT"):
|
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
|
||||||
# INT4 MoE weight - INT32 packed
|
# INT4 MoE weight - INT32 packed
|
||||||
w13_weight = torch.nn.Parameter(
|
w13_weight = torch.nn.Parameter(
|
||||||
torch.empty(
|
torch.empty(
|
||||||
@@ -617,7 +612,7 @@ class Fp8MoEMethod:
|
|||||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||||
|
|
||||||
if get_bool_env_var("USE_INT4_WEIGHT"):
|
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
|
||||||
extra_weight_attrs.update(
|
extra_weight_attrs.update(
|
||||||
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
|
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
|
||||||
)
|
)
|
||||||
@@ -649,7 +644,7 @@ class Fp8MoEMethod:
|
|||||||
layer.w2_input_scale = None
|
layer.w2_input_scale = None
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: Module) -> None:
|
def process_weights_after_loading(self, layer: Module) -> None:
|
||||||
if get_bool_env_var("USE_INT4_WEIGHT"):
|
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
|
||||||
self.process_weights_hip_int4(layer)
|
self.process_weights_hip_int4(layer)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -706,20 +701,12 @@ class Fp8MoEMethod:
|
|||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
for expert in range(layer.num_experts):
|
for expert in range(layer.num_experts):
|
||||||
if _is_cuda:
|
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
||||||
sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
)
|
||||||
)
|
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
||||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
||||||
sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
|
||||||
vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
|
||||||
)
|
|
||||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
|
||||||
vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
|
||||||
)
|
|
||||||
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
||||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||||
|
|
||||||
@@ -796,18 +783,10 @@ class Fp8MoEMethod:
|
|||||||
layer.w13_weight[expert_id][start : start + shard_size, :],
|
layer.w13_weight[expert_id][start : start + shard_size, :],
|
||||||
layer.w13_weight_scale[expert_id][shard_id],
|
layer.w13_weight_scale[expert_id][shard_id],
|
||||||
)
|
)
|
||||||
if _is_cuda:
|
(
|
||||||
(
|
layer.w13_weight[expert_id][start : start + shard_size, :],
|
||||||
layer.w13_weight[expert_id][start : start + shard_size, :],
|
_,
|
||||||
_,
|
) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
|
||||||
) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
|
|
||||||
else:
|
|
||||||
(
|
|
||||||
layer.w13_weight[expert_id][start : start + shard_size, :],
|
|
||||||
_,
|
|
||||||
) = vllm_ops.scaled_fp8_quant(
|
|
||||||
dq_weight, max_w13_scales[expert_id]
|
|
||||||
)
|
|
||||||
start += shard_size
|
start += shard_size
|
||||||
|
|
||||||
layer.w13_weight_scale = torch.nn.Parameter(
|
layer.w13_weight_scale = torch.nn.Parameter(
|
||||||
@@ -930,41 +909,11 @@ class Fp8MoEMethod:
|
|||||||
correction_bias=correction_bias,
|
correction_bias=correction_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
|
if _is_hip:
|
||||||
# TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
|
if get_bool_env_var("USE_INT4_WEIGHT"):
|
||||||
assert not no_combine, f"{no_combine=} is not supported."
|
# TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
|
||||||
return ck_moe_2stages_win4(
|
assert not no_combine, f"{no_combine=} is not supported."
|
||||||
x,
|
return ck_moe_2stages_win4(
|
||||||
layer.w13_weight,
|
|
||||||
layer.w2_weight,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
layer.w13_weight_scale1,
|
|
||||||
layer.w2_weight_scale1,
|
|
||||||
activation=(
|
|
||||||
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
|
|
||||||
),
|
|
||||||
)
|
|
||||||
if _is_hip and get_bool_env_var("CK_MOE"):
|
|
||||||
assert not no_combine, f"{no_combine=} is not supported."
|
|
||||||
if self.block_quant:
|
|
||||||
# TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
|
|
||||||
assert (
|
|
||||||
activation == "silu"
|
|
||||||
), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
|
|
||||||
return asm_moe(
|
|
||||||
x,
|
|
||||||
layer.w13_weight,
|
|
||||||
layer.w2_weight,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
layer.w13_weight_scale_inv,
|
|
||||||
layer.w2_weight_scale_inv,
|
|
||||||
block_shape=tuple(self.quant_config.weight_block_size),
|
|
||||||
expert_mask=None,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return ck_moe_2stages(
|
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
layer.w2_weight,
|
layer.w2_weight,
|
||||||
@@ -978,33 +927,65 @@ class Fp8MoEMethod:
|
|||||||
else ActivationType.Gelu
|
else ActivationType.Gelu
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
# Expert fusion with FP8 quantization
|
if get_bool_env_var("CK_MOE"):
|
||||||
return fused_experts(
|
assert not no_combine, f"{no_combine=} is not supported."
|
||||||
x,
|
if self.block_quant:
|
||||||
layer.w13_weight,
|
# TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
|
||||||
layer.w2_weight,
|
assert (
|
||||||
topk_weights=topk_weights,
|
activation == "silu"
|
||||||
topk_ids=topk_ids,
|
), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
|
||||||
inplace=inplace and not no_combine,
|
return asm_moe(
|
||||||
activation=activation,
|
x,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
layer.w13_weight,
|
||||||
use_fp8_w8a8=True,
|
layer.w2_weight,
|
||||||
w1_scale=(
|
topk_weights,
|
||||||
layer.w13_weight_scale_inv
|
topk_ids,
|
||||||
if self.block_quant
|
layer.w13_weight_scale_inv,
|
||||||
else layer.w13_weight_scale
|
layer.w2_weight_scale_inv,
|
||||||
),
|
block_shape=tuple(self.quant_config.weight_block_size),
|
||||||
w2_scale=(
|
expert_mask=None,
|
||||||
layer.w2_weight_scale_inv
|
)
|
||||||
if self.block_quant
|
else:
|
||||||
else layer.w2_weight_scale
|
return ck_moe_2stages(
|
||||||
),
|
x,
|
||||||
a1_scale=layer.w13_input_scale,
|
layer.w13_weight,
|
||||||
a2_scale=layer.w2_input_scale,
|
layer.w2_weight,
|
||||||
block_shape=self.quant_config.weight_block_size,
|
topk_weights,
|
||||||
no_combine=no_combine,
|
topk_ids,
|
||||||
)
|
layer.w13_weight_scale1,
|
||||||
|
layer.w2_weight_scale1,
|
||||||
|
activation=(
|
||||||
|
ActivationType.Silu
|
||||||
|
if activation == "silu"
|
||||||
|
else ActivationType.Gelu
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Expert fusion with FP8 quantization
|
||||||
|
return fused_experts(
|
||||||
|
x,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
inplace=inplace and not no_combine,
|
||||||
|
activation=activation,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
use_fp8_w8a8=True,
|
||||||
|
w1_scale=(
|
||||||
|
layer.w13_weight_scale_inv
|
||||||
|
if self.block_quant
|
||||||
|
else layer.w13_weight_scale
|
||||||
|
),
|
||||||
|
w2_scale=(
|
||||||
|
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
|
||||||
|
),
|
||||||
|
a1_scale=layer.w13_input_scale,
|
||||||
|
a2_scale=layer.w2_input_scale,
|
||||||
|
block_shape=self.quant_config.weight_block_size,
|
||||||
|
no_combine=no_combine,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
||||||
|
|||||||
@@ -34,15 +34,23 @@ from sglang.srt.utils import (
|
|||||||
supports_custom_op,
|
supports_custom_op,
|
||||||
)
|
)
|
||||||
|
|
||||||
_enable_jit_deepgemm = False
|
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
|
_fp8_type = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||||
|
if _is_hip:
|
||||||
|
fp8_max = 224.0
|
||||||
|
else:
|
||||||
|
fp8_max = torch.finfo(_fp8_type).max
|
||||||
|
fp8_min = -fp8_max
|
||||||
|
|
||||||
|
_enable_jit_deepgemm = False
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
import deep_gemm
|
import deep_gemm
|
||||||
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
|
from sgl_kernel import (
|
||||||
|
sgl_per_tensor_quant_fp8,
|
||||||
|
sgl_per_token_group_quant_fp8,
|
||||||
|
sgl_per_token_quant_fp8,
|
||||||
|
)
|
||||||
|
|
||||||
sm_version = get_device_sm()
|
sm_version = get_device_sm()
|
||||||
if sm_version == 90 and get_bool_env_var(
|
if sm_version == 90 and get_bool_env_var(
|
||||||
@@ -53,6 +61,7 @@ if _is_cuda:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
if supports_custom_op():
|
if supports_custom_op():
|
||||||
|
|
||||||
def deep_gemm_fp8_fp8_bf16_nt(
|
def deep_gemm_fp8_fp8_bf16_nt(
|
||||||
@@ -179,7 +188,6 @@ def per_token_group_quant_fp8(
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
group_size: int,
|
group_size: int,
|
||||||
eps: float = 1e-10,
|
eps: float = 1e-10,
|
||||||
dtype: torch.dtype = fp8_type_,
|
|
||||||
column_major_scales: bool = False,
|
column_major_scales: bool = False,
|
||||||
scale_tma_aligned: bool = False,
|
scale_tma_aligned: bool = False,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
@@ -192,7 +200,6 @@ def per_token_group_quant_fp8(
|
|||||||
x: The input tenosr with ndim >= 2.
|
x: The input tenosr with ndim >= 2.
|
||||||
group_size: The group size used for quantization.
|
group_size: The group size used for quantization.
|
||||||
eps: The minimum to avoid dividing zero.
|
eps: The minimum to avoid dividing zero.
|
||||||
dtype: The dype of output tensor.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
|
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
|
||||||
@@ -202,15 +209,7 @@ def per_token_group_quant_fp8(
|
|||||||
), "the last dimension of `x` cannot be divisible by `group_size`"
|
), "the last dimension of `x` cannot be divisible by `group_size`"
|
||||||
assert x.is_contiguous(), "`x` is not contiguous"
|
assert x.is_contiguous(), "`x` is not contiguous"
|
||||||
|
|
||||||
finfo = torch.finfo(dtype)
|
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
|
||||||
fp8_max = finfo.max
|
|
||||||
|
|
||||||
if _is_hip:
|
|
||||||
fp8_max = 224.0
|
|
||||||
|
|
||||||
fp8_min = -fp8_max
|
|
||||||
|
|
||||||
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
|
||||||
M = x.numel() // group_size
|
M = x.numel() // group_size
|
||||||
N = group_size
|
N = group_size
|
||||||
if column_major_scales:
|
if column_major_scales:
|
||||||
@@ -276,27 +275,18 @@ def sglang_per_token_group_quant_fp8(
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
group_size: int,
|
group_size: int,
|
||||||
eps: float = 1e-10,
|
eps: float = 1e-10,
|
||||||
dtype: torch.dtype = fp8_type_,
|
|
||||||
):
|
):
|
||||||
assert (
|
assert (
|
||||||
x.shape[-1] % group_size == 0
|
x.shape[-1] % group_size == 0
|
||||||
), "the last dimension of `x` cannot be divisible by `group_size`"
|
), "the last dimension of `x` cannot be divisible by `group_size`"
|
||||||
assert x.is_contiguous(), "`x` is not contiguous"
|
assert x.is_contiguous(), "`x` is not contiguous"
|
||||||
|
|
||||||
finfo = torch.finfo(dtype)
|
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
|
||||||
fp8_max = finfo.max
|
|
||||||
|
|
||||||
fp8_min = -fp8_max
|
|
||||||
|
|
||||||
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
|
||||||
M = x.numel() // group_size
|
|
||||||
N = group_size
|
|
||||||
x_s = torch.empty(
|
x_s = torch.empty(
|
||||||
x.shape[:-1] + (x.shape[-1] // group_size,),
|
x.shape[:-1] + (x.shape[-1] // group_size,),
|
||||||
device=x.device,
|
device=x.device,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
|
|
||||||
sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
|
sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
|
||||||
|
|
||||||
return x_q, x_s
|
return x_q, x_s
|
||||||
@@ -304,7 +294,7 @@ def sglang_per_token_group_quant_fp8(
|
|||||||
|
|
||||||
def sglang_per_token_quant_fp8(
|
def sglang_per_token_quant_fp8(
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
dtype: torch.dtype = fp8_type_,
|
dtype: torch.dtype = _fp8_type,
|
||||||
):
|
):
|
||||||
assert x.is_contiguous(), "`x` is not contiguous"
|
assert x.is_contiguous(), "`x` is not contiguous"
|
||||||
|
|
||||||
@@ -368,7 +358,6 @@ def static_quant_fp8(
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
x_s: torch.Tensor,
|
x_s: torch.Tensor,
|
||||||
repeat_scale: bool = False,
|
repeat_scale: bool = False,
|
||||||
dtype: torch.dtype = fp8_type_,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""Function to perform static quantization using the given scale on an input tensor `x`.
|
"""Function to perform static quantization using the given scale on an input tensor `x`.
|
||||||
|
|
||||||
@@ -386,15 +375,8 @@ def static_quant_fp8(
|
|||||||
"""
|
"""
|
||||||
assert x.is_contiguous(), "`x` is not contiguous"
|
assert x.is_contiguous(), "`x` is not contiguous"
|
||||||
assert x_s.numel() == 1, "only supports per-tensor scale"
|
assert x_s.numel() == 1, "only supports per-tensor scale"
|
||||||
finfo = torch.finfo(dtype)
|
|
||||||
fp8_max = finfo.max
|
|
||||||
|
|
||||||
if _is_hip:
|
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
|
||||||
fp8_max = 224.0
|
|
||||||
|
|
||||||
fp8_min = -fp8_max
|
|
||||||
|
|
||||||
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
|
||||||
M = x.numel() // x.shape[-1]
|
M = x.numel() // x.shape[-1]
|
||||||
N = x.shape[-1]
|
N = x.shape[-1]
|
||||||
if repeat_scale:
|
if repeat_scale:
|
||||||
@@ -896,7 +878,7 @@ def _per_tensor_quant_mla_fp8_stage2(
|
|||||||
|
|
||||||
|
|
||||||
def per_tensor_quant_mla_fp8(
|
def per_tensor_quant_mla_fp8(
|
||||||
x: torch.Tensor, eps: float = 1e-12, dtype: torch.dtype = torch.float8_e4m3fn
|
x: torch.Tensor, eps: float = 1e-12
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
This function quantizes input values to float8 values with tensor-wise quantization
|
This function quantizes input values to float8 values with tensor-wise quantization
|
||||||
@@ -904,13 +886,7 @@ def per_tensor_quant_mla_fp8(
|
|||||||
"""
|
"""
|
||||||
assert x.dim() == 3, "`x` is not a 3d-tensor"
|
assert x.dim() == 3, "`x` is not a 3d-tensor"
|
||||||
|
|
||||||
finfo = torch.finfo(dtype)
|
x_q = x.new_empty(x.size(), dtype=_fp8_type)
|
||||||
fp8_max = finfo.max
|
|
||||||
if _is_hip:
|
|
||||||
dtype = torch.float8_e4m3fnuz
|
|
||||||
fp8_max = 224.0
|
|
||||||
|
|
||||||
x_q = x.new_empty(x.size(), dtype=dtype)
|
|
||||||
x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)
|
x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)
|
||||||
|
|
||||||
num_head, num_seq, head_size = x.shape
|
num_head, num_seq, head_size = x.shape
|
||||||
@@ -935,9 +911,64 @@ def per_tensor_quant_mla_fp8(
|
|||||||
head_size,
|
head_size,
|
||||||
x.stride(0),
|
x.stride(0),
|
||||||
x.stride(1),
|
x.stride(1),
|
||||||
-fp8_max,
|
fp8_min,
|
||||||
fp8_max,
|
fp8_max,
|
||||||
BLOCK_SIZE,
|
BLOCK_SIZE,
|
||||||
)
|
)
|
||||||
|
|
||||||
return x_q, x_s
|
return x_q, x_s
|
||||||
|
|
||||||
|
|
||||||
|
def scaled_fp8_quant(
|
||||||
|
input: torch.Tensor,
|
||||||
|
scale: Optional[torch.Tensor] = None,
|
||||||
|
num_token_padding: Optional[int] = None,
|
||||||
|
use_per_token_if_dynamic: bool = False,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Quantize input tensor to FP8 (8-bit floating point) format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input (torch.Tensor): Input tensor to be quantized
|
||||||
|
scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
|
||||||
|
If None, scales will be computed dynamically.
|
||||||
|
num_token_padding (Optional[int]): If specified, pad the first dimension
|
||||||
|
of the output to at least this value.
|
||||||
|
use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
|
||||||
|
determines the quantization granularity:
|
||||||
|
- True: compute scale per token
|
||||||
|
- False: compute single scale per tensor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
||||||
|
- quantized_tensor: The FP8 quantized version of input
|
||||||
|
- scale_tensor: The scaling factors used for quantization
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If input is not 2D or if static scale's numel != 1
|
||||||
|
"""
|
||||||
|
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
|
||||||
|
shape = input.shape
|
||||||
|
out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||||
|
if num_token_padding:
|
||||||
|
shape = (max(num_token_padding, input.shape[0]), shape[1])
|
||||||
|
output = torch.empty(shape, device=input.device, dtype=out_dtype)
|
||||||
|
|
||||||
|
if scale is None:
|
||||||
|
# Dynamic scaling
|
||||||
|
if use_per_token_if_dynamic:
|
||||||
|
scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
|
||||||
|
sgl_per_token_quant_fp8(input, output, scale)
|
||||||
|
else:
|
||||||
|
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
||||||
|
sgl_per_tensor_quant_fp8(
|
||||||
|
input, output, scale, is_static=False
|
||||||
|
) # False for dynamic
|
||||||
|
else:
|
||||||
|
# Static scaling
|
||||||
|
assert scale.numel() == 1, f"Expected scalar scale, got numel={scale.numel()}"
|
||||||
|
sgl_per_tensor_quant_fp8(
|
||||||
|
input, output, scale, is_static=True
|
||||||
|
) # True for static
|
||||||
|
|
||||||
|
return output, scale
|
||||||
|
|||||||
@@ -1,11 +1,19 @@
|
|||||||
import os
|
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
try:
|
||||||
|
from vllm import _custom_ops as vllm_ops
|
||||||
|
|
||||||
|
VLLM_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
VLLM_AVAILABLE = False
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
_enable_jit_deepgemm,
|
_enable_jit_deepgemm,
|
||||||
per_token_group_quant_fp8,
|
per_token_group_quant_fp8,
|
||||||
|
scaled_fp8_quant,
|
||||||
|
sglang_per_token_quant_fp8,
|
||||||
static_quant_fp8,
|
static_quant_fp8,
|
||||||
w8a8_block_fp8_matmul,
|
w8a8_block_fp8_matmul,
|
||||||
)
|
)
|
||||||
@@ -17,30 +25,20 @@ from sglang.srt.utils import (
|
|||||||
is_hip,
|
is_hip,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
import vllm
|
|
||||||
from vllm import _custom_ops as ops
|
|
||||||
|
|
||||||
VLLM_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
VLLM_AVAILABLE = False
|
|
||||||
|
|
||||||
use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
|
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
|
_is_cuda = is_cuda()
|
||||||
|
|
||||||
if _is_hip and get_bool_env_var("CK_MOE"):
|
if _is_hip and get_bool_env_var("CK_MOE"):
|
||||||
from aiter import gemm_a8w8_blockscale
|
from aiter import gemm_a8w8_blockscale
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
|
from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
|
||||||
|
|
||||||
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
|
use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
|
|
||||||
|
|
||||||
# Input scaling factors are no longer optional in _scaled_mm starting
|
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||||
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
||||||
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
|
TORCH_DEVICE_IDENTITY = None
|
||||||
|
|
||||||
_TORCH_VERSION = torch.__version__.split("+")[0]
|
_TORCH_VERSION = torch.__version__.split("+")[0]
|
||||||
try:
|
try:
|
||||||
@@ -214,7 +212,7 @@ def block_quant_to_tensor_quant(
|
|||||||
x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
|
x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
|
||||||
|
|
||||||
x_q_tensor, scale = (
|
x_q_tensor, scale = (
|
||||||
sgl_scaled_fp8_quant(x_dq_block)
|
scaled_fp8_quant(x_dq_block)
|
||||||
if _is_cuda
|
if _is_cuda
|
||||||
else input_to_float8(x_dq_block, dtype=x_q_block.dtype)
|
else input_to_float8(x_dq_block, dtype=x_q_block.dtype)
|
||||||
)
|
)
|
||||||
@@ -227,7 +225,7 @@ def channel_quant_to_tensor_quant(
|
|||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
x_dq_channel = x_q_channel.to(torch.float32) * x_s
|
x_dq_channel = x_q_channel.to(torch.float32) * x_s
|
||||||
x_q_tensor, scale = (
|
x_q_tensor, scale = (
|
||||||
sgl_scaled_fp8_quant(x_dq_channel)
|
scaled_fp8_quant(x_dq_channel)
|
||||||
if _is_cuda
|
if _is_cuda
|
||||||
else input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
|
else input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
|
||||||
)
|
)
|
||||||
@@ -264,7 +262,7 @@ def apply_fp8_linear(
|
|||||||
# final solution should be: 1. add support to per-tensor activation scaling.
|
# final solution should be: 1. add support to per-tensor activation scaling.
|
||||||
# 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
|
# 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
|
||||||
if _is_hip and weight_scale.numel() == 1:
|
if _is_hip and weight_scale.numel() == 1:
|
||||||
qinput, x_scale = ops.scaled_fp8_quant(
|
qinput, x_scale = vllm_ops.scaled_fp8_quant(
|
||||||
input_2d,
|
input_2d,
|
||||||
input_scale,
|
input_scale,
|
||||||
use_per_token_if_dynamic=use_per_token_if_dynamic,
|
use_per_token_if_dynamic=use_per_token_if_dynamic,
|
||||||
@@ -275,32 +273,29 @@ def apply_fp8_linear(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if cutlass_fp8_supported:
|
if cutlass_fp8_supported:
|
||||||
try:
|
if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
|
||||||
if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
|
# Fall back to vllm cutlass w8a8 fp8 kernel
|
||||||
# Fall back to vllm cutlass w8a8 fp8 kernel
|
output = vllm_ops.cutlass_scaled_mm(
|
||||||
output = ops.cutlass_scaled_mm(
|
qinput,
|
||||||
qinput,
|
weight,
|
||||||
weight,
|
out_dtype=input.dtype,
|
||||||
out_dtype=input.dtype,
|
scale_a=x_scale,
|
||||||
scale_a=x_scale,
|
scale_b=weight_scale,
|
||||||
scale_b=weight_scale,
|
bias=bias,
|
||||||
bias=bias,
|
)
|
||||||
)
|
else:
|
||||||
else:
|
assert (
|
||||||
assert (
|
weight_scale.numel() == weight.shape[1]
|
||||||
weight_scale.numel() == weight.shape[1]
|
), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
|
||||||
), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
|
output = fp8_scaled_mm(
|
||||||
output = fp8_scaled_mm(
|
qinput,
|
||||||
qinput,
|
weight,
|
||||||
weight,
|
x_scale,
|
||||||
x_scale,
|
weight_scale,
|
||||||
weight_scale,
|
out_dtype=input.dtype,
|
||||||
out_dtype=input.dtype,
|
bias=bias,
|
||||||
bias=bias,
|
)
|
||||||
)
|
return output.view(*output_shape)
|
||||||
return output.view(*output_shape)
|
|
||||||
except (ImportError, NameError, AttributeError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
# torch.scaled_mm supports per tensor weights + activations only
|
# torch.scaled_mm supports per tensor weights + activations only
|
||||||
# so fallback to naive if per channel or per token
|
# so fallback to naive if per channel or per token
|
||||||
@@ -343,8 +338,10 @@ def apply_fp8_linear(
|
|||||||
|
|
||||||
# Making sure the dummy tensor is on the same device as the weight
|
# Making sure the dummy tensor is on the same device as the weight
|
||||||
global TORCH_DEVICE_IDENTITY
|
global TORCH_DEVICE_IDENTITY
|
||||||
if TORCH_DEVICE_IDENTITY.device != weight.device:
|
if TORCH_DEVICE_IDENTITY is None:
|
||||||
TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
|
TORCH_DEVICE_IDENTITY = torch.ones(
|
||||||
|
1, dtype=torch.float32, device=weight.device
|
||||||
|
)
|
||||||
|
|
||||||
# GEMM
|
# GEMM
|
||||||
# This computes C = (X * W).
|
# This computes C = (X * W).
|
||||||
@@ -372,13 +369,6 @@ def apply_fp8_linear(
|
|||||||
return output.to(dtype=input.dtype).view(*output_shape)
|
return output.to(dtype=input.dtype).view(*output_shape)
|
||||||
|
|
||||||
|
|
||||||
def maybe_create_device_identity():
|
|
||||||
# Allocate dummy ones tensor for torch._scaled_mm
|
|
||||||
global TORCH_DEVICE_IDENTITY
|
|
||||||
if TORCH_DEVICE_IDENTITY is None:
|
|
||||||
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
|
|
||||||
|
|
||||||
|
|
||||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
|
||||||
# TODO(luka): follow similar pattern for marlin and block-fp8-linear
|
# TODO(luka): follow similar pattern for marlin and block-fp8-linear
|
||||||
# https://github.com/vllm-project/vllm/issues/14397
|
# https://github.com/vllm-project/vllm/issues/14397
|
||||||
@@ -405,9 +395,7 @@ class Fp8LinearOp:
|
|||||||
# We also don't pad when using torch.compile,
|
# We also don't pad when using torch.compile,
|
||||||
# as it breaks with dynamic shapes.
|
# as it breaks with dynamic shapes.
|
||||||
if pad_output is None:
|
if pad_output is None:
|
||||||
enable_torch_compile = os.environ.get(
|
enable_torch_compile = get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
|
||||||
"SGLANG_ENABLE_TORCH_COMPILE", "0"
|
|
||||||
).lower() in ("1", "true", "yes")
|
|
||||||
pad_output = not enable_torch_compile
|
pad_output = not enable_torch_compile
|
||||||
self.output_padding = 17 if pad_output else None
|
self.output_padding = 17 if pad_output else None
|
||||||
|
|
||||||
@@ -439,13 +427,13 @@ class Fp8LinearOp:
|
|||||||
# for sgl-kernel fp8_scaled_mm, it support per channel W now
|
# for sgl-kernel fp8_scaled_mm, it support per channel W now
|
||||||
if self.cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
|
if self.cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
qinput, x_scale = sgl_scaled_fp8_quant(
|
qinput, x_scale = scaled_fp8_quant(
|
||||||
input_2d,
|
input_2d,
|
||||||
input_scale,
|
input_scale,
|
||||||
use_per_token_if_dynamic=use_per_token_if_dynamic,
|
use_per_token_if_dynamic=use_per_token_if_dynamic,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
qinput, x_scale = ops.scaled_fp8_quant(
|
qinput, x_scale = vllm_ops.scaled_fp8_quant(
|
||||||
input_2d,
|
input_2d,
|
||||||
input_scale,
|
input_scale,
|
||||||
scale_ub=input_scale_ub,
|
scale_ub=input_scale_ub,
|
||||||
@@ -455,7 +443,7 @@ class Fp8LinearOp:
|
|||||||
# Fused GEMM_DQ
|
# Fused GEMM_DQ
|
||||||
if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
|
if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
|
||||||
# Fall back to vllm cutlass w8a8 fp8 kernel
|
# Fall back to vllm cutlass w8a8 fp8 kernel
|
||||||
output = ops.cutlass_scaled_mm(
|
output = vllm_ops.cutlass_scaled_mm(
|
||||||
qinput,
|
qinput,
|
||||||
weight,
|
weight,
|
||||||
out_dtype=input.dtype,
|
out_dtype=input.dtype,
|
||||||
@@ -482,14 +470,14 @@ class Fp8LinearOp:
|
|||||||
else:
|
else:
|
||||||
# Maybe apply padding to output, see comment in __init__
|
# Maybe apply padding to output, see comment in __init__
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
qinput, x_scale = sgl_scaled_fp8_quant(
|
qinput, x_scale = scaled_fp8_quant(
|
||||||
input_2d,
|
input_2d,
|
||||||
input_scale,
|
input_scale,
|
||||||
num_token_padding=self.output_padding,
|
num_token_padding=self.output_padding,
|
||||||
use_per_token_if_dynamic=use_per_token_if_dynamic,
|
use_per_token_if_dynamic=use_per_token_if_dynamic,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
qinput, x_scale = ops.scaled_fp8_quant(
|
qinput, x_scale = vllm_ops.scaled_fp8_quant(
|
||||||
input_2d,
|
input_2d,
|
||||||
input_scale,
|
input_scale,
|
||||||
num_token_padding=self.output_padding,
|
num_token_padding=self.output_padding,
|
||||||
@@ -562,9 +550,12 @@ class Fp8LinearOp:
|
|||||||
# This computes C = (X * W).
|
# This computes C = (X * W).
|
||||||
# Output in fp32 to allow subsequent ops to happen in-place
|
# Output in fp32 to allow subsequent ops to happen in-place
|
||||||
|
|
||||||
|
# Making sure the dummy tensor is on the same device as the weight
|
||||||
global TORCH_DEVICE_IDENTITY
|
global TORCH_DEVICE_IDENTITY
|
||||||
if TORCH_DEVICE_IDENTITY.device != weight.device:
|
if TORCH_DEVICE_IDENTITY is None:
|
||||||
TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
|
TORCH_DEVICE_IDENTITY = torch.ones(
|
||||||
|
1, dtype=torch.float32, device=weight.device
|
||||||
|
)
|
||||||
|
|
||||||
output = torch._scaled_mm(
|
output = torch._scaled_mm(
|
||||||
qinput,
|
qinput,
|
||||||
|
|||||||
@@ -1,18 +1,17 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
|
||||||
|
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
from typing import List, Mapping, Optional, Tuple, Union
|
from typing import List, Mapping, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
||||||
from sglang.srt.utils import is_cuda
|
from sglang.srt.utils import is_cuda
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
|
|
||||||
if _is_cuda:
|
if not _is_cuda:
|
||||||
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
|
from vllm._custom_ops import scaled_fp8_quant
|
||||||
else:
|
|
||||||
from vllm import _custom_ops as vllm_ops
|
|
||||||
|
|
||||||
|
|
||||||
def is_fp8_fnuz() -> bool:
|
def is_fp8_fnuz() -> bool:
|
||||||
@@ -116,12 +115,7 @@ def requantize_with_max_scale(
|
|||||||
for idx, logical_width in enumerate(logical_widths):
|
for idx, logical_width in enumerate(logical_widths):
|
||||||
end = start + logical_width
|
end = start + logical_width
|
||||||
weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
|
weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
|
||||||
if _is_cuda:
|
weight[start:end, :], _ = scaled_fp8_quant(weight_dq, max_w_scale)
|
||||||
weight[start:end, :], _ = sgl_scaled_fp8_quant(weight_dq, max_w_scale)
|
|
||||||
else:
|
|
||||||
weight[start:end, :], _ = vllm_ops.scaled_fp8_quant(
|
|
||||||
weight_dq, max_w_scale
|
|
||||||
)
|
|
||||||
start = end
|
start = end
|
||||||
|
|
||||||
return max_w_scale, weight
|
return max_w_scale, weight
|
||||||
|
|||||||
@@ -1,13 +1,6 @@
|
|||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.utils import is_cuda_available, set_weight_attrs
|
|
||||||
|
|
||||||
is_cuda = is_cuda_available()
|
|
||||||
if is_cuda:
|
|
||||||
from sgl_kernel import int8_scaled_mm
|
|
||||||
|
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||||
@@ -18,6 +11,11 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
||||||
|
from sglang.srt.utils import is_cuda_available, set_weight_attrs
|
||||||
|
|
||||||
|
is_cuda = is_cuda_available()
|
||||||
|
if is_cuda:
|
||||||
|
from sgl_kernel import int8_scaled_mm
|
||||||
|
|
||||||
|
|
||||||
class W8A8Int8Config(QuantizationConfig):
|
class W8A8Int8Config(QuantizationConfig):
|
||||||
|
|||||||
@@ -11,10 +11,11 @@ from sglang.srt.custom_op import CustomOp
|
|||||||
from sglang.srt.utils import is_cuda_available
|
from sglang.srt.utils import is_cuda_available
|
||||||
|
|
||||||
_is_cuda_available = is_cuda_available()
|
_is_cuda_available = is_cuda_available()
|
||||||
|
|
||||||
if _is_cuda_available:
|
if _is_cuda_available:
|
||||||
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
|
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
|
||||||
else:
|
else:
|
||||||
from vllm import _custom_ops as ops
|
from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
|
||||||
|
|
||||||
|
|
||||||
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -159,7 +160,7 @@ class RotaryEmbedding(CustomOp):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
|
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
|
||||||
ops.rotary_embedding(
|
vllm_rotary_embedding(
|
||||||
positions,
|
positions,
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
|
|||||||
@@ -1,25 +0,0 @@
|
|||||||
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
|
|
||||||
|
|
||||||
|
|
||||||
def get_backend_from_name(name: str) -> BaseLoRABackend:
|
|
||||||
"""
|
|
||||||
Get corresponding backend class from backend's name
|
|
||||||
"""
|
|
||||||
if name == "triton":
|
|
||||||
from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
|
|
||||||
|
|
||||||
return TritonLoRABackend
|
|
||||||
elif name == "flashinfer":
|
|
||||||
from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend
|
|
||||||
|
|
||||||
return FlashInferLoRABackend
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid backend: {name}")
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"BaseLoRABackend",
|
|
||||||
"FlashInferLoRABackend",
|
|
||||||
"TritonLoRABackend",
|
|
||||||
"get_backend_from_name",
|
|
||||||
]
|
|
||||||
@@ -75,7 +75,7 @@ class BaseLoRABackend:
|
|||||||
qkv_lora_a: torch.Tensor,
|
qkv_lora_a: torch.Tensor,
|
||||||
qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
|
qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
|
||||||
*args,
|
*args,
|
||||||
**kwargs
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Run the lora pass for QKV Layer.
|
"""Run the lora pass for QKV Layer.
|
||||||
|
|
||||||
@@ -98,7 +98,7 @@ class BaseLoRABackend:
|
|||||||
gate_up_lora_a: torch.Tensor,
|
gate_up_lora_a: torch.Tensor,
|
||||||
gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
|
gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
|
||||||
*args,
|
*args,
|
||||||
**kwargs
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
|
"""Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
|
||||||
|
|
||||||
@@ -115,3 +115,19 @@ class BaseLoRABackend:
|
|||||||
|
|
||||||
def set_batch_info(self, batch_info: LoRABatchInfo):
|
def set_batch_info(self, batch_info: LoRABatchInfo):
|
||||||
self.batch_info = batch_info
|
self.batch_info = batch_info
|
||||||
|
|
||||||
|
|
||||||
|
def get_backend_from_name(name: str) -> BaseLoRABackend:
|
||||||
|
"""
|
||||||
|
Get corresponding backend class from backend's name
|
||||||
|
"""
|
||||||
|
if name == "triton":
|
||||||
|
from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
|
||||||
|
|
||||||
|
return TritonLoRABackend
|
||||||
|
elif name == "flashinfer":
|
||||||
|
from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend
|
||||||
|
|
||||||
|
return FlashInferLoRABackend
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid backend: {name}")
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from typing import Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.lora.backend import BaseLoRABackend
|
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
|
||||||
from sglang.srt.lora.utils import LoRABatchInfo
|
from sglang.srt.lora.utils import LoRABatchInfo
|
||||||
from sglang.srt.utils import is_flashinfer_available
|
from sglang.srt.utils import is_flashinfer_available
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.lora.backend import BaseLoRABackend
|
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
|
||||||
from sglang.srt.lora.triton_ops import (
|
from sglang.srt.lora.triton_ops import (
|
||||||
gate_up_lora_b_fwd,
|
gate_up_lora_b_fwd,
|
||||||
qkv_lora_b_fwd,
|
qkv_lora_b_fwd,
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from sglang.srt.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
from sglang.srt.lora.backend import BaseLoRABackend
|
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
|
||||||
|
|
||||||
|
|
||||||
class BaseLayerWithLoRA(nn.Module):
|
class BaseLayerWithLoRA(nn.Module):
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ from torch import nn
|
|||||||
|
|
||||||
from sglang.srt.configs.load_config import LoadConfig
|
from sglang.srt.configs.load_config import LoadConfig
|
||||||
from sglang.srt.hf_transformers_utils import AutoConfig
|
from sglang.srt.hf_transformers_utils import AutoConfig
|
||||||
from sglang.srt.lora.backend import BaseLoRABackend
|
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
|
||||||
from sglang.srt.lora.lora_config import LoRAConfig
|
from sglang.srt.lora.lora_config import LoRAConfig
|
||||||
from sglang.srt.model_loader.loader import DefaultModelLoader
|
from sglang.srt.model_loader.loader import DefaultModelLoader
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import torch
|
|||||||
|
|
||||||
from sglang.srt.configs.load_config import LoadConfig
|
from sglang.srt.configs.load_config import LoadConfig
|
||||||
from sglang.srt.hf_transformers_utils import AutoConfig
|
from sglang.srt.hf_transformers_utils import AutoConfig
|
||||||
from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
|
from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_from_name
|
||||||
from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
|
from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
|
||||||
from sglang.srt.lora.lora import LoRAAdapter
|
from sglang.srt.lora.lora import LoRAAdapter
|
||||||
from sglang.srt.lora.lora_config import LoRAConfig
|
from sglang.srt.lora.lora_config import LoRAConfig
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
"""DetokenizerManager is a process that detokenizes the token ids."""
|
"""DetokenizerManager is a process that detokenizes the token ids."""
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
"""
|
"""
|
||||||
Multi-modality utils
|
Multi-modality utils
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Callable, List, Optional, Tuple
|
from typing import Callable, List, Optional, Tuple
|
||||||
|
|
||||||
@@ -12,11 +13,11 @@ from sglang.srt.managers.schedule_batch import (
|
|||||||
MultimodalDataItem,
|
MultimodalDataItem,
|
||||||
MultimodalInputs,
|
MultimodalInputs,
|
||||||
global_server_args_dict,
|
global_server_args_dict,
|
||||||
logger,
|
|
||||||
)
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.utils import print_warning_once
|
from sglang.srt.utils import print_warning_once
|
||||||
from sglang.utils import logger
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MultiModalityDataPaddingPattern:
|
class MultiModalityDataPaddingPattern:
|
||||||
|
|||||||
@@ -5,8 +5,6 @@ import logging
|
|||||||
import pkgutil
|
import pkgutil
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
|
||||||
from transformers import PROCESSOR_MAPPING
|
|
||||||
|
|
||||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||||
BaseMultimodalProcessor,
|
BaseMultimodalProcessor,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ from typing import List, Optional
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import PIL
|
import PIL
|
||||||
from decord import VideoReader, cpu
|
|
||||||
from PIL import Image
|
|
||||||
from transformers import BaseImageProcessorFast
|
from transformers import BaseImageProcessorFast
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import Modality
|
from sglang.srt.managers.schedule_batch import Modality
|
||||||
@@ -102,6 +100,9 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
"""
|
"""
|
||||||
estimate the total frame count from all visual input
|
estimate the total frame count from all visual input
|
||||||
"""
|
"""
|
||||||
|
# Lazy import because decord is not available on some arm platforms.
|
||||||
|
from decord import VideoReader, cpu
|
||||||
|
|
||||||
# Before processing inputs
|
# Before processing inputs
|
||||||
estimated_frames_list = []
|
estimated_frames_list = []
|
||||||
for image in image_data:
|
for image in image_data:
|
||||||
|
|||||||
@@ -37,11 +37,11 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|||||||
from sglang.srt.patch_torch import monkey_patch_torch_compile
|
from sglang.srt.patch_torch import monkey_patch_torch_compile
|
||||||
from sglang.srt.utils import get_available_gpu_memory, is_hip
|
from sglang.srt.utils import get_available_gpu_memory, is_hip
|
||||||
|
|
||||||
_is_hip = is_hip()
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
|
||||||
|
_is_hip = is_hip()
|
||||||
|
|
||||||
|
|
||||||
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
|
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
|
||||||
for sub in model._modules.values():
|
for sub in model._modules.values():
|
||||||
|
|||||||
@@ -320,7 +320,6 @@ class ModelRunner:
|
|||||||
logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
|
logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
|
||||||
|
|
||||||
if not self.use_mla_backend:
|
if not self.use_mla_backend:
|
||||||
logger.info("Disable chunked prefix cache for non-MLA backend.")
|
|
||||||
server_args.disable_chunked_prefix_cache = True
|
server_args.disable_chunked_prefix_cache = True
|
||||||
elif self.page_size > 1:
|
elif self.page_size > 1:
|
||||||
logger.info("Disable chunked prefix cache when page size > 1.")
|
logger.info("Disable chunked prefix cache when page size > 1.")
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ _is_cuda = is_cuda()
|
|||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sgl_kernel import awq_dequantize
|
from sgl_kernel import awq_dequantize
|
||||||
else:
|
else:
|
||||||
from vllm import _custom_ops as ops
|
from vllm._custom_ops import awq_dequantize
|
||||||
|
|
||||||
|
|
||||||
class DeepseekModelNextN(nn.Module):
|
class DeepseekModelNextN(nn.Module):
|
||||||
@@ -273,7 +273,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
|||||||
self_attn.kv_b_proj.qzeros,
|
self_attn.kv_b_proj.qzeros,
|
||||||
).T
|
).T
|
||||||
else:
|
else:
|
||||||
w = ops.awq_dequantize(
|
w = awq_dequantize(
|
||||||
self_attn.kv_b_proj.qweight,
|
self_attn.kv_b_proj.qweight,
|
||||||
self_attn.kv_b_proj.scales,
|
self_attn.kv_b_proj.scales,
|
||||||
self_attn.kv_b_proj.qzeros,
|
self_attn.kv_b_proj.qzeros,
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ from sglang.srt.layers.linear import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
|
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
|
||||||
|
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
@@ -80,10 +81,8 @@ _is_cuda = is_cuda()
|
|||||||
|
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
|
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
|
||||||
|
|
||||||
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
|
||||||
else:
|
else:
|
||||||
from vllm import _custom_ops as ops
|
from vllm._custom_ops import awq_dequantize
|
||||||
|
|
||||||
if _is_hip:
|
if _is_hip:
|
||||||
from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
|
from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
|
||||||
@@ -861,7 +860,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
)
|
)
|
||||||
elif self.w_kc.dtype == torch.float8_e4m3fn:
|
elif self.w_kc.dtype == torch.float8_e4m3fn:
|
||||||
q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
|
q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
|
||||||
q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
|
q_nope.transpose(0, 1),
|
||||||
)
|
)
|
||||||
q_nope_out = bmm_fp8(
|
q_nope_out = bmm_fp8(
|
||||||
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
|
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
|
||||||
@@ -892,7 +891,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
)
|
)
|
||||||
elif self.w_vc.dtype == torch.float8_e4m3fn:
|
elif self.w_vc.dtype == torch.float8_e4m3fn:
|
||||||
attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
|
attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
|
||||||
attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
|
attn_output.transpose(0, 1),
|
||||||
)
|
)
|
||||||
attn_bmm_output = bmm_fp8(
|
attn_bmm_output = bmm_fp8(
|
||||||
attn_output_val,
|
attn_output_val,
|
||||||
@@ -1565,7 +1564,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
self_attn.kv_b_proj.qzeros,
|
self_attn.kv_b_proj.qzeros,
|
||||||
).T
|
).T
|
||||||
else:
|
else:
|
||||||
w = ops.awq_dequantize(
|
w = awq_dequantize(
|
||||||
self_attn.kv_b_proj.qweight,
|
self_attn.kv_b_proj.qweight,
|
||||||
self_attn.kv_b_proj.scales,
|
self_attn.kv_b_proj.scales,
|
||||||
self_attn.kv_b_proj.qzeros,
|
self_attn.kv_b_proj.qzeros,
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import re
|
|
||||||
from typing import Dict, Tuple
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,12 +10,11 @@ import torch
|
|||||||
import sglang.srt.sampling.penaltylib as penaltylib
|
import sglang.srt.sampling.penaltylib as penaltylib
|
||||||
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class SamplingBatchInfo:
|
class SamplingBatchInfo:
|
||||||
|
|||||||
@@ -1,18 +0,0 @@
|
|||||||
# Copyright 2023-2024 SGLang Team
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
|
|
||||||
# Some shortcuts for backward compatibility.
|
|
||||||
# They will be removed in new versions.
|
|
||||||
from sglang.srt.entrypoints.engine import Engine
|
|
||||||
from sglang.srt.entrypoints.http_server import kill_process_tree, launch_server
|
|
||||||
@@ -187,6 +187,7 @@ class ServerArgs:
|
|||||||
n_share_experts_fusion: int = 0
|
n_share_experts_fusion: int = 0
|
||||||
disable_shared_experts_fusion: bool = False
|
disable_shared_experts_fusion: bool = False
|
||||||
disable_chunked_prefix_cache: bool = False
|
disable_chunked_prefix_cache: bool = False
|
||||||
|
disable_fast_image_processor: bool = False
|
||||||
|
|
||||||
# Debug tensor dumps
|
# Debug tensor dumps
|
||||||
debug_tensor_dump_output_folder: Optional[str] = None
|
debug_tensor_dump_output_folder: Optional[str] = None
|
||||||
@@ -198,9 +199,6 @@ class ServerArgs:
|
|||||||
disaggregation_bootstrap_port: int = 8998
|
disaggregation_bootstrap_port: int = 8998
|
||||||
disaggregation_transfer_backend: str = "mooncake"
|
disaggregation_transfer_backend: str = "mooncake"
|
||||||
|
|
||||||
# multimodal
|
|
||||||
disable_fast_image_processor: bool = False
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Expert parallelism
|
# Expert parallelism
|
||||||
if self.enable_ep_moe:
|
if self.enable_ep_moe:
|
||||||
@@ -1136,6 +1134,11 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences.",
|
help="Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--disable-fast-image-processor",
|
||||||
|
action="store_true",
|
||||||
|
help="Adopt base image processor instead of fast image processor.",
|
||||||
|
)
|
||||||
|
|
||||||
# Server warmups
|
# Server warmups
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -1187,13 +1190,6 @@ class ServerArgs:
|
|||||||
help="The backend for disaggregation transfer. Default is mooncake.",
|
help="The backend for disaggregation transfer. Default is mooncake.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Multimodal
|
|
||||||
parser.add_argument(
|
|
||||||
"--disable-fast-image-processor",
|
|
||||||
action="store_true",
|
|
||||||
help="Adopt base image processor instead of fast image processor.",
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(cls, args: argparse.Namespace):
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
args.tp_size = args.tensor_parallel_size
|
args.tp_size = args.tensor_parallel_size
|
||||||
|
|||||||
@@ -55,7 +55,6 @@ import torch.distributed
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import triton
|
import triton
|
||||||
import zmq
|
import zmq
|
||||||
from decord import VideoReader, cpu
|
|
||||||
from fastapi.responses import ORJSONResponse
|
from fastapi.responses import ORJSONResponse
|
||||||
from packaging import version as pkg_version
|
from packaging import version as pkg_version
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -545,6 +544,9 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
|
|||||||
|
|
||||||
|
|
||||||
def encode_video(video_path, frame_count_limit=None):
|
def encode_video(video_path, frame_count_limit=None):
|
||||||
|
# Lazy import because decord is not available on some arm platforms.
|
||||||
|
from decord import VideoReader, cpu
|
||||||
|
|
||||||
if not os.path.exists(video_path):
|
if not os.path.exists(video_path):
|
||||||
logger.error(f"Video {video_path} does not exist")
|
logger.error(f"Video {video_path} does not exist")
|
||||||
return []
|
return []
|
||||||
|
|||||||
@@ -26,8 +26,8 @@ from transformers import (
|
|||||||
AutoProcessor,
|
AutoProcessor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.engine import Engine
|
||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
from sglang.srt.server import Engine
|
|
||||||
from sglang.srt.utils import load_image
|
from sglang.srt.utils import load_image
|
||||||
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l
|
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.custom_op import scaled_fp8_quant
|
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
||||||
from sglang.srt.utils import is_cuda
|
from sglang.srt.utils import is_cuda
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -93,9 +93,7 @@ class TestPerTokenGroupQuantFP8(TestFP8Base):
|
|||||||
A, A_quant_gt, scale_gt = self._make_A(
|
A, A_quant_gt, scale_gt = self._make_A(
|
||||||
M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type
|
M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type
|
||||||
)
|
)
|
||||||
A_quant, scale = per_token_group_quant_fp8(
|
A_quant, scale = per_token_group_quant_fp8(x=A, group_size=self.group_size)
|
||||||
x=A, group_size=self.group_size, dtype=self.quant_type
|
|
||||||
)
|
|
||||||
torch.testing.assert_close(scale, scale_gt)
|
torch.testing.assert_close(scale, scale_gt)
|
||||||
diff = (A_quant.to(torch.float16) - A_quant_gt.to(torch.float16)).abs()
|
diff = (A_quant.to(torch.float16) - A_quant_gt.to(torch.float16)).abs()
|
||||||
diff_count = (diff > 1e-5).count_nonzero()
|
diff_count = (diff > 1e-5).count_nonzero()
|
||||||
|
|||||||
@@ -3,9 +3,9 @@ import unittest
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
|
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||||
|
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
||||||
from sglang.test.test_utils import CustomTestCase
|
from sglang.test.test_utils import CustomTestCase
|
||||||
|
|
||||||
|
|
||||||
@@ -41,7 +41,7 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
|
|||||||
|
|
||||||
B, D = a.shape
|
B, D = a.shape
|
||||||
# Perform per-token quantization
|
# Perform per-token quantization
|
||||||
a_q, a_s = sgl_scaled_fp8_quant(a, use_per_token_if_dynamic=True)
|
a_q, a_s = scaled_fp8_quant(a, use_per_token_if_dynamic=True)
|
||||||
# Repeat tokens to match topk
|
# Repeat tokens to match topk
|
||||||
a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||||
# Also repeat the scale
|
# Also repeat the scale
|
||||||
@@ -69,7 +69,7 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
|
|||||||
# Activation function
|
# Activation function
|
||||||
act_out = SiluAndMul().forward_native(inter_out)
|
act_out = SiluAndMul().forward_native(inter_out)
|
||||||
# Quantize activation output with per-token
|
# Quantize activation output with per-token
|
||||||
act_out_q, act_out_s = sgl_scaled_fp8_quant(
|
act_out_q, act_out_s = scaled_fp8_quant(
|
||||||
act_out, use_per_token_if_dynamic=True
|
act_out, use_per_token_if_dynamic=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user