enable f16;

#define DECLARE_BYTE_LOADERS_SRC0
#include "common_decls.tmpl"


#ifdef FLOAT
const BLOCK_SIZE = 1u;

#elif defined(Q4_0) || defined(Q4_1) || defined(Q5_0) || defined(Q5_1) || defined(Q8_0) || defined(Q8_1) || defined(IQ4_NL)
const BLOCK_SIZE = 32u;

#elif defined(Q2_K) || defined(Q3_K) || defined(Q4_K) || defined(Q5_K) || defined(Q6_K) || defined(IQ2_XXS) || defined(IQ2_XS) || defined(IQ2_S) || defined(IQ3_XXS) || defined(IQ3_S) || defined(IQ1_S) || defined(IQ1_M) || defined(IQ4_XS)
const BLOCK_SIZE = 256u;
#endif

#ifdef FLOAT
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]);
}
#endif

#ifdef Q4_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes
    let d = load_f16_as_f32_at_src0(block_byte_base);
    var sum: f32 = 0.0;
    for (var j: u32 = 0; j < 4; j++) {
        let q_byte_offset = block_byte_base + 2 + j * 4;
        let q_packed = load_u32_at_src0(q_byte_offset);
        for (var k: u32 = 0; k < 4; k++) {
            let q_byte = get_byte(q_packed, k);
            let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;
            let q_lo = (f32(q_byte & 0xF) - 8.0f) * d;
            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
            sum += q_lo * f32(src1[src1_offset]);
            sum += q_hi * f32(src1[src1_offset + 16]);
        }
    }
    return sum;
}
#endif

#ifdef Q4_1
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_q4_1 = src0[src0_idx_base + offset];
    let d = f32(block_q4_1.d);
    let m = f32(block_q4_1.m);
    var sum: f32 = 0.0;
    for (var j: u32 = 0; j < 4; j++) {
        let q_packed = block_q4_1.qs[j];
        for (var k: u32 = 0; k < 4; k++) {
            let q_byte = get_byte(q_packed, k);
            let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
            let q_lo = f32(q_byte & 0xF) * d + m;
            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
            sum += q_lo * f32(src1[src1_offset]);
            sum += q_hi * f32(src1[src1_offset + 16]);
        }
    }
    return sum;
}
#endif

#ifdef Q5_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_byte_base = (src0_idx_base + offset) * 22; // Block stride: 22 bytes
    let d = load_f16_as_f32_at_src0(block_byte_base);
    var sum: f32 = 0.0;
    let qh_packed = load_u32_at_src0(block_byte_base + 2);
    for (var j: u32 = 0; j < 4; j++) {
        let q_byte_offset = block_byte_base + 6 + j * 4;
        let q_packed = load_u32_at_src0(q_byte_offset);
        for (var k: u32 = 0; k < 4; k++) {
            let q_byte = get_byte(q_packed, k);
            let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;
            let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
            let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10;
            let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
            sum += q_lo * f32(src1[src1_offset]);
            sum += q_hi * f32(src1[src1_offset + 16]);
        }
    }
    return sum;
}
#endif

#ifdef Q5_1
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_q5_1 = src0[src0_idx_base + offset];
    let d = f32(block_q5_1.d);
    let m = f32(block_q5_1.m);
    var sum: f32 = 0.0;
    for (var j: u32 = 0; j < 4; j++) {
        let q_packed = block_q5_1.qs[j];
        for (var k: u32 = 0; k < 4; k++) {
            let q_byte = get_byte(q_packed, k);
            let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10;
            let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m;
            let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10;
            let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m;
            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
            sum += q_lo * f32(src1[src1_offset]);
            sum += q_hi * f32(src1[src1_offset + 16]);
        }
    }
    return sum;
}
#endif

