diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py
index 3e24e7b9c..9fcbb794b 100644
--- a/python/sglang/srt/models/llama2.py
+++ b/python/sglang/srt/models/llama2.py
@@ -26,6 +26,11 @@ from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -38,10 +43,6 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.model_runner import InputMetadata
 
-MergedColumnParallelLinear = None
-QKVParallelLinear = None
-RowParallelLinear = None
-
 
 class LlamaMLP(nn.Module):
     def __init__(
@@ -295,23 +296,6 @@ class LlamaForCausalLM(nn.Module):
         cache_config: Optional[CacheConfig] = None,
         efficient_weight_load=False,
     ) -> None:
-        global MergedColumnParallelLinear
-        global QKVParallelLinear
-        global RowParallelLinear
-
-        if efficient_weight_load:
-            from sglang.srt.layers.linear import (
-                MergedColumnParallelLinear,
-                QKVParallelLinear,
-                RowParallelLinear,
-            )
-        else:
-            from vllm.model_executor.layers.linear import (
-                MergedColumnParallelLinear,
-                QKVParallelLinear,
-                RowParallelLinear,
-            )
-
         super().__init__()
         self.config = config
         self.quant_config = quant_config