Improve logging & add logit cap (#471)
@@ -1,4 +1,5 @@
 import torch
+import numpy as np
 from torch import nn
 
 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
@@ -8,13 +9,16 @@ from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
 
 
 class RadixAttention(nn.Module):
-    def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id):
+    def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
         self.layer_id = layer_id
+        self.logit_cap = logit_cap
 
+        assert np.allclose(scaling, 1.0 / (head_dim**0.5))
+
         from sglang.srt.managers.router.model_runner import global_server_args_dict
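
The constructor gains an optional logit_cap argument (default -1) and stores it on the module; the numpy import added in the first hunk serves the new np.allclose check on scaling. By the usual convention for this kind of parameter, a positive cap smoothly bounds the attention logits with tanh before softmax, while a non-positive value leaves them untouched. The snippet below is a minimal sketch of that convention only; it is an assumption about how the stored value is meant to be consumed, not the kernel code this commit feeds it into, and cap_attention_scores is a hypothetical helper. The remaining hunks below thread the value into each attention call.

import torch

def cap_attention_scores(scores: torch.Tensor, logit_cap: float) -> torch.Tensor:
    """Smoothly bound attention logits to (-logit_cap, logit_cap).

    Hypothetical helper: a non-positive cap (e.g. the -1 default) is
    assumed to disable capping entirely.
    """
    if logit_cap > 0:
        return logit_cap * torch.tanh(scores / logit_cap)
    return scores

scores = torch.randn(4, 16) * 50.0            # exaggerated logits for illustration
capped = cap_attention_scores(scores, 30.0)
assert capped.abs().max() < 30.0              # bounded but still differentiable
assert torch.equal(cap_attention_scores(scores, -1), scores)  # default: no-op

Capping this way keeps extreme scores from saturating the softmax while remaining differentiable, which is the usual motivation for exposing such a knob.
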
@@ -38,6 +42,7 @@ class RadixAttention(nn.Module):
             input_metadata.start_loc,
             input_metadata.seq_lens,
             input_metadata.max_seq_len,
+            self.logit_cap,
         )
         self.store_kv_cache(k, v, input_metadata)
 
@@ -62,6 +67,7 @@ class RadixAttention(nn.Module):
             input_metadata.extend_seq_lens,
             input_metadata.max_seq_len,
             input_metadata.max_extend_len,
+            self.logit_cap,
         )
 
         return o
@@ -82,6 +88,7 @@ class RadixAttention(nn.Module):
             input_metadata.max_seq_len,
             input_metadata.other_kv_index,
             input_metadata.total_num_tokens,
+            self.logit_cap,
         )
 
         return o
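
All three call sites above now pass self.logit_cap as a trailing argument to the underlying attention kernels, so the cap applies regardless of which forward path (prefill via context_attention_fwd, extend, or decode) handles a batch. As a rough reference for where such a cap sits relative to scaling and softmax, here is a hypothetical plain-PyTorch version; the function name, tensor layout, GQA handling, and the lack of causal masking are illustrative assumptions, not the signature of the kernels this commit modifies.

import torch

def naive_capped_attention(q, k, v, scaling, logit_cap=-1.0):
    # q: [num_heads, seq, head_dim], k/v: [num_kv_heads, seq, head_dim]
    # Hypothetical reference; the real kernels operate on paged KV caches.
    if k.shape[0] != q.shape[0]:                       # repeat KV heads for GQA
        k = k.repeat_interleave(q.shape[0] // k.shape[0], dim=0)
        v = v.repeat_interleave(q.shape[0] // v.shape[0], dim=0)
    scores = torch.einsum("hqd,hkd->hqk", q, k) * scaling
    if logit_cap > 0:                                  # same convention as sketched above
        scores = logit_cap * torch.tanh(scores / logit_cap)
    return torch.softmax(scores, dim=-1) @ v

q = torch.randn(8, 5, 64)
k = torch.randn(2, 5, 64)
v = torch.randn(2, 5, 64)
out = naive_capped_attention(q, k, v, scaling=64**-0.5, logit_cap=30.0)
print(out.shape)  # torch.Size([8, 5, 64])

The reference applies the cap after the 1/sqrt(head_dim) scaling and before softmax, which is the usual placement; non-positive caps fall through untouched, consistent with the -1 default.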