feat: remove vllm distributed (#2907)
Co-authored-by: Zhangyi <1109276519@qq.com>
@@ -21,14 +21,14 @@ from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
 from transformers import AutoModelForCausalLM, PretrainedConfig
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_loader.utils import (
     get_model_architecture,
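The hunk above is a pure import swap: the diff imports the same two tensor-model-parallel helpers, under the same names, from sglang.srt.distributed instead of vllm.distributed, so the surrounding call sites are untouched. As a minimal sketch of how these two helpers are typically used together to slice a weight dimension across tensor-parallel ranks (the shard_bounds helper is illustrative, not part of this commit, and a distributed group is assumed to be initialized already):

from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)

def shard_bounds(dim_size: int) -> tuple[int, int]:
    # Illustrative helper: compute this rank's contiguous slice of a
    # dimension that is split evenly across tensor-parallel ranks.
    rank = get_tensor_model_parallel_rank()  # 0-based TP rank of this process
    world_size = get_tensor_model_parallel_world_size()  # number of TP ranks
    per_rank = dim_size // world_size
    return rank * per_rank, (rank + 1) * per_rank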
@@ -496,7 +496,8 @@ class ShardedStateLoader(BaseModelLoader):
         device_config: DeviceConfig,
     ) -> nn.Module:
         from safetensors.torch import safe_open
-        from vllm.distributed import get_tensor_model_parallel_rank
+
+        from sglang.srt.distributed import get_tensor_model_parallel_rank
 
         local_model_path = self._prepare_weights(
             model_config.model_path, model_config.revision
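In load_model, again only the import line changes: the rank is still used to pick out the shard files that belong to the current process. Roughly like the sketch below, with illustrative names (the real logic lives in ShardedStateLoader; this assumes the pattern keeps {rank}/{part} placeholders like DEFAULT_PATTERN in the last hunk):

import glob
import os

from safetensors.torch import safe_open

from sglang.srt.distributed import get_tensor_model_parallel_rank

def load_shards_for_this_rank(local_model_path: str, pattern: str) -> dict:
    # Each tensor-parallel rank opens only its own shard files; "pattern"
    # is expected to contain {rank} and {part} placeholders.
    rank = get_tensor_model_parallel_rank()
    filepaths = sorted(
        glob.glob(os.path.join(local_model_path, pattern.format(rank=rank, part="*")))
    )
    state_dict = {}
    for path in filepaths:
        with safe_open(path, framework="pt", device="cpu") as f:
            for name in f.keys():
                state_dict[name] = f.get_tensor(name)
    return state_dict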
@@ -556,7 +557,8 @@ class ShardedStateLoader(BaseModelLoader):
         max_size: Optional[int] = None,
     ) -> None:
         from safetensors.torch import save_file
-        from vllm.distributed import get_tensor_model_parallel_rank
+
+        from sglang.srt.distributed import get_tensor_model_parallel_rank
 
         if pattern is None:
             pattern = ShardedStateLoader.DEFAULT_PATTERN
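The save path mirrors the load path: save_model needs the rank only to keep each tensor-parallel rank's output files from colliding. A hedged sketch of that idea (the helper and its signature are illustrative, not the loader's actual code):

import os

import torch
from safetensors.torch import save_file

from sglang.srt.distributed import get_tensor_model_parallel_rank

def save_shard_for_this_rank(
    state_dict: dict[str, torch.Tensor], path: str, pattern: str, part: int = 0
) -> None:
    # Embedding the rank in the filename keeps concurrent writers from
    # clobbering each other's shards.
    rank = get_tensor_model_parallel_rank()
    save_file(state_dict, os.path.join(path, pattern.format(rank=rank, part=part)))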