[BugFix][310P][v0.18.0] Use CPU generator cache for sampling (#8624)
### What this PR does / why we need it? This PR introduces a caching mechanism for CPU-based `torch.Generator` objects in the `_random_sample_310p` function to optimize sampling performance. It includes unit tests for cache persistence and state recovery. Feedback highlights a critical bug where keying the cache by batch index instead of generator ID can break RNG reproducibility during request re-scheduling, and notes a potential memory leak in the global cache. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Tested via new unit tests in `tests/ut/_310p/sample/test_sampler_310.py` verifying cache logic and error handling. --------- Signed-off-by: csoulnd <daidaicurry@foxmail.com>
This commit is contained in:
201
tests/ut/_310p/sample/test_sampler_310.py
Normal file
201
tests/ut/_310p/sample/test_sampler_310.py
Normal file
@@ -0,0 +1,201 @@
|
||||
#
|
||||
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import nullcontext
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend._310p.sample import sampler as sampler_310p
|
||||
|
||||
|
||||
class _FakeRow:
    """Stand-in for one row of the noise tensor `q`.

    Records every generator handed to ``exponential_`` so tests can assert
    which (cached) CPU generator was used for a given batch index.
    """

    def __init__(self):
        # Generators are appended in call order; None means the default path.
        self.generators = []

    def exponential_(self, generator=None):
        """Record *generator* and return ``self``, mimicking torch's in-place API."""
        self.generators += [generator]
        return self
|
||||
|
||||
|
||||
class _FakeQ:
    """Stand-in for the exponential-noise tensor `q` in ``_random_sample_310p``.

    Tracks whether the default (generator-less) ``exponential_`` path ran,
    and exposes one :class:`_FakeRow` per batch index so per-row generator
    usage can be asserted. Device moves (``cpu``/``npu``) are identity ops.
    """

    def __init__(self, batch_size):
        # Only shape[0] (the batch size) matters to the code under test;
        # the second dimension is an arbitrary placeholder.
        self.shape = (batch_size, 4)
        self.default_exponential_called = False
        self.rows = dict(enumerate(_FakeRow() for _ in range(batch_size)))

    def cpu(self):
        """No-op device transfer; the fake lives nowhere in particular."""
        return self

    def npu(self):
        """No-op device transfer; the fake lives nowhere in particular."""
        return self

    def exponential_(self):
        """Record that the whole-tensor (default-generator) fill ran."""
        self.default_exponential_called = True
        return self

    def __getitem__(self, idx):
        """Return the fake row for batch index *idx*."""
        return self.rows[idx]
|
||||
|
||||
|
||||
class _FakeCPUGenerator:
    """Stand-in for ``torch.Generator`` that just remembers what it was given.

    ``state`` is set by :meth:`set_state` (the state-sync path) and ``seed``
    by :meth:`manual_seed` (the fallback path), so tests can tell which
    branch the code under test took.
    """

    def __init__(self, device=None):
        self.device, self.state, self.seed = device, None, None

    def set_state(self, state):
        """Remember the raw state blob instead of restoring RNG state."""
        self.state = state

    def manual_seed(self, seed):
        """Remember the seed instead of reseeding anything."""
        self.seed = seed
|
||||
|
||||
|
||||
class TestSampler310pGeneratorCache(TestBase):
    """Unit tests for the CPU ``torch.Generator`` cache used by
    ``_random_sample_310p`` on 310P devices.

    All torch/NPU entry points are patched out, so these tests exercise
    only the cache bookkeeping: creation, reuse, the initial-seed fallback
    when state sync fails, and invalidation when the source generator
    object changes.

    Note: ``@patch.object`` decorators are applied bottom-up, so the mock
    arguments arrive in reverse decorator order (``npu_stream_switch``
    first, ``torch.npu`` last).
    """

    @patch.object(sampler_310p.torch, "npu", create=True)
    @patch.object(sampler_310p.torch, "Generator")
    @patch.object(sampler_310p.torch, "empty_like")
    @patch.object(sampler_310p, "global_stream", return_value=MagicMock())
    @patch.object(sampler_310p, "npu_stream_switch", return_value=nullcontext())
    def test_random_sample_310p_reuse_cpu_generator_cache(
        self,
        _mock_npu_stream_switch,
        _mock_global_stream,
        mock_empty_like,
        mock_generator_ctor,
        mock_npu,
    ):
        """Two calls with the same source generator must share one cached
        CPU generator (the ctor runs once, both fake `q`s see the same
        object), and the cache entry must record the source's ``id()``."""
        # Same source generator should reuse one cached CPU generator.
        sampler_310p._CPU_GENERATOR_CACHE_310P.clear()
        probs = MagicMock()
        # div_/argmax/view chain at the end of _random_sample_310p returns
        # a tensor; wire the mock so the chain terminates in a real tensor.
        probs.div_.return_value = probs
        probs.argmax.return_value = probs
        probs.view.return_value = torch.tensor([0])

        # One fake noise tensor per call; batch index 1 carries a generator.
        fake_q_first = _FakeQ(batch_size=2)
        fake_q_second = _FakeQ(batch_size=2)

        npu_stream = MagicMock()
        generator = MagicMock()
        generator.get_state.return_value = b"state"
        generator.initial_seed.return_value = 7
        generators = {1: generator}
        mock_empty_like.side_effect = [fake_q_first, fake_q_second]
        mock_generator_ctor.side_effect = _FakeCPUGenerator

        mock_npu.current_stream.return_value = npu_stream
        sampler_310p._random_sample_310p(probs, generators)
        sampler_310p._random_sample_310p(probs, generators)

        # Exactly one CPU generator was built despite two sampling calls.
        self.assertEqual(mock_generator_ctor.call_count, 1)
        self.assertIn(1, sampler_310p._CPU_GENERATOR_CACHE_310P)
        cached_cpu_generator, source_generator_id = sampler_310p._CPU_GENERATOR_CACHE_310P[1]
        # Both calls used the cached object for row 1.
        self.assertIs(fake_q_first.rows[1].generators[0], cached_cpu_generator)
        self.assertIs(fake_q_second.rows[1].generators[0], cached_cpu_generator)
        self.assertEqual(source_generator_id, id(generator))
        # State-sync path succeeded, so the seed fallback was never taken.
        self.assertEqual(cached_cpu_generator.state, b"state")
        self.assertIsNone(cached_cpu_generator.seed)
        # The NPU stream waited on the global stream once per call.
        self.assertEqual(npu_stream.wait_stream.call_count, 2)

    @patch.object(sampler_310p.torch, "npu", create=True)
    @patch.object(sampler_310p.torch, "Generator")
    @patch.object(sampler_310p.torch, "empty_like")
    @patch.object(sampler_310p, "global_stream", return_value=MagicMock())
    @patch.object(sampler_310p, "npu_stream_switch", return_value=nullcontext())
    def test_random_sample_310p_fallback_to_initial_seed_when_set_state_failed(
        self,
        _mock_npu_stream_switch,
        _mock_global_stream,
        mock_empty_like,
        mock_generator_ctor,
        mock_npu,
    ):
        """When either ``get_state`` or ``set_state`` raises, the code must
        fall back to ``manual_seed(initial_seed())`` and still cache the
        resulting CPU generator. Both failure sides are triggered here:
        the source raises on read and the CPU fake raises on write."""
        # If syncing generator state fails, fallback to initial seed.
        sampler_310p._CPU_GENERATOR_CACHE_310P.clear()
        probs = MagicMock()
        probs.div_.return_value = probs
        probs.argmax.return_value = probs
        probs.view.return_value = torch.tensor([1])

        fake_q = _FakeQ(batch_size=1)
        npu_stream = MagicMock()
        generator = MagicMock()
        # Reading the source generator's state fails...
        generator.get_state.side_effect = RuntimeError("state read failed")
        generator.initial_seed.return_value = 1234
        generators = {0: generator}

        # ...and writing state to the CPU generator would fail too.
        class _FailSetStateCPUGenerator(_FakeCPUGenerator):
            def set_state(self, state):
                raise RuntimeError("state set failed")

        mock_empty_like.return_value = fake_q
        mock_generator_ctor.side_effect = _FailSetStateCPUGenerator

        mock_npu.current_stream.return_value = npu_stream
        sampler_310p._random_sample_310p(probs, generators)

        cached_cpu_generator, source_generator_id = sampler_310p._CPU_GENERATOR_CACHE_310P[0]
        self.assertEqual(source_generator_id, id(generator))
        # Fallback path: the initial seed was applied via manual_seed.
        self.assertEqual(cached_cpu_generator.seed, 1234)
        self.assertIs(fake_q.rows[0].generators[0], cached_cpu_generator)
        self.assertEqual(npu_stream.wait_stream.call_count, 1)

    @patch.object(sampler_310p.torch, "npu", create=True)
    @patch.object(sampler_310p.torch, "Generator")
    @patch.object(sampler_310p.torch, "empty_like")
    @patch.object(sampler_310p, "global_stream", return_value=MagicMock())
    @patch.object(sampler_310p, "npu_stream_switch", return_value=nullcontext())
    def test_random_sample_310p_rebuild_cache_when_generator_identity_changes(
        self,
        _mock_npu_stream_switch,
        _mock_global_stream,
        mock_empty_like,
        mock_generator_ctor,
        mock_npu,
    ):
        """A different source generator object at the same batch index must
        invalidate the cached entry: a second CPU generator is built, synced
        from the new source's state, and stored with the new source's id."""
        # A new source generator object should rebuild cache entry.
        sampler_310p._CPU_GENERATOR_CACHE_310P.clear()
        probs = MagicMock()
        probs.div_.return_value = probs
        probs.argmax.return_value = probs
        probs.view.return_value = torch.tensor([0])

        fake_q_first = _FakeQ(batch_size=1)
        fake_q_second = _FakeQ(batch_size=1)
        npu_stream = MagicMock()

        generator_first = MagicMock()
        generator_first.get_state.return_value = b"state-1"
        generator_first.initial_seed.return_value = 11

        generator_second = MagicMock()
        generator_second.get_state.return_value = b"state-2"
        generator_second.initial_seed.return_value = 22
        mock_empty_like.side_effect = [fake_q_first, fake_q_second]
        mock_generator_ctor.side_effect = _FakeCPUGenerator

        mock_npu.current_stream.return_value = npu_stream
        sampler_310p._random_sample_310p(probs, {0: generator_first})
        sampler_310p._random_sample_310p(probs, {0: generator_second})

        # One CPU generator per distinct source generator.
        self.assertEqual(mock_generator_ctor.call_count, 2)
        first_cpu_generator = fake_q_first.rows[0].generators[0]
        second_cpu_generator = fake_q_second.rows[0].generators[0]
        self.assertIsNot(first_cpu_generator, second_cpu_generator)
        # Each CPU generator was synced from its own source's state.
        self.assertEqual(first_cpu_generator.state, b"state-1")
        self.assertEqual(second_cpu_generator.state, b"state-2")
        cached_cpu_generator, source_generator_id = sampler_310p._CPU_GENERATOR_CACHE_310P[0]
        # The cache now holds the rebuilt entry for the new source.
        self.assertIs(cached_cpu_generator, second_cpu_generator)
        self.assertEqual(source_generator_id, id(generator_second))
|
||||
@@ -25,6 +25,8 @@ from vllm_ascend.sample.sampler import (
|
||||
)
|
||||
from vllm_ascend.utils import global_stream, npu_stream_switch
|
||||
|
||||
# Maps batch index -> (cached CPU torch.Generator, id() of the source device
# generator it was synced from). The id is compared on each call to detect
# when the source generator object has been replaced.
# NOTE(review): entries are never evicted, so this module-level cache grows
# with the number of distinct batch indices seen — a potential memory leak;
# and keying by batch index rather than by a stable generator identity may
# break RNG reproducibility when requests are re-scheduled to different
# batch slots — TODO confirm against the scheduler's slot-assignment behavior.
_CPU_GENERATOR_CACHE_310P: dict[int, tuple[torch.Generator, int]] = {}
|
||||
|
||||
|
||||
def _random_sample_310p(
|
||||
probs: torch.Tensor,
|
||||
@@ -38,7 +40,18 @@ def _random_sample_310p(
|
||||
q.exponential_()
|
||||
if generators:
|
||||
for i, generator in generators.items():
|
||||
q[i].exponential_(generator=generator)
|
||||
cache_entry = _CPU_GENERATOR_CACHE_310P.get(i)
|
||||
if cache_entry is None or cache_entry[1] != id(generator):
|
||||
cpu_generator = torch.Generator(device="cpu")
|
||||
try:
|
||||
# Keep RNG stream consistent with the original generator.
|
||||
cpu_generator.set_state(generator.get_state())
|
||||
except Exception:
|
||||
cpu_generator.manual_seed(generator.initial_seed())
|
||||
cache_entry = (cpu_generator, id(generator))
|
||||
_CPU_GENERATOR_CACHE_310P[i] = cache_entry
|
||||
cpu_generator, _ = cache_entry
|
||||
q[i].exponential_(generator=cpu_generator)
|
||||
q = q.npu()
|
||||
torch.npu.current_stream().wait_stream(global_stream())
|
||||
return probs.div_(q).argmax(dim=-1).view(-1)
|
||||
|
||||
Reference in New Issue
Block a user