//------------------------------------------------------------------------------
// This file is contains kernels for data conversion.
// These kernels are used when loading the model, so its performance is less
// important.
//------------------------------------------------------------------------------
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif

#define QK4_0                   32
#define QR4_0                   2
#define QK4_1                   32
#define QR4_1                   2
#define QK5_0                   32
#define QR5_0                   2
#define QK5_1                   32
#define QR5_1                   2
#define QK8_0                   32
#define QR8_0                   1
#define QK_K                    256
#define K_SCALE_SIZE            (3 * QK_K / 64)
#define K_QUANTS_PER_ITERATION  2

typedef char int8_t;
typedef uchar uint8_t;
typedef short int16_t;
typedef ushort uint16_t;
typedef int int32_t;
typedef uint uint32_t;

//------------------------------------------------------------------------------
// block_q4_0
//------------------------------------------------------------------------------
struct block_q4_0
{
    half d;
    uint8_t qs[QK4_0 / 2];
};

//------------------------------------------------------------------------------
// block_q4_1
//------------------------------------------------------------------------------
struct block_q4_1 {
    half d; // delta
    half m; // min
    uchar qs[QK4_1 / 2]; // nibbles / quants
};

//------------------------------------------------------------------------------
// block_q4_k
//------------------------------------------------------------------------------
struct block_q4_K {
    half d; // delta
    half dm; // min
    uchar s[K_SCALE_SIZE];
    uchar q[QK_K / 2]; // nibbles / quants
};

//------------------------------------------------------------------------------
// block_q5_k
//------------------------------------------------------------------------------
struct block_q5_K {
    half d; // delta
    half dm; // min
    uchar s[K_SCALE_SIZE];
    uchar qh[QK_K / 8];
    uchar qs[QK_K / 2]; // nibbles / quants
};

//------------------------------------------------------------------------------
// block_q6_K
//------------------------------------------------------------------------------
struct block_q6_K {
    uint8_t ql[QK_K/2];      // quants, lower 4 bits
    uint8_t qh[QK_K/4];      // quants, upper 2 bits
    int8_t  scales[QK_K/16]; // scales, quantized with 8 bits
    half d;                  // super-block scale
};

//------------------------------------------------------------------------------
// block_iq4_nl
//------------------------------------------------------------------------------
#define QK4_NL 32

struct block_iq4_nl
{
    half d;
    uint8_t qs[QK4_NL / 2];
};

//------------------------------------------------------------------------------
// kernel_convert_block_q4_0
// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA).
// This kernel does not deshuffle the bits.
//------------------------------------------------------------------------------
kernel void kernel_convert_block_q4_0(
    global struct block_q4_0 * src0,
    global uchar * dst_q,
    global half  * dst_d
) {
    global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0);
    global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0);
    global half  * d = (global half *) dst_d + get_global_id(0);

    *d = b->d;

    for (int i = 0; i < QK4_0/2; ++i) {
        q[i] = b->qs[i];
    }
}

kernel void kernel_restore_block_q4_0(
    global uchar * src_q,
    global half  * src_d,
    global struct block_q4_0 * dst
) {
    global struct block_q4_0 * b = (global struct block_q4_0 *) dst + get_global_id(0);
    global uchar * q = (global uchar *) src_q + QK4_0/2*get_global_id(0);
    global half  * d = (global half *) src_d + get_global_id(0);

    b->d = *d;
    for (int i = 0; i < QK4_0/2; ++i) {
        b->qs[i] = q[i];
    }
}

//------------------------------------------------------------------------------
// kernel_convert_block_q4_0_noshuffle
// Flatten q4_0 weights and unshuffle the bits
//------------------------------------------------------------------------------

kernel void kernel_convert_block_q4_0_noshuffle(
    global struct block_q4_0 * src0,
    global uchar * dst_q,
    global half  * dst_d
) {
    global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0);
    global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0);
    global half  * d = (global half *) dst_d + get_global_id(0);

    *d = b->d;
    for (int i = 0; i < QK4_0/4; ++i) {
        uchar x0 = b->qs[2*i + 0];
        uchar x1 = b->qs[2*i + 1];

        q[i + 0      ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);
        q[i + QK4_0/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);

#ifdef ADRENO_GPU
        // Workaround for adreno - must have the following printf statement for
        // the kernel to work properly. Otherwise it produces incorrect result.
        // convert_uchar above also seems necessary.
        // Compare against a large number so that it does not print anything.
        // get_sub_group_local_id() also works.
        if (get_global_id(0) == 65536*4096) {
            printf("%04x - %02x\n", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0));
        }
#endif
    }
}

kernel void kernel_restore_block_q4_0_noshuffle(
    global uchar * src_q,
    global half  * src_d,
    global struct block_q4_0 * dst,
    uchar mask_0F,
    uchar mask_F0
) {
    global struct block_q4_0 * b = (global struct block_q4_0 *) dst + get_global_id(0);
    global uchar * q = (global uchar *) src_q + QK4_0/2*get_global_id(0);
    global half  * d = (global half *) src_d + get_global_id(0);

    b->d = *d;
    for (int i = 0; i < QK4_0/4; ++i) {
        uchar x0 = q[i + 0      ] ;
        uchar x1 = q[i + QK4_0/4];

        b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4));
        b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0));
    }
}

//------------------------------------------------------------------------------
// kernel_convert_block_q4_1
// Convert the block_q4_1 format to 2 separate arrays (AOS -> SOA).
// This kernel does not deshuffle the bits.
//------------------------------------------------------------------------------
kernel void kernel_convert_block_q4_1(
    global struct block_q4_1 * src0,
    global uchar * dst_q,
    global half  * dst_d,
    global half  * dst_m
) {
    global struct block_q4_1 * b = (global struct block_q4_1 *) src0 + get_global_id(0);
    global uchar * q = (global uchar *) dst_q + QK4_1/2*get_global_id(0);
    global half  * d = (global half *) dst_d + get_global_id(0);
    global half  * m = (global half *) dst_m + get_global_id(0);

    *d = b->d;
    *m = b->m;

    for (int i = 0; i < QK4_1/2; ++i) {
        q[i] = b->qs[i];
    }
}

kernel void kernel_restore_block_q4_1(
    global uchar * src_q,
    global half  * src_d,
    global half  * src_m,
    global struct block_q4_1 * dst
) {
    global struct block_q4_1 * b = (global struct block_q4_1 *) dst + get_global_id(0);
    global uchar * q = (global uchar *) src_q + QK4_1/2*get_global_id(0);
    global half  * d = (global half *) src_d + get_global_id(0);
    global half  * m = (global half *) src_m + get_global_id(0);

    b->d = *d;
    b->m = *m;
    for (int i = 0; i < QK4_1/2; ++i) {
        b->qs[i] = q[i];
    }
}

kernel void kernel_convert_block_q4_1_noshuffle(
    global struct block_q4_1 * src0,
    global uchar * dst_q,
    global half  * dst_d,
    global half  * dst_m
) {
    global struct block_q4_1 * b = (global struct block_q4_1 *) src0 + get_global_id(0);
    global uchar * q = (global uchar *) dst_q + QK4_1/2*get_global_id(0);
    global half  * d = (global half *) dst_d + get_global_id(0);
    global half  * m = (global half *) dst_m + get_global_id(0);

    *d = b->d;
    *m = b->m;
    for (int i = 0; i < QK4_1/4; ++i) {
        uchar x0 = b->qs[2*i + 0];
        uchar x1 = b->qs[2*i + 1];

        q[i + 0      ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);
        q[i + QK4_1/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);

#ifdef ADRENO_GPU
        if (get_global_id(0) == 65536*4096) {
            printf("%04x - %02x\n", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0));
        }
#endif
    }
}

kernel void kernel_restore_block_q4_1_noshuffle(
    global uchar * src_q,
    global half  * src_d,
    global half  * src_m,
    global struct block_q4_1 * dst,
    uchar mask_0F,
    uchar mask_F0
) {
    global struct block_q4_1 * b = (global struct block_q4_1 *) dst + get_global_id(0);
    global uchar * q = (global uchar *) src_q + QK4_1/2*get_global_id(0);
    global half  * d = (global half *) src_d + get_global_id(0);
    global half  * m = (global half *) src_m + get_global_id(0);

    b->d = *d;
    b->m = *m;
    for (int i = 0; i < QK4_1/4; ++i) {
        uchar x0 = q[i + 0      ] ;
        uchar x1 = q[i + QK4_1/4];

        b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4));
        b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0));
    }
}

//------------------------------------------------------------------------------
// block_mxfp4
//------------------------------------------------------------------------------
#define QK_MXFP4 32
struct block_mxfp4 {
    uchar e; // E8M0
    uchar qs[QK_MXFP4 / 2];
};

//------------------------------------------------------------------------------
// kernel_convert_block_mxfp4
// Convert the block_mxfp4 format to 2 separate arrays (AOS -> SOA).
// This kernel does not deshuffle the bits.
//------------------------------------------------------------------------------
kernel void kernel_convert_block_mxfp4(
    global struct block_mxfp4 * src0,
    global uchar * dst_q,
    global uchar * dst_e
) {
    global struct block_mxfp4 * b = (global struct block_mxfp4 *) src0 + get_global_id(0);
    global uchar * q = (global uchar *) dst_q + QK_MXFP4 / 2 * get_global_id(0);
    global uchar * e = (global uchar *) dst_e + get_global_id(0);

    *e = b->e;

    for (int i = 0; i < QK_MXFP4 / 2; ++i) {
        q[i] = b->qs[i];
    }
}

kernel void kernel_convert_block_mxfp4_trans(
    global struct block_mxfp4 * src0,
    __global uint4 * dst_q,
    __global uchar * dst_e,
    uint ne00,
    uint ne01
) {
    int i00 = get_global_id(1);
    uint i01 = get_global_id(0);
    uint i02 = get_global_id(2);

    uint ne00_blk = ne00 / QK_MXFP4;
    uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
    uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;

    global struct block_mxfp4 * b = src0 + src_blk_offset;

    dst_q[dst_blk_offset] = ((global uint4 *)(&(b->qs[0])))[0];
    dst_e[dst_blk_offset] = b->e;
}

kernel void kernel_restore_block_mxfp4(
    global uchar * src_q,
    global half  * src_e,
    global struct block_mxfp4 * dst
) {
    global struct block_mxfp4 * b = (global struct block_mxfp4 *) dst + get_global_id(0);
    global uchar * q = (global uchar *) src_q + QK_MXFP4 / 2 * get_global_id(0);
    global uchar * e = (global uchar *) src_e + get_global_id(0);

    b->e = *e;
    for (int i = 0; i < QK_MXFP4 / 2; ++i) {
        b->qs[i] = q[i];
    }
}

kernel void kernel_restore_block_mxfp4_trans(
    __global uint4 * src_q,
    __global uchar * src_e,
    global struct block_mxfp4 * dst,
    uint ne00,
    uint ne01
) {
    int i00 = get_global_id(1);
    uint i01 = get_global_id(0);
    uint i02 = get_global_id(2);

    uint ne00_blk = ne00 / QK_MXFP4;
    uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
    uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;

    global struct block_mxfp4 * b = dst + dst_blk_offset;

    ((global uint4 *)(&(b->qs[0])))[0] = src_q[src_blk_offset];
    b->e = src_e[src_blk_offset];
}

kernel void kernel_convert_block_mxfp4_trans4_ns(
    global struct block_mxfp4 * src0,
    __global uint * dst_q,
    __global uchar * dst_e,
    uint ne00,
    uint ne01
) {
    uint i00 = get_global_id(1);
    uint i01 = get_global_id(0);
    uint i02 = get_global_id(2);

    uint ne00_blk = ne00 / QK_MXFP4;
    uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
    uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;

    global struct block_mxfp4 * b = src0 + src_blk_offset;
    dst_e[dst_blk_offset] = b->e;

    // extract quantization and unshuffle
    ushort8 pre_block = ((global ushort8 *)(&(b->qs[0])))[0];

    ushort8 post_block = (ushort8)(0);

    uchar * pre_block_ptr = (uchar *)(&pre_block);
    uchar * post_block_ptr = (uchar *)(&post_block);

    for (int i = 0; i < QK_MXFP4 / 4; ++i) {
        uchar x0 = pre_block_ptr[2*i + 0];
        uchar x1 = pre_block_ptr[2*i + 1];

        post_block_ptr[i + 0        ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);
        post_block_ptr[i + QK_MXFP4 / 4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);
    }

    uint4 q_block = as_uint4(post_block);

    uint offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01;
    dst_q[offset] = q_block.x;
    dst_q[offset + ne01] = q_block.y;
    dst_q[offset + ne01 * 2] = q_block.z;
    dst_q[offset + ne01 * 3] = q_block.w;
}

kernel void kernel_restore_block_mxfp4_trans4_ns(
    __global uint * src_q,
    __global uchar * src_e,
    __global struct block_mxfp4 * dst0,
    uint ne00,
    uint ne01
) {
    uint i00 = get_global_id(1);
    uint i01 = get_global_id(0);
    uint i02 = get_global_id(2);

    uint ne00_blk = ne00 / QK_MXFP4;
    uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
    uint src_d_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;

    __global struct block_mxfp4 * b = dst0 + dst_blk_offset;
    b->e = src_e[src_d_offset];

    // collect transposed quantization parts for a block
    uint src_q_offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01;
    uint4 q_block;
    q_block.x = src_q[src_q_offset];
    q_block.y = src_q[src_q_offset + ne01];
    q_block.z = src_q[src_q_offset + ne01 * 2];
    q_block.w = src_q[src_q_offset + ne01 * 3];

    ushort8 post_block = as_ushort8(q_block);
    ushort8 pre_block = (ushort8)(0);

    uchar * pre_block_ptr = (uchar *)(&pre_block);
    uchar * post_block_ptr = (uchar *)(&post_block);

    for (int i = 0; i < QK_MXFP4 / 4; ++i) {
        uchar x0 = post_block_ptr[i + 0];
        uchar x1 = post_block_ptr[i + QK_MXFP4 / 4];

        pre_block_ptr[2 * i + 0] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);
        pre_block_ptr[2 * i + 1] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);
    }

    ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block;
}


//------------------------------------------------------------------------------
// block_q8_0
//------------------------------------------------------------------------------
typedef struct {
    half d;       // delta
    char qs[QK8_0]; // quants
} block_q8_0;

kernel void kernel_convert_block_q8_0(
    global block_q8_0 * src0,
    global uchar * dst_q,
    global half  * dst_d
) {
    global block_q8_0 * b = (global block_q8_0 *) src0 + get_global_id(0);
    global uchar      * q = (global uchar *) dst_q + QK8_0*get_global_id(0);
    global half       * d = (global half *) dst_d + get_global_id(0);

    *d = b->d;

    for (int i = 0; i < QK8_0; ++i) {
        q[i] = b->qs[i];
    }
}

kernel void kernel_restore_block_q8_0(
    global uchar * src_q,
    global half  * src_d,
    global block_q8_0 * dst
) {
    global block_q8_0 * b = (global block_q8_0 *) dst + get_global_id(0);
    global uchar      * q = (global uchar *) src_q + QK8_0*get_global_id(0);
    global half       * d = (global half *) src_d + get_global_id(0);

    b->d = *d;
    for (int i = 0; i < QK8_0; ++i) {
        b->qs[i] = q[i];
    }
}

kernel void kernel_restore_block_q8_0_trans(
    global uchar * src_q,
    global half  * src_d,
    global block_q8_0 * dst,
    uint ne00,
    uint ne01
){
    uint num_blk_per_row = ne00 / QK8_0;

    global block_q8_0 * b = (global block_q8_0 *) dst + get_global_id(0) * num_blk_per_row;
    global uchar      * q = (global uchar *) src_q + get_global_id(0) * 4; // 4 8-bit packed
    global half       * d = (global half *) src_d + get_global_id(0);

    for (uint blk = 0; blk < num_blk_per_row; blk++) {
        b->d = *d;

        for (uint i = 0; i < QK8_0; i+=4) {
            b->qs[i]   = q[0];
            b->qs[i+1] = q[1];
            b->qs[i+2] = q[2];
            b->qs[i+3] = q[3];

            q += 4 * ne01; // M stride
        }

        d += ne01;

        b++;
    }
}

//------------------------------------------------------------------------------
// kernel_convert_block_q4_K
// Convert the block_q4_K format to 4 separate arrays (AOS -> SOA).
// This kernel does not deshuffle the bits.
// Each thread processes a super block.
// Mask args are just to keep the signature consistent with the no-shuffle
// version and they are not used in this kernel.
//------------------------------------------------------------------------------
kernel void kernel_convert_block_q4_K(
    global struct block_q4_K * src0,
    global uchar * dst_q,
    global uchar * dst_s,
    global half  * dst_d,
    global half  * dst_dm,
    uchar mask_0F,
    uchar mask_F0
) {
    global struct block_q4_K * b = (global struct block_q4_K *) src0 + get_global_id(0);
    global uchar * q  = (global uchar *) dst_q  + QK_K/2*get_global_id(0);
    global uchar * s  = (global uchar *) dst_s  + K_SCALE_SIZE*get_global_id(0);
    global half  * d  = (global half  *) dst_d  + get_global_id(0);
    global half  * dm = (global half  *) dst_dm + get_global_id(0);

    *d  = b->d;
    *dm = b->dm;

    for (int i = 0; i < QK_K/2; ++i) {
        q[i] = b->q[i];
    }
    for (int i = 0; i < K_SCALE_SIZE; ++i) {
        s[i] = b->s[i];
    }
}

// Restore block_q4_K from flattened arrays.
// Each thread processes a super block.
// Mask args are just to keep the signature consistent with the no-shuffle ones.
kernel void kernel_restore_block_q4_K(
    global uchar * src_q,
    global uchar * src_s,
    global half  * src_d,
    global half  * src_dm,
    global struct block_q4_K * dst,
    uchar mask_0F,
    uchar mask_F0
) {
    global struct block_q4_K * b = (global struct block_q4_K *) dst + get_global_id(0);
    global uchar * q  = (global uchar *) src_q  + QK_K/2*get_global_id(0);
    global uchar * s  = (global uchar *) src_s + K_SCALE_SIZE*get_global_id(0);
    global half  * d  = (global half  *) src_d  + get_global_id(0);
    global half  * dm = (global half  *) src_dm  + get_global_id(0);

    b->d  = *d;
    b->dm = *dm;

    for (int i = 0; i < QK_K/2; ++i) {
        b->q[i] = q[i];
    }
    for (int i = 0; i < K_SCALE_SIZE; ++i) {
        b->s[i] = s[i];
    }
}

kernel void kernel_convert_block_q4_K_noshuffle(
    global struct block_q4_K * src0,
    global uchar * dst_q,
    global uchar * dst_s,
    global half  * dst_d,
    global half  * dst_dm,
    uchar mask_0F,
    uchar mask_F0
) {
    global struct block_q4_K * b = (global struct block_q4_K *) src0 + get_global_id(0);
    global uchar * q  = (global uchar *) dst_q  + QK_K/2 * get_global_id(0);
    global uchar * s  = (global uchar *) dst_s  + K_SCALE_SIZE * get_global_id(0);
    global half  * d  = (global half  *) dst_d  + get_global_id(0);
    global half  * dm = (global half  *) dst_dm + get_global_id(0);

    *d  = b->d;
    *dm = b->dm;

    for (int i = 0; i < QK_K / 64; ++i) {
        for (int j = 0; j < 16; ++j) {
            uchar x0 = b->q[i*32 + 2*j];
            uchar x1 = b->q[i*32 + 2*j + 1];
            q[i*32 + j]      = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4);
            q[i*32 + j + 16] = convert_uchar((x0 & mask_F0) >> 4)   | convert_uchar(x1 & mask_F0);
        }
    }

    for (int i = 0; i < K_SCALE_SIZE; ++i) {
        s[i] = b->s[i];
    }
}

kernel void kernel_restore_block_q4_K_noshuffle(
    global uchar * src_q,
    global uchar * src_s,
    global half  * src_d,
    global half  * src_dm,
    global struct block_q4_K * dst,
    uchar mask_0F,
    uchar mask_F0
) {
    global struct block_q4_K * b = (global struct block_q4_K *) dst + get_global_id(0);
    global uchar * q  = (global uchar *) src_q  + QK_K/2 * get_global_id(0);
    global uchar * s  = (global uchar *) src_s  + K_SCALE_SIZE * get_global_id(0);
    global half  * d  = (global half  *) src_d  + get_global_id(0);
    global half  * dm = (global half  *) src_dm + get_global_id(0);

    b->d  = *d;
    b->dm = *dm;

    for (int i = 0; i < QK_K / 64; ++i) {
        for (int j = 0; j < 16; ++j) {
            uchar lo = q[i*32 + j];
            uchar hi = q[i*32 + j + 16];
            b->q[i*32 + 2*j]     = convert_uchar((lo & mask_0F) | ((hi & mask_0F) << 4));
            b->q[i*32 + 2*j + 1] = convert_uchar(((lo & mask_F0) >> 4) | (hi & mask_F0));
        }
    }

    for (int i = 0; i < K_SCALE_SIZE; ++i) {
        b->s[i] = s[i];
    }
}

//------------------------------------------------------------------------------
// kernel_convert_block_q5_K
// Convert the block_q5_K format to 5 separate arrays (AOS -> SOA).
// Each thread processes a super block.
//------------------------------------------------------------------------------
kernel void kernel_convert_block_q5_K(
    global struct block_q5_K * src0,
    global uchar * dst_q,
    global uchar * dst_qh,
    global uchar * dst_s,
    global half  * dst_d,
    global half  * dst_dm,
    uchar mask_0F,
    uchar mask_F0
) {
    global struct block_q5_K * b  = (global struct block_q5_K *) src0 + get_global_id(0);
    global uchar * q  = (global uchar *) dst_q  + QK_K/2*get_global_id(0);
    global uchar * qh = (global uchar *) dst_qh + QK_K/8*get_global_id(0);
    global uchar * s  = (global uchar *) dst_s  + K_SCALE_SIZE*get_global_id(0);
    global half  * d  = (global half  *) dst_d  + get_global_id(0);
    global half  * dm = (global half  *) dst_dm + get_global_id(0);

    *d  = b->d;
    *dm = b->dm;

    for (int i = 0; i < QK_K/2; ++i) {
        q[i] = b->qs[i];
    }
    for (int i = 0; i < QK_K/8; ++i) {
        qh[i] = b->qh[i];
    }
    for (int i = 0; i < K_SCALE_SIZE; ++i) {
        s[i] = b->s[i];
    }
}

// Restore block_q5_K from flattened arrays.
// Each thread processes a super block.
kernel void kernel_restore_block_q5_K(
    global uchar * src_q,
    global uchar * src_qh,
    global uchar * src_s,
    global half  * src_d,
    global half  * src_dm,
    global struct block_q5_K * dst,
    uchar mask_0F,
    uchar mask_F0
) {
    global struct block_q5_K * b  = (global struct block_q5_K *) dst + get_global_id(0);
    global uchar * q  = (global uchar *) src_q  + QK_K/2*get_global_id(0);
    global uchar * qh = (global uchar *) src_qh + QK_K/8*get_global_id(0);
    global uchar * s  = (global uchar *) src_s  + K_SCALE_SIZE*get_global_id(0);
    global half  * d  = (global half  *) src_d  + get_global_id(0);
    global half  * dm = (global half  *) src_dm + get_global_id(0);

    b->d    = *d;
    b->dm = *dm;

    for (int i = 0; i < QK_K/2; ++i) {
        b->qs[i] = q[i];
    }
    for (int i = 0; i < QK_K/8; ++i) {
        b->qh[i] = qh[i];
    }
    for (int i = 0; i < K_SCALE_SIZE; ++i) {
        b->s[i] = s[i];
    }
}

kernel void kernel_convert_block_q5_K_noshuffle(
    global struct block_q5_K * src0,
    global uchar * dst_q,
    global uchar * dst_qh,
    global uchar * dst_s,
    global half  * dst_d,
    global half  * dst_dm,
    uchar mask_0F,
    uchar mask_F0
) {
    global struct block_q5_K * b  = (global struct block_q5_K *) src0 + get_global_id(0);
    global uchar * q  = (global uchar *) dst_q  + QK_K/2       * get_global_id(0);
    global uchar * qh = (global uchar *) dst_qh + QK_K/8       * get_global_id(0);
    global uchar * s  = (global uchar *) dst_s  + K_SCALE_SIZE * get_global_id(0);
    global half  * d  = (global half  *) dst_d  + get_global_id(0);
    global half  * dm = (global half  *) dst_dm + get_global_id(0);

    *d  = b->d;
    *dm = b->dm;

    for (int i = 0; i < QK_K / 64; ++i) {
        for (int j = 0; j < 16; ++j) {
            uchar x0 = b->qs[i*32 + 2*j];
            uchar x1 = b->qs[i*32 + 2*j + 1];
            q[i*32 + j]      = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4);
            q[i*32 + j + 16] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0);
        }
    }

    for (int l = 0; l < QK_K/8; ++l) {
        uchar x0 = 0;
        for (int i = 0; i < 8; ++i) {
            x0 |= ((b->qh[(l%4)*8+i] >> (l/4)) & 0x01) << i;
        }
        qh[l] = x0;
    }

    for (int i = 0; i < K_SCALE_SIZE; ++i) {
        s[i] = b->s[i];
    }
}

