Feat/support rerank (#6058)
This commit is contained in:
@@ -20,6 +20,7 @@ from typing import Optional
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.distributed import (
|
||||
@@ -29,6 +30,7 @@ from sglang.srt.distributed import (
|
||||
)
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.utils import is_cuda, set_weight_attrs
|
||||
from sglang.utils import resolve_obj_by_qualname
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
@@ -165,6 +167,23 @@ def get_act_fn(
|
||||
return act_fn
|
||||
|
||||
|
||||
def get_cross_encoder_activation_function(config: PretrainedConfig):
|
||||
if (
|
||||
hasattr(config, "sbert_ce_default_activation_function")
|
||||
and config.sbert_ce_default_activation_function is not None
|
||||
):
|
||||
|
||||
function_name = config.sbert_ce_default_activation_function
|
||||
assert function_name.startswith("torch.nn.modules."), (
|
||||
"Loading of activation functions is restricted to "
|
||||
"torch.nn.modules for security reasons"
|
||||
)
|
||||
return resolve_obj_by_qualname(function_name)()
|
||||
else:
|
||||
# adapt bge-reranker
|
||||
return nn.Identity()
|
||||
|
||||
|
||||
if not _is_cuda:
|
||||
logger.info(
|
||||
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
|
||||
|
||||
Reference in New Issue
Block a user