model: add support for qwen3vl series (#16780)

* support qwen3vl series.

Co-authored-by: Thireus ☠ <Thireus@users.noreply.github.com>
Co-authored-by: yairpatch <yairpatch@users.noreply.github.com>
Co-authored-by: LETS-BEE <LETS-BEE@users.noreply.github.com>

* bugfix: fix the arch check for qwen3vl-moe.

* use build_ffn

* optimize deepstack structure

* optimize deepstack feature saving

* Revert "optimize deepstack feature saving" for temporal fix

This reverts commit f321b9fdf13e59527408152e73b1071e19a87e71.

* code clean

* use fused qkv in clip

* clean up / rm is_deepstack_layers for simplification

* add test model

* move test model to "big" section

* fix imrope check

* remove trailing whitespace

* fix rope fail

* metal : add imrope support

* add imrope support for sycl

* vulkan: add imrope w/o check

* fix vulkan

* webgpu: add imrope w/o check

* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* fix tensor mapping

---------

Co-authored-by: Thireus ☠ <Thireus@users.noreply.github.com>
Co-authored-by: yairpatch <yairpatch@users.noreply.github.com>
Co-authored-by: LETS-BEE <LETS-BEE@users.noreply.github.com>
Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
JJJYmmm
2025-10-30 23:19:14 +08:00
committed by GitHub
parent dcca0d3ab8
commit d261223d24
28 changed files with 1125 additions and 97 deletions

View File

@@ -5474,7 +5474,7 @@ static void ggml_rope_cache_init(
}
static void ggml_mrope_cache_init(
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
float * cache, float sin_sign, float theta_scale) {
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
@@ -5509,14 +5509,26 @@ static void ggml_mrope_cache_init(
}
float theta = theta_t;
if (sector >= sections[0] && sector < sec_w) {
theta = theta_h;
}
else if (sector >= sec_w && sector < sec_w + sections[2]) {
theta = theta_w;
}
else if (sector >= sec_w + sections[2]) {
theta = theta_e;
if (is_imrope) { // qwen3vl apply interleaved mrope
if (sector % 3 == 1 && sector < 3 * sections[1]) {
theta = theta_h;
} else if (sector % 3 == 2 && sector < 3 * sections[2]) {
theta = theta_w;
} else if (sector % 3 == 0 && sector < 3 * sections[0]) {
theta = theta_t;
} else {
theta = theta_e;
}
} else {
if (sector >= sections[0] && sector < sec_w) {
theta = theta_h;
}
else if (sector >= sec_w && sector < sec_w + sections[2]) {
theta = theta_w;
}
else if (sector >= sec_w + sections[2]) {
theta = theta_e;
}
}
rope_yarn(
@@ -5589,6 +5601,7 @@ static void ggml_compute_forward_rope_f32(
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
@@ -5627,7 +5640,7 @@ static void ggml_compute_forward_rope_f32(
const int64_t p_w = pos[i2 + ne2 * 2];
const int64_t p_e = pos[i2 + ne2 * 3];
ggml_mrope_cache_init(
p_t, p_h, p_w, p_e, sections, is_vision,
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}
@@ -5775,6 +5788,7 @@ static void ggml_compute_forward_rope_f16(
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
@@ -5813,7 +5827,7 @@ static void ggml_compute_forward_rope_f16(
const int64_t p_w = pos[i2 + ne2 * 2];
const int64_t p_e = pos[i2 + ne2 * 3];
ggml_mrope_cache_init(
p_t, p_h, p_w, p_e, sections, is_vision,
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}

View File

@@ -125,7 +125,7 @@ template<bool forward, bool has_ff, typename T>
static __global__ void rope_multi(
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) {
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
@@ -152,17 +152,29 @@ static __global__ void rope_multi(
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0;
if (sector < sections.v[0]) {
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sections.v[0] && sector < sec_w) {
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
if (is_imrope) {
if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
} else {
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
}
} else {
if (sector < sections.v[0]) {
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sections.v[0] && sector < sec_w) {
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
}
}
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -276,7 +288,7 @@ template<bool forward, typename T>
static void rope_multi_cuda(
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -287,11 +299,11 @@ static void rope_multi_cuda(
if (freq_factors == nullptr) {
rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections);
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
} else {
rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections);
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
}
}
@@ -369,6 +381,7 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
@@ -406,11 +419,11 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
if (src0->type == GGML_TYPE_F32) {
rope_multi_cuda<forward>(
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
} else if (src0->type == GGML_TYPE_F16) {
rope_multi_cuda<forward>(
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
} else {
GGML_ABORT("fatal error");
}

View File

@@ -1332,11 +1332,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_neox) {
snprintf(base, 256, "kernel_rope_neox_%s", ggml_type_name(op->src[0]->type));
} else if (is_mrope && !is_vision) {
} else if ((is_mrope || is_imrope) && !is_vision) {
GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token
snprintf(base, 256, "kernel_rope_multi_%s", ggml_type_name(op->src[0]->type));
} else if (is_vision) {
@@ -1346,14 +1347,20 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t
snprintf(base, 256, "kernel_rope_norm_%s", ggml_type_name(op->src[0]->type));
}
snprintf(name, 256, "%s", base);
snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
return res;
}

View File

@@ -76,6 +76,7 @@
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 500
#define FC_MUL_MV 600
#define FC_MUL_MM 700
#define FC_ROPE 800
// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPTG 8

View File

@@ -3709,6 +3709,8 @@ template [[host_name("kernel_mul_mv_bf16_f32_short")]] kernel mul_mv_t_t_short_
template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, bfloat>;
#endif
constant bool FC_rope_is_imrope [[function_constant(FC_ROPE + 0)]];
static float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
@@ -3889,14 +3891,26 @@ kernel void kernel_rope_multi(
const int sector = ic % sect_dims;
float theta_base;
if (sector < args.sect_0) {
theta_base = (float) pos[i2];
} else if (sector < sec_w01) {
theta_base = (float) pos[i2 + args.ne02];
} else if (sector < sec_w012) {
theta_base = (float) pos[i2 + args.ne02 * 2];
if (FC_rope_is_imrope) {
if (sector % 3 == 1 && sector < 3 * args.sect_1) { // h
theta_base = (float) pos[i2 + args.ne02 * 1];
} else if (sector % 3 == 2 && sector < 3 * args.sect_2) { // w
theta_base = (float) pos[i2 + args.ne02 * 2];
} else if (sector % 3 == 0 && sector < 3 * args.sect_0) { // t
theta_base = (float) pos[i2 + args.ne02 * 0];
} else { // e
theta_base = (float) pos[i2 + args.ne02 * 3];
}
} else {
theta_base = (float) pos[i2 + args.ne02 * 3];
if (sector < args.sect_0) {
theta_base = (float) pos[i2];
} else if (sector < sec_w01) {
theta_base = (float) pos[i2 + args.ne02 * 1];
} else if (sector < sec_w012) {
theta_base = (float) pos[i2 + args.ne02 * 2];
} else {
theta_base = (float) pos[i2 + args.ne02 * 3];
}
}
// end of mrope

View File

@@ -119,7 +119,7 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float * freq_factors, const mrope_sections sections,
const sycl::nd_item<3> & item_ct1) {
const bool is_imrope, const sycl::nd_item<3> & item_ct1) {
// get index pos
const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
if (i0 >= ne0) {
@@ -143,17 +143,29 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
float theta_base = 0.0;
if (sector < sections.v[0]) {
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
}
else if (sector >= sections.v[0] && sector < sec_w) {
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
}
else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
if (is_imrope) {
if (sector % 3 == 1 && sector < 3 * sections.v[1]) {
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) {
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) {
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
} else {
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
}
} else {
if (sector < sections.v[0]) {
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
}
else if (sector >= sections.v[0] && sector < sec_w) {
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
}
else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
}
}
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
@@ -281,7 +293,7 @@ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1,
const size_t s2, const int n_dims, const int nr, const int32_t * pos,
const float freq_scale, const float freq_base, const float ext_factor,
const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
const mrope_sections sections, queue_ptr stream) {
const mrope_sections sections, const bool is_imrope, queue_ptr stream) {
GGML_ASSERT(ne0 % 2 == 0);
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
@@ -297,12 +309,12 @@ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1,
if (freq_factors == nullptr) {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, item_ct1);
corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
});
} else {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, item_ct1);
corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
});
}
}
@@ -381,6 +393,7 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
@@ -422,11 +435,11 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
if (dst->src[0]->type == GGML_TYPE_F16) {
rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01,
s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, sections, main_stream);
freq_factors, sections, is_imrope, main_stream);
} else if (dst->src[0]->type == GGML_TYPE_F32) {
rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
main_stream);
is_imrope, main_stream);
} else {
GGML_ABORT("Fatal error: Tensor type unsupported!");
}