kernel void kernel_restore_block_q5_K_noshuffle(
    global uchar * src_q,
    global uchar * src_qh,
    global uchar * src_s,
    global half  * src_d,
    global half  * src_dm,
    global struct block_q5_K * dst,
    uchar mask_0F,
    uchar mask_F0
) {
    global struct block_q5_K * b  = (global struct block_q5_K *) dst + get_global_id(0);
    global uchar * q  = (global uchar *) src_q  + QK_K/2       * get_global_id(0);
    global uchar * qh = (global uchar *) src_qh + QK_K/8       * get_global_id(0);
    global uchar * s  = (global uchar *) src_s  + K_SCALE_SIZE * get_global_id(0);
    global half  * d  = (global half  *) src_d  + get_global_id(0);
    global half  * dm = (global half  *) src_dm + get_global_id(0);

    b->d  = *d;
    b->dm = *dm;

    for (int i = 0; i < QK_K / 64; ++i) {
        for (int j = 0; j < 16; ++j) {
            uchar lo = q[i*32 + j];
            uchar hi = q[i*32 + j + 16];
            b->qs[i*32 + 2*j]     = convert_uchar((lo & mask_0F) | ((hi & mask_0F) << 4));
            b->qs[i*32 + 2*j + 1] = convert_uchar(((lo & mask_F0) >> 4) | (hi & mask_F0));
        }
    }

    for (int g = 0; g < 4; ++g) {
        for (int i = 0; i < 8; ++i) {
            uchar x0 = 0;
            for (int k = 0; k < 8; ++k) {
                x0 |= ((qh[4*k+g] >> i) & 0x01) << k;
            }
            b->qh[g*8+i] = x0;
        }
    }

    for (int i = 0; i < K_SCALE_SIZE; ++i) {
        b->s[i] = s[i];
    }
}

//------------------------------------------------------------------------------
// kernel_convert_block_q6_K
// Convert the block_q6_K format to 3 separate arrays (AOS -> SOA).
// This kernel does not deshuffle the bits.
// Each thread processes a super block.
//------------------------------------------------------------------------------
kernel void kernel_convert_block_q6_K(
    global struct block_q6_K * src0,
    global uchar * dst_ql,
    global uchar * dst_qh,
    global char  * dst_s,
    global half  * dst_d,
    uchar          mask_lsb_8,
    ulong          n_blk
) {
    if (get_global_id(0) >= n_blk) {
        return;
    }
    global struct block_q6_K * b = (global struct block_q6_K *) src0 + get_global_id(0);
    global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0);
    global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0);
    global char  * s  = (global char  *) dst_s  + QK_K/16*get_global_id(0);
    global half  * d  = (global half  *) dst_d  + get_global_id(0);

    *d = b->d;

    for (int i = 0; i < QK_K/2; ++i) {
        ql[i] = b->ql[i];
    }
    for (int i = 0; i < QK_K/4; ++i) {
        qh[i] = b->qh[i];
    }
    for (int i = 0; i < QK_K/16; ++i) {
        s[i] = b->scales[i];
    }
}

// Restore block_q6_K from flattened arrays.
// Each thread processes a super block.
kernel void kernel_restore_block_q6_K(
    global uchar * dst_ql,
    global uchar * dst_qh,
    global char  * dst_s,
    global half  * dst_d,
    global struct block_q6_K * dst,
    uchar mask_lsb_8,
    ulong n_blk
) {
    if (get_global_id(0) >= n_blk) {
        return;
    }
    global struct block_q6_K * b = (global struct block_q6_K *) dst + get_global_id(0);
    global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0);
    global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0);
    global char  * s  = (global char  *) dst_s  + QK_K/16*get_global_id(0);
    global half  * d  = (global half  *) dst_d  + get_global_id(0);

    b->d = *d;

    for (int i = 0; i < QK_K/2; ++i) {
        b->ql[i] = ql[i];
    }
    for (int i = 0; i < QK_K/4; ++i) {
        b->qh[i] = qh[i];
    }
    for (int i = 0; i < QK_K/16; ++i) {
        b->scales[i] = s[i];
    }
}

