Move status check in the memory pool to CPU (#1557)
This commit is contained in:
@@ -409,8 +409,11 @@ class ModelRunner:
|
||||
4096,
|
||||
)
|
||||
|
||||
device = "cuda"
|
||||
self.req_to_token_pool = ReqToTokenPool(
|
||||
max_num_reqs + 1, self.model_config.context_len + 4, device="cuda"
|
||||
max_num_reqs + 1,
|
||||
self.model_config.context_len + 4,
|
||||
device=device,
|
||||
)
|
||||
if (
|
||||
self.model_config.attention_arch == AttentionArch.MLA
|
||||
@@ -422,6 +425,7 @@ class ModelRunner:
|
||||
kv_lora_rank=self.model_config.kv_lora_rank,
|
||||
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
|
||||
layer_num=self.model_config.num_hidden_layers,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.token_to_kv_pool = MHATokenToKVPool(
|
||||
@@ -430,6 +434,7 @@ class ModelRunner:
|
||||
head_num=self.model_config.get_num_kv_heads(self.tp_size),
|
||||
head_dim=self.model_config.head_dim,
|
||||
layer_num=self.model_config.num_hidden_layers,
|
||||
device=device,
|
||||
)
|
||||
logger.info(
|
||||
f"Memory pool end. "
|
||||
|
||||
Reference in New Issue
Block a user