Improve logging & add logit cap (#471)

commit 2cea6146d8 (parent 44c998fcb5)
Author: Lianmin Zheng
Date:   2024-05-24 03:48:53 -07:00
Committed by: GitHub

12 changed files with 106 additions and 24 deletions

--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py

@@ -1,4 +1,5 @@
 import torch
+import numpy as np
 from torch import nn
 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
@@ -8,13 +9,16 @@ from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
 class RadixAttention(nn.Module):
-    def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id):
+    def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
         self.layer_id = layer_id
+        self.logit_cap = logit_cap
+        assert np.allclose(scaling, 1.0 / (head_dim**0.5))
         from sglang.srt.managers.router.model_runner import global_server_args_dict
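
For context, a minimal sketch of how a model layer would opt into the new parameter. The head counts, layer_id, and cap value below are illustrative assumptions, not taken from this commit; the logit_cap=-1 default appears to act as a sentinel that leaves attention logits uncapped.

    from sglang.srt.layers.radix_attention import RadixAttention

    head_dim = 128
    attn = RadixAttention(
        num_heads=32,            # illustrative value, not from this commit
        head_dim=head_dim,
        scaling=head_dim**-0.5,  # must equal 1/sqrt(head_dim) per the new assert
        num_kv_heads=8,          # illustrative value
        layer_id=0,
        logit_cap=30.0,          # a positive cap enables capping; default -1 disables it
    )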
@@ -38,6 +42,7 @@ class RadixAttention(nn.Module):
             input_metadata.start_loc,
             input_metadata.seq_lens,
             input_metadata.max_seq_len,
+            self.logit_cap,
         )
         self.store_kv_cache(k, v, input_metadata)
@@ -62,6 +67,7 @@ class RadixAttention(nn.Module):
             input_metadata.extend_seq_lens,
             input_metadata.max_seq_len,
             input_metadata.max_extend_len,
+            self.logit_cap,
         )
         return o
@@ -82,6 +88,7 @@ class RadixAttention(nn.Module):
             input_metadata.max_seq_len,
             input_metadata.other_kv_index,
             input_metadata.total_num_tokens,
+            self.logit_cap,
         )
         return o
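
The capping itself happens inside the attention kernels that now receive self.logit_cap (context_attention_fwd for prefill, extend_attention_fwd for extend, token_attention_fwd for decode). A common formulation of a logit cap, used e.g. by Grok-1, is a tanh soft cap on the pre-softmax attention scores; the plain-PyTorch sketch below illustrates that transform under this assumption and is not the kernel code:

    import torch

    def soft_cap(scores: torch.Tensor, logit_cap: float) -> torch.Tensor:
        # Squash pre-softmax attention scores into (-logit_cap, logit_cap).
        # A non-positive cap is treated as "disabled", matching the
        # logit_cap=-1 default above. Sketch only, not the fused kernel.
        if logit_cap <= 0:
            return scores
        return logit_cap * torch.tanh(scores / logit_cap)

    scores = torch.randn(4, 4) * 100
    assert soft_cap(scores, 30.0).abs().max() < 30.0  # every logit now bounded

Bounding each logit this way keeps extreme query-key dot products from saturating the softmax, while a non-positive cap preserves the previous uncapped behavior for models that do not need it.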