Improve type annotations and styles (#2926)

This commit is contained in:
Lianmin Zheng
2025-01-16 12:51:11 -08:00
committed by GitHub
parent a883f0790d
commit bc6915e3b9
7 changed files with 78 additions and 26 deletions

View File

@@ -37,6 +37,7 @@ from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBack
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.dp_attention import (
get_attention_tp_group,
get_attention_tp_size,
initialize_dp_attention,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -532,7 +533,7 @@ class ModelRunner:
)
else:
cell_size = (
self.model_config.get_num_kv_heads(self.tp_size)
self.model_config.get_num_kv_heads(get_attention_tp_size())
* self.model_config.head_dim
* self.model_config.num_hidden_layers
* 2
@@ -626,7 +627,7 @@ class ModelRunner:
self.token_to_kv_pool = DoubleSparseTokenToKVPool(
self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads(self.tp_size),
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
head_dim=self.model_config.head_dim,
layer_num=self.model_config.num_hidden_layers,
device=self.device,
@@ -637,7 +638,7 @@ class ModelRunner:
self.token_to_kv_pool = MHATokenToKVPool(
self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads(self.tp_size),
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
head_dim=self.model_config.head_dim,
layer_num=self.model_config.num_hidden_layers,
device=self.device,