Refactor DeepGEMM integration (#7150)

This commit is contained in:
fzyzcjy
2025-06-14 11:41:03 +08:00
committed by GitHub
parent 8b8f2e7463
commit b4c41f7276
12 changed files with 207 additions and 147 deletions

View File

@@ -26,6 +26,7 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from sglang.srt import debug_utils
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
@@ -45,10 +46,9 @@ from sglang.srt.layers.dp_attention import (
initialize_dp_attention,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
-from sglang.srt.layers.quantization.deep_gemm import (
-_ENABLE_JIT_DEEPGEMM,
-update_deep_gemm_config,
+from sglang.srt.layers.quantization import (
+deep_gemm_wrapper,
+monkey_patch_isinstance_for_vllm_base_layer,
)
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
@@ -205,8 +205,8 @@ class ModelRunner:
min_per_gpu_memory = self.init_torch_distributed()
# Update deep gemm configure
-if _ENABLE_JIT_DEEPGEMM:
-update_deep_gemm_config(gpu_id, server_args)
+if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
+deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
# If it is a draft model, tp_group can be different
self.initialize(min_per_gpu_memory)