Feat/support rerank (#6058)

This commit is contained in:
woodx
2025-06-17 01:50:01 +08:00
committed by GitHub
parent 91a066ec6a
commit e30ef368ab
20 changed files with 684 additions and 30 deletions

View File

@@ -20,6 +20,7 @@ from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
@@ -29,6 +30,7 @@ from sglang.srt.distributed import (
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.utils import is_cuda, set_weight_attrs
from sglang.utils import resolve_obj_by_qualname
_is_cuda = is_cuda()
@@ -165,6 +167,23 @@ def get_act_fn(
return act_fn
def get_cross_encoder_activation_function(config: PretrainedConfig):
    """Return the activation module to apply to cross-encoder (rerank) outputs.

    The choice is driven by the optional ``sbert_ce_default_activation_function``
    attribute of ``config``:

    * attribute missing or ``None`` -> ``nn.Identity()``
    * attribute set -> it must be a string starting with
      ``"torch.nn.modules."``; the name is handed to
      ``resolve_obj_by_qualname`` and the returned object is called with no
      arguments to build the activation.

    Args:
        config: Model configuration; only the
            ``sbert_ce_default_activation_function`` attribute is read.

    Returns:
        ``nn.Identity()`` when no activation is configured, otherwise whatever
        calling the resolved object returns (presumably an ``nn.Module``
        instance -- the resolver is defined outside this block).

    Raises:
        ValueError: If the configured name is not a string or does not start
            with ``"torch.nn.modules."``.
    """
    function_name = getattr(config, "sbert_ce_default_activation_function", None)
    if function_name is None:
        # adapt bge-reranker: no activation configured, pass values through.
        return nn.Identity()

    # The name comes from a model config file, so it is validated with an
    # explicit raise rather than `assert`: assertions are stripped when Python
    # runs with -O, which would silently drop this security restriction.
    if not (
        isinstance(function_name, str)
        and function_name.startswith("torch.nn.modules.")
    ):
        raise ValueError(
            "Loading of activation functions is restricted to "
            "torch.nn.modules for security reasons"
        )
    return resolve_obj_by_qualname(function_name)()
if not _is_cuda:
logger.info(
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."