[BugFix][310P][v0.18.0] Use CPU generator cache for sampling (#8624)

### What this PR does / why we need it?
This PR introduces a caching mechanism for CPU-based `torch.Generator`
objects in the `_random_sample_310p` function to optimize sampling
performance. It includes unit tests for cache persistence and state
recovery. Feedback highlights a critical bug where keying the cache by
batch index instead of generator ID can break RNG reproducibility during
request re-scheduling, and notes a potential memory leak in the global
cache.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Tested via new unit tests in `tests/ut/_310p/sample/test_sampler_310.py`
verifying cache logic and error handling.

---------

Signed-off-by: csoulnd <daidaicurry@foxmail.com>
This commit is contained in:
csoulnd
2026-04-24 09:34:14 +08:00
committed by GitHub
parent 00ddacf4e7
commit 97dbcaf919
2 changed files with 215 additions and 1 deletions

View File

@@ -0,0 +1,201 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend._310p.sample import sampler as sampler_310p
class _FakeRow:
def __init__(self):
self.generators = []
def exponential_(self, generator=None):
self.generators.append(generator)
return self
class _FakeQ:
    """Minimal fake of the exponential-noise tensor ``q`` built by the sampler.

    Tracks whether the bulk (generator-less) ``exponential_`` path ran and
    exposes one recording ``_FakeRow`` per batch index via indexing. Device
    moves (``cpu``/``npu``) are no-ops that return the same object.
    """

    def __init__(self, batch_size):
        self.shape = (batch_size, 4)
        self.default_exponential_called = False
        # One recording row per batch index.
        self.rows = {}
        for row_idx in range(batch_size):
            self.rows[row_idx] = _FakeRow()

    def cpu(self):
        # No real device transfer in the fake.
        return self

    def npu(self):
        return self

    def exponential_(self):
        """Mark that the default (no-generator) fill path was taken."""
        self.default_exponential_called = True
        return self

    def __getitem__(self, idx):
        return self.rows[idx]
class _FakeCPUGenerator:
def __init__(self, device=None):
self.device = device
self.state = None
self.seed = None
def set_state(self, state):
self.state = state
def manual_seed(self, seed):
self.seed = seed
class TestSampler310pGeneratorCache(TestBase):
    """Tests for the CPU torch.Generator cache used by ``_random_sample_310p``.

    Every test patches ``torch.npu`` / ``torch.Generator`` / ``torch.empty_like``
    plus the stream helpers so the sampler runs without NPU hardware, then
    inspects ``sampler_310p._CPU_GENERATOR_CACHE_310P`` and the recording fakes.

    NOTE(review): the cache is keyed by batch index ``i`` while entries are
    validated via ``id(generator)``; per the PR discussion this keying can
    break RNG reproducibility when requests are re-scheduled to different
    batch slots, and the module-level cache is never pruned — confirm against
    the sampler implementation.
    """

    @patch.object(sampler_310p.torch, "npu", create=True)
    @patch.object(sampler_310p.torch, "Generator")
    @patch.object(sampler_310p.torch, "empty_like")
    @patch.object(sampler_310p, "global_stream", return_value=MagicMock())
    @patch.object(sampler_310p, "npu_stream_switch", return_value=nullcontext())
    def test_random_sample_310p_reuse_cpu_generator_cache(
        self,
        _mock_npu_stream_switch,
        _mock_global_stream,
        mock_empty_like,
        mock_generator_ctor,
        mock_npu,
    ):
        # Same source generator should reuse one cached CPU generator.
        sampler_310p._CPU_GENERATOR_CACHE_310P.clear()
        probs = MagicMock()
        probs.div_.return_value = probs
        probs.argmax.return_value = probs
        probs.view.return_value = torch.tensor([0])
        # One fake q per sampler call — empty_like is consumed once per call.
        fake_q_first = _FakeQ(batch_size=2)
        fake_q_second = _FakeQ(batch_size=2)
        npu_stream = MagicMock()
        generator = MagicMock()
        generator.get_state.return_value = b"state"
        generator.initial_seed.return_value = 7
        # The request at batch index 1 carries a seeded generator.
        generators = {1: generator}
        mock_empty_like.side_effect = [fake_q_first, fake_q_second]
        mock_generator_ctor.side_effect = _FakeCPUGenerator
        mock_npu.current_stream.return_value = npu_stream
        sampler_310p._random_sample_310p(probs, generators)
        sampler_310p._random_sample_310p(probs, generators)
        # Only one CPU generator was constructed across the two calls.
        self.assertEqual(mock_generator_ctor.call_count, 1)
        self.assertIn(1, sampler_310p._CPU_GENERATOR_CACHE_310P)
        cached_cpu_generator, source_generator_id = sampler_310p._CPU_GENERATOR_CACHE_310P[1]
        # Both calls applied the same cached generator to row 1.
        self.assertIs(fake_q_first.rows[1].generators[0], cached_cpu_generator)
        self.assertIs(fake_q_second.rows[1].generators[0], cached_cpu_generator)
        # The cache entry records the identity of the source generator.
        self.assertEqual(source_generator_id, id(generator))
        # State was copied via set_state; the manual_seed fallback never ran.
        self.assertEqual(cached_cpu_generator.state, b"state")
        self.assertIsNone(cached_cpu_generator.seed)
        # Each sampler call waits on the global stream exactly once.
        self.assertEqual(npu_stream.wait_stream.call_count, 2)

    @patch.object(sampler_310p.torch, "npu", create=True)
    @patch.object(sampler_310p.torch, "Generator")
    @patch.object(sampler_310p.torch, "empty_like")
    @patch.object(sampler_310p, "global_stream", return_value=MagicMock())
    @patch.object(sampler_310p, "npu_stream_switch", return_value=nullcontext())
    def test_random_sample_310p_fallback_to_initial_seed_when_set_state_failed(
        self,
        _mock_npu_stream_switch,
        _mock_global_stream,
        mock_empty_like,
        mock_generator_ctor,
        mock_npu,
    ):
        # If syncing generator state fails, fallback to initial seed.
        sampler_310p._CPU_GENERATOR_CACHE_310P.clear()
        probs = MagicMock()
        probs.div_.return_value = probs
        probs.argmax.return_value = probs
        probs.view.return_value = torch.tensor([1])
        fake_q = _FakeQ(batch_size=1)
        npu_stream = MagicMock()
        generator = MagicMock()
        # Reading the source state raises, exercising the except branch.
        generator.get_state.side_effect = RuntimeError("state read failed")
        generator.initial_seed.return_value = 1234

        generators = {0: generator}

        # CPU generator whose set_state also raises, so only manual_seed works.
        class _FailSetStateCPUGenerator(_FakeCPUGenerator):
            def set_state(self, state):
                raise RuntimeError("state set failed")

        mock_empty_like.return_value = fake_q
        mock_generator_ctor.side_effect = _FailSetStateCPUGenerator
        mock_npu.current_stream.return_value = npu_stream
        sampler_310p._random_sample_310p(probs, generators)
        cached_cpu_generator, source_generator_id = sampler_310p._CPU_GENERATOR_CACHE_310P[0]
        self.assertEqual(source_generator_id, id(generator))
        # Fallback path seeded the CPU generator from initial_seed().
        self.assertEqual(cached_cpu_generator.seed, 1234)
        self.assertIs(fake_q.rows[0].generators[0], cached_cpu_generator)
        self.assertEqual(npu_stream.wait_stream.call_count, 1)

    @patch.object(sampler_310p.torch, "npu", create=True)
    @patch.object(sampler_310p.torch, "Generator")
    @patch.object(sampler_310p.torch, "empty_like")
    @patch.object(sampler_310p, "global_stream", return_value=MagicMock())
    @patch.object(sampler_310p, "npu_stream_switch", return_value=nullcontext())
    def test_random_sample_310p_rebuild_cache_when_generator_identity_changes(
        self,
        _mock_npu_stream_switch,
        _mock_global_stream,
        mock_empty_like,
        mock_generator_ctor,
        mock_npu,
    ):
        # A new source generator object should rebuild cache entry.
        sampler_310p._CPU_GENERATOR_CACHE_310P.clear()
        probs = MagicMock()
        probs.div_.return_value = probs
        probs.argmax.return_value = probs
        probs.view.return_value = torch.tensor([0])
        fake_q_first = _FakeQ(batch_size=1)
        fake_q_second = _FakeQ(batch_size=1)
        npu_stream = MagicMock()
        # Two distinct source generators reuse the same batch index 0.
        generator_first = MagicMock()
        generator_first.get_state.return_value = b"state-1"
        generator_first.initial_seed.return_value = 11
        generator_second = MagicMock()
        generator_second.get_state.return_value = b"state-2"
        generator_second.initial_seed.return_value = 22
        mock_empty_like.side_effect = [fake_q_first, fake_q_second]
        mock_generator_ctor.side_effect = _FakeCPUGenerator
        mock_npu.current_stream.return_value = npu_stream
        sampler_310p._random_sample_310p(probs, {0: generator_first})
        sampler_310p._random_sample_310p(probs, {0: generator_second})
        # Identity change at the same index forces a second CPU generator.
        self.assertEqual(mock_generator_ctor.call_count, 2)
        first_cpu_generator = fake_q_first.rows[0].generators[0]
        second_cpu_generator = fake_q_second.rows[0].generators[0]
        self.assertIsNot(first_cpu_generator, second_cpu_generator)
        # Each CPU generator received its own source's state.
        self.assertEqual(first_cpu_generator.state, b"state-1")
        self.assertEqual(second_cpu_generator.state, b"state-2")
        cached_cpu_generator, source_generator_id = sampler_310p._CPU_GENERATOR_CACHE_310P[0]
        # The cache now holds the entry for the second (latest) generator.
        self.assertIs(cached_cpu_generator, second_cpu_generator)
        self.assertEqual(source_generator_id, id(generator_second))

View File

@@ -25,6 +25,8 @@ from vllm_ascend.sample.sampler import (
 )
 from vllm_ascend.utils import global_stream, npu_stream_switch

+_CPU_GENERATOR_CACHE_310P: dict[int, tuple[torch.Generator, int]] = {}
+
 def _random_sample_310p(
     probs: torch.Tensor,
@@ -38,7 +40,18 @@ def _random_sample_310p(
     q.exponential_()
     if generators:
         for i, generator in generators.items():
-            q[i].exponential_(generator=generator)
+            cache_entry = _CPU_GENERATOR_CACHE_310P.get(i)
+            if cache_entry is None or cache_entry[1] != id(generator):
+                cpu_generator = torch.Generator(device="cpu")
+                try:
+                    # Keep RNG stream consistent with the original generator.
+                    cpu_generator.set_state(generator.get_state())
+                except Exception:
+                    cpu_generator.manual_seed(generator.initial_seed())
+                cache_entry = (cpu_generator, id(generator))
+                _CPU_GENERATOR_CACHE_310P[i] = cache_entry
+            cpu_generator, _ = cache_entry
+            q[i].exponential_(generator=cpu_generator)
     q = q.npu()
     torch.npu.current_stream().wait_stream(global_stream())
     return probs.div_(q).argmax(dim=-1).view(-1)