feat: remove vllm distributed (#2907)

Co-authored-by: Zhangyi <1109276519@qq.com>
This commit is contained in:
Yineng Zhang
2025-01-17 22:31:51 +08:00
committed by GitHub
parent f3e9b4894b
commit 5dc54f1a62
45 changed files with 111 additions and 102 deletions

View File

@@ -21,16 +21,17 @@ from typing import List, Optional, Tuple
import torch
import torch.distributed as dist
from vllm.distributed import (
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.distributed import (
get_tp_group,
init_distributed_environment,
initialize_model_parallel,
set_custom_all_reduce,
)
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
@@ -295,12 +296,15 @@ class ModelRunner:
monkey_patch_vllm_gguf_config()
# Load the model
# Remove monkey_patch once the quant code in linear.py no longer depends on vllm
monkey_patch_vllm_parallel_state()
with self.memory_saver_adapter.region():
self.model = get_model(
model_config=self.model_config,
load_config=self.load_config,
device_config=DeviceConfig(self.device),
)
monkey_patch_vllm_parallel_state(reverse=True)
if self.server_args.kv_cache_dtype == "fp8_e4m3":
if self.server_args.quantization_param_path is not None: