Clean up imports (#5467)
This commit is contained in:
@@ -24,6 +24,7 @@ from sglang.api import (
|
||||
user_end,
|
||||
video,
|
||||
)
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
||||
from sglang.lang.choices import (
|
||||
greedy_token_selection,
|
||||
@@ -31,6 +32,7 @@ from sglang.lang.choices import (
|
||||
unconditional_likelihood_normalized,
|
||||
)
|
||||
from sglang.utils import LazyImport
|
||||
from sglang.version import __version__
|
||||
|
||||
ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
|
||||
Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
|
||||
@@ -38,10 +40,6 @@ LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
|
||||
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
|
||||
VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
|
||||
|
||||
# Other configs
|
||||
from sglang.global_config import global_config
|
||||
from sglang.version import __version__
|
||||
|
||||
__all__ = [
|
||||
"Engine",
|
||||
"Runtime",
|
||||
|
||||
@@ -707,10 +707,6 @@ def sample_random_requests(
|
||||
|
||||
# Download sharegpt if necessary
|
||||
if not os.path.isfile(dataset_path):
|
||||
print(
|
||||
"If you do not want to randomly sample from a dataset,"
|
||||
" please use --dataset-name random-ids."
|
||||
)
|
||||
dataset_path = download_and_cache_file(SHAREGPT_URL)
|
||||
|
||||
# Load the dataset.
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import get_chat_template
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Callable, List, Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from sglang.lang.chat_template import get_chat_template
|
||||
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
|
||||
|
||||
@@ -2,7 +2,7 @@ import dataclasses
|
||||
import logging
|
||||
import time
|
||||
import warnings
|
||||
from typing import Callable, List, Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import get_chat_template
|
||||
|
||||
@@ -5,13 +5,7 @@ from typing import List, Union
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
|
||||
from sglang.lang.ir import (
|
||||
SglArgument,
|
||||
SglConstantText,
|
||||
SglExpr,
|
||||
SglSamplingParams,
|
||||
SglVariable,
|
||||
)
|
||||
from sglang.lang.ir import SglArgument, SglExpr, SglSamplingParams, SglVariable
|
||||
|
||||
|
||||
def compile_func(function, backend):
|
||||
|
||||
@@ -1,20 +1,16 @@
|
||||
"""Tracing a program."""
|
||||
|
||||
import uuid
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.interpreter import ProgramState, ProgramStateGroup
|
||||
from sglang.lang.ir import (
|
||||
SglArgument,
|
||||
SglCommitLazy,
|
||||
SglConcateAndAppend,
|
||||
SglConstantText,
|
||||
SglExpr,
|
||||
SglExprList,
|
||||
SglFork,
|
||||
SglFunction,
|
||||
SglGen,
|
||||
SglGetForkItem,
|
||||
SglRoleBegin,
|
||||
@@ -230,8 +226,8 @@ class TracerProgramState(ProgramState):
|
||||
self.cur_role = None
|
||||
|
||||
def _execute_var_scope_end(self, expr: SglVarScopeEnd):
|
||||
new_node = SglVariable(name, source=self.last_node)
|
||||
self.variables[name] = new_node
|
||||
new_node = SglVariable(expr.name, source=self.last_node)
|
||||
self.variables[expr.name] = new_node
|
||||
|
||||
def get_var(self, name):
|
||||
ret = self.arguments.get(name, None)
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.library
|
||||
|
||||
from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu
|
||||
|
||||
|
||||
@@ -42,65 +42,3 @@ class CustomOp(nn.Module):
|
||||
return self.forward_hip
|
||||
else:
|
||||
return self.forward_native
|
||||
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
|
||||
|
||||
def scaled_fp8_quant(
|
||||
input: torch.Tensor,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
num_token_padding: Optional[int] = None,
|
||||
use_per_token_if_dynamic: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Quantize input tensor to FP8 (8-bit floating point) format.
|
||||
|
||||
Args:
|
||||
input (torch.Tensor): Input tensor to be quantized
|
||||
scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
|
||||
If None, scales will be computed dynamically.
|
||||
num_token_padding (Optional[int]): If specified, pad the first dimension
|
||||
of the output to at least this value.
|
||||
use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
|
||||
determines the quantization granularity:
|
||||
- True: compute scale per token
|
||||
- False: compute single scale per tensor
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
||||
- quantized_tensor: The FP8 quantized version of input
|
||||
- scale_tensor: The scaling factors used for quantization
|
||||
|
||||
Raises:
|
||||
AssertionError: If input is not 2D or if static scale's numel != 1
|
||||
"""
|
||||
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
|
||||
shape = input.shape
|
||||
out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
if num_token_padding:
|
||||
shape = (max(num_token_padding, input.shape[0]), shape[1])
|
||||
output = torch.empty(shape, device=input.device, dtype=out_dtype)
|
||||
|
||||
if scale is None:
|
||||
# Dynamic scaling
|
||||
if use_per_token_if_dynamic:
|
||||
scale = torch.empty(
|
||||
(shape[0], 1), device=input.device, dtype=torch.float32
|
||||
)
|
||||
sgl_per_token_quant_fp8(input, output, scale)
|
||||
else:
|
||||
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
||||
sgl_per_tensor_quant_fp8(
|
||||
input, output, scale, is_static=False
|
||||
) # False for dynamic
|
||||
else:
|
||||
# Static scaling
|
||||
assert (
|
||||
scale.numel() == 1
|
||||
), f"Expected scalar scale, got numel={scale.numel()}"
|
||||
sgl_per_tensor_quant_fp8(
|
||||
input, output, scale, is_static=True
|
||||
) # True for static
|
||||
|
||||
return output, scale
|
||||
|
||||
@@ -19,11 +19,10 @@ import torch.distributed as dist
|
||||
from PIL.Image import Image
|
||||
from torch.distributed.tensor import DeviceMesh, DTensor
|
||||
|
||||
from sglang.srt.entrypoints.engine import Engine
|
||||
from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter
|
||||
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
|
||||
from sglang.srt.patch_torch import monkey_patch_torch_reductions
|
||||
from sglang.srt.server import Engine
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj
|
||||
|
||||
|
||||
|
||||
@@ -21,13 +21,6 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sglang.srt.utils import is_cuda_available
|
||||
|
||||
_is_cuda = is_cuda_available()
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
|
||||
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.distributed import (
|
||||
divide,
|
||||
@@ -35,7 +28,12 @@ from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
from sglang.srt.utils import is_cuda_available, set_weight_attrs
|
||||
|
||||
_is_cuda = is_cuda_available()
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.utils import is_cuda_available
|
||||
|
||||
_is_cuda = is_cuda_available()
|
||||
@@ -31,7 +32,6 @@ if _is_cuda:
|
||||
rmsnorm,
|
||||
)
|
||||
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import logging
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
|
||||
try:
|
||||
from deep_gemm import (
|
||||
@@ -13,8 +14,6 @@ try:
|
||||
except ImportError:
|
||||
use_deep_gemm = False
|
||||
|
||||
from torch.nn import Module
|
||||
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
@@ -37,21 +36,16 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
|
||||
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.utils import DeepEPMode, is_cuda, is_hip, set_weight_attrs
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
if _is_cuda:
|
||||
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
|
||||
else:
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
_buffer = None
|
||||
if _is_hip:
|
||||
from vllm._custom_ops import scaled_fp8_quant
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GroupedGemmRunner(torch.nn.Module):
|
||||
@@ -740,20 +734,12 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
|
||||
)
|
||||
|
||||
for expert in range(layer.num_experts_per_partition):
|
||||
if _is_cuda:
|
||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||
sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
||||
)
|
||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
||||
sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
||||
)
|
||||
else:
|
||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||
vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
||||
)
|
||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
||||
vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
||||
)
|
||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
||||
)
|
||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
||||
scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
||||
)
|
||||
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||
return
|
||||
|
||||
@@ -13,6 +13,7 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
||||
from sglang.srt.utils import (
|
||||
direct_register_custom_op,
|
||||
get_bool_env_var,
|
||||
@@ -22,28 +23,25 @@ from sglang.srt.utils import (
|
||||
)
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
|
||||
|
||||
enable_moe_align_block_size_triton = bool(
|
||||
int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
|
||||
)
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import gelu_and_mul, silu_and_mul
|
||||
|
||||
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
|
||||
else:
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
from vllm._custom_ops import scaled_fp8_quant
|
||||
|
||||
if _is_cuda or _is_hip:
|
||||
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
|
||||
enable_moe_align_block_size_triton = bool(
|
||||
int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def write_zeros_to_output(
|
||||
c_ptr,
|
||||
@@ -770,14 +768,9 @@ def invoke_fused_moe_kernel(
|
||||
# activation tensor-wise fp8 quantization, dynamic or static
|
||||
padded_size = padding_size
|
||||
# activations apply per-token quantization when weights apply per-channel quantization by default
|
||||
if _is_cuda:
|
||||
A, A_scale = sgl_scaled_fp8_quant(
|
||||
A, A_scale, use_per_token_if_dynamic=per_channel_quant
|
||||
)
|
||||
else:
|
||||
A, A_scale = vllm_ops.scaled_fp8_quant(
|
||||
A, A_scale, use_per_token_if_dynamic=per_channel_quant
|
||||
)
|
||||
A, A_scale = scaled_fp8_quant(
|
||||
A, A_scale, use_per_token_if_dynamic=per_channel_quant
|
||||
)
|
||||
else:
|
||||
# activation block-wise fp8 quantization
|
||||
assert len(block_shape) == 2
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# ==============================================================================
|
||||
|
||||
import math
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
@@ -29,6 +28,10 @@ _is_hip = is_hip()
|
||||
if _is_cuda:
|
||||
from sgl_kernel import moe_fused_gate
|
||||
|
||||
if _is_cuda or _is_hip:
|
||||
from sgl_kernel import topk_softmax
|
||||
|
||||
|
||||
expert_distribution_recorder = ExpertDistributionRecorder()
|
||||
|
||||
|
||||
@@ -59,11 +62,6 @@ def fused_topk(
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
):
|
||||
if _is_cuda or _is_hip:
|
||||
from sgl_kernel import topk_softmax
|
||||
else:
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||
|
||||
M, _ = hidden_states.shape
|
||||
@@ -76,20 +74,12 @@ def fused_topk(
|
||||
M, topk, dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
|
||||
if _is_cuda or _is_hip:
|
||||
topk_softmax(
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
token_expert_indicies,
|
||||
gating_output.float(),
|
||||
)
|
||||
else:
|
||||
vllm_ops.topk_softmax(
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
token_expert_indicies,
|
||||
gating_output.float(),
|
||||
)
|
||||
topk_softmax(
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
token_expert_indicies,
|
||||
gating_output.float(),
|
||||
)
|
||||
del token_expert_indicies
|
||||
|
||||
if renormalize:
|
||||
|
||||
@@ -7,8 +7,6 @@ from typing import Callable, Optional, Union
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||
|
||||
__all__ = [
|
||||
"BasevLLMParameter",
|
||||
"PackedvLLMParameter",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
|
||||
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import logging
|
||||
@@ -39,7 +39,6 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
|
||||
is_activation_quantization_format,
|
||||
should_ignore_layer,
|
||||
)
|
||||
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,22 +1,16 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
|
||||
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import enum
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors import CompressionFormat
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.fused_moe_triton import (
|
||||
FusedMoE,
|
||||
FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
||||
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
||||
from sglang.srt.layers.quantization.utils import (
|
||||
all_close_1d,
|
||||
@@ -29,10 +23,9 @@ from sglang.srt.utils import set_weight_attrs
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
if _is_cuda:
|
||||
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
|
||||
else:
|
||||
if not _is_cuda:
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
from vllm._custom_ops import scaled_fp8_quant
|
||||
|
||||
try:
|
||||
import vllm
|
||||
@@ -58,8 +51,6 @@ __all__ = [
|
||||
|
||||
class CompressedTensorsMoEMethod:
|
||||
def __new__(cls, *args, **kwargs):
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
|
||||
|
||||
if cls is CompressedTensorsMoEMethod:
|
||||
return super().__new__(cls)
|
||||
return super().__new__(cls)
|
||||
@@ -76,7 +67,7 @@ class CompressedTensorsMoEMethod:
|
||||
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
|
||||
if not VLLM_AVAILABLE:
|
||||
raise ImportError(
|
||||
"vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm"
|
||||
"vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm."
|
||||
)
|
||||
return CompressedTensorsWNA16MoEMethod(quant_config)
|
||||
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
|
||||
@@ -92,11 +83,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
def __init__(
|
||||
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
||||
):
|
||||
from sglang.srt.layers.moe.fused_moe_triton import (
|
||||
FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported,
|
||||
)
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
|
||||
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
|
||||
@@ -267,19 +253,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer.w13_weight[expert_id][start : start + shard_size, :],
|
||||
layer.w13_weight_scale[expert_id][shard_id],
|
||||
)
|
||||
(
|
||||
layer.w13_weight[expert_id][start : start + shard_size, :],
|
||||
_,
|
||||
) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
|
||||
|
||||
if _is_cuda:
|
||||
(
|
||||
layer.w13_weight[expert_id][start : start + shard_size, :],
|
||||
_,
|
||||
) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
|
||||
else:
|
||||
(
|
||||
layer.w13_weight[expert_id][start : start + shard_size, :],
|
||||
_,
|
||||
) = vllm_ops.scaled_fp8_quant(
|
||||
dq_weight, max_w13_scales[expert_id]
|
||||
)
|
||||
start += shard_size
|
||||
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
@@ -345,11 +323,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
def __init__(
|
||||
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
||||
):
|
||||
from sglang.srt.layers.moe.fused_moe_triton import (
|
||||
FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported,
|
||||
)
|
||||
|
||||
self.quant_config = quant_config
|
||||
# TODO: @dsikka: refactor this to use schemes as other kernels
|
||||
# are supported + check if the layer is being ignored.
|
||||
@@ -609,7 +582,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
|
||||
marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
|
||||
layer.w13_weight_packed,
|
||||
layer.w13_g_idx_sort_indices,
|
||||
layer.w13_weight_packed.shape[1] * self.packed_factor,
|
||||
@@ -617,7 +590,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
self.num_bits,
|
||||
)
|
||||
replace_tensor("w13_weight_packed", marlin_w13_qweight)
|
||||
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
|
||||
marlin_w2_qweight = vllm_ops.gptq_marlin_moe_repack(
|
||||
layer.w2_weight_packed,
|
||||
layer.w2_g_idx_sort_indices,
|
||||
layer.w2_weight_packed.shape[1] * self.packed_factor,
|
||||
@@ -661,14 +634,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
if not VLLM_AVAILABLE:
|
||||
raise ImportError(
|
||||
"vllm is not installed, to use fused_marlin_moe, please install vllm"
|
||||
)
|
||||
if expert_map is not None:
|
||||
raise NotImplementedError(
|
||||
"Expert Parallelism is not supported for " "fused Marlin MoE method."
|
||||
|
||||
@@ -17,7 +17,6 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_utils import (
|
||||
Fp8LinearOp,
|
||||
maybe_create_device_identity,
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
)
|
||||
from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
|
||||
@@ -99,8 +98,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
maybe_create_device_identity()
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
|
||||
@@ -8,15 +8,6 @@ import torch.nn.functional as F
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from sglang.srt.layers.quantization.utils import (
|
||||
all_close_1d,
|
||||
convert_to_channelwise,
|
||||
is_layer_skipped,
|
||||
per_tensor_dequantize,
|
||||
requantize_with_max_scale,
|
||||
)
|
||||
|
||||
try:
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear,
|
||||
@@ -27,11 +18,12 @@ try:
|
||||
except ImportError:
|
||||
MARLIN_FP8_AVAILABLE = False
|
||||
|
||||
def apply_fp8_marlin_linear(*args, **kwargs):
|
||||
raise ImportError("vllm is not installed")
|
||||
def dummy_func(*args, **kwargs):
|
||||
raise ImportError(
|
||||
"marlin FP8 requires some operators from vllm. Please install vllm."
|
||||
)
|
||||
|
||||
def prepare_fp8_layer_for_marlin(*args, **kwargs):
|
||||
raise ImportError("vllm is not installed")
|
||||
apply_fp8_marlin_linear = prepare_fp8_layer_for_marlin = dummy_func
|
||||
|
||||
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||
@@ -49,7 +41,10 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
per_token_group_quant_fp8,
|
||||
scaled_fp8_quant,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_utils import (
|
||||
apply_fp8_linear,
|
||||
apply_w8a8_block_fp8_linear,
|
||||
@@ -57,30 +52,35 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
||||
input_to_float8,
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
)
|
||||
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from sglang.srt.layers.quantization.utils import (
|
||||
all_close_1d,
|
||||
convert_to_channelwise,
|
||||
is_layer_skipped,
|
||||
per_tensor_dequantize,
|
||||
requantize_with_max_scale,
|
||||
)
|
||||
from sglang.srt.utils import (
|
||||
get_bool_env_var,
|
||||
is_cuda,
|
||||
is_hip,
|
||||
permute_weight,
|
||||
print_warning_once,
|
||||
set_weight_attrs,
|
||||
)
|
||||
|
||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
|
||||
_is_hip = is_hip()
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
if _is_hip:
|
||||
from aiter import ActivationType
|
||||
from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
|
||||
from aiter.ops.shuffle import shuffle_weight
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
if not _is_cuda:
|
||||
from vllm._custom_ops import scaled_fp8_quant
|
||||
|
||||
if _is_cuda:
|
||||
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
|
||||
else:
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
|
||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -243,7 +243,6 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
)
|
||||
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.orig_dtype = params_dtype
|
||||
@@ -327,7 +326,9 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
layer.weight_scale_inv.data, requires_grad=False
|
||||
)
|
||||
return
|
||||
|
||||
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
|
||||
|
||||
# If checkpoint not serialized fp8, quantize the weights.
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
if self.cutlass_fp8_supported or self.use_marlin:
|
||||
@@ -391,12 +392,9 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
)
|
||||
|
||||
if self.use_marlin:
|
||||
try:
|
||||
prepare_fp8_layer_for_marlin(layer)
|
||||
# Activations not quantized for marlin.
|
||||
del layer.input_scale
|
||||
except ImportError:
|
||||
self.use_marlin = False
|
||||
prepare_fp8_layer_for_marlin(layer)
|
||||
# Activations not quantized for marlin.
|
||||
del layer.input_scale
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@@ -406,18 +404,15 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
) -> torch.Tensor:
|
||||
|
||||
if self.use_marlin:
|
||||
try:
|
||||
return apply_fp8_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias,
|
||||
)
|
||||
except ImportError:
|
||||
self.use_marlin = False
|
||||
return apply_fp8_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
if self.block_quant:
|
||||
return apply_w8a8_block_fp8_linear(
|
||||
@@ -516,7 +511,7 @@ class Fp8MoEMethod:
|
||||
)
|
||||
|
||||
# WEIGHTS
|
||||
if get_bool_env_var("USE_INT4_WEIGHT"):
|
||||
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
|
||||
# INT4 MoE weight - INT32 packed
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
@@ -617,7 +612,7 @@ class Fp8MoEMethod:
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
if get_bool_env_var("USE_INT4_WEIGHT"):
|
||||
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
|
||||
)
|
||||
@@ -649,7 +644,7 @@ class Fp8MoEMethod:
|
||||
layer.w2_input_scale = None
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
if get_bool_env_var("USE_INT4_WEIGHT"):
|
||||
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
|
||||
self.process_weights_hip_int4(layer)
|
||||
return
|
||||
|
||||
@@ -706,20 +701,12 @@ class Fp8MoEMethod:
|
||||
requires_grad=False,
|
||||
)
|
||||
for expert in range(layer.num_experts):
|
||||
if _is_cuda:
|
||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||
sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
||||
)
|
||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
||||
sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
||||
)
|
||||
else:
|
||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||
vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
||||
)
|
||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
||||
vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
||||
)
|
||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
||||
)
|
||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
||||
scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
||||
)
|
||||
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||
|
||||
@@ -796,18 +783,10 @@ class Fp8MoEMethod:
|
||||
layer.w13_weight[expert_id][start : start + shard_size, :],
|
||||
layer.w13_weight_scale[expert_id][shard_id],
|
||||
)
|
||||
if _is_cuda:
|
||||
(
|
||||
layer.w13_weight[expert_id][start : start + shard_size, :],
|
||||
_,
|
||||
) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
|
||||
else:
|
||||
(
|
||||
layer.w13_weight[expert_id][start : start + shard_size, :],
|
||||
_,
|
||||
) = vllm_ops.scaled_fp8_quant(
|
||||
dq_weight, max_w13_scales[expert_id]
|
||||
)
|
||||
(
|
||||
layer.w13_weight[expert_id][start : start + shard_size, :],
|
||||
_,
|
||||
) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
|
||||
start += shard_size
|
||||
|
||||
layer.w13_weight_scale = torch.nn.Parameter(
|
||||
@@ -930,41 +909,11 @@ class Fp8MoEMethod:
|
||||
correction_bias=correction_bias,
|
||||
)
|
||||
|
||||
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
|
||||
# TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
|
||||
assert not no_combine, f"{no_combine=} is not supported."
|
||||
return ck_moe_2stages_win4(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
layer.w13_weight_scale1,
|
||||
layer.w2_weight_scale1,
|
||||
activation=(
|
||||
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
|
||||
),
|
||||
)
|
||||
if _is_hip and get_bool_env_var("CK_MOE"):
|
||||
assert not no_combine, f"{no_combine=} is not supported."
|
||||
if self.block_quant:
|
||||
# TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
|
||||
assert (
|
||||
activation == "silu"
|
||||
), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
|
||||
return asm_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
layer.w13_weight_scale_inv,
|
||||
layer.w2_weight_scale_inv,
|
||||
block_shape=tuple(self.quant_config.weight_block_size),
|
||||
expert_mask=None,
|
||||
)
|
||||
else:
|
||||
return ck_moe_2stages(
|
||||
if _is_hip:
|
||||
if get_bool_env_var("USE_INT4_WEIGHT"):
|
||||
# TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
|
||||
assert not no_combine, f"{no_combine=} is not supported."
|
||||
return ck_moe_2stages_win4(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
@@ -978,33 +927,65 @@ class Fp8MoEMethod:
|
||||
else ActivationType.Gelu
|
||||
),
|
||||
)
|
||||
else:
|
||||
# Expert fusion with FP8 quantization
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=inplace and not no_combine,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=(
|
||||
layer.w13_weight_scale_inv
|
||||
if self.block_quant
|
||||
else layer.w13_weight_scale
|
||||
),
|
||||
w2_scale=(
|
||||
layer.w2_weight_scale_inv
|
||||
if self.block_quant
|
||||
else layer.w2_weight_scale
|
||||
),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
no_combine=no_combine,
|
||||
)
|
||||
|
||||
if get_bool_env_var("CK_MOE"):
|
||||
assert not no_combine, f"{no_combine=} is not supported."
|
||||
if self.block_quant:
|
||||
# TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
|
||||
assert (
|
||||
activation == "silu"
|
||||
), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
|
||||
return asm_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
layer.w13_weight_scale_inv,
|
||||
layer.w2_weight_scale_inv,
|
||||
block_shape=tuple(self.quant_config.weight_block_size),
|
||||
expert_mask=None,
|
||||
)
|
||||
else:
|
||||
return ck_moe_2stages(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
layer.w13_weight_scale1,
|
||||
layer.w2_weight_scale1,
|
||||
activation=(
|
||||
ActivationType.Silu
|
||||
if activation == "silu"
|
||||
else ActivationType.Gelu
|
||||
),
|
||||
)
|
||||
|
||||
# Expert fusion with FP8 quantization
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=inplace and not no_combine,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=(
|
||||
layer.w13_weight_scale_inv
|
||||
if self.block_quant
|
||||
else layer.w13_weight_scale
|
||||
),
|
||||
w2_scale=(
|
||||
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
|
||||
),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
no_combine=no_combine,
|
||||
)
|
||||
|
||||
|
||||
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
||||
|
||||
@@ -34,15 +34,23 @@ from sglang.srt.utils import (
|
||||
supports_custom_op,
|
||||
)
|
||||
|
||||
_enable_jit_deepgemm = False
|
||||
|
||||
_is_hip = is_hip()
|
||||
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
_fp8_type = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
if _is_hip:
|
||||
fp8_max = 224.0
|
||||
else:
|
||||
fp8_max = torch.finfo(_fp8_type).max
|
||||
fp8_min = -fp8_max
|
||||
|
||||
_enable_jit_deepgemm = False
|
||||
if _is_cuda:
|
||||
import deep_gemm
|
||||
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
|
||||
from sgl_kernel import (
|
||||
sgl_per_tensor_quant_fp8,
|
||||
sgl_per_token_group_quant_fp8,
|
||||
sgl_per_token_quant_fp8,
|
||||
)
|
||||
|
||||
sm_version = get_device_sm()
|
||||
if sm_version == 90 and get_bool_env_var(
|
||||
@@ -53,6 +61,7 @@ if _is_cuda:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if supports_custom_op():
|
||||
|
||||
def deep_gemm_fp8_fp8_bf16_nt(
|
||||
@@ -179,7 +188,6 @@ def per_token_group_quant_fp8(
|
||||
x: torch.Tensor,
|
||||
group_size: int,
|
||||
eps: float = 1e-10,
|
||||
dtype: torch.dtype = fp8_type_,
|
||||
column_major_scales: bool = False,
|
||||
scale_tma_aligned: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -192,7 +200,6 @@ def per_token_group_quant_fp8(
|
||||
x: The input tenosr with ndim >= 2.
|
||||
group_size: The group size used for quantization.
|
||||
eps: The minimum to avoid dividing zero.
|
||||
dtype: The dype of output tensor.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
|
||||
@@ -202,15 +209,7 @@ def per_token_group_quant_fp8(
|
||||
), "the last dimension of `x` cannot be divisible by `group_size`"
|
||||
assert x.is_contiguous(), "`x` is not contiguous"
|
||||
|
||||
finfo = torch.finfo(dtype)
|
||||
fp8_max = finfo.max
|
||||
|
||||
if _is_hip:
|
||||
fp8_max = 224.0
|
||||
|
||||
fp8_min = -fp8_max
|
||||
|
||||
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
||||
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
|
||||
M = x.numel() // group_size
|
||||
N = group_size
|
||||
if column_major_scales:
|
||||
@@ -276,27 +275,18 @@ def sglang_per_token_group_quant_fp8(
|
||||
x: torch.Tensor,
|
||||
group_size: int,
|
||||
eps: float = 1e-10,
|
||||
dtype: torch.dtype = fp8_type_,
|
||||
):
|
||||
assert (
|
||||
x.shape[-1] % group_size == 0
|
||||
), "the last dimension of `x` cannot be divisible by `group_size`"
|
||||
assert x.is_contiguous(), "`x` is not contiguous"
|
||||
|
||||
finfo = torch.finfo(dtype)
|
||||
fp8_max = finfo.max
|
||||
|
||||
fp8_min = -fp8_max
|
||||
|
||||
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
||||
M = x.numel() // group_size
|
||||
N = group_size
|
||||
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
|
||||
x_s = torch.empty(
|
||||
x.shape[:-1] + (x.shape[-1] // group_size,),
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
|
||||
|
||||
return x_q, x_s
|
||||
@@ -304,7 +294,7 @@ def sglang_per_token_group_quant_fp8(
|
||||
|
||||
def sglang_per_token_quant_fp8(
|
||||
x: torch.Tensor,
|
||||
dtype: torch.dtype = fp8_type_,
|
||||
dtype: torch.dtype = _fp8_type,
|
||||
):
|
||||
assert x.is_contiguous(), "`x` is not contiguous"
|
||||
|
||||
@@ -368,7 +358,6 @@ def static_quant_fp8(
|
||||
x: torch.Tensor,
|
||||
x_s: torch.Tensor,
|
||||
repeat_scale: bool = False,
|
||||
dtype: torch.dtype = fp8_type_,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Function to perform static quantization using the given scale on an input tensor `x`.
|
||||
|
||||
@@ -386,15 +375,8 @@ def static_quant_fp8(
|
||||
"""
|
||||
assert x.is_contiguous(), "`x` is not contiguous"
|
||||
assert x_s.numel() == 1, "only supports per-tensor scale"
|
||||
finfo = torch.finfo(dtype)
|
||||
fp8_max = finfo.max
|
||||
|
||||
if _is_hip:
|
||||
fp8_max = 224.0
|
||||
|
||||
fp8_min = -fp8_max
|
||||
|
||||
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
||||
x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
|
||||
M = x.numel() // x.shape[-1]
|
||||
N = x.shape[-1]
|
||||
if repeat_scale:
|
||||
@@ -896,7 +878,7 @@ def _per_tensor_quant_mla_fp8_stage2(
|
||||
|
||||
|
||||
def per_tensor_quant_mla_fp8(
|
||||
x: torch.Tensor, eps: float = 1e-12, dtype: torch.dtype = torch.float8_e4m3fn
|
||||
x: torch.Tensor, eps: float = 1e-12
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
This function quantizes input values to float8 values with tensor-wise quantization
|
||||
@@ -904,13 +886,7 @@ def per_tensor_quant_mla_fp8(
|
||||
"""
|
||||
assert x.dim() == 3, "`x` is not a 3d-tensor"
|
||||
|
||||
finfo = torch.finfo(dtype)
|
||||
fp8_max = finfo.max
|
||||
if _is_hip:
|
||||
dtype = torch.float8_e4m3fnuz
|
||||
fp8_max = 224.0
|
||||
|
||||
x_q = x.new_empty(x.size(), dtype=dtype)
|
||||
x_q = x.new_empty(x.size(), dtype=_fp8_type)
|
||||
x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)
|
||||
|
||||
num_head, num_seq, head_size = x.shape
|
||||
@@ -935,9 +911,64 @@ def per_tensor_quant_mla_fp8(
|
||||
head_size,
|
||||
x.stride(0),
|
||||
x.stride(1),
|
||||
-fp8_max,
|
||||
fp8_min,
|
||||
fp8_max,
|
||||
BLOCK_SIZE,
|
||||
)
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
def scaled_fp8_quant(
|
||||
input: torch.Tensor,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
num_token_padding: Optional[int] = None,
|
||||
use_per_token_if_dynamic: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Quantize input tensor to FP8 (8-bit floating point) format.
|
||||
|
||||
Args:
|
||||
input (torch.Tensor): Input tensor to be quantized
|
||||
scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
|
||||
If None, scales will be computed dynamically.
|
||||
num_token_padding (Optional[int]): If specified, pad the first dimension
|
||||
of the output to at least this value.
|
||||
use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
|
||||
determines the quantization granularity:
|
||||
- True: compute scale per token
|
||||
- False: compute single scale per tensor
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
||||
- quantized_tensor: The FP8 quantized version of input
|
||||
- scale_tensor: The scaling factors used for quantization
|
||||
|
||||
Raises:
|
||||
AssertionError: If input is not 2D or if static scale's numel != 1
|
||||
"""
|
||||
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
|
||||
shape = input.shape
|
||||
out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
if num_token_padding:
|
||||
shape = (max(num_token_padding, input.shape[0]), shape[1])
|
||||
output = torch.empty(shape, device=input.device, dtype=out_dtype)
|
||||
|
||||
if scale is None:
|
||||
# Dynamic scaling
|
||||
if use_per_token_if_dynamic:
|
||||
scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
|
||||
sgl_per_token_quant_fp8(input, output, scale)
|
||||
else:
|
||||
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
||||
sgl_per_tensor_quant_fp8(
|
||||
input, output, scale, is_static=False
|
||||
) # False for dynamic
|
||||
else:
|
||||
# Static scaling
|
||||
assert scale.numel() == 1, f"Expected scalar scale, got numel={scale.numel()}"
|
||||
sgl_per_tensor_quant_fp8(
|
||||
input, output, scale, is_static=True
|
||||
) # True for static
|
||||
|
||||
return output, scale
|
||||
|
||||
@@ -1,11 +1,19 @@
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
VLLM_AVAILABLE = False
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
_enable_jit_deepgemm,
|
||||
per_token_group_quant_fp8,
|
||||
scaled_fp8_quant,
|
||||
sglang_per_token_quant_fp8,
|
||||
static_quant_fp8,
|
||||
w8a8_block_fp8_matmul,
|
||||
)
|
||||
@@ -17,30 +25,20 @@ from sglang.srt.utils import (
|
||||
is_hip,
|
||||
)
|
||||
|
||||
try:
|
||||
import vllm
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
VLLM_AVAILABLE = False
|
||||
|
||||
use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
|
||||
|
||||
_is_hip = is_hip()
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
if _is_hip and get_bool_env_var("CK_MOE"):
|
||||
from aiter import gemm_a8w8_blockscale
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
if _is_cuda:
|
||||
from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
|
||||
|
||||
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
|
||||
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
|
||||
use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
|
||||
|
||||
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
||||
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
|
||||
TORCH_DEVICE_IDENTITY = None
|
||||
|
||||
_TORCH_VERSION = torch.__version__.split("+")[0]
|
||||
try:
|
||||
@@ -214,7 +212,7 @@ def block_quant_to_tensor_quant(
|
||||
x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
|
||||
|
||||
x_q_tensor, scale = (
|
||||
sgl_scaled_fp8_quant(x_dq_block)
|
||||
scaled_fp8_quant(x_dq_block)
|
||||
if _is_cuda
|
||||
else input_to_float8(x_dq_block, dtype=x_q_block.dtype)
|
||||
)
|
||||
@@ -227,7 +225,7 @@ def channel_quant_to_tensor_quant(
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
x_dq_channel = x_q_channel.to(torch.float32) * x_s
|
||||
x_q_tensor, scale = (
|
||||
sgl_scaled_fp8_quant(x_dq_channel)
|
||||
scaled_fp8_quant(x_dq_channel)
|
||||
if _is_cuda
|
||||
else input_to_float8(x_dq_channel, dtype=x_q_channel.dtype)
|
||||
)
|
||||
@@ -264,7 +262,7 @@ def apply_fp8_linear(
|
||||
# final solution should be: 1. add support to per-tensor activation scaling.
|
||||
# 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
|
||||
if _is_hip and weight_scale.numel() == 1:
|
||||
qinput, x_scale = ops.scaled_fp8_quant(
|
||||
qinput, x_scale = vllm_ops.scaled_fp8_quant(
|
||||
input_2d,
|
||||
input_scale,
|
||||
use_per_token_if_dynamic=use_per_token_if_dynamic,
|
||||
@@ -275,32 +273,29 @@ def apply_fp8_linear(
|
||||
)
|
||||
|
||||
if cutlass_fp8_supported:
|
||||
try:
|
||||
if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
|
||||
# Fall back to vllm cutlass w8a8 fp8 kernel
|
||||
output = ops.cutlass_scaled_mm(
|
||||
qinput,
|
||||
weight,
|
||||
out_dtype=input.dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=weight_scale,
|
||||
bias=bias,
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
weight_scale.numel() == weight.shape[1]
|
||||
), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
|
||||
output = fp8_scaled_mm(
|
||||
qinput,
|
||||
weight,
|
||||
x_scale,
|
||||
weight_scale,
|
||||
out_dtype=input.dtype,
|
||||
bias=bias,
|
||||
)
|
||||
return output.view(*output_shape)
|
||||
except (ImportError, NameError, AttributeError):
|
||||
pass
|
||||
if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
|
||||
# Fall back to vllm cutlass w8a8 fp8 kernel
|
||||
output = vllm_ops.cutlass_scaled_mm(
|
||||
qinput,
|
||||
weight,
|
||||
out_dtype=input.dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=weight_scale,
|
||||
bias=bias,
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
weight_scale.numel() == weight.shape[1]
|
||||
), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
|
||||
output = fp8_scaled_mm(
|
||||
qinput,
|
||||
weight,
|
||||
x_scale,
|
||||
weight_scale,
|
||||
out_dtype=input.dtype,
|
||||
bias=bias,
|
||||
)
|
||||
return output.view(*output_shape)
|
||||
|
||||
# torch.scaled_mm supports per tensor weights + activations only
|
||||
# so fallback to naive if per channel or per token
|
||||
@@ -343,8 +338,10 @@ def apply_fp8_linear(
|
||||
|
||||
# Making sure the dummy tensor is on the same device as the weight
|
||||
global TORCH_DEVICE_IDENTITY
|
||||
if TORCH_DEVICE_IDENTITY.device != weight.device:
|
||||
TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
|
||||
if TORCH_DEVICE_IDENTITY is None:
|
||||
TORCH_DEVICE_IDENTITY = torch.ones(
|
||||
1, dtype=torch.float32, device=weight.device
|
||||
)
|
||||
|
||||
# GEMM
|
||||
# This computes C = (X * W).
|
||||
@@ -372,13 +369,6 @@ def apply_fp8_linear(
|
||||
return output.to(dtype=input.dtype).view(*output_shape)
|
||||
|
||||
|
||||
def maybe_create_device_identity():
|
||||
# Allocate dummy ones tensor for torch._scaled_mm
|
||||
global TORCH_DEVICE_IDENTITY
|
||||
if TORCH_DEVICE_IDENTITY is None:
|
||||
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
|
||||
|
||||
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
|
||||
# TODO(luka): follow similar pattern for marlin and block-fp8-linear
|
||||
# https://github.com/vllm-project/vllm/issues/14397
|
||||
@@ -405,9 +395,7 @@ class Fp8LinearOp:
|
||||
# We also don't pad when using torch.compile,
|
||||
# as it breaks with dynamic shapes.
|
||||
if pad_output is None:
|
||||
enable_torch_compile = os.environ.get(
|
||||
"SGLANG_ENABLE_TORCH_COMPILE", "0"
|
||||
).lower() in ("1", "true", "yes")
|
||||
enable_torch_compile = get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
|
||||
pad_output = not enable_torch_compile
|
||||
self.output_padding = 17 if pad_output else None
|
||||
|
||||
@@ -439,13 +427,13 @@ class Fp8LinearOp:
|
||||
# for sgl-kernel fp8_scaled_mm, it support per channel W now
|
||||
if self.cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
|
||||
if _is_cuda:
|
||||
qinput, x_scale = sgl_scaled_fp8_quant(
|
||||
qinput, x_scale = scaled_fp8_quant(
|
||||
input_2d,
|
||||
input_scale,
|
||||
use_per_token_if_dynamic=use_per_token_if_dynamic,
|
||||
)
|
||||
else:
|
||||
qinput, x_scale = ops.scaled_fp8_quant(
|
||||
qinput, x_scale = vllm_ops.scaled_fp8_quant(
|
||||
input_2d,
|
||||
input_scale,
|
||||
scale_ub=input_scale_ub,
|
||||
@@ -455,7 +443,7 @@ class Fp8LinearOp:
|
||||
# Fused GEMM_DQ
|
||||
if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
|
||||
# Fall back to vllm cutlass w8a8 fp8 kernel
|
||||
output = ops.cutlass_scaled_mm(
|
||||
output = vllm_ops.cutlass_scaled_mm(
|
||||
qinput,
|
||||
weight,
|
||||
out_dtype=input.dtype,
|
||||
@@ -482,14 +470,14 @@ class Fp8LinearOp:
|
||||
else:
|
||||
# Maybe apply padding to output, see comment in __init__
|
||||
if _is_cuda:
|
||||
qinput, x_scale = sgl_scaled_fp8_quant(
|
||||
qinput, x_scale = scaled_fp8_quant(
|
||||
input_2d,
|
||||
input_scale,
|
||||
num_token_padding=self.output_padding,
|
||||
use_per_token_if_dynamic=use_per_token_if_dynamic,
|
||||
)
|
||||
else:
|
||||
qinput, x_scale = ops.scaled_fp8_quant(
|
||||
qinput, x_scale = vllm_ops.scaled_fp8_quant(
|
||||
input_2d,
|
||||
input_scale,
|
||||
num_token_padding=self.output_padding,
|
||||
@@ -562,9 +550,12 @@ class Fp8LinearOp:
|
||||
# This computes C = (X * W).
|
||||
# Output in fp32 to allow subsequent ops to happen in-place
|
||||
|
||||
# Making sure the dummy tensor is on the same device as the weight
|
||||
global TORCH_DEVICE_IDENTITY
|
||||
if TORCH_DEVICE_IDENTITY.device != weight.device:
|
||||
TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
|
||||
if TORCH_DEVICE_IDENTITY is None:
|
||||
TORCH_DEVICE_IDENTITY = torch.ones(
|
||||
1, dtype=torch.float32, device=weight.device
|
||||
)
|
||||
|
||||
output = torch._scaled_mm(
|
||||
qinput,
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
|
||||
|
||||
from types import MappingProxyType
|
||||
from typing import List, Mapping, Optional, Tuple, Union
|
||||
from typing import List, Mapping, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
||||
from sglang.srt.utils import is_cuda
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
if _is_cuda:
|
||||
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
|
||||
else:
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
if not _is_cuda:
|
||||
from vllm._custom_ops import scaled_fp8_quant
|
||||
|
||||
|
||||
def is_fp8_fnuz() -> bool:
|
||||
@@ -116,12 +115,7 @@ def requantize_with_max_scale(
|
||||
for idx, logical_width in enumerate(logical_widths):
|
||||
end = start + logical_width
|
||||
weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
|
||||
if _is_cuda:
|
||||
weight[start:end, :], _ = sgl_scaled_fp8_quant(weight_dq, max_w_scale)
|
||||
else:
|
||||
weight[start:end, :], _ = vllm_ops.scaled_fp8_quant(
|
||||
weight_dq, max_w_scale
|
||||
)
|
||||
weight[start:end, :], _ = scaled_fp8_quant(weight_dq, max_w_scale)
|
||||
start = end
|
||||
|
||||
return max_w_scale, weight
|
||||
|
||||
@@ -1,13 +1,6 @@
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.utils import is_cuda_available, set_weight_attrs
|
||||
|
||||
is_cuda = is_cuda_available()
|
||||
if is_cuda:
|
||||
from sgl_kernel import int8_scaled_mm
|
||||
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||
@@ -18,6 +11,11 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
||||
from sglang.srt.utils import is_cuda_available, set_weight_attrs
|
||||
|
||||
is_cuda = is_cuda_available()
|
||||
if is_cuda:
|
||||
from sgl_kernel import int8_scaled_mm
|
||||
|
||||
|
||||
class W8A8Int8Config(QuantizationConfig):
|
||||
|
||||
@@ -11,10 +11,11 @@ from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.utils import is_cuda_available
|
||||
|
||||
_is_cuda_available = is_cuda_available()
|
||||
|
||||
if _is_cuda_available:
|
||||
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
|
||||
else:
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
|
||||
|
||||
|
||||
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -159,7 +160,7 @@ class RotaryEmbedding(CustomOp):
|
||||
)
|
||||
else:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
|
||||
ops.rotary_embedding(
|
||||
vllm_rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
|
||||
|
||||
|
||||
def get_backend_from_name(name: str) -> BaseLoRABackend:
|
||||
"""
|
||||
Get corresponding backend class from backend's name
|
||||
"""
|
||||
if name == "triton":
|
||||
from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
|
||||
|
||||
return TritonLoRABackend
|
||||
elif name == "flashinfer":
|
||||
from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend
|
||||
|
||||
return FlashInferLoRABackend
|
||||
else:
|
||||
raise ValueError(f"Invalid backend: {name}")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BaseLoRABackend",
|
||||
"FlashInferLoRABackend",
|
||||
"TritonLoRABackend",
|
||||
"get_backend_from_name",
|
||||
]
|
||||
@@ -75,7 +75,7 @@ class BaseLoRABackend:
|
||||
qkv_lora_a: torch.Tensor,
|
||||
qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
|
||||
*args,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Run the lora pass for QKV Layer.
|
||||
|
||||
@@ -98,7 +98,7 @@ class BaseLoRABackend:
|
||||
gate_up_lora_a: torch.Tensor,
|
||||
gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
|
||||
*args,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
|
||||
|
||||
@@ -115,3 +115,19 @@ class BaseLoRABackend:
|
||||
|
||||
def set_batch_info(self, batch_info: LoRABatchInfo):
|
||||
self.batch_info = batch_info
|
||||
|
||||
|
||||
def get_backend_from_name(name: str) -> BaseLoRABackend:
|
||||
"""
|
||||
Get corresponding backend class from backend's name
|
||||
"""
|
||||
if name == "triton":
|
||||
from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
|
||||
|
||||
return TritonLoRABackend
|
||||
elif name == "flashinfer":
|
||||
from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend
|
||||
|
||||
return FlashInferLoRABackend
|
||||
else:
|
||||
raise ValueError(f"Invalid backend: {name}")
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.lora.backend import BaseLoRABackend
|
||||
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
|
||||
from sglang.srt.lora.utils import LoRABatchInfo
|
||||
from sglang.srt.utils import is_flashinfer_available
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
|
||||
from sglang.srt.lora.backend import BaseLoRABackend
|
||||
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
|
||||
from sglang.srt.lora.triton_ops import (
|
||||
gate_up_lora_b_fwd,
|
||||
qkv_lora_b_fwd,
|
||||
|
||||
@@ -16,7 +16,7 @@ from sglang.srt.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from sglang.srt.lora.backend import BaseLoRABackend
|
||||
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
|
||||
|
||||
|
||||
class BaseLayerWithLoRA(nn.Module):
|
||||
|
||||
@@ -27,7 +27,7 @@ from torch import nn
|
||||
|
||||
from sglang.srt.configs.load_config import LoadConfig
|
||||
from sglang.srt.hf_transformers_utils import AutoConfig
|
||||
from sglang.srt.lora.backend import BaseLoRABackend
|
||||
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
|
||||
from sglang.srt.lora.lora_config import LoRAConfig
|
||||
from sglang.srt.model_loader.loader import DefaultModelLoader
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ import torch
|
||||
|
||||
from sglang.srt.configs.load_config import LoadConfig
|
||||
from sglang.srt.hf_transformers_utils import AutoConfig
|
||||
from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
|
||||
from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_from_name
|
||||
from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
|
||||
from sglang.srt.lora.lora import LoRAAdapter
|
||||
from sglang.srt.lora.lora_config import LoRAConfig
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
"""DetokenizerManager is a process that detokenizes the token ids."""
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""
|
||||
Multi-modality utils
|
||||
Multi-modality utils
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
@@ -12,11 +13,11 @@ from sglang.srt.managers.schedule_batch import (
|
||||
MultimodalDataItem,
|
||||
MultimodalInputs,
|
||||
global_server_args_dict,
|
||||
logger,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import print_warning_once
|
||||
from sglang.utils import logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MultiModalityDataPaddingPattern:
|
||||
|
||||
@@ -5,8 +5,6 @@ import logging
|
||||
import pkgutil
|
||||
from functools import lru_cache
|
||||
|
||||
from transformers import PROCESSOR_MAPPING
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
)
|
||||
|
||||
@@ -8,8 +8,6 @@ from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
from decord import VideoReader, cpu
|
||||
from PIL import Image
|
||||
from transformers import BaseImageProcessorFast
|
||||
|
||||
from sglang.srt.managers.schedule_batch import Modality
|
||||
@@ -102,6 +100,9 @@ class BaseMultimodalProcessor(ABC):
|
||||
"""
|
||||
estimate the total frame count from all visual input
|
||||
"""
|
||||
# Lazy import because decord is not available on some arm platforms.
|
||||
from decord import VideoReader, cpu
|
||||
|
||||
# Before processing inputs
|
||||
estimated_frames_list = []
|
||||
for image in image_data:
|
||||
|
||||
@@ -37,11 +37,11 @@ from sglang.srt.model_executor.forward_batch_info import (
|
||||
from sglang.srt.patch_torch import monkey_patch_torch_compile
|
||||
from sglang.srt.utils import get_available_gpu_memory, is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
|
||||
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
|
||||
for sub in model._modules.values():
|
||||
|
||||
@@ -320,7 +320,6 @@ class ModelRunner:
|
||||
logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
|
||||
|
||||
if not self.use_mla_backend:
|
||||
logger.info("Disable chunked prefix cache for non-MLA backend.")
|
||||
server_args.disable_chunked_prefix_cache = True
|
||||
elif self.page_size > 1:
|
||||
logger.info("Disable chunked prefix cache when page size > 1.")
|
||||
|
||||
@@ -48,7 +48,7 @@ _is_cuda = is_cuda()
|
||||
if _is_cuda:
|
||||
from sgl_kernel import awq_dequantize
|
||||
else:
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._custom_ops import awq_dequantize
|
||||
|
||||
|
||||
class DeepseekModelNextN(nn.Module):
|
||||
@@ -273,7 +273,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
||||
self_attn.kv_b_proj.qzeros,
|
||||
).T
|
||||
else:
|
||||
w = ops.awq_dequantize(
|
||||
w = awq_dequantize(
|
||||
self_attn.kv_b_proj.qweight,
|
||||
self_attn.kv_b_proj.scales,
|
||||
self_attn.kv_b_proj.qzeros,
|
||||
|
||||
@@ -51,6 +51,7 @@ from sglang.srt.layers.linear import (
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
|
||||
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
@@ -80,10 +81,8 @@ _is_cuda = is_cuda()
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
|
||||
|
||||
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
||||
else:
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._custom_ops import awq_dequantize
|
||||
|
||||
if _is_hip:
|
||||
from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
|
||||
@@ -861,7 +860,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
)
|
||||
elif self.w_kc.dtype == torch.float8_e4m3fn:
|
||||
q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
|
||||
q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
|
||||
q_nope.transpose(0, 1),
|
||||
)
|
||||
q_nope_out = bmm_fp8(
|
||||
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
|
||||
@@ -892,7 +891,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
)
|
||||
elif self.w_vc.dtype == torch.float8_e4m3fn:
|
||||
attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
|
||||
attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
|
||||
attn_output.transpose(0, 1),
|
||||
)
|
||||
attn_bmm_output = bmm_fp8(
|
||||
attn_output_val,
|
||||
@@ -1565,7 +1564,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
self_attn.kv_b_proj.qzeros,
|
||||
).T
|
||||
else:
|
||||
w = ops.awq_dequantize(
|
||||
w = awq_dequantize(
|
||||
self_attn.kv_b_proj.qweight,
|
||||
self_attn.kv_b_proj.scales,
|
||||
self_attn.kv_b_proj.qzeros,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import re
|
||||
from typing import Dict, Tuple
|
||||
|
||||
|
||||
|
||||
@@ -10,12 +10,11 @@ import torch
|
||||
import sglang.srt.sampling.penaltylib as penaltylib
|
||||
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SamplingBatchInfo:
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
# Some shortcuts for backward compatibility.
|
||||
# They will be removed in new versions.
|
||||
from sglang.srt.entrypoints.engine import Engine
|
||||
from sglang.srt.entrypoints.http_server import kill_process_tree, launch_server
|
||||
@@ -187,6 +187,7 @@ class ServerArgs:
|
||||
n_share_experts_fusion: int = 0
|
||||
disable_shared_experts_fusion: bool = False
|
||||
disable_chunked_prefix_cache: bool = False
|
||||
disable_fast_image_processor: bool = False
|
||||
|
||||
# Debug tensor dumps
|
||||
debug_tensor_dump_output_folder: Optional[str] = None
|
||||
@@ -198,9 +199,6 @@ class ServerArgs:
|
||||
disaggregation_bootstrap_port: int = 8998
|
||||
disaggregation_transfer_backend: str = "mooncake"
|
||||
|
||||
# multimodal
|
||||
disable_fast_image_processor: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# Expert parallelism
|
||||
if self.enable_ep_moe:
|
||||
@@ -1136,6 +1134,11 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-fast-image-processor",
|
||||
action="store_true",
|
||||
help="Adopt base image processor instead of fast image processor.",
|
||||
)
|
||||
|
||||
# Server warmups
|
||||
parser.add_argument(
|
||||
@@ -1187,13 +1190,6 @@ class ServerArgs:
|
||||
help="The backend for disaggregation transfer. Default is mooncake.",
|
||||
)
|
||||
|
||||
# Multimodal
|
||||
parser.add_argument(
|
||||
"--disable-fast-image-processor",
|
||||
action="store_true",
|
||||
help="Adopt base image processor instead of fast image processor.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
args.tp_size = args.tensor_parallel_size
|
||||
|
||||
@@ -55,7 +55,6 @@ import torch.distributed
|
||||
import torch.distributed as dist
|
||||
import triton
|
||||
import zmq
|
||||
from decord import VideoReader, cpu
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from packaging import version as pkg_version
|
||||
from PIL import Image
|
||||
@@ -545,6 +544,9 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
|
||||
|
||||
|
||||
def encode_video(video_path, frame_count_limit=None):
|
||||
# Lazy import because decord is not available on some arm platforms.
|
||||
from decord import VideoReader, cpu
|
||||
|
||||
if not os.path.exists(video_path):
|
||||
logger.error(f"Video {video_path} does not exist")
|
||||
return []
|
||||
|
||||
@@ -26,8 +26,8 @@ from transformers import (
|
||||
AutoProcessor,
|
||||
)
|
||||
|
||||
from sglang.srt.entrypoints.engine import Engine
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.server import Engine
|
||||
from sglang.srt.utils import load_image
|
||||
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from sglang.srt.custom_op import scaled_fp8_quant
|
||||
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
||||
from sglang.srt.utils import is_cuda
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user