Gemma Support (#256)

Liangsheng Yin
2024-03-11 12:14:27 +08:00
committed by GitHub
parent 64fe311593
commit 89885b31ef
10 changed files with 428 additions and 55 deletions


@@ -1,15 +1,9 @@
-from typing import List
-
 import torch
 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
 from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
 from torch import nn
-from vllm.model_executor.parallel_utils.parallel_state import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 
 
 class RadixAttention(nn.Module):
@@ -21,9 +15,9 @@ class RadixAttention(nn.Module):
         self.head_dim = head_dim
         self.layer_id = layer_id
 
-        from sglang.srt.managers.router.model_runner import global_model_mode
+        from sglang.srt.managers.router.model_runner import global_server_args
 
-        self.use_flashinfer = "flashinfer" in global_model_mode
+        self.use_flashinfer = "flashinfer" in global_server_args.model_mode
 
         if self.use_flashinfer:
             self.prefill_forward = self.prefill_forward_flashinfer
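
For readers skimming the diff: the change above replaces the standalone global_model_mode list with a field on a shared server-args object, so layers read the attention backend setting from one place. The following is a minimal, self-contained sketch of that pattern, not sglang's actual code; the _ServerArgs dataclass, the type of its model_mode field, and the prefill_forward_triton fallback name are assumptions made for illustration.

from dataclasses import dataclass, field
from typing import List


@dataclass
class _ServerArgs:
    # Hypothetical stand-in for the real server-args object; field names are assumptions.
    model_mode: List[str] = field(default_factory=list)


# Set once at startup (e.g. from CLI flags), then imported by model layers.
global_server_args = _ServerArgs(model_mode=["flashinfer"])


class _AttentionBackendToggle:
    # Mirrors the membership test introduced in RadixAttention.__init__ above.
    def __init__(self):
        self.use_flashinfer = "flashinfer" in global_server_args.model_mode
        if self.use_flashinfer:
            self.prefill_forward = self.prefill_forward_flashinfer
        else:
            self.prefill_forward = self.prefill_forward_triton

    def prefill_forward_flashinfer(self, q, k, v, input_metadata):
        raise NotImplementedError("flashinfer kernel lives in sglang, not in this sketch")

    def prefill_forward_triton(self, q, k, v, input_metadata):
        raise NotImplementedError("triton kernel lives in sglang, not in this sketch")

With this indirection, adding another backend mode only means adding a mode string and a branch in one place, rather than threading a new global through every layer.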