#ifdef U32_DEQUANT_HELPERS
#define SRC0_TYPE u32

// Return byte `b` of `v` (b = 0 is the least significant byte) as an unsigned value in [0, 255].
fn byte_of(v: u32, b: u32) -> u32 {
    let bit_offset = b * 8u;
    let shifted = v >> bit_offset;
    return shifted & 0xFFu;
}

// Return byte `b` of `v` reinterpreted as a two's-complement int8, i.e. sign-extended
// into the range [-128, 127].
fn sbyte_of(v: u32, b: u32) -> i32 {
    let unsigned_byte = i32((v >> (b * 8u)) & 0xFFu);
    if (unsigned_byte < 128) {
        return unsigned_byte;
    }
    // Top bit set: the byte encodes a negative value.
    return unsigned_byte - 256;
}
#endif

#ifdef VEC
#define VEC_SIZE 4u
#define SRC0_TYPE vec4<SRC0_INNER_TYPE>
#define SRC1_TYPE vec4<SRC1_INNER_TYPE>

// Dot product of one 4-wide chunk of a src0 row with the matching 4-wide chunk of src1.
// Both operands are widened to f32 before the dot, so the partial sum is formed at f32
// precision. This matches the SCALAR path, which converts each operand to f32. It also
// never narrows src0 into SRC1's element type: converting f32 -> f16 would round and can
// overflow to inf, and an f16 dot accumulates its four products in f16.
fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
    return dot(vec4<f32>(src0_val), vec4<f32>(src1_val));
}
#endif

#ifdef SCALAR
#define VEC_SIZE 1u
#define SRC0_TYPE SRC0_INNER_TYPE
#define SRC1_TYPE SRC1_INNER_TYPE

// Scalar fallback: widen both operands to f32 and return their product.
fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
    let a = f32(src0_val);
    let b = f32(src1_val);
    return a * b;
}
#endif

#ifdef MUL_ACC_FLOAT
// Non-quantized path: per-thread partial dot products between up to OUTPUTS_PER_WG src0
// rows (starting at row_base) and the src1 vector.
// This thread visits VEC_SIZE-wide chunks thread_id, thread_id + WG_SIZE, ... of K.
// The returned per-row sums cover only those chunks; presumably the caller reduces them
// across the workgroup.
// NOTE(review): params.k is integer-divided by VEC_SIZE, so any remainder is not visited
// here. This assumes the host only picks this variant when k and the src0/src1 offsets are
// multiples of VEC_SIZE — confirm.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var partial: array<f32, OUTPUTS_PER_WG>;

    let num_chunks = params.k / VEC_SIZE;
    let x_chunk_base = src1_idx_base / VEC_SIZE;

    for (var chunk = thread_id; chunk < num_chunks; chunk += WG_SIZE) {
        // Load the src1 chunk once and reuse it for every output row.
        let x_chunk = src1[x_chunk_base + chunk];
        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            // Rows past the end of the matrix contribute nothing.
            if (out_row >= params.m) {
                continue;
            }
            let w_chunk = (src0_batch_offset + out_row * params.stride_01) / VEC_SIZE + chunk;
            partial[r] += inner_dot(src0[w_chunk], x_chunk);
        }
    }

    return partial;
}
#endif

#ifdef MUL_ACC_Q1_0
#define BLOCK_SIZE 128
#define BLOCK_SIZE_BYTES 18
#define THREADS_PER_BLOCK 16
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
// Q1_0 path. As read below, each 18-byte block holds an f16 scale `d` at byte 0, followed
// by 16 bytes of sign bits. Bit j of byte t selects +d (set) or -d (clear) for element 8*t + j.
// Sixteen threads share a block; each thread ("lane") owns one sign byte, i.e. 8
// consecutive elements.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var partial: array<f32, OUTPUTS_PER_WG>;

    let blocks_per_row = params.k / BLOCK_SIZE;
    let lane = thread_id % THREADS_PER_BLOCK;
    let block_stride = WG_SIZE / THREADS_PER_BLOCK;

    for (var blk = thread_id / THREADS_PER_BLOCK; blk < blocks_per_row; blk += block_stride) {
        // Cache this lane's activations; they are reused for every output row.
        let x_start = src1_idx_base + blk * BLOCK_SIZE + lane * ELEMS_PER_THREAD;
        var xs: array<f32, ELEMS_PER_THREAD>;
        for (var e = 0u; e < ELEMS_PER_THREAD; e++) {
            xs[e] = f32(src1[x_start + e]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let blk_bytes = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;
            let scale = f32(load_f16_at_src0(blk_bytes));
            // Only the low byte of the 32-bit load is used.
            // NOTE(review): for lane 15 this load covers bytes 17..20, i.e. past the 18-byte
            // block. That is only safe if load_u32_at_src0 tolerates it at the end of the
            // buffer — confirm.
            let sign_bits = load_u32_at_src0(blk_bytes + 2u + lane) & 0xFFu;
            var dot_sum = 0.0;
            for (var j = 0u; j < 8u; j++) {
                let is_set = ((sign_bits >> j) & 1u) != 0u;
                dot_sum += select(-scale, scale, is_set) * xs[j];
            }
            partial[r] += dot_sum;
        }
    }

    return partial;
}
#endif

#ifdef MUL_ACC_Q4_0
#define BLOCK_SIZE 32
#define BLOCK_SIZE_BYTES 18
#define THREADS_PER_BLOCK 4
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
// Q4_0 path. As read below, each 18-byte block holds an f16 scale `d` at byte 0 and 16
// bytes of packed 4-bit quants from byte 2. Quant byte j supplies element j (low nibble)
// and element j + 16 (high nibble), each dequantized as (q - 8) * d.
// Four threads share a block: thread t owns quant bytes 4t..4t+3, i.e. elements
// 4t..4t+3 and 4t+16..4t+19.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let num_blocks = params.k / BLOCK_SIZE;
    // Derive the lane from THREADS_PER_BLOCK (as the other quantized paths do) so it stays
    // in sync with the block stride below instead of relying on a hard-coded 4.
    let thread_within_block = thread_id % THREADS_PER_BLOCK;
    for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) {
        // x_block[0..3] pair with the low nibbles, x_block[4..7] with the high nibbles
        // (16 elements later in the block).
        let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4;
        var x_block: array<f32, ELEMS_PER_THREAD>;
        for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) {
            x_block[i] = f32(src1[x_base + i]);
            x_block[i + 4] = f32(src1[x_base + i + 16]);
        }

        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
            let output_row = row_base + row;
            if (output_row < params.m) {
                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
                let d = f32(load_f16_at_src0(block_byte_base));
                var row_sum = 0.0;

                // This thread's 4 quant bytes, starting after the 2-byte scale.
                let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block);
                for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
                    let q_byte = get_byte(q_packed, byte_idx);
                    let q_lo = (f32(q_byte & 0xFu) - 8.0) * d;
                    let q_hi = (f32((q_byte >> 4u) & 0xFu) - 8.0) * d;
                    row_sum += q_lo * x_block[byte_idx];
                    row_sum += q_hi * x_block[byte_idx + 4u];
                }
                acc[row] += row_sum;
            }
        }
    }

    return acc;
}
#endif

#ifdef MUL_ACC_Q4_1
#define BLOCK_SIZE 32
#define BLOCK_SIZE_BYTES 20
#define THREADS_PER_BLOCK 4
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
// Q4_1 path. Each 20-byte block is read as follows:
//   - byte 0: f16 scale `d`
//   - byte 2: f16 offset `m`
//   - bytes 4..19: packed 4-bit quants
// Quant byte j yields element j (low nibble) and element j + 16 (high nibble), each
// dequantized as q * d + m. Thread t of the 4 sharing a block owns quant bytes 4t..4t+3.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var partial: array<f32, OUTPUTS_PER_WG>;

    let blocks_per_row = params.k / BLOCK_SIZE;
    let lane = thread_id % THREADS_PER_BLOCK;
    let block_stride = WG_SIZE / THREADS_PER_BLOCK;

    for (var blk = thread_id / THREADS_PER_BLOCK; blk < blocks_per_row; blk += block_stride) {
        // xs[0..3] pair with the low nibbles, xs[4..7] with the high nibbles (16 elements later).
        let x_start = src1_idx_base + blk * BLOCK_SIZE + lane * 4;
        var xs: array<f32, ELEMS_PER_THREAD>;
        for (var e = 0u; e < ELEMS_PER_THREAD / 2; e++) {
            xs[e] = f32(src1[x_start + e]);
            xs[e + 4] = f32(src1[x_start + e + 16]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let blk_bytes = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;
            let scale = f32(load_f16_at_src0(blk_bytes));
            let bias = f32(load_f16_at_src0(blk_bytes + 2u));
            let quants = load_u32_at_src0(blk_bytes + 4u + 4u * lane);

            var dot_sum = 0.0;
            for (var b = 0u; b < 4u; b++) {
                let qb = get_byte(quants, b);
                let w_lo = f32(qb & 0xFu) * scale + bias;
                let w_hi = f32((qb >> 4u) & 0xFu) * scale + bias;
                dot_sum += w_lo * xs[b];
                dot_sum += w_hi * xs[b + 4u];
            }
            partial[r] += dot_sum;
        }
    }

    return partial;
}
#endif

#ifdef MUL_ACC_Q5_0
#define BLOCK_SIZE 32
#define BLOCK_SIZE_BYTES 22
#define THREADS_PER_BLOCK 4
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
// Q5_0 path. Each 22-byte block is read as follows:
//   - byte 0: f16 scale `d`
//   - bytes 2..5: a 32-bit word of fifth bits
//   - bytes 6..21: packed low 4-bit quants
// Element j (< 16) uses the low nibble of quant byte j plus bit j of the high word.
// Element j + 16 uses the high nibble plus bit j + 16. Each value is dequantized as (q - 16) * d.
// Thread t of the 4 sharing a block owns quant bytes 4t..4t+3.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var partial: array<f32, OUTPUTS_PER_WG>;

    let blocks_per_row = params.k / BLOCK_SIZE;
    let lane = thread_id % THREADS_PER_BLOCK;
    let block_stride = WG_SIZE / THREADS_PER_BLOCK;

    for (var blk = thread_id / THREADS_PER_BLOCK; blk < blocks_per_row; blk += block_stride) {
        // xs[0..3] pair with the low nibbles, xs[4..7] with the high nibbles (16 elements later).
        let x_start = src1_idx_base + blk * BLOCK_SIZE + lane * 4;
        var xs: array<f32, ELEMS_PER_THREAD>;
        for (var e = 0u; e < ELEMS_PER_THREAD / 2; e++) {
            xs[e] = f32(src1[x_start + e]);
            xs[e + 4] = f32(src1[x_start + e + 16]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let blk_bytes = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;
            let scale = f32(load_f16_at_src0(blk_bytes));
            let high_bits = load_u32_at_src0(blk_bytes + 2u);
            let quants = load_u32_at_src0(blk_bytes + 6u + 4u * lane);
            let first_elem = lane * 4u;

            var dot_sum = 0.0;
            for (var b = 0u; b < 4u; b++) {
                let qb = get_byte(quants, b);
                let j = first_elem + b;
                // Move bit j (resp. bit j + 16) of high_bits into bit position 4.
                let hb_lo = ((high_bits >> j) << 4u) & 0x10u;
                let hb_hi = (high_bits >> (j + 12u)) & 0x10u;
                let w_lo = (f32((qb & 0xFu) | hb_lo) - 16.0) * scale;
                let w_hi = (f32(((qb >> 4u) & 0xFu) | hb_hi) - 16.0) * scale;
                dot_sum += w_lo * xs[b];
                dot_sum += w_hi * xs[b + 4u];
            }
            partial[r] += dot_sum;
        }
    }

    return partial;
}
#endif

#ifdef MUL_ACC_Q5_1
#define BLOCK_SIZE 32
#define BLOCK_SIZE_BYTES 24
#define THREADS_PER_BLOCK 4
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
// Q5_1 path. Each 24-byte block is read as follows:
//   - byte 0: f16 scale `d`
//   - byte 2: f16 offset `m`
//   - bytes 4..7: a 32-bit word of fifth bits
//   - bytes 8..23: packed low 4-bit quants
// Element j (< 16) uses the low nibble of quant byte j plus bit j of the high word.
// Element j + 16 uses the high nibble plus bit j + 16. Each value is dequantized as q * d + m.
// Thread t of the 4 sharing a block owns quant bytes 4t..4t+3.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var partial: array<f32, OUTPUTS_PER_WG>;

    let blocks_per_row = params.k / BLOCK_SIZE;
    let lane = thread_id % THREADS_PER_BLOCK;
    let block_stride = WG_SIZE / THREADS_PER_BLOCK;

    for (var blk = thread_id / THREADS_PER_BLOCK; blk < blocks_per_row; blk += block_stride) {
        // xs[0..3] pair with the low nibbles, xs[4..7] with the high nibbles (16 elements later).
        let x_start = src1_idx_base + blk * BLOCK_SIZE + lane * 4;
        var xs: array<f32, ELEMS_PER_THREAD>;
        for (var e = 0u; e < ELEMS_PER_THREAD / 2; e++) {
            xs[e] = f32(src1[x_start + e]);
            xs[e + 4] = f32(src1[x_start + e + 16]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let blk_bytes = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;
            let scale = f32(load_f16_at_src0(blk_bytes));
            let bias = f32(load_f16_at_src0(blk_bytes + 2u));
            let high_bits = load_u32_at_src0(blk_bytes + 4u);
            let quants = load_u32_at_src0(blk_bytes + 8u + 4u * lane);
            let first_elem = lane * 4u;

            var dot_sum = 0.0;
            for (var b = 0u; b < 4u; b++) {
                let qb = get_byte(quants, b);
                let j = first_elem + b;
                // Move bit j (resp. bit j + 16) of high_bits into bit position 4.
                let hb_lo = ((high_bits >> j) << 4u) & 0x10u;
                let hb_hi = (high_bits >> (j + 12u)) & 0x10u;
                let w_lo = f32((qb & 0xFu) | hb_lo) * scale + bias;
                let w_hi = f32(((qb >> 4u) & 0xFu) | hb_hi) * scale + bias;
                dot_sum += w_lo * xs[b];
                dot_sum += w_hi * xs[b + 4u];
            }
            partial[r] += dot_sum;
        }
    }

    return partial;
}
#endif

#ifdef MUL_ACC_Q8_0
#define BLOCK_SIZE 32
#define BLOCK_SIZE_BYTES 34
#define THREADS_PER_BLOCK 4
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
// Q8_0 path. Each 34-byte block is read as an f16 scale `d` at byte 0 followed by 32
// signed 8-bit quants. Element values are q * d.
// Thread t of the 4 sharing a block owns the 8 consecutive elements 8t..8t+7.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var partial: array<f32, OUTPUTS_PER_WG>;

    let blocks_per_row = params.k / BLOCK_SIZE;
    let lane = thread_id % THREADS_PER_BLOCK;
    let block_stride = WG_SIZE / THREADS_PER_BLOCK;

    for (var blk = thread_id / THREADS_PER_BLOCK; blk < blocks_per_row; blk += block_stride) {
        let x_start = src1_idx_base + blk * BLOCK_SIZE + lane * ELEMS_PER_THREAD;
        var xs: array<f32, ELEMS_PER_THREAD>;
        for (var e = 0u; e < ELEMS_PER_THREAD; e++) {
            xs[e] = f32(src1[x_start + e]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let blk_bytes = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;
            let scale = f32(load_f16_at_src0(blk_bytes));

            var dot_sum = 0.0;
            // Two 32-bit words (4 quants each) per thread, after the 2-byte scale.
            for (var word = 0u; word < ELEMS_PER_THREAD / 4u; word++) {
                let word_offset = 2u + 4u * (lane * 2u + word);
                let packed = load_u32_at_src0(blk_bytes + word_offset);
                for (var b = 0u; b < 4u; b++) {
                    let w = f32(get_byte_i32(packed, b)) * scale;
                    dot_sum += w * xs[word * 4u + b];
                }
            }
            partial[r] += dot_sum;
        }
    }

    return partial;
}
#endif

#ifdef MUL_ACC_Q8_1
#define BLOCK_SIZE 32
#define BLOCK_SIZE_BYTES 36
#define THREADS_PER_BLOCK 4
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
// Q8_1 path. Each 36-byte block is read as follows:
//   - byte 0: f16 scale `d`
//   - byte 2: a second f16 (`s`)
//   - bytes 4..35: 32 signed 8-bit quants
// Thread t of the 4 sharing a block owns elements 8t..8t+7.
// The f16 at byte 2 is NOT an additive offset. In ggml's block_q8_1 it is the precomputed
// s = d * sum(qs), kept for integer dot-product kernels. Element values are therefore just
// q * d (as for Q8_0). Adding s to each element would bias every block's contribution by
// s * sum(x).
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let num_blocks = params.k / BLOCK_SIZE;
    let thread_within_block = thread_id % THREADS_PER_BLOCK;
    for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
        let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD;
        var x_block: array<f32, ELEMS_PER_THREAD>;
        for (var i = 0u; i < ELEMS_PER_THREAD; i++) {
            x_block[i] = f32(src1[x_base + i]);
        }

        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
            let output_row = row_base + row;
            if (output_row < params.m) {
                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
                let d = f32(load_f16_at_src0(block_byte_base));
                var row_sum = 0.0;

                // Two 32-bit words (4 quants each) per thread; quants start after d and s.
                for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) {
                    let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * (thread_within_block * 2u + packed_idx));
                    for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
                        let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d;
                        row_sum += q_val * x_block[packed_idx * 4u + byte_idx];
                    }
                }
                acc[row] += row_sum;
            }
        }
    }

    return acc;
}
#endif

#ifdef MUL_ACC_Q2_K
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 84
#define THREADS_PER_BLOCK 16
// Q2_K path: per-thread partial dot products of up to OUTPUTS_PER_WG src0 rows (from
// row_base) with src1.
// Byte offsets read within each 84-byte block:
//   - bytes 0..15: scale/min bytes (low nibble = scale, high nibble = min)
//   - from byte 16: 2-bit quants
//   - byte 80: f16 `dall`
//   - byte 82: f16 `dmin`
// THREADS_PER_BLOCK threads cooperate on one 256-element block. Each thread handles 16
// elements: 4 consecutive values from each of four groups spaced 32 apart. Blocks are
// strided by the number of such thread groups in the workgroup.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let tid = thread_id % THREADS_PER_BLOCK;
    let block_group = thread_id / THREADS_PER_BLOCK;
    let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;

    // Pairs of threads share a `lane`. `phase` picks which 4 of the lane's 8 consecutive
    // elements (and which pair of 16-bit quant words) this thread handles.
    let lane = tid / 2u;
    let phase = tid % 2u;
    // iq: which 128-element half of the block; ir: which 8-element slot inside each
    // 32-element group; is: which of the two scale bytes applies to that slot.
    let iq = lane / 4u;
    let ir = lane % 4u;
    let is = ir / 2u;

    let y_offset = 128u * iq + 8u * ir + 4u * phase;
    // Scale bytes for the four 32-element groups this thread touches (2 bytes apart).
    let sc0_byte = 8u * iq + is;
    let sc2_byte = 8u * iq + is + 2u;
    let sc4_byte = 8u * iq + is + 4u;
    let sc6_byte = 8u * iq + is + 6u;
    // Byte offset of this thread's two 16-bit quant words: u16 index (16*iq + 4*ir) past the
    // quant area at byte 16, plus 4 bytes for phase 1.
    let qs_byte = 16u + (16u * iq + 4u * ir) * 2u + 4u * phase;

    let num_blocks = params.k / BLOCK_SIZE;

    for (var block = block_group; block < num_blocks; block += num_block_groups) {
        let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
        // x_block[4g .. 4g+3] = the 4 activations from group g (offset 32*g).
        var x_block: array<f32, 16>;
        for (var i = 0u; i < 4u; i++) {
            x_block[i]       = f32(src1[x_base + i]);
            x_block[i + 4u]  = f32(src1[x_base + 32u + i]);
            x_block[i + 8u]  = f32(src1[x_base + 64u + i]);
            x_block[i + 12u] = f32(src1[x_base + 96u + i]);
        }

        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
            let output_row = row_base + row;
            if (output_row < params.m) {
                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;

                // The min nibble is used below without shifting, as (sc & 0xF0) = 16 * min,
                // so dmin is pre-scaled by 1/16.
                let dall = f32(load_f16_at_src0(block_byte_base + 80u));
                let dmin = f32(load_f16_at_src0(block_byte_base + 82u)) * (1.0 / 16.0);

                // NOTE(review): taking byte (sc*_byte & 3u) of the aligned word assumes two things:
                // load_u32_at_src0_aligned returns the 4-byte-aligned word containing the address,
                // and block_byte_base is 4-byte aligned. The latter holds while BLOCK_SIZE_BYTES
                // (84) is a multiple of 4 — confirm the helper's contract.
                let sc0 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc0_byte), sc0_byte & 3u);
                let sc2 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc2_byte), sc2_byte & 3u);
                let sc4 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc4_byte), sc4_byte & 3u);
                let sc6 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc6_byte), sc6_byte & 3u);

                // qs0 covers elements 0/1 of each group's 4-element slice (low/high byte), qs1 elements 2/3.
                let q_u32 = load_u32_at_src0_aligned(block_byte_base + qs_byte);
                let qs0 = q_u32 & 0xFFFFu;
                let qs1 = q_u32 >> 16u;

                var sumy = vec4<f32>(0.0, 0.0, 0.0, 0.0);
                var acc1 = vec4<f32>(0.0, 0.0, 0.0, 0.0);
                var acc2 = vec4<f32>(0.0, 0.0, 0.0, 0.0);

                // Per-group activation sums, used for the min term.
                sumy[0] = x_block[0] + x_block[1] + x_block[2] + x_block[3];
                sumy[1] = x_block[4] + x_block[5] + x_block[6] + x_block[7];
                sumy[2] = x_block[8] + x_block[9] + x_block[10] + x_block[11];
                sumy[3] = x_block[12] + x_block[13] + x_block[14] + x_block[15];

                // The masks leave each 2-bit quant at its original bit position instead of shifting it down.
                // acc1 uses the low byte of each u16 and acc2 the high byte (hence the 1/256 below).
                // Group g's quant sits at bit 2*g, so it is 4^g too large; the /4, /16, /64 below undo this.
                acc1[0] = x_block[0] * f32(qs0 & 0x0003u) + x_block[2] * f32(qs1 & 0x0003u);
                acc2[0] = x_block[1] * f32(qs0 & 0x0300u) + x_block[3] * f32(qs1 & 0x0300u);
                acc1[1] = x_block[4] * f32(qs0 & 0x000Cu) + x_block[6] * f32(qs1 & 0x000Cu);
                acc2[1] = x_block[5] * f32(qs0 & 0x0C00u) + x_block[7] * f32(qs1 & 0x0C00u);
                acc1[2] = x_block[8] * f32(qs0 & 0x0030u) + x_block[10] * f32(qs1 & 0x0030u);
                acc2[2] = x_block[9] * f32(qs0 & 0x3000u) + x_block[11] * f32(qs1 & 0x3000u);
                acc1[3] = x_block[12] * f32(qs0 & 0x00C0u) + x_block[14] * f32(qs1 & 0x00C0u);
                acc2[3] = x_block[13] * f32(qs0 & 0xC000u) + x_block[15] * f32(qs1 & 0xC000u);

                // dall * sum_g(scale_g * dot_g)  -  dmin * sum_g(min_g * sum(x_g))
                acc[row] += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) +
                                    (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 +
                                    (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 +
                                    (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0)
                          - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) +
                                    sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u));
            }
        }
    }

    return acc;
}
#endif


#ifdef MUL_ACC_Q3_K
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 110
#define THREADS_PER_BLOCK 16
// Q3_K path: per-thread partial dot products of up to OUTPUTS_PER_WG src0 rows (from
// row_base) with src1.
// Byte offsets read within each 110-byte block:
//   - from byte 0: high-bit mask bytes
//   - from byte 32: 2-bit quants
//   - from byte 96: packed 6-bit scales
//   - byte 108: f16 super-scale `d`
// THREADS_PER_BLOCK threads cooperate on one 256-element block. Each thread handles 16
// elements: 8 consecutive values from each of two groups spaced 32 apart.
// BLOCK_SIZE_BYTES is not a multiple of 4, so only the unaligned load helpers are used.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let tid = thread_id % THREADS_PER_BLOCK;
    let block_group = thread_id / THREADS_PER_BLOCK;
    let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;

    // `phase` offsets the quant, mask and activation reads by 16 and selects the matching
    // pair of scale bytes below.
    let lane = tid / 2u;
    let phase = tid % 2u;
    let ip = lane / 4u;
    let il = 2u * ((lane % 4u) / 2u);
    let ir = lane % 2u;
    let l0 = 8u * ir;

    let q_byte = 32u + 32u * ip + l0 + 16u * phase;
    let h_byte = l0 + 16u * phase;
    let y_offset = 128u * ip + 32u * il + l0 + 16u * phase;

    // Shifts used to rebuild the 6-bit scales from their packed low-4 / high-2 bit fields.
    let s_shift1 = 4u * ip;
    let s_shift2 = s_shift1 + il;

    // The qm masks below leave quants at bit offset 0/2 (il == 0) or 4/6 (il == 2) rather
    // than shifting them down. The "high bit clear" correction (-4 per element) is scaled to
    // match: v1 for the qm0/qm1 terms, and v2 = 4 * v1 for the qm2/qm3 terms, which are
    // further compensated by the 0.25 factor.
    // The final divide by (1 << shift) undoes the remaining factor of 16 when il == 2.
    let v1 = select(64.0, 4.0, il == 0u);
    let v2 = 4.0 * v1;
    let shift = 2u * il;

    // 2-bit quant masks: qm0/qm2 act on the low byte, qm1/qm3 on the high byte of each u16.
    var qm0: u32; var qm1: u32; var qm2: u32; var qm3: u32;
    if (il == 0u) {
        qm0 = 0x0003u; qm1 = 0x0300u; qm2 = 0x000Cu; qm3 = 0x0C00u;
    } else {
        qm0 = 0x0030u; qm1 = 0x3000u; qm2 = 0x00C0u; qm3 = 0xC000u;
    }

    // High-bit masks selected by (ip, il), laid out like the qm masks (low/high byte).
    let mm_idx = 2u * ip + il / 2u;
    var hm0: u32; var hm1: u32; var hm2: u32; var hm3: u32;
    switch (mm_idx) {
        case 0u: { hm0=0x0001u; hm1=0x0100u; hm2=0x0002u; hm3=0x0200u; }
        case 1u: { hm0=0x0004u; hm1=0x0400u; hm2=0x0008u; hm3=0x0800u; }
        case 2u: { hm0=0x0010u; hm1=0x1000u; hm2=0x0020u; hm3=0x2000u; }
        default: { hm0=0x0040u; hm1=0x4000u; hm2=0x0080u; hm3=0x8000u; }
    }

    let num_blocks = params.k / BLOCK_SIZE;

    for (var block = block_group; block < num_blocks; block += num_block_groups) {
        let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
        // x_block[0..7]: 8 activations at y_offset; x_block[8..15]: 8 activations 32 later.
        var x_block: array<f32, 16>;
        for (var i = 0u; i < 8u; i++) {
            x_block[i] = f32(src1[x_base + i]);
            x_block[i + 8u] = f32(src1[x_base + 32u + i]);
        }

        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
            let output_row = row_base + row;
            if (output_row < params.m) {
                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;

                let d = f32(load_f16_at_src0(block_byte_base + 108u));
                let a_base = 96u;
                let a_il0 = load_u16_at_src0(block_byte_base + a_base + il * 2u);
                let a_il1 = load_u16_at_src0(block_byte_base + a_base + (il + 1u) * 2u);
                let a_4 = load_u16_at_src0(block_byte_base + a_base + 8u);
                let a_5 = load_u16_at_src0(block_byte_base + a_base + 10u);

                // Rebuild four 6-bit scales in scales32. The low 4 bits come from scale words
                // il/il+1; the high 2 bits (aux32) come from words 4/5.
                var scales32 = a_4 | (a_5 << 16u);
                let aux32 = ((scales32 >> s_shift2) << 4u) & 0x30303030u;
                scales32 = a_il0 | (a_il1 << 16u);
                scales32 = ((scales32 >> s_shift1) & 0x0F0F0F0Fu) | aux32;

                // This thread's two scales (bytes phase and phase + 2), recentred by -32.
                let scale0 = f32(i32(byte_of(scales32, phase + 0u)) - 32);
                let scale1 = f32(i32(byte_of(scales32, phase + 2u)) - 32);

                let q_u32_0 = load_u32_at_src0(block_byte_base + q_byte + 0u);
                let q_u32_1 = load_u32_at_src0(block_byte_base + q_byte + 4u);
                let h_u32_0 = load_u32_at_src0(block_byte_base + h_byte + 0u);
                let h_u32_1 = load_u32_at_src0(block_byte_base + h_byte + 4u);

                var s1 = 0.0; var s2 = 0.0; var s3 = 0.0;
                var s4 = 0.0; var s5 = 0.0; var s6 = 0.0;

                // l steps over element pairs. The l-th 16-bit quant/mask word covers element l
                // (low byte) and element l + 1 (high byte).
                // s3/s6 sum the activations whose high bit is clear; they receive the -4 correction.
                for (var l = 0u; l < 8u; l += 2u) {
                    let q_u32 = select(q_u32_0, q_u32_1, l >= 4u);
                    let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u);
                    let h_u32 = select(h_u32_0, h_u32_1, l >= 4u);
                    let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u);

                    s1 += x_block[l + 0u] * f32(qs & qm0);
                    s2 += x_block[l + 1u] * f32(qs & qm1);
                    s3 += select(0.0, x_block[l + 0u], (hv & hm0) == 0u) +
                          select(0.0, x_block[l + 1u], (hv & hm1) == 0u);
                    s4 += x_block[l + 8u] * f32(qs & qm2);
                    s5 += x_block[l + 9u] * f32(qs & qm3);
                    s6 += select(0.0, x_block[l + 8u], (hv & hm2) == 0u) +
                          select(0.0, x_block[l + 9u], (hv & hm3) == 0u);
                }

                // 1/256 undoes the high-byte position of the qm1/qm3 terms.
                let d1 = d * (s1 + (1.0/256.0) * s2 - s3 * v1);
                let d2 = d * (s4 + (1.0/256.0) * s5 - s6 * v2);
                acc[row] += (d1 * scale0 + 0.25 * d2 * scale1) / f32(1u << shift);
            }
        }
    }

    return acc;
}
#endif

#ifdef MUL_ACC_Q4_K
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 144
#define THREADS_PER_BLOCK 16
// Q4_K path: per-thread partial dot products of up to OUTPUTS_PER_WG src0 rows (from
// row_base) with src1.
// Byte offsets read within each 144-byte block:
//   - byte 0: f16 `d`
//   - byte 2: f16 `dmin`
//   - bytes 4..15: packed 6-bit scales/mins
//   - from byte 16: 4-bit quants
// THREADS_PER_BLOCK threads cooperate on one block. Each thread handles 16 elements:
// 4 consecutive values at y_offset, +32, +128 and +160.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let tid = thread_id % THREADS_PER_BLOCK;
    let block_group = thread_id / THREADS_PER_BLOCK;
    let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;

    // im picks the 64-element region of each 128-element half; l0 (from ir and in) is the
    // 4-element slot within each 32-element group.
    let il = tid / 4u;
    let ir = tid % 4u;
    let im = il / 2u;
    let in = il % 2u;
    let l0 = 4u * (2u * ir + in);

    let y_offset = 64u * im + l0;
    let q_offset = 32u * im + l0;
    // Byte offsets (within the block) of the three 16-bit scale words this thread needs.
    let sc0_byte = 4u + im * 2u;
    let sc2_byte = 4u + (im + 2u) * 2u;
    let sc4_byte = 4u + (im + 4u) * 2u;

    let num_blocks = params.k / BLOCK_SIZE;

    for (var block = block_group; block < num_blocks; block += num_block_groups) {
        let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
        var x_block: array<f32, 16>;
        for (var i = 0u; i < 4u; i++) {
            x_block[i]       = f32(src1[x_base + i]);
            x_block[i + 4u]  = f32(src1[x_base + 32u + i]);
            x_block[i + 8u]  = f32(src1[x_base + 128u + i]);
            x_block[i + 12u] = f32(src1[x_base + 160u + i]);
        }

        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
            let output_row = row_base + row;
            if (output_row < params.m) {
                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;

                let d = f32(load_f16_at_src0(block_byte_base + 0u));
                let dmin = f32(load_f16_at_src0(block_byte_base + 2u));

                // (sc*_byte & 2u) picks the upper 16 bits when the word starts 2 bytes into an
                // aligned u32.
                // NOTE(review): assumes load_u32_at_src0_aligned returns the enclosing aligned
                // word and block_byte_base is 4-aligned (BLOCK_SIZE_BYTES = 144) — confirm.
                let sc0_u32 = load_u32_at_src0_aligned(block_byte_base + sc0_byte);
                let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u);
                let sc2_u32 = load_u32_at_src0_aligned(block_byte_base + sc2_byte);
                let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u);
                let sc4_u32 = load_u32_at_src0_aligned(block_byte_base + sc4_byte);
                let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u);

                // Unpack 6-bit values two per u16:
                //   - sc16_0: scales for q1's low/high nibbles
                //   - sc16_1: mins for q1's low/high nibbles
                //   - sc16_2 / sc16_3: the same for q2; each takes its 4 low bits from sc4 and
                //     its 2 high bits from the top of sc0 / sc2.
                let sc16_0 = sc0 & 0x3F3Fu;
                let sc16_1 = sc2 & 0x3F3Fu;
                let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u);
                let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u);

                let scale0 = f32(sc16_0 & 0xFFu);
                let scale1 = f32((sc16_0 >> 8u) & 0xFFu);
                let min0 = f32(sc16_1 & 0xFFu);
                let min1 = f32((sc16_1 >> 8u) & 0xFFu);
                let scale2 = f32(sc16_2 & 0xFFu);
                let scale3 = f32((sc16_2 >> 8u) & 0xFFu);
                let min2 = f32(sc16_3 & 0xFFu);
                let min3 = f32((sc16_3 >> 8u) & 0xFFu);

                // q1: quants for x_block[0..7] (low/high nibble); q2 (64 bytes later): x_block[8..15].
                let q1_u32 = load_u32_at_src0_aligned(block_byte_base + 16u + q_offset);
                let q2_u32 = load_u32_at_src0_aligned(block_byte_base + 80u + q_offset);

                var dot = vec4<f32>(0.0, 0.0, 0.0, 0.0);
                var sumx = vec4<f32>(0.0, 0.0, 0.0, 0.0);
                for (var i = 0u; i < 4u; i++) {
                    let q1b = byte_of(q1_u32, i);
                    let q2b = byte_of(q2_u32, i);
                    dot[0] += x_block[i] * f32(q1b & 0x0Fu);
                    dot[1] += x_block[i + 4u] * f32(q1b >> 4u);
                    dot[2] += x_block[i + 8u] * f32(q2b & 0x0Fu);
                    dot[3] += x_block[i + 12u] * f32(q2b >> 4u);
                    sumx[0] += x_block[i];
                    sumx[1] += x_block[i + 4u];
                    sumx[2] += x_block[i + 8u];
                    sumx[3] += x_block[i + 12u];
                }

                // d * sum(scale_g * dot_g)  -  dmin * sum(min_g * sum(x_g))
                acc[row] += d * (dot[0] * scale0 + dot[1] * scale1 + dot[2] * scale2 + dot[3] * scale3)
                          - dmin * (sumx[0] * min0 + sumx[1] * min1 + sumx[2] * min2 + sumx[3] * min3);
            }
        }
    }

    return acc;
}
#endif

#ifdef MUL_ACC_Q5_K
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 176
#define THREADS_PER_BLOCK 16
// Q5_K path: per-thread partial dot products of up to OUTPUTS_PER_WG src0 rows (from
// row_base) with src1.
// Byte offsets read within each 176-byte block:
//   - byte 0: f16 `d`
//   - byte 2: f16 `dmin`
//   - bytes 4..15: packed 6-bit scales/mins
//   - from byte 16: fifth-bit (qh) bytes
//   - from byte 48: low 4-bit quants
// THREADS_PER_BLOCK threads cooperate on one block. Each thread handles 16 elements:
// 4 consecutive values at y_offset, +32, +128 and +160.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let tid = thread_id % THREADS_PER_BLOCK;
    let block_group = thread_id / THREADS_PER_BLOCK;
    let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;

    // Same thread -> (im, l0) mapping as the Q4_K path.
    let il = tid / 4u;
    let ir = tid % 4u;
    let im = il / 2u;
    let in = il % 2u;
    let l0 = 4u * (2u * ir + in);

    let y_offset = 64u * im + l0;
    let q_offset = 48u + 32u * im + l0;
    let qh_offset = 16u + 8u * ir + 4u * in;
    // Byte offsets (within the block) of the three 16-bit scale words this thread needs.
    let sc0_byte = 4u + im * 2u;
    let sc2_byte = 4u + (im + 2u) * 2u;
    let sc4_byte = 4u + (im + 4u) * 2u;

    // Masks selecting the fifth bit in each qh byte for this thread's four element groups.
    let hm1 = 1u << (2u * im);
    let hm2 = hm1 << 1u;
    let hm3 = hm1 << 4u;
    let hm4 = hm2 << 4u;

    let num_blocks = params.k / BLOCK_SIZE;

    for (var block = block_group; block < num_blocks; block += num_block_groups) {
        let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
        var x_block: array<f32, 16>;
        for (var i = 0u; i < 4u; i++) {
            x_block[i]       = f32(src1[x_base + i]);
            x_block[i + 4u]  = f32(src1[x_base + 32u + i]);
            x_block[i + 8u]  = f32(src1[x_base + 128u + i]);
            x_block[i + 12u] = f32(src1[x_base + 160u + i]);
        }

        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
            let output_row = row_base + row;
            if (output_row < params.m) {
                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;

                let d = f32(load_f16_at_src0(block_byte_base + 0u));
                let dmin = f32(load_f16_at_src0(block_byte_base + 2u));

                // (sc*_byte & 2u) picks the upper 16 bits when the word starts 2 bytes into an
                // aligned u32.
                // NOTE(review): assumes load_u32_at_src0_aligned returns the enclosing aligned
                // word and block_byte_base is 4-aligned (BLOCK_SIZE_BYTES = 176) — confirm.
                let sc0_u32 = load_u32_at_src0_aligned(block_byte_base + sc0_byte);
                let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u);
                let sc2_u32 = load_u32_at_src0_aligned(block_byte_base + sc2_byte);
                let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u);
                let sc4_u32 = load_u32_at_src0_aligned(block_byte_base + sc4_byte);
                let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u);

                // Unpack 6-bit scales (f*) and mins (m*), two per u16, as in the Q4_K path.
                let sc16_0 = sc0 & 0x3F3Fu;
                let sc16_1 = sc2 & 0x3F3Fu;
                let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u);
                let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u);

                let f0 = f32(sc16_0 & 0xFFu);
                let f1 = f32((sc16_0 >> 8u) & 0xFFu);
                let m0 = f32(sc16_1 & 0xFFu);
                let m1 = f32((sc16_1 >> 8u) & 0xFFu);
                let f4 = f32(sc16_2 & 0xFFu);
                let f5 = f32((sc16_2 >> 8u) & 0xFFu);
                let m4 = f32(sc16_3 & 0xFFu);
                let m5 = f32((sc16_3 >> 8u) & 0xFFu);

                let q1_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset);
                let q2_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset + 64u);
                let qh_u32 = load_u32_at_src0_aligned(block_byte_base + qh_offset);

                var vals = vec4<f32>(0.0, 0.0, 0.0, 0.0);
                var sumy = vec4<f32>(0.0, 0.0, 0.0, 0.0);
                for (var i = 0u; i < 4u; i++) {
                    let q1b = byte_of(q1_u32, i);
                    let q2b = byte_of(q2_u32, i);
                    let qhb = byte_of(qh_u32, i);

                    let yl0 = x_block[i];
                    let yl8 = x_block[i + 4u];
                    let yh0 = x_block[i + 8u];
                    let yh8 = x_block[i + 12u];

                    sumy[0] += yl0;
                    sumy[1] += yl8;
                    sumy[2] += yh0;
                    sumy[3] += yh8;

                    // 5-bit quant = 4-bit nibble | (0x10 if the matching qh bit is set).
                    let q0 = f32((q1b & 0x0Fu) | select(0u, 0x10u, (qhb & hm1) != 0u));
                    let q1 = f32((q1b >> 4u) | select(0u, 0x10u, (qhb & hm2) != 0u));
                    let q2 = f32((q2b & 0x0Fu) | select(0u, 0x10u, (qhb & hm3) != 0u));
                    let q3 = f32((q2b >> 4u) | select(0u, 0x10u, (qhb & hm4) != 0u));

                    vals[0] += yl0 * q0;
                    vals[1] += yl8 * q1;
                    vals[2] += yh0 * q2;
                    vals[3] += yh8 * q3;
                }

                // d * sum(scale_g * dot_g)  -  dmin * sum(min_g * sum(y_g))
                acc[row] += d * (f0 * vals[0] + f1 * vals[1] + f4 * vals[2] + f5 * vals[3])
                          - dmin * (sumy[0] * m0 + sumy[1] * m1 +
                                    sumy[2] * m4 + sumy[3] * m5);
            }
        }
    }

    return acc;
}
#endif

#ifdef MUL_ACC_Q6_K
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 210
#define THREADS_PER_BLOCK 16
// Q6_K mat-vec: per-row partial dot products for this workgroup's rows.
//
// Super-block layout as read here (210 bytes):
//   [0, 128)   low 4 bits of each quant, two quants per byte
//   [128, 192) high 2 bits of each quant, four quants per byte
//   [192, 208) 16 signed 8-bit sub-block scales
//   [208, 210) f16 super-block scale d
// Each weight contributes d * scale * (q - 32) with q a 6-bit value.
//
// THREADS_PER_BLOCK lanes share one super-block: lanes 0..7 cover the first
// 128 elements and lanes 8..15 the second, each lane taking a run of 4
// elements in each of four 32-element strips. Lane groups then stride over
// super-blocks. Rows at or beyond params.m are left at 0.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let lane = thread_id % THREADS_PER_BLOCK;
    let group_stride: u32 = WG_SIZE / THREADS_PER_BLOCK;

    let half_idx = lane / 8u;          // which 128-element half of the super-block
    let run = 4u * (lane % 8u);        // first of the 4 elements taken in each strip
    let scale_idx = 8u * half_idx + run / 16u;

    let y_off = 128u * half_idx + run;
    let lo_off = 64u * half_idx + run;
    let hi_off = 128u + 32u * half_idx + run;

    // Scales scale_idx, +2, +4 and +6 are fetched as two u32 words that start
    // at the 4-aligned scale index at or below scale_idx.
    let scale_off = 192u + (scale_idx & ~3u);
    let scale_byte = scale_idx & 3u;

    let n_blocks = params.k / BLOCK_SIZE;

    var blk = thread_id / THREADS_PER_BLOCK;
    while (blk < n_blocks) {
        // y[4*s + i] holds the activation at strip s (0..3), position i (0..3).
        let y_start = src1_idx_base + blk * BLOCK_SIZE + y_off;
        var y: array<f32, 16>;
        for (var t = 0u; t < 16u; t++) {
            y[t] = f32(src1[y_start + 32u * (t / 4u) + t % 4u]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let base = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;

            let lo_a = load_u32_at_src0(base + lo_off);
            let lo_b = load_u32_at_src0(base + lo_off + 32u);
            let hi_bits = load_u32_at_src0(base + hi_off);
            let scales_a = load_u32_at_src0(base + scale_off);
            let scales_b = load_u32_at_src0(base + scale_off + 4u);
            let d = f32(load_f16_at_src0(base + 208u));

            // One running sum per strip; each strip has its own scale.
            var strip = vec4<f32>(0.0);
            for (var i = 0u; i < 4u; i++) {
                let a = byte_of(lo_a, i);
                let b = byte_of(lo_b, i);
                let h = byte_of(hi_bits, i);

                let w0 = f32(i32((a & 0x0Fu) | ((h & 0x03u) << 4u)) - 32);
                let w1 = f32(i32((b & 0x0Fu) | ((h & 0x0Cu) << 2u)) - 32);
                let w2 = f32(i32((a >> 4u) | (h & 0x30u)) - 32);
                let w3 = f32(i32((b >> 4u) | ((h & 0xC0u) >> 2u)) - 32);

                strip += vec4<f32>(y[i] * w0, y[i + 4u] * w1, y[i + 8u] * w2, y[i + 12u] * w3);
            }

            let s0 = f32(sbyte_of(scales_a, scale_byte));
            let s2 = f32(sbyte_of(scales_a, scale_byte + 2u));
            let s4 = f32(sbyte_of(scales_b, scale_byte));
            let s6 = f32(sbyte_of(scales_b, scale_byte + 2u));
            acc[r] += d * (strip.x * s0 + strip.y * s2 + strip.z * s4 + strip.w * s6);
        }

        blk += group_stride;
    }

    return acc;
}
#endif

#ifdef MUL_ACC_IQ1_S
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 50
#define THREADS_PER_BLOCK 16
// IQ1_S mat-vec: per-row partial dot products for this workgroup's rows.
//
// Super-block layout as read here (50 bytes):
//   [0, 2)   f16 scale d
//   [2, 34)  low 8 bits of each grid index, 4 per 32-element sub-block
//   [34, 50) one u16 per sub-block: bits 0..11 hold the 3 high index bits of
//            each of its four grid indices, bits 12..14 a 3-bit scale, and
//            bit 15 selects the sign of IQ1_DELTA.
// iq1_grid packs 2-bit codes, 16 per u32 (two 8-code grid entries per word);
// codes with bit 1 set decode to code - 4.
//
// THREADS_PER_BLOCK lanes share a super-block: two lanes per sub-block, each
// taking two consecutive 8-element slices. Lane groups stride over
// super-blocks. Rows at or beyond params.m are left at 0.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let tid = thread_id % THREADS_PER_BLOCK;
    let block_group = thread_id / THREADS_PER_BLOCK;
    let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;

    let sub_blk = tid / 2u;          // 32-element sub-block handled by this lane
    let half    = tid % 2u;
    let slot0   = half * 2u;         // first of this lane's two 8-element slices
    let y_offset = sub_blk * 32u + slot0 * 8u;

    let num_blocks = params.k / BLOCK_SIZE;

    for (var block = block_group; block < num_blocks; block += num_block_groups) {
        let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
        var x_block: array<f32, 16>;
        for (var i = 0u; i < 16u; i++) {
            x_block[i] = f32(src1[x_base + i]);
        }

        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
            let output_row = row_base + row;
            if (output_row < params.m) {
                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;

                let d     = f32(load_f16_at_src0(block_byte_base));
                // This sub-block's u16 sits at bytes 34 + 2*sub_blk. A u32 read
                // starting there would touch bytes 48..51 for sub_blk == 7,
                // which is 2 bytes past this 50-byte block (and past the tensor
                // for its final block). Read the u32 that ends on the u16
                // instead and keep its upper half, so every byte read lies in
                // [32, 50). This relies on load_u32_at_src0 assembling bytes
                // little-endian, as the `& 0xFFFFu` u16 extractions elsewhere
                // in this file do.
                let qh    = load_u32_at_src0(block_byte_base + 32u + sub_blk * 2u) >> 16u;
                let dl    = d * f32(2u * ((qh >> 12u) & 7u) + 1u);
                let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u);
                let qs_w  = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u);

                var row_sum = 0.0;
                for (var ll = 0u; ll < 2u; ll++) {
                    let l       = slot0 + ll;
                    let qs_byte = get_byte(qs_w, l);
                    // 11-bit grid index: 8 low bits from qs, 3 high bits from qh.
                    let ig      = (qs_byte | (((qh >> (3u * l)) & 7u) << 8u)) * 8u;
                    let gw      = iq1_grid[ig / 16u];
                    let bit_base = (ig % 16u) * 2u;
                    for (var j = 0u; j < 8u; j++) {
                        let g  = (gw >> (bit_base + j * 2u)) & 3u;
                        let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u);
                        row_sum += dl * (gs + delta) * x_block[ll * 8u + j];
                    }
                }
                acc[row] += row_sum;
            }
        }
    }

    return acc;
}
#endif

#ifdef MUL_ACC_IQ1_M
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 56
#define THREADS_PER_BLOCK 16
// IQ1_M mat-vec: per-row partial dot products for this workgroup's rows.
//
// Super-block layout as read here (56 bytes):
//   [0, 32)  low 8 bits of each grid index, 4 per 32-element sub-block
//   [32, 48) 2 bytes per sub-block; each nibble gives 3 high index bits
//            (bits 0..2) and a delta-sign bit (bit 3) for one 8-element slice
//   [48, 56) four u16 words: bits 0..11 hold 3-bit sub-scales, and the top
//            nibbles of the four words concatenate into the f16 scale d
//
// THREADS_PER_BLOCK lanes share a super-block: two lanes per sub-block, each
// taking two consecutive 8-element slices. Lane groups stride over
// super-blocks. Rows at or beyond params.m are left at 0.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let lane = thread_id % THREADS_PER_BLOCK;
    let group_stride: u32 = WG_SIZE / THREADS_PER_BLOCK;
    let sub = lane / 2u;
    let first_slice = 2u * (lane % 2u);
    let n_blocks = params.k / BLOCK_SIZE;

    var blk = thread_id / THREADS_PER_BLOCK;
    while (blk < n_blocks) {
        // Activations for slices first_slice and first_slice + 1 of sub-block `sub`.
        let y_start = src1_idx_base + blk * BLOCK_SIZE + sub * 32u + first_slice * 8u;
        var y: array<f32, 16>;
        for (var t = 0u; t < 16u; t++) {
            y[t] = f32(src1[y_start + t]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let base = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;

            let sw_lo = load_u32_at_src0(base + 48u);
            let sw_hi = load_u32_at_src0(base + 52u);
            let scale_words = vec4<u32>(
                sw_lo & 0xFFFFu,
                (sw_lo >> 16u) & 0xFFFFu,
                sw_hi & 0xFFFFu,
                (sw_hi >> 16u) & 0xFFFFu);
            let d_bits = (scale_words.x >> 12u)
                       | ((scale_words.y >> 8u) & 0xF0u)
                       | ((scale_words.z >> 4u) & 0xF00u)
                       | (scale_words.w & 0xF000u);
            let d = f32(bitcast<vec2<f16>>(d_bits)[0]);
            // Sub-blocks 2k and 2k+1 share scale word k.
            let sub_scales = scale_words[sub / 2u];

            let qs_word = load_u32_at_src0(base + sub * 4u);
            let qh_pair = load_u32_at_src0(base + 32u + sub * 2u) & 0xFFFFu;

            var partial = 0.0;
            for (var s = 0u; s < 2u; s++) {
                let slice = first_slice + s;
                let nib = 4u * (slice % 2u);
                let qh_byte = (qh_pair >> (8u * (slice / 2u))) & 0xFFu;

                let sub_scale = (sub_scales >> (6u * (sub % 2u) + 3u * (slice / 2u))) & 0x7u;
                let dl = d * f32(2u * sub_scale + 1u);
                let grid_idx = get_byte(qs_word, slice) | (((qh_byte >> nib) & 7u) << 8u);
                let delta = select(IQ1_DELTA, -IQ1_DELTA, ((qh_byte >> (nib + 3u)) & 1u) != 0u);

                let ig = grid_idx * 8u;
                let gw = iq1_grid[ig / 16u];
                let bit_base = (ig % 16u) * 2u;
                for (var j = 0u; j < 8u; j++) {
                    let g = (gw >> (bit_base + j * 2u)) & 3u;
                    let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u);
                    partial += dl * (gs + delta) * y[s * 8u + j];
                }
            }
            acc[r] += partial;
        }

        blk += group_stride;
    }

    return acc;
}
#endif

#ifdef MUL_ACC_IQ2_XXS
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 66
#define THREADS_PER_BLOCK 16
// IQ2_XXS mat-vec: per-row partial dot products for this workgroup's rows.
//
// A 66-byte super-block is an f16 scale d at byte 0, followed by eight 8-byte
// groups (one per 32-element sub-block) starting at byte 2. In each group the
// first u32 holds four 8-bit grid indices. The second u32 holds four 7-bit
// sign-table indices in bits 0..27 and a 4-bit scale in bits 28..31.
// iq2xxs_grid is indexed as two u32 words (8 magnitude bytes) per entry.
//
// THREADS_PER_BLOCK lanes share a super-block: two lanes per sub-block, each
// taking two consecutive 8-element slices. Lane groups stride over
// super-blocks. Rows at or beyond params.m are left at 0.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let lane = thread_id % THREADS_PER_BLOCK;
    let group_stride: u32 = WG_SIZE / THREADS_PER_BLOCK;
    let sub = lane / 2u;
    let first_slice = 2u * (lane % 2u);
    let n_blocks = params.k / BLOCK_SIZE;

    var blk = thread_id / THREADS_PER_BLOCK;
    while (blk < n_blocks) {
        // Activations for slices first_slice and first_slice + 1 of sub-block `sub`.
        let y_start = src1_idx_base + blk * BLOCK_SIZE + sub * 32u + first_slice * 8u;
        var y: array<f32, 16>;
        for (var t = 0u; t < 16u; t++) {
            y[t] = f32(src1[y_start + t]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let base = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;
            let group_off = base + 2u + sub * 8u;
            let grid_word = load_u32_at_src0(group_off);
            let sign_word = load_u32_at_src0(group_off + 4u);
            let db = f32(load_f16_at_src0(base)) * (0.5 + f32(sign_word >> 28u)) * 0.25;

            var partial = 0.0;
            for (var s = 0u; s < 2u; s++) {
                let slice = first_slice + s;
                let grid_row = ((grid_word >> (8u * slice)) & 0xFFu) * 2u;
                let sign_idx = (sign_word >> (7u * slice)) & 0x7Fu;
                let signs = (ksigns_iq2xs[sign_idx / 4u] >> ((sign_idx % 4u) * 8u)) & 0xFFu;

                // Elements 0..3 come from the first grid word, 4..7 from the second.
                for (var w = 0u; w < 2u; w++) {
                    let word = iq2xxs_grid[grid_row + w];
                    for (var k = 0u; k < 4u; k++) {
                        let j = 4u * w + k;
                        let mag = f32((word >> (8u * k)) & 0xFFu);
                        let sgn = select(1.0, -1.0, ((signs >> j) & 1u) != 0u);
                        partial += db * mag * sgn * y[s * 8u + j];
                    }
                }
            }
            acc[r] += partial;
        }

        blk += group_stride;
    }

    return acc;
}
#endif

#ifdef MUL_ACC_IQ2_XS
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 74
#define THREADS_PER_BLOCK 16
// IQ2_XS mat-vec: per-row partial dot products for this workgroup's rows.
//
// Super-block layout as read here (74 bytes):
//   [0, 2)   f16 scale d
//   [2, 66)  four u16 codes per 32-element sub-block; bits 0..8 index
//            iq2xs_grid, bits 9..15 index the sign table
//   [66, 74) one byte per sub-block with two 4-bit scales (low nibble for
//            slices 0-1, high nibble for slices 2-3)
//
// THREADS_PER_BLOCK lanes share a super-block: two lanes per sub-block, each
// taking two consecutive 8-element slices. Lane groups stride over
// super-blocks. Rows at or beyond params.m are left at 0.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let lane = thread_id % THREADS_PER_BLOCK;
    let group_stride: u32 = WG_SIZE / THREADS_PER_BLOCK;
    let sub = lane / 2u;
    let first_slice = 2u * (lane % 2u);
    let n_blocks = params.k / BLOCK_SIZE;

    var blk = thread_id / THREADS_PER_BLOCK;
    while (blk < n_blocks) {
        // Activations for slices first_slice and first_slice + 1 of sub-block `sub`.
        let y_start = src1_idx_base + blk * BLOCK_SIZE + sub * 32u + first_slice * 8u;
        var y: array<f32, 16>;
        for (var t = 0u; t < 16u; t++) {
            y[t] = f32(src1[y_start + t]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let base = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;
            let d = f32(load_f16_at_src0(base));
            let code_off = base + 2u + sub * 8u;
            // codes[0] packs slices 0-1, codes[1] packs slices 2-3.
            let codes = vec2<u32>(load_u32_at_src0(code_off), load_u32_at_src0(code_off + 4u));
            let scale_pair = get_byte(load_u32_at_src0(base + 66u + (sub / 4u) * 4u), sub % 4u);

            var partial = 0.0;
            for (var s = 0u; s < 2u; s++) {
                let slice = first_slice + s;
                let code = (codes[slice / 2u] >> (16u * (slice % 2u))) & 0xFFFFu;
                let sign_idx = (code >> 9u) & 0x7Fu;
                let db = d * (0.5 + f32((scale_pair >> (4u * (slice / 2u))) & 0xFu)) * 0.25;
                let signs = (ksigns_iq2xs[sign_idx / 4u] >> ((sign_idx % 4u) * 8u)) & 0xFFu;
                let grid_row = (code & 0x1FFu) * 2u;

                // Elements 0..3 come from the first grid word, 4..7 from the second.
                for (var w = 0u; w < 2u; w++) {
                    let word = iq2xs_grid[grid_row + w];
                    for (var k = 0u; k < 4u; k++) {
                        let j = 4u * w + k;
                        let mag = f32((word >> (8u * k)) & 0xFFu);
                        let sgn = select(1.0, -1.0, ((signs >> j) & 1u) != 0u);
                        partial += db * mag * sgn * y[s * 8u + j];
                    }
                }
            }
            acc[r] += partial;
        }

        blk += group_stride;
    }

    return acc;
}
#endif

#ifdef MUL_ACC_IQ2_S
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 82
#define THREADS_PER_BLOCK 16
// IQ2_S mat-vec: per-row partial dot products for this workgroup's rows.
//
// Super-block layout as read here (82 bytes):
//   [0, 2)   f16 scale d
//   [2, 34)  low 8 bits of each grid index, 4 per 32-element sub-block
//   [34, 66) one sign byte per 8-element slice
//   [66, 74) one byte per sub-block with 2 high grid-index bits per slice
//   [74, 82) one byte per sub-block with two 4-bit scales (low nibble for
//            slices 0-1, high nibble for slices 2-3)
//
// THREADS_PER_BLOCK lanes share a super-block: two lanes per sub-block, each
// taking two consecutive 8-element slices. Lane groups stride over
// super-blocks. Rows at or beyond params.m are left at 0.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let lane = thread_id % THREADS_PER_BLOCK;
    let group_stride: u32 = WG_SIZE / THREADS_PER_BLOCK;
    let sub = lane / 2u;
    let first_slice = 2u * (lane % 2u);
    let n_blocks = params.k / BLOCK_SIZE;

    var blk = thread_id / THREADS_PER_BLOCK;
    while (blk < n_blocks) {
        // Activations for slices first_slice and first_slice + 1 of sub-block `sub`.
        let y_start = src1_idx_base + blk * BLOCK_SIZE + sub * 32u + first_slice * 8u;
        var y: array<f32, 16>;
        for (var t = 0u; t < 16u; t++) {
            y[t] = f32(src1[y_start + t]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let base = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;
            let d = f32(load_f16_at_src0(base));
            let idx_lo = load_u32_at_src0(base + 2u + sub * 4u);
            let sign_bytes = load_u32_at_src0(base + 34u + sub * 4u);
            // The per-sub-block bytes at 66.. and 74.. share the same word/byte addressing.
            let word_off = (sub / 4u) * 4u;
            let idx_hi = get_byte(load_u32_at_src0(base + 66u + word_off), sub % 4u);
            let scale_pair = get_byte(load_u32_at_src0(base + 74u + word_off), sub % 4u);

            var partial = 0.0;
            for (var s = 0u; s < 2u; s++) {
                let slice = first_slice + s;
                let signs = get_byte(sign_bytes, slice);
                let grid_idx = get_byte(idx_lo, slice) | (((idx_hi >> (2u * slice)) & 3u) << 8u);
                let db = d * (0.5 + f32((scale_pair >> (4u * (slice / 2u))) & 0xFu)) * 0.25;
                let grid_row = grid_idx * 2u;

                // Elements 0..3 come from the first grid word, 4..7 from the second.
                for (var w = 0u; w < 2u; w++) {
                    let word = iq2s_grid[grid_row + w];
                    for (var k = 0u; k < 4u; k++) {
                        let j = 4u * w + k;
                        let mag = f32((word >> (8u * k)) & 0xFFu);
                        let sgn = select(1.0, -1.0, ((signs >> j) & 1u) != 0u);
                        partial += db * mag * sgn * y[s * 8u + j];
                    }
                }
            }
            acc[r] += partial;
        }

        blk += group_stride;
    }

    return acc;
}
#endif

#ifdef MUL_ACC_IQ3_XXS
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 98
#define THREADS_PER_BLOCK 16
// IQ3_XXS mat-vec: per-row partial dot products for this workgroup's rows.
//
// Super-block layout as read here (98 bytes):
//   [0, 2)   f16 scale d
//   [2, 66)  8-bit iq3xxs_grid indices, 8 per 32-element sub-block (one per
//            4-element group, each grid entry a u32 of 4 magnitude bytes)
//   [66, 98) one u32 per sub-block: four 7-bit sign-table indices in bits
//            0..27 and a 4-bit scale in bits 28..31
//
// THREADS_PER_BLOCK lanes share a super-block: two lanes per sub-block, each
// taking two consecutive 8-element slices. Lane groups stride over
// super-blocks. Rows at or beyond params.m are left at 0.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let lane = thread_id % THREADS_PER_BLOCK;
    let group_stride: u32 = WG_SIZE / THREADS_PER_BLOCK;
    let sub = lane / 2u;
    let first_slice = 2u * (lane % 2u);
    let n_blocks = params.k / BLOCK_SIZE;

    var blk = thread_id / THREADS_PER_BLOCK;
    while (blk < n_blocks) {
        // Activations for slices first_slice and first_slice + 1 of sub-block `sub`.
        let y_start = src1_idx_base + blk * BLOCK_SIZE + sub * 32u + first_slice * 8u;
        var y: array<f32, 16>;
        for (var t = 0u; t < 16u; t++) {
            y[t] = f32(src1[y_start + t]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let base = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;
            let d = f32(load_f16_at_src0(base));
            let idx_off = base + 2u + sub * 8u;
            // idx_words[0] covers slices 0-1, idx_words[1] slices 2-3.
            let idx_words = vec2<u32>(load_u32_at_src0(idx_off), load_u32_at_src0(idx_off + 4u));
            let aux = load_u32_at_src0(base + 66u + sub * 4u);
            let db = d * (0.5 + f32(aux >> 28u)) * 0.5;

            var partial = 0.0;
            for (var s = 0u; s < 2u; s++) {
                let slice = first_slice + s;
                let idx_word = idx_words[slice / 2u];
                let shift = 16u * (slice % 2u);
                let first_grid = iq3xxs_grid[(idx_word >> shift) & 0xFFu];
                let second_grid = iq3xxs_grid[(idx_word >> (shift + 8u)) & 0xFFu];
                let sign_idx = (aux >> (7u * slice)) & 0x7Fu;
                let signs = (ksigns_iq2xs[sign_idx / 4u] >> ((sign_idx % 4u) * 8u)) & 0xFFu;

                for (var j = 0u; j < 4u; j++) {
                    let m1 = f32((first_grid >> (j * 8u)) & 0xFFu);
                    let m2 = f32((second_grid >> (j * 8u)) & 0xFFu);
                    let sg1 = select(1.0, -1.0, ((signs >> j) & 1u) != 0u);
                    let sg2 = select(1.0, -1.0, ((signs >> (j + 4u)) & 1u) != 0u);
                    partial += db * m1 * sg1 * y[s * 8u + j];
                    partial += db * m2 * sg2 * y[s * 8u + j + 4u];
                }
            }
            acc[r] += partial;
        }

        blk += group_stride;
    }

    return acc;
}
#endif