#ifdef Q8_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_byte_base = (src0_idx_base + offset) * 34; // Block stride: 34 bytes
    let d = load_f16_as_f32_at_src0(block_byte_base);
    var sum: f32 = 0.0;
    for (var j: u32 = 0; j < 8; j++) {
        let q_byte_offset = block_byte_base + 2 + j * 4;
        let q_packed = load_u32_at_src0(q_byte_offset);
        for (var k: u32 = 0u; k < 4u; k++) {
            let q_byte = get_byte_i32(q_packed, k);
            let q_val = f32(q_byte) * d;
            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
            sum += q_val * f32(src1[src1_offset]);
        }
    }
    return sum;
}
#endif

#ifdef Q8_1
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_q8_1 = src0[src0_idx_base + offset];
    let d = f32(block_q8_1.d);
    let m = f32(block_q8_1.m);
    var sum: f32 = 0.0;
    for (var j: u32 = 0; j < 8; j++) {
        let q_packed = block_q8_1.qs[j];
        for (var k: u32 = 0; k < 4; k++) {
            let q_byte = get_byte_i32(q_packed, k);
            let q_val = f32(q_byte) * d + m;
            let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
            sum += q_val * f32(src1[src1_offset]);
        }
    }
    return sum;
}
#endif

#ifdef Q2_K
// 16 blocks of 16 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);
    let m = f32(block.dmin);
    var sum = 0.0;
    var src1_i = src1_idx_base + offset * 256;
    var is: u32 = 0;
    // 2 halves of the block (128 elements each)
    for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
        // 4 groups (each group has 2 blocks of 16 elements)
        for (var shift: u32 = 0; shift < 8; shift += 2) {
            // 2 blocks
            for (var k: u32 = 0; k < 32; k += 16) {
                let sc = get_byte(block.scales[is / 4], is % 4);
                is++;
                let dl = d * f32(sc & 0xF);
                let ml = m * f32(sc >> 4);
                for (var l: u32 = 0u; l < 16; l++) {
                    let q_idx = q_b_idx + k + l;
                    let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
                    let qs_val = (q_byte >> shift) & 3;
                    sum += (f32(qs_val) * dl - ml) * src1[src1_i];
                    src1_i++;
                }
            }
        }
    }
    return sum;
}
#endif

#ifdef Q3_K
// 16 blocks of 16 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes

    // Bytes 108-109: f16 scale 'd'
    let d = load_f16_as_f32_at_src0(block_byte_base + 108);

    // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,
    // and 2-bits from the last 4 bytes
    // Bytes 96-107: 12 bytes of scales (3 u32s)
    let kmask1: u32 = 0x03030303;
    let kmask2: u32 = 0x0f0f0f0f;
    var scale_vals: array<u32, 4>;
    scale_vals[0] = load_u32_at_src0(block_byte_base + 96);
    scale_vals[1] = load_u32_at_src0(block_byte_base + 100);
    scale_vals[2] = load_u32_at_src0(block_byte_base + 104);

    var tmp: u32 = scale_vals[2];
    scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
    scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
    scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4);
    scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);

    // Bytes 0-31: 32 bytes of hmask (8 u32s)
    var hmask_vals: array<u32, 8>;
    for (var i: u32 = 0; i < 8; i++) {
        hmask_vals[i] = load_u32_at_src0(block_byte_base + i * 4);
    }

    // Bytes 32-95: 64 bytes of qs (16 u32s)
    var qs_vals: array<u32, 16>;
    for (var i: u32 = 0u; i < 16; i++) {
        qs_vals[i] = load_u32_at_src0(block_byte_base + 32 + i * 4);
    }

    var sum = 0.0;
    var src1_i = src1_idx_base + offset * 256;
    var is: u32 = 0;
    var m: u32 = 1;
    // 2 halves of the block (128 elements each)
    for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
        // 4 groups (each group has 2 blocks of 16 elements)
        for (var shift: u32 = 0; shift < 8; shift += 2) {
            // 2 blocks
            for (var k: u32 = 0; k < 32; k += 16) {
                let sc = get_byte(scale_vals[is / 4], is % 4);
                is++;
                let dl = d * (f32(sc) - 32.0);
                for (var l: u32 = 0u; l < 16u; l++) {
                    let q_idx = q_b_idx + k + l;
                    let hm_idx = k + l;
                    let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4);
                    let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4);
                    let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
                    let qs_val = (q_byte >> shift) & 3;
                    sum += ((f32(qs_val) - hm) * dl) * src1[src1_i];
                    src1_i++;
                }
            }
            m <<= 1;
        }
    }
    return sum;
}
#endif

#ifdef Q4_K
// 8 blocks of 32 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);
    let m = f32(block.dmin);
    var sum = 0.0;
    var src1_i = src1_idx_base + offset * 256;
    var is: u32 = 0;
    // 2 blocks each iteration
    for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
        for (var shift: u32 = 0; shift < 8; shift += 4) {
            let scale_min = get_scale_min(is, block.scales);
            is++;
            let dl = d * scale_min.x;
            let ml = m * scale_min.y;
            for (var l: u32 = 0; l < 32; l++) {
                let q_idx = q_b_idx + l;
                let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
                let qs_val = (q_byte >> shift) & 0xF;
                sum += (f32(qs_val) * dl - ml) * src1[src1_i];
                src1_i++;
            }
        }
    }
    return sum;
}
#endif

#ifdef Q5_K
// 8 blocks of 32 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);
    let m = f32(block.dmin);
    var sum = 0.0;
    var src1_i = src1_idx_base + offset * 256;
    var is: u32 = 0;
    var u: u32 = 1;
    // 2 blocks each iteration
    for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
        for (var shift: u32 = 0; shift < 8; shift += 4) {
            let scale_min = get_scale_min(is, block.scales);
            is++;
            let dl = d * scale_min.x;
            let ml = m * scale_min.y;
            for (var l: u32 = 0; l < 32; l++) {
                let q_idx = q_b_idx + l;
                let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4);
                let qh_byte = get_byte(block.qh[l / 4], l % 4);
                let qs_val = (q_byte >> shift) & 0xF;
                let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
                sum += ((f32(qs_val) + qh_val) * dl - ml) * src1[src1_i];
               src1_i++;
            }
            u <<= 1;
        }
    }
    return sum;
}
#endif

#ifdef Q6_K
// 16 blocks of 16 elements each
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_byte_base = (src0_idx_base + offset) * 210; // Block stride: 210 bytes

    // Bytes 208-209: f16 scale 'd'
    let d = load_f16_as_f32_at_src0(block_byte_base + 208);

    // Bytes 0-127: 128 bytes of ql (32 u32s)
    var ql_vals: array<u32, 32>;
    for (var i: u32 = 0; i < 32; i++) {
        ql_vals[i] = load_u32_at_src0(block_byte_base + i * 4);
    }

    // Bytes 128-191: 64 bytes of qh (16 u32s)
    var qh_vals: array<u32, 16>;
    for (var i: u32 = 0; i < 16; i++) {
        qh_vals[i] = load_u32_at_src0(block_byte_base + 128 + i * 4);
    }

    // Bytes 192-207: 16 bytes of scales (4 u32s)
    var scale_vals: array<u32, 4>;
    for (var i: u32 = 0; i < 4; i++) {
        scale_vals[i] = load_u32_at_src0(block_byte_base + 192 + i * 4);
    }

    var sum = 0.0;
    var src1_i = src1_idx_base + offset * 256;
    var qh_b_idx: u32 = 0;
    var sc_b_idx: u32 = 0;
    for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) {
        for (var l: u32 = 0; l < 32; l++) {
            let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4);
            let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4);
            let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4);

            let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0;
            let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0;
            let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0;
            let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0;

            let is = l/16;
            let is1 = sc_b_idx + is;
            let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4);
            let is2 = sc_b_idx + is + 2;
            let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4);
            let is3 = sc_b_idx + is + 4;
            let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4);
            let is4 = sc_b_idx + is + 6;
            let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4);

            sum += d * f32(sc1) * q1 * src1[src1_i + l];
            sum += d * f32(sc2) * q2 * src1[src1_i + l + 32];
            sum += d * f32(sc3) * q3 * src1[src1_i + l + 64];
            sum += d * f32(sc4) * q4 * src1[src1_i + l + 96];
        }
        src1_i += 128;
        qh_b_idx += 32;
        sc_b_idx += 8;
    }
    return sum;
}
#endif

#ifdef IQ2_XXS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_byte_base = (src0_idx_base + offset) * 66; // Block stride: 66 bytes
    let d = load_f16_as_f32_at_src0(block_byte_base);
    var src1_i = src1_idx_base + offset * 256;
    var sum = 0.0;
    for (var ib: u32 = 0; ib < 32; ib += 4) {
        let aux0_offset = block_byte_base + 2 + ib * 2;
        let aux1_offset = block_byte_base + 2 + (ib + 2) * 2;
        let aux0 = load_u32_at_src0(aux0_offset);
        let aux1 = load_u32_at_src0(aux1_offset);
        let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
        for (var l: u32 = 0; l < 4; l++) {
            let ig = get_byte(aux0, l) * 8;
            let is = (aux1 >> (7 * l)) & 127;
            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
            for (var j: u32 = 0; j < 8; j++) {
                let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4);
                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
                sum += db * f32(g) * m * src1[src1_i];
                src1_i++;
            }
        }
    }
    return sum;
}
#endif

#ifdef IQ2_XS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_byte_base = (src0_idx_base + offset) * 74; // Block stride: 74 bytes
    let d = load_f16_as_f32_at_src0(block_byte_base);
    var src1_i = src1_idx_base + offset * 256;

    var scale_vals = array<u32, 2>(
        load_u32_at_src0(block_byte_base + 66),
        load_u32_at_src0(block_byte_base + 70)
    );

    var sum = 0.0;
    for (var ib: u32 = 0; ib < 32; ib += 4) {
        let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4);
        let db = array<f32, 2>(
            d * (0.5 + f32(s & 0xF)) * 0.25,
            d * (0.5 + f32(s >> 4)) * 0.25
        );
        for (var l: u32 = 0; l < 4; l++) {
            let qs_offset = block_byte_base + 2 + (ib + l) * 2;
            let qs_val = load_u32_at_src0(qs_offset) & 0xFFFF;
            let ig = (qs_val & 511) * 8;
            let is = qs_val >> 9;
            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
            let dl = db[l/2];
            for (var j: u32 = 0; j < 8; j++) {
                let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4);
                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
                sum += dl * f32(g) * m * src1[src1_i];
                src1_i++;
            }
        }
    }
    return sum;
}
#endif

#ifdef IQ2_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_byte_base = (src0_idx_base + offset) * 82; // Block stride: 82 bytes
    let d = load_f16_as_f32_at_src0(block_byte_base);
    var src1_i = src1_idx_base + offset * 256;

    var qs_vals : array<u32, 16>;
    for (var i: u32 = 0; i < 16; i++) {
        qs_vals[i] = load_u32_at_src0(block_byte_base + 2 + i * 4);
    }

    var qh_vals: array<u32, 2>;
    qh_vals[0] = load_u32_at_src0(block_byte_base + 66);
    qh_vals[1] = load_u32_at_src0(block_byte_base + 70);

    var scale_vals: array<u32, 2>;
    scale_vals[0] = load_u32_at_src0(block_byte_base + 74);
    scale_vals[1] = load_u32_at_src0(block_byte_base + 78);

    var sum = 0.0;
    for (var ib: u32 = 0; ib < 8; ib ++) {
        let s = get_byte(scale_vals[ib / 4], ib % 4);
        let db = array<f32, 2>(
            d * (0.5 + f32(s & 0xF)) * 0.25,
            d * (0.5 + f32(s >> 4)) * 0.25
        );
        let qs_w = qs_vals[ib];
        for (var l: u32 = 0; l < 4; l++) {
            let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300;
            let ig = (get_byte(qs_w, l) | qh_b) * 8;
            let signs = get_byte(qs_vals[ib + 8], l);
            let dl = db[l/2];
            for (var j: u32 = 0; j < 8; j++) {
                let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4);
                let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0);
                sum += dl * f32(g) * m * src1[src1_i];
                src1_i++;
            }
        }
    }
    return sum;
}
#endif

#ifdef IQ3_XXS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_byte_base = (src0_idx_base + offset) * 98; // Block stride: 98 bytes
    let d = load_f16_as_f32_at_src0(block_byte_base);
    var src1_i = src1_idx_base + offset * 256;
    var sum = 0.0;
    for (var ib: u32 = 0; ib < 16; ib += 2) {
        let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2;
        let sc_sign = load_u32_at_src0(sc_sign_offset);
        let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
        for (var l: u32 = 0; l < 4; l++) {
            let is = (sc_sign >> (7 * l)) & 127;
            let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
            let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
            let ig1 = get_byte(ig_val, 0);
            let ig2 = get_byte(ig_val, 1);
            for (var j: u32 = 0; j < 4; j++) {
                let g1 = get_byte(iq3xxs_grid[ig1], j);
                let g2 = get_byte(iq3xxs_grid[ig2], j);
                let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
                let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
                sum += db * f32(g1) * m1 * src1[src1_i];
                sum += db * f32(g2) * m2 * src1[src1_i + 4];
                src1_i++;
            }
            src1_i += 4;
        }
    }
    return sum;
}
#endif

#ifdef IQ3_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes
    let d = load_f16_as_f32_at_src0(block_byte_base);
    var src1_i = src1_idx_base + offset * 256;

    var qh_vals = array<u32, 2>(
        load_u32_at_src0(block_byte_base + 66),
        load_u32_at_src0(block_byte_base + 70)
    );

    var sign_vals: array<u32, 8>;
    for (var i: u32 = 0; i < 8; i++) {
        sign_vals[i] = load_u32_at_src0(block_byte_base + 74 + i * 4);
    }

    var scale_vals = load_u32_at_src0(block_byte_base + 106);

    var sum = 0.0;
    for (var ib: u32 = 0; ib < 4; ib++) {
        let s = get_byte(scale_vals, ib);
        let db = array<f32, 2>(
            d * (1.0 + 2.0 * f32(s & 0xF)),
            d * (1.0 + 2.0 * f32(s >> 4))
        );
        for (var k: u32 = 0; k < 2; k++) {
            let dl = db[k];
            let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k);
            let sign_w = sign_vals[ib * 2 + k];
            for (var l: u32 = 0; l < 4; l++) {
                let signs = get_byte(sign_w, l);
                let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
                let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
                let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
                for (var j: u32 = 0; j < 4; j++) {
                    let g1 = get_byte(iq3s_grid[ig1], j);
                    let g2 = get_byte(iq3s_grid[ig2], j);
                    let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0);
                    let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0);
                    sum += dl * f32(g1) * m1 * src1[src1_i];
                    sum += dl * f32(g2) * m2 * src1[src1_i + 4];
                    src1_i++;
                }
                src1_i += 4;
            }
        }
    }
    return sum;
}
#endif

#ifdef IQ1_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_byte_base = (src0_idx_base + offset) * 50; // Block stride: 50 bytes
    let d = load_f16_as_f32_at_src0(block_byte_base);
    var src1_i = src1_idx_base + offset * 256;
    var sum = 0.0;
    for (var ib: u32 = 0; ib < 8; ib++) {
        let qh = load_u32_at_src0(block_byte_base + 34 + ib * 2) & 0xFFFF;
        let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0);
        let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
        let qs_w = load_u32_at_src0(block_byte_base + 2 + ib * 4);
        for (var l: u32 = 0; l < 4; l++) {
            let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
            for (var j: u32 = 0; j < 8; j++) {
                let gw = iq1_grid[(ig + j) / 16];
                let g = (gw >> (((ig + j) % 16) * 2)) & 3;
                let gs = bitcast<i32>(g << 30) >> 30;
                sum += dl * (f32(gs) + delta) * src1[src1_i];
                src1_i++;
            }
        }
    }
    return sum;
}
#endif


#ifdef IQ1_M
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];

    let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000);
    let d = f32(bitcast<vec2<f16>>(scale).x);
    var src1_i = src1_idx_base + offset * 256;
    var sum = 0.0;
    for (var ib: u32 = 0; ib < 8; ib++) {
        let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF;
        let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7;
        let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7;
        var dl = array<f32, 2>(
            d * f32(2 * s1 + 1),
            d * f32(2 * s2 + 1)
        );

        let qh = block.qh[ib / 2] >> (16 * (ib % 2));
        var idx = array<u32, 4>(
            get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700),
            get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700),
            get_byte(block.qs[ib], 2) | ((qh) & 0x700),
            get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700)
        );
        var delta = array<f32, 4>(
            select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0),
            select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0),
            select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0),
            select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0)
        );
        for (var l: u32 = 0; l < 4; l++) {
            let ig = idx[l] * 8;
            for (var j: u32 = 0; j < 8; j++) {
                let gw = iq1_grid[(ig + j) / 16];
                let g = (gw >> (((ig + j) % 16) * 2)) & 3;
                let gs = bitcast<i32>(g << 30) >> 30;
                sum += dl[l/2] * (f32(gs) + delta[l]) * src1[src1_i];
                src1_i++;
            }
        }
    }
    return sum;
}
#endif

#ifdef IQ4_NL
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes
    let d = load_f16_as_f32_at_src0(block_byte_base);
    var src1_i = src1_idx_base + offset * 32;
    var sum = 0.0;
    var qs: array<u32, 4>;
    for (var i: u32 = 0; i < 4; i++) {
        qs[i] = load_u32_at_src0(block_byte_base + 2 + i * 4);
    }
    for (var j: u32 = 0; j < 16; j++) {
        let qsb = get_byte(qs[j / 4], j % 4);
        sum += d * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i];
        sum += d * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16];
        src1_i++;
    }
    return sum;
}
#endif

#ifdef IQ4_XS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = unpack2x16float(block.d_scales_h)[0];
    let scales_h = block.d_scales_h >> 16;
    var src1_i = src1_idx_base + offset * 256;
    var sum = 0.0;
    for (var ib: u32 = 0; ib < 8; ib++) {
        let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4);
        let dl = d * (f32(ls) - 32.0);
        for (var j: u32 = 0; j < 16; j++) {
            let iqs = ib * 16 + j;
            let qsb = get_byte(block.qs[iqs / 4], iqs % 4);
            sum += dl * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i];
            sum += dl * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16];
            src1_i++;
        }
        src1_i += 16;
    }
    return sum;
}
#endif

struct MulMatParams {
    offset_src0: u32, // in elements/blocks
    offset_src1: u32, // in elements/blocks
    offset_dst: u32, // in elements/blocks
    m: u32,
    n: u32,
    k: u32,
    // all strides are in elements/blocks
    stride_01: u32,
    stride_11: u32,
    stride_02: u32,
    stride_12: u32,
    stride_03: u32,
    stride_13: u32,

    bs02: u32,
    bs03: u32,
    broadcast2: u32,
    broadcast3: u32
};

@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns

@group(0) @binding(3) var<uniform> params: MulMatParams;

@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) local_id: vec3<u32>,
        @builtin(workgroup_id) wg_id: vec3<u32>,
        @builtin(num_workgroups) num_wg: vec3<u32>) {
    let wg_linear = wg_id.y * num_wg.x + wg_id.x;
    let global_idx = wg_linear * 256u + local_id.x;

    let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
    if (global_idx >= total) {
        return;
    }

    let dst2_stride = params.m * params.n;
    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;

    let dst3_idx = global_idx / dst3_stride;
    let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension
    let src13_idx = dst3_idx; // src1 is not broadcast
    let dst3_rem = global_idx % dst3_stride;

    let dst2_idx = dst3_rem / dst2_stride;
    let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension
    let src12_idx = dst2_idx; // src1 is not broadcast

    let dst2_rem = dst3_rem % dst2_stride;

    let row = dst2_rem / params.m; // output row
    let col = dst2_rem % params.m; // output column

    let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01;
    let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11;

    var sum = 0.0;
    for (var i: u32 = 0u; i < params.k/BLOCK_SIZE; i = i + 1u) {
        sum += multiply_add(src0_idx_base, src1_idx_base, i);
    }
    dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum;
}
