Fix llama for classification (#855)
This commit is contained in:
@@ -26,6 +26,11 @@ from vllm.config import CacheConfig
|
|||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.linear import (
|
||||||
|
MergedColumnParallelLinear,
|
||||||
|
QKVParallelLinear,
|
||||||
|
RowParallelLinear,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
@@ -38,10 +43,6 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
|
|||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
from sglang.srt.model_executor.model_runner import InputMetadata
|
||||||
|
|
||||||
MergedColumnParallelLinear = None
|
|
||||||
QKVParallelLinear = None
|
|
||||||
RowParallelLinear = None
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaMLP(nn.Module):
|
class LlamaMLP(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -295,23 +296,6 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
efficient_weight_load=False,
|
efficient_weight_load=False,
|
||||||
) -> None:
|
) -> None:
|
||||||
global MergedColumnParallelLinear
|
|
||||||
global QKVParallelLinear
|
|
||||||
global RowParallelLinear
|
|
||||||
|
|
||||||
if efficient_weight_load:
|
|
||||||
from sglang.srt.layers.linear import (
|
|
||||||
MergedColumnParallelLinear,
|
|
||||||
QKVParallelLinear,
|
|
||||||
RowParallelLinear,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
from vllm.model_executor.layers.linear import (
|
|
||||||
MergedColumnParallelLinear,
|
|
||||||
QKVParallelLinear,
|
|
||||||
RowParallelLinear,
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
|||||||
Reference in New Issue
Block a user