View File

@@ -1056,6 +1056,7 @@ struct vk_op_rope_push_constants {
uint32_t s1;
uint32_t s2;
int32_t sections[4];
uint32_t is_imrope;
uint32_t is_back;
uint32_t set_rows_stride;
};
@@ -9927,6 +9928,8 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
}
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
@@ -9948,7 +9951,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
(uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
{ sections[0], sections[1], sections[2], sections[3] }, backprop, set_rows_stride,
{ sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
}, dryrun);
}

View File

@@ -27,6 +27,7 @@ layout (push_constant) uniform parameter {
uint s1;
uint s2;
int sections[4];
uint is_imrope;
uint is_back;
uint set_rows_stride;
} p;

View File

@@ -32,17 +32,29 @@ void main() {
const uint sector = (i0 / 2) % sect_dims;
float theta_base = 0.0;
if (sector < p.sections[0]) {
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= p.sections[0] && sector < sec_w) {
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= sec_w + p.sections[2]) {
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
if (p.is_imrope != 0) {
if (sector % 3 == 1 && sector < 3 * p.sections[1]) {
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
} else if (sector % 3 == 2 && sector < 3 * p.sections[2]) {
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
} else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
} else {
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
}
} else {
if (sector < p.sections[0]) {
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= p.sections[0] && sector < sec_w) {
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= sec_w + p.sections[2]) {
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
}
}
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;

View File

@@ -221,6 +221,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let is_neox = bool(params.mode & 2);
let is_mrope = bool(params.mode & 8);
let is_imrope = params.mode == 40;
let is_vision = params.mode == 24;
var i = gid.x * 2; // start index for this thread
@@ -248,24 +249,36 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let sec_w = params.sections1 + params.sections0;
let sec_e = params.sections2 + sec_w;
let sector = (i0 / 2) % sect_dims;
if (sector >= params.sections0 && sector < sec_w) {
theta_base_mult = 1;
if (is_vision) {
theta_scale_pwr = sector - params.sections0;
}
} else if (sector >= sec_w && sector < sec_e) {
theta_base_mult = 2;
if (is_vision) {
theta_scale_pwr = sector - sec_w;
}
} else if (sector >= sec_e) {
if (is_vision) {
theta_scale_pwr = sector - sec_e;
theta_scale_pwr = (i0 / 2) % sec_e;
}
theta_base_mult = 3;
} else if (is_vision) {
theta_scale_pwr = sector;
if (is_imrope) {
if (sector % 3 == 1 && sector < 3 * params.sections1) {
theta_base_mult = 1;
} else if (sector % 3 == 2 && sector < 3 * params.sections2) {
theta_base_mult = 2;
} else if (sector % 3 == 0 && sector < 3 * params.sections0) {
theta_base_mult = 0;
} else {
theta_base_mult = 3;
}
} else {
if (sector >= params.sections0 && sector < sec_w) {
theta_base_mult = 1;
if (is_vision) {
theta_scale_pwr = sector - params.sections0;
}
} else if (sector >= sec_w && sector < sec_e) {
theta_base_mult = 2;
if (is_vision) {
theta_scale_pwr = sector - sec_w;
}
} else if (sector >= sec_e) {
if (is_vision) {
theta_scale_pwr = sector - sec_e;
theta_scale_pwr = (i0 / 2) % sec_e;
}
theta_base_mult = 3;
} else if (is_vision) {
theta_scale_pwr = sector;
}
}
}
let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr));