diff --git a/docs/source/assets/cp/blocktable.png b/docs/source/assets/cp/blocktable.png
index 6d73e730..86fc229a 100644
Binary files a/docs/source/assets/cp/blocktable.png and b/docs/source/assets/cp/blocktable.png differ
diff --git a/docs/source/assets/cp/device_world.png b/docs/source/assets/cp/device_world.png
new file mode 100644
index 00000000..2f798ec0
Binary files /dev/null and b/docs/source/assets/cp/device_world.png differ
diff --git a/docs/source/assets/cp/pcp-prefill.png b/docs/source/assets/cp/pcp-prefill.png
index 622a073a..a717d8d7 100644
Binary files a/docs/source/assets/cp/pcp-prefill.png and b/docs/source/assets/cp/pcp-prefill.png differ
diff --git a/docs/source/developer_guide/feature_guide/context_parallel.md b/docs/source/developer_guide/feature_guide/context_parallel.md
index 75d7900b..0de85507 100644
--- a/docs/source/developer_guide/feature_guide/context_parallel.md
+++ b/docs/source/developer_guide/feature_guide/context_parallel.md
@@ -25,20 +25,29 @@ Please refer to the [context parallel user guide](../../user_guide/feature_guide
 ## How It Works?
+### Device Distribution
+
+We introduce new communication domains for PCP and reuse TP for DCP, and this is the new layout of devices for PCP2, DCP2, and TP4.
+![device_world](../../assets/cp/device_world.png)
+
 ### Block Table
-CP performs sequence sharding on the KV cache storage. To facilitate efficient storage and access, tokens are stored in an interleaved manner across devices, with the interleaving granularity determined by `cp_kv_cache_interleave_size`.
+CP performs sequence sharding on the KV cache storage. To facilitate efficient storage and access, tokens are stored in an interleaved manner across devices, with the interleaving granularity determined by `cp_kv_cache_interleave_size`, whose default value is `cp_kv_cache_interleave_size=1`, a.k.a. 'token interleave'.
-As illustrated, a virtual block is defined in the block table, where blocks within the same CP device group form a virtual block. The virtual block size is `virtual_block_size = block_size * pcp_size * dcp_size`.
+Given that PCP and DCP behave similarly for KV cache sharding, we refer to them collectively as CP. Specifically, `cp_size = pcp_size * dcp_size`, and `cp_rank = pcp_rank * dcp_size + dcp_rank`.
-For any token `x`, its (virtual) block index is `x // virtual_block_size`, and the offset within the virtual block is `x % virtual_block_size`. The local block index is `offset_within_virtual_block // cp_kv_cache_interleave_size`, and the device number is `local_block_index % (pcp_size * dcp_size)`. The offset within the local block is `(local_block_index // (pcp_size * dcp_size)) * cp_kv_cache_interleave_size + offset_within_virtual_block % cp_kv_cache_interleave_size`.
+As illustrated, a virtual block is defined in the block table, where blocks within the same CP device group form a virtual block. The virtual block size is `virtual_block_size = block_size * cp_size`.
+
+For any token `x`, referencing the following figure, its (virtual) block index is `x // virtual_block_size`, and the offset within the virtual block is `offset_within_virtual_block = x % virtual_block_size`.
+The local block index is `local_block_index = offset_within_virtual_block // cp_kv_cache_interleave_size`, and the device number is `target_rank = local_block_index % cp_size`.
+The offset within the local block is `(local_block_index // cp_size) * cp_kv_cache_interleave_size + offset_within_virtual_block % cp_kv_cache_interleave_size`.
+
+![BlockTable](../../assets/cp/blocktable.png)
 Based on the logic above, the `slot_mapping` calculation process is adjusted, and the `slot_mapping` values on each device are modified to ensure the KV cache is sharded along the sequence dimension and stored across different devices as expected.
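The index arithmetic in the Block Table section can be sketched in a few lines of Python. This is only an illustration of the formulas quoted in the documentation, not the project's actual `slot_mapping` code. The names `locate_token`, `split_cp_rank`, and `TokenLocation` are hypothetical, and `split_cp_rank` is simply the inverse of `cp_rank = pcp_rank * dcp_size + dcp_rank`, assuming `0 <= dcp_rank < dcp_size`.

```python
from typing import NamedTuple


class TokenLocation(NamedTuple):
    """Hypothetical container for where one token's KV entry is stored."""
    virtual_block_index: int
    target_rank: int      # cp_rank of the device holding the token
    local_offset: int     # offset inside that device's local block


def locate_token(x: int, block_size: int, pcp_size: int, dcp_size: int,
                 cp_kv_cache_interleave_size: int = 1) -> TokenLocation:
    """Apply the documented formulas to token position x."""
    # Requirement stated in the documentation for the current implementation.
    if block_size % cp_kv_cache_interleave_size != 0:
        raise ValueError("block_size must be divisible by the interleave size")

    cp_size = pcp_size * dcp_size
    virtual_block_size = block_size * cp_size

    virtual_block_index = x // virtual_block_size
    offset_within_virtual_block = x % virtual_block_size
    local_block_index = offset_within_virtual_block // cp_kv_cache_interleave_size
    target_rank = local_block_index % cp_size
    local_offset = ((local_block_index // cp_size) * cp_kv_cache_interleave_size
                    + offset_within_virtual_block % cp_kv_cache_interleave_size)
    return TokenLocation(virtual_block_index, target_rank, local_offset)


def split_cp_rank(cp_rank: int, dcp_size: int) -> tuple[int, int]:
    """Invert cp_rank = pcp_rank * dcp_size + dcp_rank (derived, not from the docs)."""
    return cp_rank // dcp_size, cp_rank % dcp_size
```

For example, with `block_size=4`, `pcp_size=2`, `dcp_size=2`, and the default token interleave, `locate_token(5, 4, 2, 2)` places token 5 in virtual block 0, on `cp_rank` 1, at local offset 1. Within one virtual block every `(target_rank, local_offset)` pair is used exactly once, which is what allows each device to fill its local block without gaps.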
 The current implementation requires that `block_size % cp_kv_cache_interleave_size == 0`.
-![BlockTable](../../assets/cp/blocktable.png)
-
 ### Decode Context Parallel (DCP)
 As mentioned above, the primary function of DCP is to shard the KV cache along the sequence dimension for storage. Its impact lies in the logic of the decode and chunked prefill phases.