[Bugfix] Fix multi-instance serving OOM on single card (#7427)
### What this PR does / why we need it?
Fix https://github.com/vllm-project/vllm-ascend/issues/7308.
Subtracting `init_non_torch_memory` (maybe used by the first instance)
from the total `non_torch_memory` when calculating
`available_kv_cache_memory`.
Directly use `non_torch_memory_increase` (contained in
`non_kv_cache_memory`) to calculate `available_kv_cache_memory`.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Launch tow vllm-ascend instances sequentially on single card.
```bash
# Launch first instance
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen3-0.6B \
--port 8100 \
--host 0.0.0.0 \
--additional-config='{"enable_cpu_binding":true}' \
--gpu-memory-utilization 0.3 \
--max-num-seqs 1 \
--max-model-len 2048 \
--max-num-batched-tokens 2048 \
--no-enable-prefix-caching \
--enforce-eager
# Launch second instance
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen3-0.6B \
--port 8101 \
--host 0.0.0.0 \
--additional-config='{"enable_cpu_binding":true}' \
--gpu-memory-utilization 0.3 \
--max-num-seqs 1 \
--max-model-len 2048 \
--max-num-batched-tokens 2048 \
--no-enable-prefix-caching \
--enforce-eager
```
**Before this PR:**
```bash
# First instance:
------------------------------------------------------------------
requested_memory: 18.287109375 GiB
non_kv_cache_memory: 1.2340388298034668 GiB
init_non_torch_memory: 0.3616676330566406 GiB
non_torch_memory_before_empty_cache: 0.3896217346191406 GiB
non_torch_memory_increase: 0.0279541015625 GiB
non_torch_memory_cleared_by_empty_cache: 0.3616676330566406 GiB
------------------------------------------------------------------
# Second instance:
------------------------------------------------------------------
requested_memory: 18.287109375 GiB
non_kv_cache_memory: 1.2336344718933105 GiB
init_non_torch_memory: 18.37220001220703 GiB
non_torch_memory_before_empty_cache: 18.399906158447266 GiB
non_torch_memory_increase: 0.02754974365234375 GiB
non_torch_memory_cleared_by_empty_cache: 18.372356414794922 GiB
------------------------------------------------------------------
# available_kv_cache_memory = requested_memory - non_kv_cache_memory - non_torch_memory_cleared_by_empty_cache
Available KV cache memory: -1.32 GiB
```
**After this PR:**
```bash
# First instance:
------------------------------------------------------------------
requested_memory: 18.287109375 GiB
non_kv_cache_memory: 1.2340540885925293 GiB
init_non_torch_memory: 0.36182403564453125 GiB
non_torch_memory_before_empty_cache: 0.38979339599609375 GiB
non_torch_memory_increase: 0.0279693603515625 GiB
non_torch_memory_cleared_by_empty_cache: 0.0 GiB
------------------------------------------------------------------
# Second instance:
------------------------------------------------------------------
requested_memory: 18.287109375 GiB
non_kv_cache_memory: 1.233344554901123 GiB
init_non_torch_memory: 18.74309539794922 GiB
non_torch_memory_before_empty_cache: 18.770355224609375 GiB
non_torch_memory_increase: 0.02725982666015625 GiB
non_torch_memory_cleared_by_empty_cache: 0.0 GiB
------------------------------------------------------------------
# available_kv_cache_memory = requested_memory - non_kv_cache_memory - non_torch_memory_cleared_by_empty_cache
Available KV cache memory: 17.05 GiB
```
- vLLM version: v0.17.0
- vLLM main:
4497431df6
---------
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Shanshan Shen <87969357+shen-shanshan@users.noreply.github.com>
This commit is contained in:
@@ -264,6 +264,7 @@ def test_deepseek3_2_w8a8_pruning_mtp_tp2_ep():
|
||||
additional_config={"layer_sharding": ["q_b_proj", "o_proj"]},
|
||||
reasoning_parser="deepseek_v3",
|
||||
tokenizer_mode="deepseek_v32",
|
||||
gpu_memory_utilization=0.8,
|
||||
) as vllm_model:
|
||||
vllm_model.generate_greedy(short_example_prompts, max_tokens)
|
||||
vllm_model.generate_greedy(long_example_prompts, max_tokens)
|
||||
@@ -292,6 +293,7 @@ def test_deepseek3_2_w8a8c8_pruning_mtp_tp2_ep():
|
||||
additional_config={"layer_sharding": ["q_b_proj", "o_proj"], "enable_sparse_c8": True},
|
||||
reasoning_parser="deepseek_v3",
|
||||
tokenizer_mode="deepseek_v32",
|
||||
gpu_memory_utilization=0.8,
|
||||
) as vllm_model:
|
||||
vllm_model.generate_greedy(short_example_prompts, max_tokens)
|
||||
vllm_model.generate_greedy(long_example_prompts, max_tokens)
|
||||
|
||||
Reference in New Issue
Block a user