Fix llama for classification (#855)
This commit is contained in:
@@ -26,6 +26,11 @@ from vllm.config import CacheConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@@ -38,10 +43,6 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
||||
|
||||
MergedColumnParallelLinear = None
|
||||
QKVParallelLinear = None
|
||||
RowParallelLinear = None
|
||||
|
||||
|
||||
class LlamaMLP(nn.Module):
|
||||
def __init__(
|
||||
@@ -295,23 +296,6 @@ class LlamaForCausalLM(nn.Module):
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
efficient_weight_load=False,
|
||||
) -> None:
|
||||
global MergedColumnParallelLinear
|
||||
global QKVParallelLinear
|
||||
global RowParallelLinear
|
||||
|
||||
if efficient_weight_load:
|
||||
from sglang.srt.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
|
||||
Reference in New Issue
Block a user