feat: remove vllm distributed (#2907)

Author: Yineng Zhang
Date: 2025-01-17 22:31:51 +08:00
Committed by: GitHub
Co-authored-by: Zhangyi <1109276519@qq.com>
Parent: f3e9b4894b
Commit: 5dc54f1a62

45 changed files with 111 additions and 102 deletions
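
The change itself is mechanical: every symbol that sglang previously imported from vllm.distributed now comes from sglang's own sglang.srt.distributed package under the same name, so call sites only swap the module path. A minimal before/after sketch of the pattern repeated across the hunks below (shard_size is an illustrative helper, not code from this commit, and the TP helpers assume the distributed environment was already set up via init_distributed_environment and initialize_model_parallel):

# Before: tensor-parallel helpers came from vllm.
# from vllm.distributed import (
#     get_tensor_model_parallel_rank,
#     get_tensor_model_parallel_world_size,
#     tensor_model_parallel_all_reduce,
# )

# After: identical names, now owned by sglang.
from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)

def shard_size(total_dim: int) -> int:
    # Size of one rank's shard of a TP-partitioned dimension.
    tp_size = get_tensor_model_parallel_world_size()
    assert total_dim % tp_size == 0, "dimension must divide evenly across TP ranks"
    return total_dim // tp_size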


@@ -25,13 +25,13 @@ from sglang.srt.utils import is_flashinfer_available
 if is_flashinfer_available():
     from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 
-from vllm.distributed import (
+from vllm.model_executor.custom_op import CustomOp
+
+from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.custom_op import CustomOp
-
 from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.utils import set_weight_attrs


@@ -1,5 +1,6 @@
 import torch
-from vllm.distributed import GroupCoordinator, get_tp_group
+
+from sglang.srt.distributed import GroupCoordinator, get_tp_group
 
 _ATTN_TP_GROUP = None
 _ATTN_TP_RANK = None


@@ -7,7 +7,8 @@ from typing import Dict, List, Optional, Tuple
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -15,7 +16,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
-
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     PackedColumnParameter,


@@ -20,11 +20,11 @@ import torch
 import triton
 import triton.language as tl
 from torch import nn
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )
-
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,


@@ -4,13 +4,13 @@ from typing import Callable, List, Optional, Tuple
 import torch
 from torch.nn import Module
 from vllm import _custom_ops as ops
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.moe.ep_moe.kernels import (
     grouped_gemm_triton,


@@ -5,13 +5,13 @@ from enum import Enum
 from typing import Callable, List, Optional, Tuple
 
 import torch
-from vllm.distributed import (
+from vllm.model_executor.custom_op import CustomOp
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.custom_op import CustomOp
-
 from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 from sglang.srt.layers.moe.topk import select_experts


@@ -6,7 +6,8 @@ from typing import Callable, Optional, Union
 
 import torch
 from torch.nn import Parameter
-from vllm.distributed import get_tensor_model_parallel_rank
+
+from sglang.srt.distributed import get_tensor_model_parallel_rank
 
 __all__ = [
     "BasevLLMParameter",


@@ -8,7 +8,6 @@ import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 from vllm import _custom_ops as ops
-from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear,
@@ -24,6 +23,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     requantize_with_max_scale,
 )
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.linear import (
     LinearBase,
     LinearMethodBase,


@@ -6,13 +6,13 @@ from typing import List, Optional, Sequence, Tuple
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-
 from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,


@@ -21,10 +21,10 @@ from typing import TYPE_CHECKING, Callable
 
 import torch
 import tqdm
-from vllm.distributed import get_tensor_model_parallel_rank
-from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.distributed.parallel_state import graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache


@@ -21,16 +21,17 @@ from typing import List, Optional, Tuple
 
 import torch
 import torch.distributed as dist
-from vllm.distributed import (
+
+from sglang.srt.configs.device_config import DeviceConfig
+from sglang.srt.configs.load_config import LoadConfig
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.distributed import (
     get_tp_group,
     init_distributed_environment,
     initialize_model_parallel,
     set_custom_all_reduce,
 )
-
-from sglang.srt.configs.device_config import DeviceConfig
-from sglang.srt.configs.load_config import LoadConfig
-from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
 from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
@@ -295,12 +296,15 @@ class ModelRunner:
             monkey_patch_vllm_gguf_config()
 
         # Load the model
+        # Remove monkey_patch when linear.py quant remove dependencies with vllm
+        monkey_patch_vllm_parallel_state()
         with self.memory_saver_adapter.region():
             self.model = get_model(
                 model_config=self.model_config,
                 load_config=self.load_config,
                 device_config=DeviceConfig(self.device),
             )
+        monkey_patch_vllm_parallel_state(reverse=True)
 
         if self.server_args.kv_cache_dtype == "fp8_e4m3":
             if self.server_args.quantization_param_path is not None:
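
The monkey_patch_vllm_parallel_state() / monkey_patch_vllm_parallel_state(reverse=True) pair bracketing get_model above is the one deliberate leftover: vllm's quantized linear layers still consult vllm's own process-group globals while weights load, so sglang temporarily points them at its groups and restores them afterwards, as the in-code comment notes. A rough sketch of that contract (the _TP attribute name is assumed purely for illustration; the real helper lives in sglang.srt.distributed.parallel_state and swaps more state than shown):

import vllm.distributed.parallel_state as vllm_ps

from sglang.srt.distributed import get_tp_group

_saved_tp = None

def monkey_patch_vllm_parallel_state(reverse: bool = False) -> None:
    global _saved_tp
    if not reverse:
        # Redirect vllm's tensor-parallel group to sglang's for the load.
        _saved_tp = getattr(vllm_ps, "_TP", None)  # "_TP" is an assumed name
        vllm_ps._TP = get_tp_group()
    else:
        # Put back whatever vllm had once loading is done.
        vllm_ps._TP = _saved_tp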


@@ -21,14 +21,14 @@ from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
 from transformers import AutoModelForCausalLM, PretrainedConfig
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_loader.utils import (
     get_model_architecture,
@@ -496,7 +496,8 @@ class ShardedStateLoader(BaseModelLoader):
         device_config: DeviceConfig,
     ) -> nn.Module:
         from safetensors.torch import safe_open
-        from vllm.distributed import get_tensor_model_parallel_rank
+
+        from sglang.srt.distributed import get_tensor_model_parallel_rank
 
         local_model_path = self._prepare_weights(
             model_config.model_path, model_config.revision
@@ -556,7 +557,8 @@ class ShardedStateLoader(BaseModelLoader):
         max_size: Optional[int] = None,
     ) -> None:
         from safetensors.torch import save_file
-        from vllm.distributed import get_tensor_model_parallel_rank
+
+        from sglang.srt.distributed import get_tensor_model_parallel_rank
 
         if pattern is None:
             pattern = ShardedStateLoader.DEFAULT_PATTERN


@@ -19,10 +19,10 @@ import torch
 from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
 from safetensors.torch import load_file, safe_open, save_file
 from tqdm.auto import tqdm
-from vllm.distributed import get_tensor_model_parallel_rank
 
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
 from sglang.srt.utils import print_warning_once


@@ -24,10 +24,6 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -35,6 +31,10 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor


@@ -21,10 +21,10 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from torch.nn import LayerNorm
-from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.configs import ChatGLMConfig
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (


@@ -44,12 +44,12 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import PretrainedConfig
-from vllm.distributed import (
+from vllm.model_executor.layers.rotary_embedding import get_rope
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,


@@ -19,14 +19,14 @@ from typing import Iterable, Optional, Tuple
 
 import torch
 import torch.nn as nn
-from vllm.distributed import (
+from vllm.model_executor.layers.rotary_embedding import get_rope
+
+from sglang.srt.configs import DbrxConfig
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
-from sglang.srt.configs import DbrxConfig
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,


@@ -21,13 +21,13 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import (
+from vllm.model_executor.layers.rotary_embedding import get_rope
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (


@@ -23,14 +23,14 @@ import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
 from vllm import _custom_ops as ops
-from vllm.distributed import (
+from vllm.model_executor.layers.rotary_embedding import get_rope
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     get_tp_group,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (


@@ -20,9 +20,9 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
 from torch import nn
-from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (


@@ -21,9 +21,9 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (


@@ -20,8 +20,8 @@ from typing import Iterable, Optional, Set, Tuple, Union
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import GemmaRMSNorm
 from sglang.srt.layers.linear import (


@@ -22,10 +22,11 @@ from typing import Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
 from transformers import GPT2Config
-from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 
+from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_world_size
+# from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,


@@ -21,8 +21,8 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import GPTBigCodeConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,


@@ -22,9 +22,9 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import GraniteConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (


@@ -22,12 +22,12 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import (
+from vllm.model_executor.layers.rotary_embedding import get_rope
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (


@@ -19,9 +19,9 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (


@@ -22,13 +22,13 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import kv_cache_scales_loader
 
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (


@@ -18,9 +18,9 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
 from torch import nn
-from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (


@@ -19,7 +19,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     MergedColumnParallelLinear,
@@ -28,6 +27,7 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor


@@ -21,12 +21,12 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import MixtralConfig
-from vllm.distributed import (
+from vllm.model_executor.layers.rotary_embedding import get_rope
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     QKVParallelLinear,


@@ -23,13 +23,13 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import MixtralConfig
-from vllm.distributed import (
+from vllm.model_executor.layers.rotary_embedding import get_rope
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     QKVParallelLinear,


@@ -8,14 +8,14 @@ import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 import transformers.models.mllama.configuration_mllama as config_mllama
-import vllm.distributed.parallel_state as ps
 from torch import nn
 from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
 from transformers.models.mllama.modeling_mllama import (
     _prepare_aspect_ratio_attention_mask,
 )
-from vllm.distributed import get_tensor_model_parallel_world_size
 
+import sglang.srt.distributed.parallel_state as ps
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (


@@ -20,9 +20,9 @@ from typing import Iterable, List, Optional, Tuple
 import torch
 from torch import nn
 from transformers import OlmoConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,


@@ -21,15 +21,15 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import (
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (


@@ -23,10 +23,6 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import (
-    get_tensor_model_parallel_world_size,
-    tensor_model_parallel_all_reduce,
-)
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -35,6 +31,10 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput


@@ -5,9 +5,9 @@ import torch
 from torch import nn
 from transformers import Phi3Config
 from transformers.configuration_utils import PretrainedConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,


@@ -20,9 +20,9 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (


@@ -20,9 +20,9 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
 from torch import nn
-from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (


@@ -22,12 +22,12 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import (
+from vllm.model_executor.layers.rotary_embedding import get_rope
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (


@@ -30,12 +30,12 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange, repeat
-from vllm.distributed import parallel_state
-from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import QuickGELU
 
 from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
+from sglang.srt.distributed import parallel_state
+from sglang.srt.distributed import utils as dist_utils
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,


@@ -24,9 +24,9 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,


@@ -47,12 +47,12 @@ import torch
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import LlamaConfig
-from vllm.distributed import (
+from vllm.model_executor.layers.rotary_embedding import get_rope
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput


@@ -21,7 +21,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
@@ -31,6 +30,7 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention


@@ -18,11 +18,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-    tensor_model_parallel_all_reduce,
-)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
@@ -33,6 +28,11 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig