Clean up imports (#5467)
This commit is contained in:
@@ -2,6 +2,7 @@ import logging
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
|
||||
try:
|
||||
from deep_gemm import (
|
||||
@@ -13,8 +14,6 @@ try:
|
||||
except ImportError:
|
||||
use_deep_gemm = False
|
||||
|
||||
from torch.nn import Module
|
||||
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
@@ -37,21 +36,16 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
|
||||
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.utils import DeepEPMode, is_cuda, is_hip, set_weight_attrs
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
if _is_cuda:
|
||||
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
|
||||
else:
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
_buffer = None
|
||||
if _is_hip:
|
||||
from vllm._custom_ops import scaled_fp8_quant
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GroupedGemmRunner(torch.nn.Module):
|
||||
@@ -740,20 +734,12 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
|
||||
)
|
||||
|
||||
for expert in range(layer.num_experts_per_partition):
|
||||
if _is_cuda:
|
||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||
sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
||||
)
|
||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
||||
sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
||||
)
|
||||
else:
|
||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||
vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
||||
)
|
||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
||||
vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
||||
)
|
||||
w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
|
||||
scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
||||
)
|
||||
w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
|
||||
scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
|
||||
)
|
||||
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user