[bugfix] Fix bugs in _dumm_run and re-initialize kv-cache. (#3262)
### What this PR does / why we need it? Currently we run an extra profile_run with `num_tokens == self.mc2_tokens_capacity`. However, when setting `max_num_batched_tokens < self.mc2_tokens_capacity`, this will trigger an assertion error that requires num_tokens in `_dummy_run` to be smaller than `max_num_batched_tokens`. This PR skips this extra `profile_run` if `self.max_num_tokens <= self.mc2_tokens_capacity` so as to avoid this bug. This PR fixes a bug that `kernel_block_sizes` never equals to `[self.cache_config.block_size]`. `kernel_block_sizes` is type of List[List[int]], so the condition should be `kernel_block_sizes != [[self.cache_config.block_size]]`. This also helps to resolve a issue that cpu_offload_gb cannot be enabled. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? - vLLM version: v0.10.2 - vLLM main: https://github.com/vllm-project/vllm/commit/releases/v0.11.0 Signed-off-by: Angazenn <supperccell@163.com>
This commit is contained in:
@@ -502,7 +502,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.is_pooling_model,
|
||||
self.vllm_config.model_config.logits_processors),
|
||||
is_pooling_model=self.is_pooling_model,
|
||||
kernel_block_sizes=None,
|
||||
kernel_block_sizes=[[self.vllm_config.cache_config.block_size]],
|
||||
)
|
||||
self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
|
||||
dtype=torch.int64)
|
||||
@@ -2511,7 +2511,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# MC2 will consume additional NPU memory.
|
||||
# Therefore, we need to run the MC2 path once here to complete its initialization,
|
||||
# allowing vLLM to correctly estimate the maximum memory required.
|
||||
if self._select_moe_comm_method(
|
||||
if self.max_num_tokens > self.mc2_tokens_capacity and \
|
||||
self._select_moe_comm_method(
|
||||
self.mc2_tokens_capacity,
|
||||
with_prefill=True) == MoECommType.MC2:
|
||||
self._dummy_run(self.mc2_tokens_capacity, with_prefill=True)
|
||||
@@ -3140,7 +3141,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# of mamba block. In this case, BlockTable.block_size will never equal
|
||||
# to kernel_block_sizes[0]
|
||||
kernel_block_sizes.append([0])
|
||||
if kernel_block_sizes != [self.cache_config.block_size]:
|
||||
if kernel_block_sizes != [[self.cache_config.block_size]]:
|
||||
assert self.cache_config.cpu_offload_gb == 0, (
|
||||
"Cannot re-initialize the input batch when CPU weight "
|
||||
"offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501
|
||||
|
||||
Reference in New Issue
Block a user