Higher priority for user input of max_prefill_tokens & format (#540)

This commit is contained in:
Ying Sheng
2024-06-12 21:48:40 -07:00
committed by GitHub
parent 1374334d38
commit fb9296f0ed
50 changed files with 817 additions and 569 deletions

View File

@@ -1,6 +1,7 @@
"""Radix attention."""
import torch
import numpy as np
import torch
from torch import nn
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
@@ -10,7 +11,9 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetada
class RadixAttention(nn.Module):
def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1):
def __init__(
self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1
):
super().__init__()
self.tp_q_head_num = num_heads
self.tp_k_head_num = num_kv_heads