#ifdef MUL_ACC_IQ3_S
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 110
#define THREADS_PER_BLOCK 16
// IQ3_S mat-vec: per-row partial dot products for this workgroup's rows.
//
// Super-block layout as read here (110 bytes):
//   [0, 2)     f16 scale d
//   [2, 66)    low 8 bits of iq3s_grid indices, 8 per 32-element sub-block
//   [66, 74)   one byte per sub-block with the 9th index bit of each of its
//              8 indices
//   [74, 106)  one sign byte per 8-element slice
//   [106, 110) 4-bit sub-block scales, two per byte (even sub-block in the
//              low nibble)
//
// THREADS_PER_BLOCK lanes share a super-block: two lanes per sub-block, each
// taking two consecutive 8-element slices. Lane groups stride over
// super-blocks. Rows at or beyond params.m are left at 0.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let lane = thread_id % THREADS_PER_BLOCK;
    let group_stride: u32 = WG_SIZE / THREADS_PER_BLOCK;
    let sub = lane / 2u;
    let first_slice = 2u * (lane % 2u);
    let n_blocks = params.k / BLOCK_SIZE;

    var blk = thread_id / THREADS_PER_BLOCK;
    while (blk < n_blocks) {
        // Activations for slices first_slice and first_slice + 1 of sub-block `sub`.
        let y_start = src1_idx_base + blk * BLOCK_SIZE + sub * 32u + first_slice * 8u;
        var y: array<f32, 16>;
        for (var t = 0u; t < 16u; t++) {
            y[t] = f32(src1[y_start + t]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let base = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;
            let d = f32(load_f16_at_src0(base));
            let idx_off = base + 2u + sub * 8u;
            // idx_words[0] covers slices 0-1, idx_words[1] slices 2-3.
            let idx_words = vec2<u32>(load_u32_at_src0(idx_off), load_u32_at_src0(idx_off + 4u));
            let idx_high = get_byte(load_u32_at_src0(base + 66u + (sub / 4u) * 4u), sub % 4u);
            let sign_bytes = load_u32_at_src0(base + 74u + sub * 4u);
            let scale_pair = get_byte(load_u32_at_src0(base + 106u), sub / 2u);
            let db = d * (1.0 + 2.0 * f32((scale_pair >> (4u * (sub % 2u))) & 0xFu));

            var partial = 0.0;
            for (var s = 0u; s < 2u; s++) {
                let slice = first_slice + s;
                let idx_word = idx_words[slice / 2u];
                let shift = 16u * (slice % 2u);
                let gi1 = ((idx_word >> shift) & 0xFFu) | (((idx_high >> (2u * slice)) & 1u) << 8u);
                let gi2 = ((idx_word >> (shift + 8u)) & 0xFFu) | (((idx_high >> (2u * slice + 1u)) & 1u) << 8u);
                let signs = get_byte(sign_bytes, slice);
                let first_grid = iq3s_grid[gi1];
                let second_grid = iq3s_grid[gi2];

                for (var j = 0u; j < 4u; j++) {
                    let m1 = f32((first_grid >> (j * 8u)) & 0xFFu);
                    let m2 = f32((second_grid >> (j * 8u)) & 0xFFu);
                    let sg1 = select(1.0, -1.0, ((signs >> j) & 1u) != 0u);
                    let sg2 = select(1.0, -1.0, ((signs >> (j + 4u)) & 1u) != 0u);
                    partial += db * m1 * sg1 * y[s * 8u + j];
                    partial += db * m2 * sg2 * y[s * 8u + j + 4u];
                }
            }
            acc[r] += partial;
        }

        blk += group_stride;
    }

    return acc;
}
#endif

#ifdef MUL_ACC_IQ4_NL
#define BLOCK_SIZE 32
#define BLOCK_SIZE_BYTES 18
#define THREADS_PER_BLOCK 4
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
// IQ4_NL mat-vec: per-row partial dot products for this workgroup's rows.
//
// An 18-byte block is an f16 scale d followed by 16 bytes of 4-bit indices
// into kvalues_iq4nl: byte i's low nibble pairs with element i and its high
// nibble with element i + 16. Four lanes share a block; lane t handles bytes
// 4t..4t+3, i.e. elements 4t..4t+3 and 4t+16..4t+19. Lane groups stride over
// blocks. Rows at or beyond params.m are left at 0.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let lane = thread_id % THREADS_PER_BLOCK;
    let group_stride: u32 = WG_SIZE / THREADS_PER_BLOCK;
    let n_blocks = params.k / BLOCK_SIZE;

    var blk = thread_id / THREADS_PER_BLOCK;
    while (blk < n_blocks) {
        // y[0..4) pairs with the low nibbles, y[4..8) with the high nibbles.
        let y_start = src1_idx_base + blk * BLOCK_SIZE + lane * 4u;
        var y: array<f32, ELEMS_PER_THREAD>;
        for (var i = 0u; i < 4u; i++) {
            y[i] = f32(src1[y_start + i]);
        }
        for (var i = 0u; i < 4u; i++) {
            y[4u + i] = f32(src1[y_start + 16u + i]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let base = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;
            let d = f32(load_f16_at_src0(base));
            let packed = load_u32_at_src0(base + 2u + 4u * lane);

            var partial = 0.0;
            for (var i = 0u; i < 4u; i++) {
                let q = get_byte(packed, i);
                let w_lo = f32(kvalues_iq4nl[q & 0xFu]) * d;
                let w_hi = f32(kvalues_iq4nl[(q >> 4u) & 0xFu]) * d;
                partial += w_lo * y[i];
                partial += w_hi * y[i + 4u];
            }
            acc[r] += partial;
        }

        blk += group_stride;
    }

    return acc;
}
#endif

#ifdef MUL_ACC_IQ4_XS
#define BLOCK_SIZE 256
#define BLOCK_SIZE_BYTES 136
#define THREADS_PER_BLOCK 16
// IQ4_XS mat-vec: per-row partial dot products for this workgroup's rows.
//
// Super-block layout as read here (136 bytes):
//   [0, 2)   f16 scale d
//   [2, 4)   high 2 bits of each of the eight 6-bit sub-block scales
//   [4, 8)   low 4 bits of the sub-block scales, two per byte
//   [8, 136) 16 bytes of 4-bit kvalues_iq4nl indices per 32-element
//            sub-block: byte i's low nibble is element i, its high nibble
//            element i + 16
// Each sub-block's multiplier is d * (scale - 32).
//
// Two lanes share a sub-block: the even lane takes the low nibbles
// (elements 0..15), the odd lane the high nibbles (elements 16..31). Lane
// groups stride over super-blocks. Rows at or beyond params.m are left at 0.
fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let lane = thread_id % THREADS_PER_BLOCK;
    let group_stride: u32 = WG_SIZE / THREADS_PER_BLOCK;
    let sub = lane / 2u;
    let nibble_shift = 4u * (lane % 2u);
    let n_blocks = params.k / BLOCK_SIZE;

    var blk = thread_id / THREADS_PER_BLOCK;
    while (blk < n_blocks) {
        let y_start = src1_idx_base + blk * BLOCK_SIZE + sub * 32u + (lane % 2u) * 16u;
        var y: array<f32, 16>;
        for (var t = 0u; t < 16u; t++) {
            y[t] = f32(src1[y_start + t]);
        }

        for (var r = 0u; r < OUTPUTS_PER_WG; r++) {
            let out_row = row_base + r;
            if (out_row >= params.m) {
                continue;
            }
            let base = (src0_batch_offset + out_row * params.stride_01 + blk) * BLOCK_SIZE_BYTES;
            let d = f32(load_f16_at_src0(base));
            let scales_high = load_u16_at_src0(base + 2u);
            let scales_low = get_byte(load_u32_at_src0(base + 4u), sub / 2u);
            let low4 = (scales_low >> (4u * (sub % 2u))) & 0xFu;
            let high2 = (scales_high >> (2u * sub)) & 3u;
            let dl = d * f32(i32(low4 | (high2 << 4u)) - 32);

            // Walk the sub-block's 16 index bytes as four u32 words.
            let q_off = base + 8u + sub * 16u;
            var partial = 0.0;
            for (var w = 0u; w < 4u; w++) {
                let q_word = load_u32_at_src0(q_off + 4u * w);
                for (var k = 0u; k < 4u; k++) {
                    let nib = (get_byte(q_word, k) >> nibble_shift) & 0xFu;
                    partial += f32(kvalues_iq4nl[nib]) * dl * y[4u * w + k];
                }
            }
            acc[r] += partial;
        }

        blk += group_stride;
    }

    return acc;
}
#endif