kernel void kernel_convert_block_q6_K_noshuffle(
    global struct block_q6_K * src0,
    global uchar * dst_ql,
    global uchar * dst_qh,
    global char  * dst_s,
    global half  * dst_d,
    uchar          mask_lsb_8,
    ulong          n_blk
) {
    if (get_global_id(0) >= n_blk) {
        return;
    }
    global struct block_q6_K * b = (global struct block_q6_K *) src0 + get_global_id(0);
    global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0);
    global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0);
    global char  * s  = (global char  *) dst_s  + QK_K/16*get_global_id(0);
    global half  * d  = (global half  *) dst_d  + get_global_id(0);

    *d = b->d;

    for (int i = 0; i < QK_K/2/4; ++i) {
        uchar x0 = b->ql[i*2 + 0] & mask_lsb_8;
        uchar x1 = b->ql[i*2 + 1] & mask_lsb_8;
        ql[i +  0] = (x0 & 0x0F)        | ((x1 & 0x0F) << 4);
        ql[i + 32] = ((x0 & 0xF0) >> 4) | (x1 & 0xF0);

        uchar x2 = b->ql[i*2 + 0 + 64] & mask_lsb_8;
        uchar x3 = b->ql[i*2 + 1 + 64] & mask_lsb_8;
        ql[i + 64] = (x2 & 0x0F)        | ((x3 & 0x0F) << 4);
        ql[i + 96] = ((x2 & 0xF0) >> 4) | (x3 & 0xF0);
    }

    for (int i = 0; i < QK_K/4/8; ++i) {
        uchar x0 = b->qh[i*4 + 0] & mask_lsb_8;
        uchar x1 = b->qh[i*4 + 1] & mask_lsb_8;
        uchar x2 = b->qh[i*4 + 2] & mask_lsb_8;
        uchar x3 = b->qh[i*4 + 3] & mask_lsb_8;
        qh[i +  0] = (x0 & 0x03)        | ((x1 & 0x03) << 2) | ((x2 & 0x03) << 4) | ((x3 & 0x03) << 6);
        qh[i +  8] = ((x0 & 0x0C) >> 2) | (x1 & 0x0C)        | ((x2 & 0x0C) << 2) | ((x3 & 0x0C) << 4);
        qh[i + 16] = ((x0 & 0x30) >> 4) | ((x1 & 0x30) >> 2) | (x2 & 0x30)        | ((x3 & 0x30) << 2);
        qh[i + 24] = ((x0 & 0xC0) >> 6) | ((x1 & 0xC0) >> 4) | ((x2 & 0xC0) >> 2) | (x3 & 0xC0);

        uchar x4 = b->qh[i*4 + 0 + 32] & mask_lsb_8;
        uchar x5 = b->qh[i*4 + 1 + 32] & mask_lsb_8;
        uchar x6 = b->qh[i*4 + 2 + 32] & mask_lsb_8;
        uchar x7 = b->qh[i*4 + 3 + 32] & mask_lsb_8;
        qh[i + 32] = (x4 & 0x03)        | ((x5 & 0x03) << 2) | ((x6 & 0x03) << 4) | ((x7 & 0x03) << 6);
        qh[i + 40] = ((x4 & 0x0C) >> 2) | (x5 & 0x0C)        | ((x6 & 0x0C) << 2) | ((x7 & 0x0C) << 4);
        qh[i + 48] = ((x4 & 0x30) >> 4) | ((x5 & 0x30) >> 2) | (x6 & 0x30)        | ((x7 & 0x30) << 2);
        qh[i + 56] = ((x4 & 0xC0) >> 6) | ((x5 & 0xC0) >> 4) | ((x6 & 0xC0) >> 2) | (x7 & 0xC0);
    }

    for (int i = 0; i < QK_K/16; ++i) {
        s[i] = b->scales[i];
    }
}

kernel void kernel_restore_block_q6_K_noshuffle(
    global uchar * src_ql,
    global uchar * src_qh,
    global char  * src_s,
    global half  * src_d,
    global struct block_q6_K * dst,
    uchar          mask_lsb_8,
    ulong          n_blk
) {
    if (get_global_id(0) >= n_blk) {
        return;
    }
    global struct block_q6_K * b = (global struct block_q6_K *) dst + get_global_id(0);
    global uchar * ql = (global uchar *) src_ql + QK_K/2*get_global_id(0);
    global uchar * qh = (global uchar *) src_qh + QK_K/4*get_global_id(0);
    global char  * s  = (global char  *) src_s  + QK_K/16*get_global_id(0);
    global half  * d  = (global half  *) src_d  + get_global_id(0);

    b->d = *d;

    for (int i = 0; i < QK_K/2/4; ++i) {
        uchar x0   = ql[i +  0] & mask_lsb_8;
        uchar x1   = ql[i + 32] & mask_lsb_8;
        b->ql[i*2 + 0] = (x0 & 0x0F)        | ((x1 & 0x0F) << 4);
        b->ql[i*2 + 1] = ((x0 & 0xF0) >> 4) | (x1 & 0xF0);

        uchar x2   = ql[i + 64] & mask_lsb_8;
        uchar x3   = ql[i + 96] & mask_lsb_8;
        b->ql[i*2 + 0 + 64] = (x2 & 0x0F)        | ((x3 & 0x0F) << 4);
        b->ql[i*2 + 1 + 64] = ((x2 & 0xF0) >> 4) | (x3 & 0xF0);
    }

    for (int i = 0; i < QK_K/4/8; ++i) {
        uchar x0 = qh[i +  0] & mask_lsb_8;
        uchar x1 = qh[i +  8] & mask_lsb_8;
        uchar x2 = qh[i + 16] & mask_lsb_8;
        uchar x3 = qh[i + 24] & mask_lsb_8;
        b->qh[i*4 + 0] = (x0 & 0x03)        | ((x1 & 0x03) << 2) | ((x2 & 0x03) << 4) | ((x3 & 0x03) << 6);
        b->qh[i*4 + 1] = ((x0 & 0x0C) >> 2) | (x1 & 0x0C)        | ((x2 & 0x0C) << 2) | ((x3 & 0x0C) << 4);
        b->qh[i*4 + 2] = ((x0 & 0x30) >> 4) | ((x1 & 0x30) >> 2) | (x2 & 0x30)        | ((x3 & 0x30) << 2);
        b->qh[i*4 + 3] = ((x0 & 0xC0) >> 6) | ((x1 & 0xC0) >> 4) | ((x2 & 0xC0) >> 2) | (x3 & 0xC0);

        uchar x4 = qh[i +  0 + 32] & mask_lsb_8;
        uchar x5 = qh[i +  8 + 32] & mask_lsb_8;
        uchar x6 = qh[i + 16 + 32] & mask_lsb_8;
        uchar x7 = qh[i + 24 + 32] & mask_lsb_8;
        b->qh[i*4 + 0 + 32] = (x4 & 0x03)        | ((x5 & 0x03) << 2) | ((x6 & 0x03) << 4) | ((x7 & 0x03) << 6);
        b->qh[i*4 + 1 + 32] = ((x4 & 0x0C) >> 2) | (x5 & 0x0C)        | ((x6 & 0x0C) << 2) | ((x7 & 0x0C) << 4);
        b->qh[i*4 + 2 + 32] = ((x4 & 0x30) >> 4) | ((x5 & 0x30) >> 2) | (x6 & 0x30)        | ((x7 & 0x30) << 2);
        b->qh[i*4 + 3 + 32] = ((x4 & 0xC0) >> 6) | ((x5 & 0xC0) >> 4) | ((x6 & 0xC0) >> 2) | (x7 & 0xC0);
    }

    for (int i = 0; i < QK_K/16; ++i) {
        b->scales[i] = s[i];
    }
}

//------------------------------------------------------------------------------
// kernel_convert_block_iq4_nl
// Convert the block_iq4_nl format to 2 separate arrays (AOS -> SOA).
//------------------------------------------------------------------------------
kernel void kernel_convert_block_iq4_nl(
    global struct block_iq4_nl * src0,
    global uchar * dst_q,
    global half  * dst_d,
    uchar          mask_0F,
    uchar          mask_F0,
    ulong          n_blk
) {
    if (get_global_id(0) >= n_blk) {
        return;
    }
    global struct block_iq4_nl * b = (global struct block_iq4_nl *) src0 + get_global_id(0);
    global uchar * q = (global uchar *) dst_q + QK4_NL/2*get_global_id(0);
    global half  * d = (global half *) dst_d + get_global_id(0);

    *d = b->d;

    for (int i = 0; i < QK4_NL/2; ++i) {
        q[i] = b->qs[i];
    }
}

kernel void kernel_restore_block_iq4_nl(
    global uchar * src_q,
    global half  * src_d,
    global struct block_iq4_nl * dst,
    ulong          n_blk
) {
    if (get_global_id(0) >= n_blk) {
        return;
    }
    global struct block_iq4_nl * b = (global struct block_iq4_nl *) dst + get_global_id(0);
    global uchar * q = (global uchar *) src_q + QK4_NL/2*get_global_id(0);
    global half  * d = (global half *) src_d + get_global_id(0);

    b->d = *d;

    for (int i = 0; i < QK4_NL/2; ++i) {
        b->qs[i] = q[i];
    }
}

kernel void kernel_convert_block_iq4_nl_noshuffle(
    global struct block_iq4_nl * src0,
    global uchar * dst_q,
    global half  * dst_d,
    uchar          mask_0F,
    uchar          mask_F0,
    ulong          n_blk
) {
    if (get_global_id(0) >= n_blk) {
        return;
    }
    global struct block_iq4_nl * b = (global struct block_iq4_nl *) src0 + get_global_id(0);
    global uchar * q = (global uchar *) dst_q + QK4_NL/2*get_global_id(0);
    global half  * d = (global half *) dst_d + get_global_id(0);

    *d = b->d;
    for (int i = 0; i < QK4_NL/4; ++i) {
        uchar x0 = b->qs[2*i + 0];
        uchar x1 = b->qs[2*i + 1];

        q[i + 0       ] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4);
        q[i + QK4_NL/4] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0);
    }
}

kernel void kernel_restore_block_iq4_nl_noshuffle(
    global uchar * src_q,
    global half  * src_d,
    global struct block_iq4_nl * dst,
    uchar mask_0F,
    uchar mask_F0,
    ulong n_blk
) {
    if (get_global_id(0) >= n_blk) {
        return;
    }
    global struct block_iq4_nl * b = (global struct block_iq4_nl *) dst + get_global_id(0);
    global uchar * q = (global uchar *) src_q + QK4_NL/2*get_global_id(0);
    global half  * d = (global half *) src_d + get_global_id(0);

    b->d = *d;
    for (int i = 0; i < QK4_NL/4; ++i) {
        uchar x0 = q[i + 0       ];
        uchar x1 = q[i + QK4_NL/4];

        b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4));
        b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0));
    }
}
