feat: remove vllm distributed (#2907)

Co-authored-by: Zhangyi <1109276519@qq.com>
This commit is contained in:
Yineng Zhang
2025-01-17 22:31:51 +08:00
committed by GitHub
parent f3e9b4894b
commit 5dc54f1a62
45 changed files with 111 additions and 102 deletions

View File

@@ -25,13 +25,13 @@ from sglang.srt.utils import is_flashinfer_available
if is_flashinfer_available(): if is_flashinfer_available():
from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
from vllm.distributed import ( from vllm.model_executor.custom_op import CustomOp
from sglang.srt.distributed import (
divide, divide,
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from vllm.model_executor.custom_op import CustomOp
from sglang.srt.layers.custom_op_util import register_custom_op from sglang.srt.layers.custom_op_util import register_custom_op
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.utils import set_weight_attrs from sglang.srt.utils import set_weight_attrs

View File

@@ -1,5 +1,6 @@
import torch import torch
from vllm.distributed import GroupCoordinator, get_tp_group
from sglang.srt.distributed import GroupCoordinator, get_tp_group
_ATTN_TP_GROUP = None _ATTN_TP_GROUP = None
_ATTN_TP_RANK = None _ATTN_TP_RANK = None

View File

@@ -7,7 +7,8 @@ from typing import Dict, List, Optional, Tuple
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter from torch.nn.parameter import Parameter, UninitializedParameter
from vllm.distributed import (
from sglang.srt.distributed import (
divide, divide,
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
@@ -15,7 +16,6 @@ from vllm.distributed import (
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from sglang.srt.layers.parameter import ( from sglang.srt.layers.parameter import (
BasevLLMParameter, BasevLLMParameter,
PackedColumnParameter, PackedColumnParameter,

View File

@@ -20,11 +20,11 @@ import torch
import triton import triton
import triton.language as tl import triton.language as tl
from torch import nn from torch import nn
from vllm.distributed import (
from sglang.srt.distributed import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
) )
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ( from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode, CaptureHiddenMode,

View File

@@ -4,13 +4,13 @@ from typing import Callable, List, Optional, Tuple
import torch import torch
from torch.nn import Module from torch.nn import Module
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.custom_op_util import register_custom_op from sglang.srt.layers.custom_op_util import register_custom_op
from sglang.srt.layers.moe.ep_moe.kernels import ( from sglang.srt.layers.moe.ep_moe.kernels import (
grouped_gemm_triton, grouped_gemm_triton,

View File

@@ -5,13 +5,13 @@ from enum import Enum
from typing import Callable, List, Optional, Tuple from typing import Callable, List, Optional, Tuple
import torch import torch
from vllm.distributed import ( from vllm.model_executor.custom_op import CustomOp
from sglang.srt.distributed import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from vllm.model_executor.custom_op import CustomOp
from sglang.srt.layers.custom_op_util import register_custom_op from sglang.srt.layers.custom_op_util import register_custom_op
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.moe.topk import select_experts

View File

@@ -6,7 +6,8 @@ from typing import Callable, Optional, Union
import torch import torch
from torch.nn import Parameter from torch.nn import Parameter
from vllm.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed import get_tensor_model_parallel_rank
__all__ = [ __all__ = [
"BasevLLMParameter", "BasevLLMParameter",

View File

@@ -8,7 +8,6 @@ import torch.nn.functional as F
from torch.nn import Module from torch.nn import Module
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, apply_fp8_marlin_linear,
@@ -24,6 +23,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
requantize_with_max_scale, requantize_with_max_scale,
) )
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (
LinearBase, LinearBase,
LinearMethodBase, LinearMethodBase,

View File

@@ -6,13 +6,13 @@ from typing import List, Optional, Sequence, Tuple
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter from torch.nn.parameter import Parameter, UninitializedParameter
from vllm.distributed import (
from sglang.srt.distributed import (
divide, divide,
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from sglang.srt.layers.parameter import BasevLLMParameter from sglang.srt.layers.parameter import BasevLLMParameter
from sglang.srt.layers.quantization.base_config import ( from sglang.srt.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,

View File

@@ -21,10 +21,10 @@ from typing import TYPE_CHECKING, Callable
import torch import torch
import tqdm import tqdm
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.distributed.parallel_state import graph_capture
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed.parallel_state import graph_capture
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
from sglang.srt.layers.torchao_utils import save_gemlite_cache from sglang.srt.layers.torchao_utils import save_gemlite_cache

View File

@@ -21,16 +21,17 @@ from typing import List, Optional, Tuple
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from vllm.distributed import (
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.distributed import (
get_tp_group, get_tp_group,
init_distributed_environment, init_distributed_environment,
initialize_model_parallel, initialize_model_parallel,
set_custom_all_reduce, set_custom_all_reduce,
) )
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
@@ -295,12 +296,15 @@ class ModelRunner:
monkey_patch_vllm_gguf_config() monkey_patch_vllm_gguf_config()
# Load the model # Load the model
# Remove monkey_patch when linear.py quant remove dependencies with vllm
monkey_patch_vllm_parallel_state()
with self.memory_saver_adapter.region(): with self.memory_saver_adapter.region():
self.model = get_model( self.model = get_model(
model_config=self.model_config, model_config=self.model_config,
load_config=self.load_config, load_config=self.load_config,
device_config=DeviceConfig(self.device), device_config=DeviceConfig(self.device),
) )
monkey_patch_vllm_parallel_state(reverse=True)
if self.server_args.kv_cache_dtype == "fp8_e4m3": if self.server_args.kv_cache_dtype == "fp8_e4m3":
if self.server_args.quantization_param_path is not None: if self.server_args.quantization_param_path is not None:

View File

@@ -21,14 +21,14 @@ from huggingface_hub import HfApi, hf_hub_download
from torch import nn from torch import nn
from transformers import AutoModelForCausalLM, PretrainedConfig from transformers import AutoModelForCausalLM, PretrainedConfig
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig, LoadFormat from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.configs.model_config import ModelConfig from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_loader.utils import ( from sglang.srt.model_loader.utils import (
get_model_architecture, get_model_architecture,
@@ -496,7 +496,8 @@ class ShardedStateLoader(BaseModelLoader):
device_config: DeviceConfig, device_config: DeviceConfig,
) -> nn.Module: ) -> nn.Module:
from safetensors.torch import safe_open from safetensors.torch import safe_open
from vllm.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed import get_tensor_model_parallel_rank
local_model_path = self._prepare_weights( local_model_path = self._prepare_weights(
model_config.model_path, model_config.revision model_config.model_path, model_config.revision
@@ -556,7 +557,8 @@ class ShardedStateLoader(BaseModelLoader):
max_size: Optional[int] = None, max_size: Optional[int] = None,
) -> None: ) -> None:
from safetensors.torch import save_file from safetensors.torch import save_file
from vllm.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed import get_tensor_model_parallel_rank
if pattern is None: if pattern is None:
pattern = ShardedStateLoader.DEFAULT_PATTERN pattern = ShardedStateLoader.DEFAULT_PATTERN

View File

@@ -19,10 +19,10 @@ import torch
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import load_file, safe_open, save_file from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm from tqdm.auto import tqdm
from vllm.distributed import get_tensor_model_parallel_rank
from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
from sglang.srt.utils import print_warning_once from sglang.srt.utils import print_warning_once

View File

@@ -24,10 +24,6 @@ from typing import Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
@@ -35,6 +31,10 @@ from vllm.model_executor.layers.linear import (
) )
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor

View File

@@ -21,10 +21,10 @@ from typing import Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.configs import ChatGLMConfig from sglang.srt.configs import ChatGLMConfig
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (

View File

@@ -44,12 +44,12 @@ import torch.utils.checkpoint
from torch import nn from torch import nn
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.distributed import ( from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,

View File

@@ -19,14 +19,14 @@ from typing import Iterable, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.distributed import ( from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.configs import DbrxConfig
from sglang.srt.distributed import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.configs import DbrxConfig
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,

View File

@@ -21,13 +21,13 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.distributed import ( from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (

View File

@@ -23,14 +23,14 @@ import torch.nn.functional as F
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.distributed import ( from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
get_tp_group, get_tp_group,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (

View File

@@ -20,9 +20,9 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (

View File

@@ -21,9 +21,9 @@ from typing import Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (

View File

@@ -20,8 +20,8 @@ from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.layernorm import GemmaRMSNorm from sglang.srt.layers.layernorm import GemmaRMSNorm
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (

View File

@@ -22,10 +22,11 @@ from typing import Iterable, List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import GPT2Config from transformers import GPT2Config
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_world_size
# from sglang.srt.layers.activation import get_act_fn # from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,

View File

@@ -21,8 +21,8 @@ from typing import Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import GPTBigCodeConfig from transformers import GPTBigCodeConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,

View File

@@ -22,9 +22,9 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import GraniteConfig from transformers import GraniteConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (

View File

@@ -22,12 +22,12 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.distributed import ( from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (

View File

@@ -19,9 +19,9 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (

View File

@@ -22,13 +22,13 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import LlamaConfig from transformers import LlamaConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import kv_cache_scales_loader from vllm.model_executor.model_loader.weight_utils import kv_cache_scales_loader
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (

View File

@@ -18,9 +18,9 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (

View File

@@ -19,7 +19,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
@@ -28,6 +27,7 @@ from vllm.model_executor.layers.linear import (
) )
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor

View File

@@ -21,12 +21,12 @@ from typing import Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import MixtralConfig from transformers import MixtralConfig
from vllm.distributed import ( from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (
QKVParallelLinear, QKVParallelLinear,

View File

@@ -23,13 +23,13 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from transformers import MixtralConfig from transformers import MixtralConfig
from vllm.distributed import ( from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (
QKVParallelLinear, QKVParallelLinear,

View File

@@ -8,14 +8,14 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.checkpoint import torch.utils.checkpoint
import transformers.models.mllama.configuration_mllama as config_mllama import transformers.models.mllama.configuration_mllama as config_mllama
import vllm.distributed.parallel_state as ps
from torch import nn from torch import nn
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
from transformers.models.mllama.modeling_mllama import ( from transformers.models.mllama.modeling_mllama import (
_prepare_aspect_ratio_attention_mask, _prepare_aspect_ratio_attention_mask,
) )
from vllm.distributed import get_tensor_model_parallel_world_size
import sglang.srt.distributed.parallel_state as ps
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (

View File

@@ -20,9 +20,9 @@ from typing import Iterable, List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import OlmoConfig from transformers import OlmoConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,

View File

@@ -21,15 +21,15 @@ from typing import Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.distributed import ( from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.distributed import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
split_tensor_along_last_dim, split_tensor_along_last_dim,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
) )
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (

View File

@@ -23,10 +23,6 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
@@ -35,6 +31,10 @@ from vllm.model_executor.layers.linear import (
) )
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput

View File

@@ -5,9 +5,9 @@ import torch
from torch import nn from torch import nn
from transformers import Phi3Config from transformers import Phi3Config
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,

View File

@@ -20,9 +20,9 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (

View File

@@ -20,9 +20,9 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (

View File

@@ -22,12 +22,12 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.distributed import ( from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (

View File

@@ -30,12 +30,12 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU from vllm.model_executor.layers.activation import QuickGELU
from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
from sglang.srt.distributed import parallel_state
from sglang.srt.distributed import utils as dist_utils
from sglang.srt.hf_transformers_utils import get_processor from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.attention.triton_ops.prefill_attention import ( from sglang.srt.layers.attention.triton_ops.prefill_attention import (
context_attention_fwd, context_attention_fwd,

View File

@@ -24,9 +24,9 @@ from typing import Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,

View File

@@ -47,12 +47,12 @@ import torch
from torch import nn from torch import nn
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from transformers import LlamaConfig from transformers import LlamaConfig
from vllm.distributed import ( from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput

View File

@@ -21,7 +21,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import LlamaConfig from transformers import LlamaConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
@@ -31,6 +30,7 @@ from vllm.model_executor.layers.linear import (
) )
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention

View File

@@ -18,11 +18,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
@@ -33,6 +28,11 @@ from vllm.model_executor.layers.linear import (
) )
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import fused_moe from sglang.srt.layers.moe.fused_moe_triton import fused_moe
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig