Refactor DeepGEMM integration (#7150)
This commit is contained in:
@@ -26,6 +26,7 @@ from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from sglang.srt import debug_utils
|
||||
from sglang.srt.configs.device_config import DeviceConfig
|
||||
from sglang.srt.configs.load_config import LoadConfig
|
||||
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
||||
@@ -45,10 +46,9 @@ from sglang.srt.layers.dp_attention import (
|
||||
initialize_dp_attention,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
|
||||
from sglang.srt.layers.quantization.deep_gemm import (
|
||||
_ENABLE_JIT_DEEPGEMM,
|
||||
update_deep_gemm_config,
|
||||
from sglang.srt.layers.quantization import (
|
||||
deep_gemm_wrapper,
|
||||
monkey_patch_isinstance_for_vllm_base_layer,
|
||||
)
|
||||
from sglang.srt.layers.sampler import Sampler
|
||||
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
|
||||
@@ -205,8 +205,8 @@ class ModelRunner:
|
||||
min_per_gpu_memory = self.init_torch_distributed()
|
||||
|
||||
# Update deep gemm configure
|
||||
if _ENABLE_JIT_DEEPGEMM:
|
||||
update_deep_gemm_config(gpu_id, server_args)
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
||||
deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
|
||||
|
||||
# If it is a draft model, tp_group can be different
|
||||
self.initialize(min_per_gpu_memory)
|
||||
|
||||
Reference in New Issue
Block a user