Improve type annotation and styles (#2926)
This commit is contained in:
@@ -37,6 +37,7 @@ from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBack
|
||||
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
get_attention_tp_group,
|
||||
+get_attention_tp_size,
|
||||
initialize_dp_attention,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
@@ -532,7 +533,7 @@ class ModelRunner:
|
||||
)
|
||||
else:
|
||||
cell_size = (
|
||||
-self.model_config.get_num_kv_heads(self.tp_size)
|
||||
+self.model_config.get_num_kv_heads(get_attention_tp_size())
|
||||
* self.model_config.head_dim
|
||||
* self.model_config.num_hidden_layers
|
||||
* 2
|
||||
@@ -626,7 +627,7 @@ class ModelRunner:
|
||||
self.token_to_kv_pool = DoubleSparseTokenToKVPool(
|
||||
self.max_total_num_tokens,
|
||||
dtype=self.kv_cache_dtype,
|
||||
-head_num=self.model_config.get_num_kv_heads(self.tp_size),
|
||||
+head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
|
||||
head_dim=self.model_config.head_dim,
|
||||
layer_num=self.model_config.num_hidden_layers,
|
||||
device=self.device,
|
||||
@@ -637,7 +638,7 @@ class ModelRunner:
|
||||
self.token_to_kv_pool = MHATokenToKVPool(
|
||||
self.max_total_num_tokens,
|
||||
dtype=self.kv_cache_dtype,
|
||||
-head_num=self.model_config.get_num_kv_heads(self.tp_size),
|
||||
+head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
|
||||
head_dim=self.model_config.head_dim,
|
||||
layer_num=self.model_config.num_hidden_layers,
|
||||
device=self.device,
|
||||
|
||||
Reference in New Issue
Block a user