mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-05-28 17:27:26 +03:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aa50b2c2ae | ||
|
|
c40006a62e | ||
|
|
c6e4088376 | ||
|
|
b36eefc1b3 |
@@ -68,6 +68,7 @@ static u32vec opt_pmu_evt { 0x3, 0x111, 0x100, 0x105, 0x240, 0x256, 0x7D, 0x8C }
|
||||
static int opt_opstage = HTP_OPSTAGE_QUEUE | HTP_OPSTAGE_COMPUTE;
|
||||
static int opt_opbatch = 1024; // max number of ops in a batch
|
||||
static int opt_opqueue = 16; // max number of pending batches
|
||||
static int opt_oppoll = 0; // polling for batch completions
|
||||
|
||||
static std::regex* opt_opfilter = NULL; // regex of ops to not claim
|
||||
|
||||
@@ -550,7 +551,7 @@ static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size)
|
||||
|
||||
size_t row_size = ggml_row_size(t->type, t->ne[0]);
|
||||
size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad
|
||||
size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
|
||||
size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales)
|
||||
|
||||
// Ensure we don't try to read more data than is available in the source buffer 'data'
|
||||
// or write more than the tensor can hold.
|
||||
@@ -611,7 +612,7 @@ static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size)
|
||||
|
||||
size_t row_size = ggml_row_size(t->type, t->ne[0]);
|
||||
size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad
|
||||
size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
|
||||
size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales)
|
||||
|
||||
// Ensure we don't try to copy more data than the tensor actually contains.
|
||||
const size_t total_tensor_size = (size_t)nrows * row_size;
|
||||
@@ -660,6 +661,239 @@ static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size)
|
||||
ggml_aligned_free(buf_rp, row_size_rp);
|
||||
}
|
||||
|
||||
static void unpack_q4_1_quants(uint8_t * qs, const block_q4_1 * x, unsigned int bi) {
|
||||
static const int qk = QK4_1;
|
||||
|
||||
for (unsigned int i = 0; i < qk / 2; ++i) {
|
||||
const int x0 = (x->qs[i] & 0x0F);
|
||||
const int x1 = (x->qs[i] >> 4);
|
||||
qs[bi * qk + i + 0] = x0;
|
||||
qs[bi * qk + i + qk / 2] = x1;
|
||||
}
|
||||
}
|
||||
|
||||
static void pack_q4_1_quants(block_q4_1 * x, const uint8_t * qs, unsigned int bi) {
|
||||
static const int qk = QK4_1;
|
||||
|
||||
for (unsigned int i = 0; i < qk / 2; ++i) {
|
||||
const uint8_t x0 = qs[bi * qk + i + 0];
|
||||
const uint8_t x1 = qs[bi * qk + i + qk / 2];
|
||||
x->qs[i] = x0 | (x1 << 4);
|
||||
}
|
||||
}
|
||||
|
||||
static void repack_row_q4_1x4x2(uint8_t * y, const block_q4_1 * x, int64_t k) {
|
||||
static const int qk = QK_Q4_0x4x2;
|
||||
const int nb = (k + qk - 1) / qk; // number of blocks (padded)
|
||||
const int nloe = k % qk; // leftovers
|
||||
|
||||
const int dblk_size = 8 * 4; // 8x (d, m) __fp16 = 32 bytes
|
||||
const int qblk_size = qk / 2; // int4 = 128 bytes
|
||||
const int qrow_size = k / 2; // int4 (not padded to blocks)
|
||||
|
||||
uint8_t * y_q = y + 0; // quants first
|
||||
uint8_t * y_d = y + qrow_size; // then scales/offsets
|
||||
|
||||
// Repack the quants
|
||||
for (int i = 0; i < nb; i++) {
|
||||
uint8_t qs[QK_Q4_0x4x2]; // unpacked quants
|
||||
unpack_q4_1_quants(qs, &x[i * 8 + 0], 0);
|
||||
unpack_q4_1_quants(qs, &x[i * 8 + 1], 1);
|
||||
unpack_q4_1_quants(qs, &x[i * 8 + 2], 2);
|
||||
unpack_q4_1_quants(qs, &x[i * 8 + 3], 3);
|
||||
unpack_q4_1_quants(qs, &x[i * 8 + 4], 4);
|
||||
unpack_q4_1_quants(qs, &x[i * 8 + 5], 5);
|
||||
unpack_q4_1_quants(qs, &x[i * 8 + 6], 6);
|
||||
unpack_q4_1_quants(qs, &x[i * 8 + 7], 7);
|
||||
|
||||
bool partial = (nloe && i == nb-1);
|
||||
|
||||
uint8_t * q = y_q + (i * qblk_size);
|
||||
for (int j = 0; j < qk / 2; j++) {
|
||||
q[j] = partial ? (qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000];
|
||||
}
|
||||
}
|
||||
|
||||
// Repack the scales and offsets
|
||||
for (int i = 0; i < nb; i++) {
|
||||
ggml_half * d_m = (ggml_half *) (y_d + i * dblk_size);
|
||||
for (int j = 0; j < 8; j++) {
|
||||
d_m[j * 2 + 0] = x[i * 8 + j].d;
|
||||
d_m[j * 2 + 1] = x[i * 8 + j].m;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void unpack_row_q4_1x4x2(block_q4_1 * x, const uint8_t * y, int64_t k) {
|
||||
static const int qk = QK_Q4_0x4x2;
|
||||
const int nb = (k + qk - 1) / qk; // number of blocks (padded)
|
||||
const int nloe = k % qk; // leftovers
|
||||
|
||||
const int dblk_size = 8 * 4; // 8x (d, m) __fp16 = 32 bytes
|
||||
const int qblk_size = qk / 2; // int4 = 128 bytes
|
||||
const int qrow_size = k / 2; // int4 (not padded to blocks)
|
||||
|
||||
const uint8_t * y_q = y + 0; // quants first
|
||||
const uint8_t * y_d = y + qrow_size; // then scales/offsets
|
||||
|
||||
// Unpack the quants
|
||||
for (int i = 0; i < nb; i++) {
|
||||
uint8_t qs[QK_Q4_0x4x2];
|
||||
bool partial = (nloe && i == nb-1);
|
||||
|
||||
const uint8_t * q = y_q + (i * qblk_size);
|
||||
for (int j = 0; j < qk / 2; j++) {
|
||||
if (partial) {
|
||||
qs[j*2+0] = q[j] & 0x0F;
|
||||
qs[j*2+1] = q[j] >> 4;
|
||||
} else {
|
||||
qs[j+000] = q[j] & 0x0F;
|
||||
qs[j+128] = q[j] >> 4;
|
||||
}
|
||||
}
|
||||
|
||||
pack_q4_1_quants(&x[i * 8 + 0], qs, 0);
|
||||
pack_q4_1_quants(&x[i * 8 + 1], qs, 1);
|
||||
pack_q4_1_quants(&x[i * 8 + 2], qs, 2);
|
||||
pack_q4_1_quants(&x[i * 8 + 3], qs, 3);
|
||||
pack_q4_1_quants(&x[i * 8 + 4], qs, 4);
|
||||
pack_q4_1_quants(&x[i * 8 + 5], qs, 5);
|
||||
pack_q4_1_quants(&x[i * 8 + 6], qs, 6);
|
||||
pack_q4_1_quants(&x[i * 8 + 7], qs, 7);
|
||||
}
|
||||
|
||||
// Unpack the scales and offsets
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const ggml_half * d_m = (const ggml_half *) (y_d + i * dblk_size);
|
||||
for (int j = 0; j < 8; j++) {
|
||||
x[i * 8 + j].d = d_m[j * 2 + 0];
|
||||
x[i * 8 + j].m = d_m[j * 2 + 1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void init_row_q4_1x4x2(block_q4_1 * x, int64_t k) {
|
||||
static const int qk = QK_Q4_0x4x2;
|
||||
const int nb = (k + qk - 1) / qk; // number of blocks (padded)
|
||||
|
||||
uint8_t qs[QK_Q4_0x4x2]; // unpacked quants
|
||||
memset(qs, 0, sizeof(qs));
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
pack_q4_1_quants(&x[i * 8 + 0], qs, 0);
|
||||
pack_q4_1_quants(&x[i * 8 + 1], qs, 1);
|
||||
pack_q4_1_quants(&x[i * 8 + 2], qs, 2);
|
||||
pack_q4_1_quants(&x[i * 8 + 3], qs, 3);
|
||||
pack_q4_1_quants(&x[i * 8 + 4], qs, 4);
|
||||
pack_q4_1_quants(&x[i * 8 + 5], qs, 5);
|
||||
pack_q4_1_quants(&x[i * 8 + 6], qs, 6);
|
||||
pack_q4_1_quants(&x[i * 8 + 7], qs, 7);
|
||||
}
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
for (int j = 0; j < 8; j++) {
|
||||
x[i * 8 + j].d = 0;
|
||||
x[i * 8 + j].m = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void repack_q4_1_q4x4x2(ggml_tensor * t, const void * data, size_t size) {
|
||||
int64_t nrows = ggml_nrows(t);
|
||||
|
||||
size_t row_size = ggml_row_size(t->type, t->ne[0]);
|
||||
size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2));
|
||||
size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales)
|
||||
|
||||
const size_t total_tensor_size = (size_t)nrows * row_size;
|
||||
const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
|
||||
|
||||
const int64_t n_full_rows = n_bytes_to_copy / row_size;
|
||||
const size_t n_rem_bytes = n_bytes_to_copy % row_size;
|
||||
|
||||
void * buf_pd = ggml_aligned_malloc(row_size_pd);
|
||||
GGML_ASSERT(buf_pd != NULL);
|
||||
|
||||
void * buf_rp = ggml_aligned_malloc(row_size_rp);
|
||||
GGML_ASSERT(buf_rp != NULL);
|
||||
|
||||
HEX_VERBOSE("ggml-hex: repack-q4_1-q4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size,
|
||||
t->ne[0], nrows, row_size);
|
||||
|
||||
init_row_q4_1x4x2((block_q4_1 *) buf_pd, t->ne[0]);
|
||||
|
||||
for (int64_t i = 0; i < n_full_rows; i++) {
|
||||
const uint8_t * src = (const uint8_t *) data + (i * row_size);
|
||||
uint8_t * dst = (uint8_t *) t->data + (i * row_size);
|
||||
|
||||
memcpy(buf_pd, src, row_size);
|
||||
repack_row_q4_1x4x2((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]);
|
||||
memcpy(dst, buf_rp, row_size);
|
||||
}
|
||||
|
||||
if (n_rem_bytes > 0) {
|
||||
const int64_t i = n_full_rows;
|
||||
const uint8_t * src = (const uint8_t *) data + (i * row_size);
|
||||
uint8_t * dst = (uint8_t *) t->data + (i * row_size);
|
||||
|
||||
init_row_q4_1x4x2((block_q4_1 *) buf_pd, t->ne[0]);
|
||||
memcpy(buf_pd, src, n_rem_bytes);
|
||||
repack_row_q4_1x4x2((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]);
|
||||
memcpy(dst, buf_rp, n_rem_bytes);
|
||||
}
|
||||
|
||||
ggml_aligned_free(buf_pd, row_size_pd);
|
||||
ggml_aligned_free(buf_rp, row_size_rp);
|
||||
}
|
||||
|
||||
static void repack_q4x4x2_q4_1(void * data, const ggml_tensor * t, size_t size) {
|
||||
int64_t nrows = ggml_nrows(t);
|
||||
|
||||
size_t row_size = ggml_row_size(t->type, t->ne[0]);
|
||||
size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2));
|
||||
size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales)
|
||||
|
||||
const size_t total_tensor_size = (size_t)nrows * row_size;
|
||||
const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
|
||||
|
||||
const int64_t n_full_rows = n_bytes_to_copy / row_size;
|
||||
const size_t n_rem_bytes = n_bytes_to_copy % row_size;
|
||||
|
||||
void * buf_pd = ggml_aligned_malloc(row_size_pd);
|
||||
GGML_ASSERT(buf_pd != NULL);
|
||||
|
||||
void * buf_rp = ggml_aligned_malloc(row_size_rp);
|
||||
GGML_ASSERT(buf_rp != NULL);
|
||||
|
||||
HEX_VERBOSE("ggml-hex: repack-q4x4x2-q4_1 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size,
|
||||
t->ne[0], nrows, row_size);
|
||||
|
||||
memset(buf_rp, 0, row_size_rp); // clear-out padded buffer to make sure the tail is all zeros
|
||||
|
||||
for (int64_t i = 0; i < n_full_rows; i++) {
|
||||
const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
|
||||
uint8_t * dst = (uint8_t *) data + (i * row_size);
|
||||
|
||||
memcpy(buf_rp, src, row_size);
|
||||
unpack_row_q4_1x4x2((block_q4_1 *) buf_pd, (const uint8_t *) buf_rp, t->ne[0]);
|
||||
memcpy(dst, buf_pd, row_size);
|
||||
}
|
||||
|
||||
if (n_rem_bytes > 0) {
|
||||
const int64_t i = n_full_rows;
|
||||
const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
|
||||
uint8_t * dst = (uint8_t *) data + (i * row_size);
|
||||
|
||||
// We still need to read and unpack the entire source row because quantization is block-based.
|
||||
memcpy(buf_rp, src, row_size);
|
||||
unpack_row_q4_1x4x2((block_q4_1 *) buf_pd, (const uint8_t *) buf_rp, t->ne[0]);
|
||||
memcpy(dst, buf_pd, n_rem_bytes);
|
||||
}
|
||||
|
||||
ggml_aligned_free(buf_pd, row_size_pd);
|
||||
ggml_aligned_free(buf_rp, row_size_rp);
|
||||
}
|
||||
|
||||
// ======== Q8x4x2 ====================
|
||||
static void dump_block_q8_0(const block_q8_0 * b, int i) {
|
||||
HEX_VERBOSE("ggml-hex: repack q8_0 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, b->qs[0], b->qs[1], b->qs[2],
|
||||
@@ -876,7 +1110,7 @@ static void repack_q8_0_q8x4x2(ggml_tensor * t, const void * data, size_t size)
|
||||
|
||||
size_t row_size = ggml_row_size(t->type, t->ne[0]);
|
||||
size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad
|
||||
size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
|
||||
size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size quants + scales)
|
||||
|
||||
// Ensure we don't try to read more data than is available in the source buffer 'data'
|
||||
// or write more than the tensor can hold.
|
||||
@@ -937,7 +1171,7 @@ static void repack_q8x4x2_q8_0(void * data, const ggml_tensor * t, size_t size)
|
||||
|
||||
size_t row_size = ggml_row_size(t->type, t->ne[0]);
|
||||
size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad
|
||||
size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
|
||||
size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size quants + scales)
|
||||
|
||||
// Ensure we don't try to copy more data than the tensor actually contains.
|
||||
const size_t total_tensor_size = (size_t)nrows * row_size;
|
||||
@@ -1238,7 +1472,7 @@ static void repack_mxfp4_mxfp4x4x2(ggml_tensor * t, const void * data, size_t si
|
||||
|
||||
size_t row_size = ggml_row_size(t->type, t->ne[0]);
|
||||
size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad
|
||||
size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
|
||||
size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales)
|
||||
|
||||
// Ensure we don't try to read more data than is available in the source buffer 'data'
|
||||
// or write more than the tensor can hold.
|
||||
@@ -1299,7 +1533,7 @@ static void repack_mxfp4x4x2_mxfp4(void * data, const ggml_tensor * t, size_t si
|
||||
|
||||
size_t row_size = ggml_row_size(t->type, t->ne[0]);
|
||||
size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad
|
||||
size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
|
||||
size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales)
|
||||
|
||||
// Ensure we don't try to copy more data than the tensor actually contains.
|
||||
const size_t total_tensor_size = (size_t)nrows * row_size;
|
||||
@@ -1365,6 +1599,12 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
||||
repack_q4_0_q4x4x2(tensor, data, size);
|
||||
break;
|
||||
|
||||
case GGML_TYPE_Q4_1:
|
||||
GGML_ASSERT(offset == 0);
|
||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
|
||||
repack_q4_1_q4x4x2(tensor, data, size);
|
||||
break;
|
||||
|
||||
case GGML_TYPE_Q8_0:
|
||||
GGML_ASSERT(offset == 0);
|
||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
|
||||
@@ -1407,6 +1647,12 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
||||
repack_q4x4x2_q4_0(data, tensor, size);
|
||||
break;
|
||||
|
||||
case GGML_TYPE_Q4_1:
|
||||
GGML_ASSERT(offset == 0);
|
||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
|
||||
repack_q4x4x2_q4_1(data, tensor, size);
|
||||
break;
|
||||
|
||||
case GGML_TYPE_Q8_0:
|
||||
GGML_ASSERT(offset == 0);
|
||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
|
||||
@@ -1886,7 +2132,8 @@ void ggml_hexagon_session::flush_pending(bool all) {
|
||||
uint32_t n_dbufs;
|
||||
|
||||
// Read response packet from queue
|
||||
int err = dspqueue_read(this->queue, &flags, 1, &n_dbufs, &dbuf, sizeof(rsp), &rsp_size, (uint8_t *) &rsp, DSPQUEUE_TIMEOUT);
|
||||
const uint32_t timeo = opt_oppoll ? 0 : DSPQUEUE_TIMEOUT;
|
||||
int err = dspqueue_read(this->queue, &flags, 1, &n_dbufs, &dbuf, sizeof(rsp), &rsp_size, (uint8_t *) &rsp, timeo);
|
||||
if (err == AEE_EEXPIRED) {
|
||||
continue;
|
||||
}
|
||||
@@ -2327,6 +2574,7 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
@@ -2377,6 +2625,7 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
@@ -3622,6 +3871,8 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
|
||||
// Basic sanity checks to make sure definitions match
|
||||
static_assert((unsigned int) HTP_TYPE_Q4_0 == (unsigned int) GGML_TYPE_Q4_0,
|
||||
"please update hexagon_type to match ggml_type");
|
||||
static_assert((unsigned int) HTP_TYPE_Q4_1 == (unsigned int) GGML_TYPE_Q4_1,
|
||||
"please update hexagon_type to match ggml_type");
|
||||
static_assert((unsigned int) HTP_TYPE_Q8_0 == (unsigned int) GGML_TYPE_Q8_0,
|
||||
"please update hexagon_type to match ggml_type");
|
||||
static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4,
|
||||
@@ -3634,6 +3885,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
|
||||
const char * str_opstage = getenv("GGML_HEXAGON_OPSTAGE");
|
||||
const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH");
|
||||
const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE");
|
||||
const char * str_oppoll = getenv("GGML_HEXAGON_OPPOLL");
|
||||
const char * str_opfilter = getenv("GGML_HEXAGON_OPFILTER");
|
||||
const char * str_profile = getenv("GGML_HEXAGON_PROFILE");
|
||||
const char * str_etm = getenv("GGML_HEXAGON_ETM");
|
||||
@@ -3671,6 +3923,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
|
||||
opt_opstage = str_opstage ? strtoul(str_opstage, NULL, 0) : opt_opstage;
|
||||
opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch;
|
||||
opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue;
|
||||
opt_oppoll = str_oppoll ? strtoul(str_oppoll, NULL, 0) : opt_oppoll;
|
||||
opt_profile = str_profile ? atoi(str_profile) : 0;
|
||||
opt_etm = str_etm ? atoi(str_etm) : 0;
|
||||
opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
|
||||
|
||||
@@ -59,14 +59,14 @@ list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
|
||||
if (_hmx_idx GREATER_EQUAL 0)
|
||||
target_sources(${HTP_LIB} PRIVATE
|
||||
hmx-queue.c
|
||||
hmx-matmul-ops.c
|
||||
hmx-flash-attn-ops.c
|
||||
hmx-matmul-ops.c
|
||||
)
|
||||
|
||||
# -mhmx enables HMX instruction set (needed by files that include hmx-utils.h)
|
||||
set_source_files_properties(
|
||||
hmx-matmul-ops.c
|
||||
hmx-flash-attn-ops.c
|
||||
hmx-matmul-ops.c
|
||||
PROPERTIES COMPILE_OPTIONS "-mhmx"
|
||||
)
|
||||
|
||||
|
||||
@@ -34,6 +34,10 @@ static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
|
||||
-8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0,
|
||||
};
|
||||
|
||||
static const __fp16 q4_1_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
|
||||
0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0, 13, 0, 14, 0, 15, 0,
|
||||
};
|
||||
|
||||
// MXFP4 dequantization LUT: maps 4-bit index to fp16 mantissa value
|
||||
// kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6
|
||||
static const __fp16 mxfp4_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
|
||||
@@ -62,6 +66,8 @@ static inline size_t get_x4x2_row_stride(int weight_type, int k) {
|
||||
case HTP_TYPE_Q4_0:
|
||||
case HTP_TYPE_IQ4_NL:
|
||||
return (size_t) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb
|
||||
case HTP_TYPE_Q4_1:
|
||||
return (size_t) nb * (QK_Q4_0x4x2 / 2 + 32); // 160 * nb
|
||||
case HTP_TYPE_Q8_0:
|
||||
return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb
|
||||
case HTP_TYPE_MXFP4:
|
||||
@@ -233,6 +239,54 @@ static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx(
|
||||
return r;
|
||||
}
|
||||
|
||||
static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale_offset, const HVX_Vector vlut_cvt) {
|
||||
HVX_Vector vq = hvx_vmemu(packed_32);
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
HVX_Vector v_dm = hvx_vmemu(scale_offset);
|
||||
HVX_Vector v_scales = hvx_vec_repl_f16(v_dm);
|
||||
HVX_Vector v_offsets = hvx_vec_repl_f16(Q6_V_vror_VR(v_dm, 2));
|
||||
|
||||
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
|
||||
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
|
||||
v_quants = Q6_Vb_vshuff_Vb(v_quants);
|
||||
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
|
||||
HVX_Vector v_hf = Q6_V_lo_W(vp);
|
||||
|
||||
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales), v_offsets));
|
||||
}
|
||||
|
||||
static inline HVX_Vector_x2 dequantize_x4x2_q4_1_x4groups_hvx(
|
||||
const uint8_t *packed_128, bool upper_nibbles,
|
||||
const __fp16 *scales_offsets_4, const HVX_Vector vlut_cvt) {
|
||||
HVX_Vector vq = hvx_vmemu(packed_128);
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
|
||||
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
|
||||
|
||||
v_quants = Q6_Vb_vshuff_Vb(v_quants);
|
||||
|
||||
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
|
||||
HVX_Vector v_lo = Q6_V_lo_W(vp);
|
||||
HVX_Vector v_hi = Q6_V_hi_W(vp);
|
||||
|
||||
HVX_Vector vscale_offset = hvx_vmemu(scales_offsets_4);
|
||||
HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(vscale_offset, vscale_offset, -2);
|
||||
HVX_Vector vd = Q6_V_lo_W(dm_deal);
|
||||
HVX_Vector vm = Q6_V_hi_W(dm_deal);
|
||||
|
||||
HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vd);
|
||||
HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vd, 4));
|
||||
|
||||
HVX_Vector v_os01 = hvx_vec_repl_2x_f16(vm);
|
||||
HVX_Vector v_os23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vm, 4));
|
||||
|
||||
v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01), v_os01));
|
||||
v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23), v_os23));
|
||||
|
||||
HVX_Vector_x2 r = { v_lo, v_hi };
|
||||
return r;
|
||||
}
|
||||
|
||||
// Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes.
|
||||
static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(const int8_t *quants_32, const __fp16 *scale) {
|
||||
HVX_Vector vq = hvx_vmemu(quants_32);
|
||||
@@ -331,11 +385,13 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
||||
int start_tile, int end_tile) {
|
||||
|
||||
const int n_k_tiles = (unsigned)k_block / HMX_FP16_TILE_N_COLS;
|
||||
const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL);
|
||||
const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_Q4_1 || weight_type == HTP_TYPE_IQ4_NL);
|
||||
const bool is_q4_1 = (weight_type == HTP_TYPE_Q4_1);
|
||||
const int qrow_size = is_q4 ? ((unsigned)k_block / 2) : k_block;
|
||||
|
||||
const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) :
|
||||
(weight_type == HTP_TYPE_MXFP4) ? hvx_vmem(mxfp4_to_fp16_lut) :
|
||||
(weight_type == HTP_TYPE_Q4_1) ? hvx_vmem(q4_1_to_fp16_lut) :
|
||||
hvx_vmem(q4_0_to_fp16_lut);
|
||||
|
||||
// vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions.
|
||||
@@ -356,8 +412,10 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
||||
unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4
|
||||
bool upper = (sub_blk_base >= 4);
|
||||
unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes
|
||||
unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE
|
||||
+ sub_blk_base * (int)sizeof(__fp16); // 4 consecutive scales
|
||||
unsigned dblk_size = is_q4_1 ? 32 : HMX_X4X2_DBLK_SIZE;
|
||||
unsigned scale_step = is_q4_1 ? 4 : (int)sizeof(__fp16);
|
||||
unsigned scale_off = qrow_size + blk_idx * dblk_size
|
||||
+ sub_blk_base * scale_step;
|
||||
|
||||
__fp16 *tile_bases[4];
|
||||
for (unsigned g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; }
|
||||
@@ -367,20 +425,38 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
||||
unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride;
|
||||
unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1;
|
||||
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
|
||||
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
if (is_q4_1) {
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
|
||||
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
|
||||
HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
|
||||
HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt);
|
||||
HVX_Vector_x2 dv0 = dequantize_x4x2_q4_1_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
|
||||
HVX_Vector_x2 dv1 = dequantize_x4x2_q4_1_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt);
|
||||
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
} else {
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
|
||||
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
|
||||
HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
|
||||
HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt);
|
||||
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]);
|
||||
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
}
|
||||
|
||||
for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); }
|
||||
@@ -446,26 +522,43 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
||||
unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32;
|
||||
bool upper = (sub_blk >= 4);
|
||||
unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32;
|
||||
unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16);
|
||||
unsigned dblk_size = is_q4_1 ? 32 : HMX_X4X2_DBLK_SIZE;
|
||||
unsigned scale_step = is_q4_1 ? 4 : (int)sizeof(__fp16);
|
||||
unsigned scale_off = qrow_size + blk_idx * dblk_size + sub_blk * scale_step;
|
||||
|
||||
HVX_Vector v_off = v_scat_base; // reset to column 0
|
||||
unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride;
|
||||
unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1;
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
|
||||
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
if (is_q4_1) {
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
|
||||
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
|
||||
HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx(
|
||||
r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
|
||||
HVX_Vector v1 = (row1 < n_cols)
|
||||
? dequantize_x4x2_q4_0_group_hvx(
|
||||
r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt)
|
||||
: Q6_V_vzero();
|
||||
HVX_Vector v0 = dequantize_x4x2_q4_1_group_hvx(r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
|
||||
HVX_Vector v1 = (row1 < n_cols)
|
||||
? dequantize_x4x2_q4_1_group_hvx(r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt)
|
||||
: Q6_V_vzero();
|
||||
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
} else {
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
|
||||
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;
|
||||
|
||||
HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx(r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
|
||||
HVX_Vector v1 = (row1 < n_cols)
|
||||
? dequantize_x4x2_q4_0_group_hvx(r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt)
|
||||
: Q6_V_vzero();
|
||||
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
}
|
||||
(void) *(volatile HVX_Vector *)(tile_base);
|
||||
} else if (weight_type == HTP_TYPE_MXFP4) {
|
||||
@@ -593,6 +686,8 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles(
|
||||
|
||||
// --- End x4x2 dequantizers ---
|
||||
|
||||
#pragma clang diagnostic ignored "-Wbackend-plugin" // spurios warning for hmx intrinsics
|
||||
|
||||
// requires external HMX lock
|
||||
static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict activation, const __fp16 *restrict weight, const __fp16 *restrict scales,
|
||||
int n_row_tiles, int n_col_tiles, int n_dot_tiles) {
|
||||
|
||||
@@ -20,6 +20,7 @@ enum htp_data_type {
|
||||
HTP_TYPE_F32 = 0,
|
||||
HTP_TYPE_F16 = 1,
|
||||
HTP_TYPE_Q4_0 = 2,
|
||||
HTP_TYPE_Q4_1 = 3,
|
||||
HTP_TYPE_Q8_0 = 8,
|
||||
HTP_TYPE_IQ4_NL = 20,
|
||||
HTP_TYPE_I32 = 26,
|
||||
@@ -28,6 +29,7 @@ enum htp_data_type {
|
||||
|
||||
// types used internally for repack, dyn.quant, etc
|
||||
HTP_TYPE_Q4_0x4x2 = 200,
|
||||
HTP_TYPE_Q4_1x4x2,
|
||||
HTP_TYPE_Q8_0x4x2,
|
||||
HTP_TYPE_MXFP4x4x2,
|
||||
|
||||
|
||||
@@ -853,6 +853,11 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
for (uint32_t i=0; i < n_ops; i++) {
|
||||
struct profile_data prof;
|
||||
|
||||
if (i == (n_ops-1)) {
|
||||
// wake up the host before starting the last op
|
||||
dspqueue_write_early_wakeup_noblock(queue, 0, 0);
|
||||
}
|
||||
|
||||
profile_start(ctx->profiler, &prof);
|
||||
|
||||
proc_op_req(octx, tens, i, &ops[i]);
|
||||
@@ -869,8 +874,6 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
}
|
||||
}
|
||||
|
||||
// dspqueue_write_early_wakeup_noblock(ctx->queue, 10, 0);
|
||||
|
||||
struct htp_opbatch_rsp rsp;
|
||||
rsp.id = req.id;
|
||||
rsp.status = HTP_STATUS_OK;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -79,6 +79,12 @@ if (Vulkan_FOUND)
|
||||
"GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT"
|
||||
)
|
||||
|
||||
test_shader_extension_support(
|
||||
"GL_NV_cooperative_matrix_decode_vector"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp"
|
||||
"GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT"
|
||||
)
|
||||
|
||||
test_shader_extension_support(
|
||||
"GL_EXT_integer_dot_product"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/integer_dot.comp"
|
||||
|
||||
@@ -21,6 +21,19 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
|
||||
|
||||
#include <vulkan/vulkan.hpp>
|
||||
|
||||
// Fallback definitions for VK_NV_cooperative_matrix_decode_vector in case the
|
||||
// installed Vulkan headers predate the extension.
|
||||
#ifndef VK_NV_cooperative_matrix_decode_vector
|
||||
#define VK_NV_cooperative_matrix_decode_vector 1
|
||||
#define VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME "VK_NV_cooperative_matrix_decode_vector"
|
||||
#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV ((VkStructureType)1000689000)
|
||||
typedef struct VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV {
|
||||
VkStructureType sType;
|
||||
void* pNext;
|
||||
VkBool32 cooperativeMatrixDecodeVector;
|
||||
} VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV;
|
||||
#endif
|
||||
|
||||
// SPIR-V Headers: different SDK installations expose different include paths.
|
||||
// LunarG Vulkan SDK on Windows typically provides <spirv-headers/spirv.hpp>.
|
||||
// Linux packages, MSYS2 and MinGW often use the Khronos layout <spirv/unified1/spirv.hpp>.
|
||||
@@ -678,6 +691,7 @@ struct vk_device_struct {
|
||||
uint32_t coopmat_int_k;
|
||||
|
||||
bool coopmat2;
|
||||
bool coopmat2_decode_vector;
|
||||
|
||||
bool pipeline_executable_properties_support {};
|
||||
|
||||
@@ -2167,6 +2181,136 @@ static uint32_t compile_count = 0;
|
||||
static std::mutex compile_count_mutex;
|
||||
static std::condition_variable compile_count_cond;
|
||||
|
||||
static constexpr uint32_t kSpvOpCooperativeMatrixLoadTensorNV = 5367;
|
||||
static constexpr uint32_t kSpvCapabilityCooperativeMatrixDecodeVectorNV = 5447;
|
||||
static constexpr uint32_t kSpvTensorAddressingDecodeVectorFuncBit = 0x4;
|
||||
|
||||
// Remove SPV_NV_cooperative_matrix_decode_vector usage from a SPIR-V module so it
|
||||
// can be loaded on drivers that only support SPV_NV_cooperative_matrix2. Drops the
|
||||
// OpExtension declaration, the CooperativeMatrixDecodeVectorNV OpCapability, and the
|
||||
// DecodeVectorFunc operand from any OpCooperativeMatrixLoadTensorNV instruction.
|
||||
// Returns true when the input used the extension (and `out` was populated with a
|
||||
// stripped copy); returns false otherwise without touching `out`.
|
||||
static bool ggml_vk_strip_decode_vector(const uint32_t * code, size_t word_count, std::vector<uint32_t> & out) {
|
||||
static const char kDecodeVectorExt[] = "SPV_NV_cooperative_matrix_decode_vector";
|
||||
|
||||
if (word_count < 5) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool uses_decode_vector = false;
|
||||
for (size_t pos = 5; pos < word_count; ) {
|
||||
uint32_t word = code[pos];
|
||||
uint32_t wc = word >> spv::WordCountShift;
|
||||
uint32_t op = word & spv::OpCodeMask;
|
||||
GGML_ASSERT(wc > 0 && pos + wc <= word_count);
|
||||
if (op == spv::OpExtension && wc >= 2) {
|
||||
const char * s = reinterpret_cast<const char *>(&code[pos + 1]);
|
||||
if (strcmp(s, kDecodeVectorExt) == 0) {
|
||||
uses_decode_vector = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
pos += wc;
|
||||
}
|
||||
|
||||
if (!uses_decode_vector) {
|
||||
return false;
|
||||
}
|
||||
|
||||
VK_LOG_DEBUG("ggml_vk_strip_decode_vector: stripping SPV_NV_cooperative_matrix_decode_vector");
|
||||
|
||||
// Bulk-copy unchanged runs and only break the run when an instruction needs to
|
||||
// be dropped or patched. Use reserve + insert/push_back so the destination buffer
|
||||
// is touched exactly once (no zero-initialization pass from resize()).
|
||||
out.clear();
|
||||
out.reserve(word_count);
|
||||
|
||||
size_t run_start = 0;
|
||||
auto flush_run = [&](size_t up_to) {
|
||||
if (up_to > run_start) {
|
||||
out.insert(out.end(), code + run_start, code + up_to);
|
||||
}
|
||||
};
|
||||
|
||||
for (size_t pos = 5; pos < word_count; ) {
|
||||
uint32_t word = code[pos];
|
||||
uint32_t wc = word >> spv::WordCountShift;
|
||||
uint32_t op = word & spv::OpCodeMask;
|
||||
GGML_ASSERT(wc > 0 && pos + wc <= word_count);
|
||||
|
||||
if (op == spv::OpExtension && wc >= 2) {
|
||||
const char * s = reinterpret_cast<const char *>(&code[pos + 1]);
|
||||
if (strcmp(s, kDecodeVectorExt) == 0) {
|
||||
flush_run(pos);
|
||||
pos += wc;
|
||||
run_start = pos;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (op == spv::OpCapability && wc == 2 && code[pos + 1] == kSpvCapabilityCooperativeMatrixDecodeVectorNV) {
|
||||
flush_run(pos);
|
||||
pos += wc;
|
||||
run_start = pos;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (op == kSpvOpCooperativeMatrixLoadTensorNV) {
|
||||
// [opcode/wc][ResultType][Result][Pointer][Object][TensorLayout][MemOperand mask][mem extras...][TA mask][ta extras...]
|
||||
GGML_ASSERT(wc >= 8);
|
||||
|
||||
uint32_t mem_mask = code[pos + 6];
|
||||
size_t cur = pos + 7;
|
||||
// Each of these MemoryAccess bits (when set) carries one trailing operand.
|
||||
cur += (mem_mask & 0x2) ? 1 : 0; // Aligned
|
||||
cur += (mem_mask & 0x8) ? 1 : 0; // MakePointerAvailable
|
||||
cur += (mem_mask & 0x10) ? 1 : 0; // MakePointerVisible
|
||||
cur += (mem_mask & 0x10000) ? 1 : 0; // AliasScopeINTELMask
|
||||
cur += (mem_mask & 0x20000) ? 1 : 0; // NoAliasINTELMask
|
||||
GGML_ASSERT(cur < pos + wc);
|
||||
|
||||
uint32_t ta_mask = code[cur];
|
||||
if ((ta_mask & kSpvTensorAddressingDecodeVectorFuncBit) == 0) {
|
||||
pos += wc;
|
||||
continue; // leave instruction inside the current unchanged run
|
||||
}
|
||||
|
||||
flush_run(pos);
|
||||
|
||||
// Append unchanged prefix of the instruction (header through the mem-extras).
|
||||
size_t inst_start = out.size();
|
||||
size_t pre_n = cur - pos;
|
||||
out.insert(out.end(), code + pos, code + pos + pre_n);
|
||||
|
||||
// Emit TA mask with the DecodeVectorFunc bit cleared.
|
||||
out.push_back(ta_mask & ~kSpvTensorAddressingDecodeVectorFuncBit);
|
||||
|
||||
// TA extras: TensorView (0x1) and DecodeFunc (0x2) are kept verbatim;
|
||||
// DecodeVectorFunc (0x4) is dropped along with its trailing id operand.
|
||||
size_t keep_ta_extras = ((ta_mask & 0x1) ? 1 : 0) + ((ta_mask & 0x2) ? 1 : 0);
|
||||
if (keep_ta_extras) {
|
||||
out.insert(out.end(), code + cur + 1, code + cur + 1 + keep_ta_extras);
|
||||
}
|
||||
|
||||
GGML_ASSERT(wc == pre_n + 1 + keep_ta_extras + 1);
|
||||
|
||||
// Patch the instruction header with the new (one-shorter) word count.
|
||||
uint32_t new_wc = wc - 1;
|
||||
out[inst_start] = (new_wc << spv::WordCountShift) | op;
|
||||
|
||||
pos += wc;
|
||||
run_start = pos;
|
||||
continue;
|
||||
}
|
||||
|
||||
pos += wc;
|
||||
}
|
||||
|
||||
flush_run(word_count);
|
||||
return true;
|
||||
}
|
||||
|
||||
static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint,
|
||||
uint32_t parameter_count, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants,
|
||||
bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) {
|
||||
@@ -2238,6 +2382,18 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
|
||||
shader_module_create_info = vk::ShaderModuleCreateInfo({}, spirv.size() * sizeof(uint32_t), spirv.data());
|
||||
}
|
||||
|
||||
#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT)
|
||||
if (device->coopmat2 && !device->coopmat2_decode_vector) {
|
||||
const uint32_t * src = spirv.empty() ? reinterpret_cast<const uint32_t *>(spv_data) : spirv.data();
|
||||
size_t src_n = spirv.empty() ? spv_size / sizeof(uint32_t) : spirv.size();
|
||||
std::vector<uint32_t> stripped;
|
||||
if (ggml_vk_strip_decode_vector(src, src_n, stripped)) {
|
||||
spirv = std::move(stripped);
|
||||
shader_module_create_info = vk::ShaderModuleCreateInfo({}, spirv.size() * sizeof(uint32_t), spirv.data());
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
pipeline->shader_module = device->device.createShaderModule(shader_module_create_info);
|
||||
|
||||
vk::PushConstantRange pcr(
|
||||
@@ -5159,6 +5315,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
bool amd_shader_core_properties2 = false;
|
||||
bool pipeline_robustness = false;
|
||||
bool coopmat2_support = false;
|
||||
bool coopmat2_decode_vector_support = false;
|
||||
bool pipeline_executable_properties_support = false;
|
||||
device->coopmat_support = false;
|
||||
device->integer_dot_product = false;
|
||||
@@ -5193,6 +5350,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
!getenv("GGML_VK_DISABLE_COOPMAT2")) {
|
||||
coopmat2_support = true;
|
||||
#endif
|
||||
} else if (strcmp(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME, properties.extensionName) == 0 &&
|
||||
!getenv("GGML_VK_DISABLE_COOPMAT2_DECODE_VECTOR")) {
|
||||
coopmat2_decode_vector_support = true;
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
|
||||
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
|
||||
@@ -5470,6 +5630,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
}
|
||||
#endif
|
||||
|
||||
VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV coopmat2_decode_vector_features {};
|
||||
coopmat2_decode_vector_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV;
|
||||
if (coopmat2_decode_vector_support) {
|
||||
last_struct->pNext = (VkBaseOutStructure *)&coopmat2_decode_vector_features;
|
||||
last_struct = (VkBaseOutStructure *)&coopmat2_decode_vector_features;
|
||||
device_extensions.push_back(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME);
|
||||
}
|
||||
|
||||
#if defined(VK_KHR_shader_bfloat16)
|
||||
VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
|
||||
bfloat16_features.pNext = nullptr;
|
||||
@@ -5629,6 +5797,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
found_fp32_128 && found_fp32_256 &&
|
||||
coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) {
|
||||
device->coopmat2 = true;
|
||||
device->coopmat2_decode_vector = coopmat2_decode_vector_support && coopmat2_decode_vector_features.cooperativeMatrixDecodeVector;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@@ -5915,6 +6084,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
||||
bool fp16_compute = false;
|
||||
bool coopmat_support = false;
|
||||
bool coopmat2_support = false;
|
||||
bool coopmat2_decode_vector_support = false;
|
||||
bool integer_dot_product = false;
|
||||
bool bfloat16_support = false;
|
||||
|
||||
@@ -5933,6 +6103,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
||||
!getenv("GGML_VK_DISABLE_COOPMAT2")) {
|
||||
coopmat2_support = true;
|
||||
#endif
|
||||
} else if (strcmp(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME, properties.extensionName) == 0 &&
|
||||
!getenv("GGML_VK_DISABLE_COOPMAT2_DECODE_VECTOR")) {
|
||||
coopmat2_decode_vector_support = true;
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
|
||||
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
|
||||
@@ -6017,6 +6190,13 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
||||
}
|
||||
#endif
|
||||
|
||||
VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV coopmat2_decode_vector_features {};
|
||||
coopmat2_decode_vector_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV;
|
||||
if (coopmat2_decode_vector_support) {
|
||||
last_struct->pNext = (VkBaseOutStructure *)&coopmat2_decode_vector_features;
|
||||
last_struct = (VkBaseOutStructure *)&coopmat2_decode_vector_features;
|
||||
}
|
||||
|
||||
vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
|
||||
|
||||
fp16 = fp16 && vk12_features.shaderFloat16;
|
||||
@@ -6041,7 +6221,14 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
||||
#endif
|
||||
&& ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture);
|
||||
|
||||
std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
|
||||
coopmat2_decode_vector_support = coopmat2_decode_vector_support && coopmat2_decode_vector_features.cooperativeMatrixDecodeVector;
|
||||
#if !defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT)
|
||||
coopmat2_decode_vector_support = false;
|
||||
#endif
|
||||
|
||||
std::string matrix_cores = coopmat2_support ? (coopmat2_decode_vector_support ? "NV_coopmat2v" : "NV_coopmat2")
|
||||
: coopmat_support ? "KHR_coopmat"
|
||||
: "none";
|
||||
|
||||
std::string device_name = props2.properties.deviceName.data();
|
||||
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
|
||||
|
||||
@@ -11,6 +11,10 @@ if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
message(STATUS "Enabling coopmat2 glslc support")
|
||||
endif()
|
||||
if (GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT)
|
||||
add_compile_definitions(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT)
|
||||
message(STATUS "Enabling coopmat2 decode_vector glslc support")
|
||||
endif()
|
||||
if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
message(STATUS "Enabling dot glslc support")
|
||||
|
||||
@@ -5,21 +5,60 @@
|
||||
#include "types.glsl"
|
||||
|
||||
#if defined(DATA_A_F32)
|
||||
FLOAT_TYPE dequantize1(uint ib, uint iqs, uint a_offset) {
|
||||
return data_a[a_offset + ib];
|
||||
}
|
||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
|
||||
}
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
return vec4(data_a[a_offset + ib ], data_a[a_offset + ib + 1],
|
||||
data_a[a_offset + ib + 2], data_a[a_offset + ib + 3]);
|
||||
}
|
||||
vec4 dequantize4_2aligned(uint ib, uint iqs, uint a_offset) {
|
||||
return vec4(data_a[a_offset + ib ], data_a[a_offset + ib + 1],
|
||||
data_a[a_offset + ib + 2], data_a[a_offset + ib + 3]);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_F16)
|
||||
FLOAT_TYPE dequantize1(uint ib, uint iqs, uint a_offset) {
|
||||
return data_a[a_offset + ib];
|
||||
}
|
||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
|
||||
}
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
return vec4(data_a[a_offset + ib ], data_a[a_offset + ib + 1],
|
||||
data_a[a_offset + ib + 2], data_a[a_offset + ib + 3]);
|
||||
}
|
||||
vec4 dequantize4_2aligned(uint ib, uint iqs, uint a_offset) {
|
||||
const vec2 a = data_a_packed32[(a_offset + ib)/2];
|
||||
const vec2 b = data_a_packed32[(a_offset + ib)/2 + 1];
|
||||
return vec4(a, b);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_BF16)
|
||||
FLOAT_TYPE dequantize1(uint ib, uint iqs, uint a_offset) {
|
||||
return bf16_to_fp32(data_a[a_offset + ib]);
|
||||
}
|
||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1]));
|
||||
}
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
return vec4(bf16_to_fp32(data_a[a_offset + ib ]), bf16_to_fp32(data_a[a_offset + ib + 1]),
|
||||
bf16_to_fp32(data_a[a_offset + ib + 2]), bf16_to_fp32(data_a[a_offset + ib + 3]));
|
||||
}
|
||||
vec4 dequantize4_2aligned(uint ib, uint iqs, uint a_offset) {
|
||||
const uint a = data_a_packed32[(a_offset + ib)/2];
|
||||
const uint b = data_a_packed32[(a_offset + ib)/2 + 1];
|
||||
return vec4(uintBitsToFloat((a & 0x0000ffff) << 16),
|
||||
uintBitsToFloat( a & 0xffff0000),
|
||||
uintBitsToFloat((b & 0x0000ffff) << 16),
|
||||
uintBitsToFloat( b & 0xffff0000));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0)
|
||||
|
||||
@@ -1,4 +1,12 @@
|
||||
|
||||
// Each format defines a scalar dequantFunc<T> plus a V=4 dequantFunc<T>_v
|
||||
// passed as the optional vector decoder to coopMatLoadTensorNV via
|
||||
// GL_NV_cooperative_matrix_decode_vector. When the driver doesn't support
|
||||
// the extension, ggml-vulkan.cpp strips it from the compiled SPIR-V.
|
||||
#ifdef GL_NV_cooperative_matrix_decode_vector
|
||||
#extension GL_NV_cooperative_matrix_decode_vector : enable
|
||||
#endif
|
||||
|
||||
#include "types.glsl"
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufF32 {
|
||||
@@ -25,6 +33,19 @@ float16_t dequantFuncQ1_0(const in decodeBufQ1_0 bl, const in uint blockCoords[2
|
||||
return bit != 0u ? d : -d;
|
||||
}
|
||||
|
||||
f16vec4 dequantFuncQ1_0_v(const in decodeBufQ1_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const float16_t d = bl.block.d;
|
||||
const float16_t md = -d;
|
||||
const uint idx = coordInBlock[1];
|
||||
const uint qs_nib = uint(bl.block.qs[idx >> 3]) >> (idx & 0x4u);
|
||||
return f16vec4(
|
||||
(qs_nib & 1u) != 0u ? d : md,
|
||||
(qs_nib & 2u) != 0u ? d : md,
|
||||
(qs_nib & 4u) != 0u ? d : md,
|
||||
(qs_nib & 8u) != 0u ? d : md);
|
||||
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
|
||||
block_q4_0_packed16 block;
|
||||
};
|
||||
@@ -42,10 +63,28 @@ float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2
|
||||
return ret;
|
||||
}
|
||||
|
||||
f16vec4 dequantFuncQ4_0_v(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const float16_t d = bl.block.d;
|
||||
const uint idx = coordInBlock[1];
|
||||
const uint shift = (idx & 0x10) >> 2; // 0 or 4
|
||||
const uint qs_i = (idx & 0xE) >> 1; // even, in {0,2,4,6}
|
||||
const uint qsw = uint32_t(bl.block.qs[qs_i ])
|
||||
| (uint32_t(bl.block.qs[qs_i + 1u]) << 16);
|
||||
// shift in {0,4}: per-byte mask 0x0F isolates the wanted nibble in each byte.
|
||||
const uint q4 = (qsw >> shift) & 0x0F0F0F0Fu;
|
||||
const u8vec4 q = unpack8(q4);
|
||||
return f16vec4((vec4(q) - vec4(8.0)) * vec4(float(d)));
|
||||
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 {
|
||||
block_q4_1 block;
|
||||
};
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1_packed32 {
|
||||
block_q4_1_packed32 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const float16_t d = bl.block.d;
|
||||
@@ -60,10 +99,27 @@ float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2
|
||||
return ret;
|
||||
}
|
||||
|
||||
f16vec4 dequantFuncQ4_1_v(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
decodeBufQ4_1_packed32 bl32 = decodeBufQ4_1_packed32(bl);
|
||||
const float16_t d = bl.block.d;
|
||||
const float16_t m = bl.block.m;
|
||||
const uint idx = coordInBlock[1];
|
||||
const uint shift = (idx & 0x10) >> 2; // 0 or 4
|
||||
const uint qs_w = (idx & 0xC) >> 2; // iqs / 4 in [0,4)
|
||||
const uint qsw = uint32_t(bl32.block.qs[qs_w]);
|
||||
const u8vec4 q = unpack8((qsw >> shift) & 0x0F0F0F0Fu);
|
||||
return f16vec4(vec4(q) * vec4(float(d)) + vec4(float(m)));
|
||||
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 {
|
||||
block_q5_0 block;
|
||||
};
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0_packed16 {
|
||||
block_q5_0_packed16 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const float16_t d = bl.block.d;
|
||||
@@ -82,10 +138,32 @@ float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2
|
||||
return ret;
|
||||
}
|
||||
|
||||
f16vec4 dequantFuncQ5_0_v(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
decodeBufQ5_0_packed16 bl16 = decodeBufQ5_0_packed16(bl);
|
||||
const float16_t d = bl.block.d;
|
||||
const uint idx = coordInBlock[1];
|
||||
const uint shift = (idx & 0x10) >> 2; // 0 or 4
|
||||
const uint qs_i = (idx & 0xC) >> 1; // packed16 word index, in {0,2,4,6}
|
||||
const uint qsw = uint32_t(bl16.block.qs[qs_i ])
|
||||
| (uint32_t(bl16.block.qs[qs_i + 1u]) << 16);
|
||||
const u8vec4 ql = unpack8((qsw >> shift) & 0x0F0F0F0Fu);
|
||||
|
||||
const uint uint_qh = uint(bl16.block.qh[1]) << 16 | uint(bl16.block.qh[0]);
|
||||
const uint qh_pack = uint_qh >> idx; // bits 0..3 = element idx..idx+3 high bits
|
||||
const uvec4 qh_high = (uvec4(qh_pack, qh_pack >> 1u, qh_pack >> 2u, qh_pack >> 3u) & uvec4(0x01u)) << 4u;
|
||||
|
||||
return f16vec4((vec4(ql) + vec4(qh_high) - vec4(16.0)) * vec4(float(d)));
|
||||
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 {
|
||||
block_q5_1 block;
|
||||
};
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1_packed32 {
|
||||
block_q5_1_packed32 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const float16_t d = bl.block.d;
|
||||
@@ -105,6 +183,23 @@ float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2
|
||||
return ret;
|
||||
}
|
||||
|
||||
f16vec4 dequantFuncQ5_1_v(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
decodeBufQ5_1_packed32 bl32 = decodeBufQ5_1_packed32(bl);
|
||||
const float16_t d = bl.block.d;
|
||||
const float16_t m = bl.block.m;
|
||||
const uint idx = coordInBlock[1];
|
||||
const uint shift = (idx & 0x10) >> 2; // 0 or 4
|
||||
const uint qs_w = (idx & 0xC) >> 2; // iqs / 4 in [0,4)
|
||||
const uint qsw = uint32_t(bl32.block.qs[qs_w]);
|
||||
const u8vec4 ql = unpack8((qsw >> shift) & 0x0F0F0F0Fu);
|
||||
|
||||
const uint qh_pack = bl.block.qh >> idx; // bits 0..3 = element idx..idx+3 high bits
|
||||
const uvec4 qh_high = (uvec4(qh_pack, qh_pack >> 1u, qh_pack >> 2u, qh_pack >> 3u) & uvec4(0x01u)) << 4u;
|
||||
|
||||
return f16vec4((vec4(ql) + vec4(qh_high)) * vec4(float(d)) + vec4(float(m)));
|
||||
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 {
|
||||
block_q8_0_packed16 block;
|
||||
};
|
||||
@@ -121,6 +216,17 @@ float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2
|
||||
return ret;
|
||||
}
|
||||
|
||||
f16vec4 dequantFuncQ8_0_v(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const float16_t d = bl.block.d;
|
||||
const uint idx = coordInBlock[1];
|
||||
const uint base = idx >> 1u;
|
||||
const uint w = uint(uint16_t(bl.block.qs[base]))
|
||||
| (uint(uint16_t(bl.block.qs[base + 1u])) << 16u);
|
||||
const i8vec4 qi = unpack8(int32_t(w));
|
||||
return f16vec4(vec4(qi) * vec4(float(d)));
|
||||
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K {
|
||||
block_q2_K block;
|
||||
};
|
||||
@@ -129,6 +235,10 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2
|
||||
block_q2_K_packed16 block;
|
||||
};
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K_packed32 {
|
||||
block_q2_K_packed32 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
|
||||
@@ -147,10 +257,36 @@ float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2
|
||||
return ret;
|
||||
}
|
||||
|
||||
f16vec4 dequantFuncQ2_K_v(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
decodeBufQ2_K_packed32 bl32 = decodeBufQ2_K_packed32(bl);
|
||||
const f16vec2 dm = bl.block.dm;
|
||||
const uint idx = coordInBlock[1];
|
||||
|
||||
const uint scalesi = idx >> 4; // 0..15
|
||||
const uint qsshift = (idx & 0x60) >> 4; // 0,2,4,6
|
||||
|
||||
// qs_i (packed16) = ((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1) is even for idx % 4 == 0,
|
||||
// so qs_w (packed32) = qs_i / 2 = ((idx & 0x80) >> 4) + ((idx & 0x1Cu) >> 2).
|
||||
const uint qs_w = ((idx & 0x80) >> 4) + ((idx & 0x1Cu) >> 2);
|
||||
const uint qsw = uint32_t(bl32.block.qs[qs_w]);
|
||||
const uint qs4 = (qsw >> qsshift) & 0x03030303u;
|
||||
const u8vec4 qi = unpack8(qs4);
|
||||
|
||||
const uint scales = bl.block.scales[scalesi];
|
||||
const float16_t d_sub = dm.x * float16_t(scales & 0xF);
|
||||
const float16_t m_sub = dm.y * float16_t(scales >> 4);
|
||||
return f16vec4(vec4(qi) * vec4(float(d_sub)) - vec4(float(m_sub)));
|
||||
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K {
|
||||
block_q3_K block;
|
||||
};
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K_packed16 {
|
||||
block_q3_K_packed16 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const uint idx = coordInBlock[1];
|
||||
@@ -179,6 +315,47 @@ float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2
|
||||
return ret;
|
||||
}
|
||||
|
||||
f16vec4 dequantFuncQ3_K_v(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
decodeBufQ3_K_packed16 bl16 = decodeBufQ3_K_packed16(bl);
|
||||
const uint idx = coordInBlock[1];
|
||||
|
||||
const uint n = idx >> 7; // 0,1
|
||||
const uint is = idx >> 4; // 0..15
|
||||
const uint halfsplit = (idx & 0x60) >> 5; // 0,1,2,3
|
||||
const uint qsshift = halfsplit << 1; // 0,2,4,6
|
||||
const uint hbit = (n << 2) + halfsplit; // 0..7 (bit position in hmask byte)
|
||||
|
||||
uint32_t scaleidx0 = (is < 8) ? is : (is - 8);
|
||||
uint32_t scaleidx0shift = (is < 8) ? 0u : 4u;
|
||||
uint32_t scaleidx1 = is + 8 - (is / 4) * 4;
|
||||
uint32_t scaleidx1shift = (is / 4) * 2;
|
||||
|
||||
const int8_t us = int8_t(
|
||||
((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) |
|
||||
(((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4));
|
||||
const float16_t dl = bl.block.d * float16_t(int(us) - 32);
|
||||
|
||||
// For idx % 4 == 0: (idx & 0x1F) == (idx & 0x1C) is a multiple of 4.
|
||||
const uint qsi = (n << 5) + (idx & 0x1Cu);
|
||||
const uint hmi = (idx & 0x1Cu);
|
||||
|
||||
// Two adjacent uint16 packed16 reads, combined into a uint32 in registers.
|
||||
// After this: byte j of qsw / hmw holds the data for element idx+j.
|
||||
const uint qsw = uint32_t(bl16.block.qs[qsi >> 1])
|
||||
| (uint32_t(bl16.block.qs[(qsi >> 1) + 1u]) << 16);
|
||||
const uint hmw = uint32_t(bl16.block.hmask[hmi >> 1])
|
||||
| (uint32_t(bl16.block.hmask[(hmi >> 1) + 1u]) << 16);
|
||||
|
||||
// qsshift in {0,2,4,6} and hbit in {0..7}: per-byte masks isolate the wanted bits
|
||||
// with no inter-byte leakage.
|
||||
const uint ql4 = (qsw >> qsshift) & 0x03030303u;
|
||||
const uint qh4 = (hmw >> hbit) & 0x01010101u;
|
||||
|
||||
const ivec4 q = ivec4(unpack8(ql4 | (qh4 << 2))) - ivec4(4);
|
||||
return f16vec4(vec4(q) * vec4(float(dl)));
|
||||
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K {
|
||||
block_q4_K block;
|
||||
};
|
||||
@@ -187,6 +364,10 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4
|
||||
block_q4_K_packed16 block;
|
||||
};
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed32 {
|
||||
block_q4_K_packed32 block;
|
||||
};
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 {
|
||||
block_q4_K_packed128 block;
|
||||
};
|
||||
@@ -334,6 +515,55 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
|
||||
return float16_t(ret);
|
||||
}
|
||||
|
||||
f16vec4 dequantFuncQ4_K_v(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
decodeBufQ4_K_packed32 bl32 = decodeBufQ4_K_packed32(bl);
|
||||
decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl);
|
||||
const uint idx = coordInBlock[1];
|
||||
|
||||
const uint is = idx >> 5; // 0..7
|
||||
|
||||
#if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K)
|
||||
vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
|
||||
float d = v.x;
|
||||
float m = v.y;
|
||||
#else
|
||||
uvec4 v = bl128.block.q4k[0];
|
||||
const vec2 loadd = vec2(unpackFloat2x16(v.x));
|
||||
|
||||
uint32_t sc;
|
||||
uint32_t mbyte;
|
||||
|
||||
uint32_t scale0 = v.y;
|
||||
uint32_t scale4 = v.z;
|
||||
uint32_t scale8 = v.w;
|
||||
|
||||
uint32_t sc_lo = scale0;
|
||||
uint32_t mb_lo = scale4;
|
||||
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
|
||||
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
|
||||
|
||||
sc = is < 4 ? sc_lo : sc_hi;
|
||||
mbyte = is < 4 ? mb_lo : mb_hi;
|
||||
sc = sc >> (8 * (is & 3));
|
||||
mbyte = mbyte >> (8 * (is & 3));
|
||||
sc &= 0x3F;
|
||||
mbyte &= 0x3F;
|
||||
|
||||
const float d = loadd.x * float(sc);
|
||||
const float m = loadd.y * float(mbyte);
|
||||
#endif
|
||||
|
||||
// idx in [0,256); vector decode uses idx a multiple of 4. packed32 word index:
|
||||
// (qs_i >> 1) == (idx >> 6) * 8 + ((idx & 0x1E) >> 2). sh is 0 or 4 only, so a
|
||||
// single (w >> sh) & 0x0F0F0F0F isolates all four nibbles without inter-byte leakage.
|
||||
const uint sh = (idx & 0x20u) >> 3u;
|
||||
const uint w = uint32_t(bl32.block.qs[(idx >> 6) * 8u + ((idx & 0x1Eu) >> 2)]);
|
||||
const u8vec4 q = unpack8((w >> sh) & 0x0F0F0F0Fu);
|
||||
|
||||
return f16vec4(vec4(d) * vec4(q) - vec4(m));
|
||||
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K {
|
||||
block_q5_K block;
|
||||
};
|
||||
@@ -346,6 +576,10 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5
|
||||
block_q5_K_packed128 block;
|
||||
};
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed32 {
|
||||
block_q5_K_packed32 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl);
|
||||
@@ -399,6 +633,58 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
|
||||
return float16_t(ret);
|
||||
}
|
||||
|
||||
f16vec4 dequantFuncQ5_K_v(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
decodeBufQ5_K_packed32 bl32 = decodeBufQ5_K_packed32(bl);
|
||||
decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl);
|
||||
const uint idx = coordInBlock[1];
|
||||
const uint is = idx >> 5;
|
||||
|
||||
#if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K)
|
||||
vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
|
||||
float d = v.x;
|
||||
float m = v.y;
|
||||
#else
|
||||
uvec4 v = bl128.block.q5k[0];
|
||||
|
||||
const f16vec2 loadd = unpackFloat2x16(v.x);
|
||||
|
||||
uint32_t sc;
|
||||
uint32_t mbyte;
|
||||
|
||||
uint32_t scale0 = v.y;
|
||||
uint32_t scale4 = v.z;
|
||||
uint32_t scale8 = v.w;
|
||||
|
||||
uint32_t sc_lo = scale0;
|
||||
uint32_t mb_lo = scale4;
|
||||
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
|
||||
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
|
||||
|
||||
sc = is < 4 ? sc_lo : sc_hi;
|
||||
mbyte = is < 4 ? mb_lo : mb_hi;
|
||||
sc = sc >> (8 * (is & 3));
|
||||
mbyte = mbyte >> (8 * (is & 3));
|
||||
sc &= 0x3F;
|
||||
mbyte &= 0x3F;
|
||||
|
||||
const float16_t d = loadd.x * float16_t(sc);
|
||||
const float16_t m = loadd.y * float16_t(mbyte);
|
||||
#endif
|
||||
|
||||
// sh is 0 or 4; mask 0x0F0F0F0F covers the four nibbles regardless (no inter-byte leakage).
|
||||
const uint sh = (idx & 0x20u) >> 3u;
|
||||
const uint qs_w = (idx >> 6) * 8u + ((idx & 0x1Eu) >> 2);
|
||||
const uint qh_w = (idx & 0x1Eu) >> 2;
|
||||
|
||||
const uint ql4 = (uint32_t(bl32.block.qs[qs_w]) >> sh) & 0x0F0F0F0Fu;
|
||||
// qh stores bit `is` per element across 4 consecutive bytes; one shift+mask handles all 4.
|
||||
const uint qh4 = ((uint32_t(bl32.block.qh[qh_w]) >> is) & 0x01010101u) << 4u;
|
||||
|
||||
const u8vec4 qi = unpack8(ql4 | qh4);
|
||||
return f16vec4(vec4(qi) * vec4(d) - vec4(m));
|
||||
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K {
|
||||
block_q6_K block;
|
||||
};
|
||||
@@ -431,6 +717,35 @@ float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2
|
||||
return ret;
|
||||
}
|
||||
|
||||
f16vec4 dequantFuncQ6_K_v(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl);
|
||||
const uint idx = coordInBlock[1];
|
||||
|
||||
const uint b = (idx & 0x40) >> 6;
|
||||
const uint qhshift = (idx & 0x60) >> 4; // 0,2,4,6
|
||||
const uint is = idx >> 4;
|
||||
const uint sh = b * 4; // 0 or 4
|
||||
|
||||
const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]);
|
||||
|
||||
const uint ql_i = ((idx & 0x80) >> 2) + ((idx & 0x3E) >> 1);
|
||||
const uint qh_i = ((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1);
|
||||
|
||||
// Two adjacent uint16 packed16 reads, combined into a uint32 in registers.
|
||||
// After this: byte j of qlw / qhw holds the data for element idx+j.
|
||||
const uint qlw = uint32_t(bl16.block.ql[ql_i ]) | (uint32_t(bl16.block.ql[ql_i + 1]) << 16);
|
||||
const uint qhw = uint32_t(bl16.block.qh[qh_i ]) | (uint32_t(bl16.block.qh[qh_i + 1]) << 16);
|
||||
|
||||
// sh in {0,4} and qhshift in {0,2,4,6}: per-byte masks 0x0F / 0x03 keep only the
|
||||
// wanted bits with no inter-byte leakage; place qh's 2 bits at nibble high position.
|
||||
const uint ql4 = (qlw >> sh) & 0x0F0F0F0Fu;
|
||||
const uint qh4 = ((qhw >> qhshift) & 0x03030303u) << 4u;
|
||||
|
||||
const ivec4 qi = ivec4(unpack8(ql4 | qh4));
|
||||
return f16vec4((vec4(qi) - vec4(32.0f)) * vec4(float(dscale)));
|
||||
}
|
||||
|
||||
#if defined(DATA_A_IQ1_S)
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_S {
|
||||
block_iq1_s block;
|
||||
@@ -453,6 +768,29 @@ float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords
|
||||
float16_t ret = float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * int(idx % 8), 2)) + float16_t(delta));
|
||||
return ret;
|
||||
}
|
||||
|
||||
f16vec4 dequantFuncIQ1_S_v(const in decodeBufIQ1_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const float16_t d = bl.block.d;
|
||||
const uint idx = coordInBlock[1];
|
||||
|
||||
const uint ib32 = idx >> 5;
|
||||
const uint ib8 = idx >> 3;
|
||||
const int i8b = int(idx & 4); // 0 or 4
|
||||
|
||||
const uint qh = bl.block.qh[ib32];
|
||||
const uint qs = bl.block.qs[ib8];
|
||||
const float dl = float(d) * float(2 * bitfieldExtract(qh, 12, 3) + 1);
|
||||
const float delta = ((qh & 0x8000u) != 0u) ? -IQ1S_DELTA : IQ1S_DELTA;
|
||||
const uint grid = iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)];
|
||||
|
||||
const ivec4 q = ivec4(
|
||||
bitfieldExtract(int(grid), 2 * (i8b + 0), 2),
|
||||
bitfieldExtract(int(grid), 2 * (i8b + 1), 2),
|
||||
bitfieldExtract(int(grid), 2 * (i8b + 2), 2),
|
||||
bitfieldExtract(int(grid), 2 * (i8b + 3), 2));
|
||||
return f16vec4((vec4(q) + vec4(delta)) * dl);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ1_M)
|
||||
@@ -485,6 +823,33 @@ float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords
|
||||
float16_t ret = d * float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * i8, 2)) + float16_t(delta));
|
||||
return ret;
|
||||
}
|
||||
|
||||
f16vec4 dequantFuncIQ1_M_v(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
decodeBufIQ1_M_packed64 bl64 = decodeBufIQ1_M_packed64(bl);
|
||||
const uint idx = coordInBlock[1];
|
||||
|
||||
uvec2 scales = unpack32(bl64.block.scales);
|
||||
const float16_t d = uint16BitsToHalf(uint16_t(((scales.x & 0xF000) >> 12) | ((scales.x & 0xF0000000) >> 24) | ((scales.y & 0xF000) >> 4) | ((scales.y & 0xF0000000) >> 16)));
|
||||
|
||||
const uint ib8 = idx >> 3;
|
||||
const uint ib16 = idx >> 4;
|
||||
const int i8b = int(idx & 4); // 0 or 4 -- i8 base for the V=4 group
|
||||
|
||||
const uint sc = bl.block.scales[ib8 / 8];
|
||||
const uint qs = bl.block.qs[ib8];
|
||||
const uint qh = bl.block.qh[ib16] >> (4 * (ib8 & 1));
|
||||
const float dl = 2.0 * float(bitfieldExtract(sc, 3 * int(ib16 & 3), 3)) + 1.0;
|
||||
const float delta = ((qh & 8u) != 0u) ? -IQ1S_DELTA : IQ1S_DELTA;
|
||||
const uint grid = iq1s_grid[qs | ((qh & 7u) << 8)];
|
||||
|
||||
const ivec4 q = ivec4(
|
||||
bitfieldExtract(int(grid), 2 * (i8b + 0), 2),
|
||||
bitfieldExtract(int(grid), 2 * (i8b + 1), 2),
|
||||
bitfieldExtract(int(grid), 2 * (i8b + 2), 2),
|
||||
bitfieldExtract(int(grid), 2 * (i8b + 3), 2));
|
||||
return f16vec4((vec4(q) + vec4(delta)) * (float(d) * dl));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ2_XXS)
|
||||
@@ -520,6 +885,33 @@ float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCo
|
||||
vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
|
||||
return float16_t(ret[idx & 1]);
|
||||
}
|
||||
|
||||
f16vec4 dequantFuncIQ2_XXS_v(const in decodeBufIQ2_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
decodeBufIQ2_XXS_packed16 bl16 = decodeBufIQ2_XXS_packed16(bl);
|
||||
const uint idx = coordInBlock[1];
|
||||
|
||||
const uint ib32 = idx >> 5;
|
||||
const uint ib8 = (idx & 0x18) >> 3;
|
||||
const uint iqs = 8 * ib32 + ib8;
|
||||
|
||||
const uint qs = bl.block.qs[iqs];
|
||||
const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3]));
|
||||
const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28));
|
||||
|
||||
uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7);
|
||||
sign |= bitCount(sign) << 7;
|
||||
const uint sb = sign >> (idx & 7u);
|
||||
|
||||
const uint g2 = iq2xxs_grid[qs][(idx & 4) >> 2];
|
||||
const u8vec4 g = unpack8(g2);
|
||||
|
||||
return f16vec4(
|
||||
dscale * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0),
|
||||
dscale * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0),
|
||||
dscale * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0),
|
||||
dscale * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ2_XS)
|
||||
@@ -548,6 +940,31 @@ float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoor
|
||||
vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
|
||||
return float16_t(ret[idx & 1]);
|
||||
}
|
||||
|
||||
f16vec4 dequantFuncIQ2_XS_v(const in decodeBufIQ2_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const uint idx = coordInBlock[1];
|
||||
|
||||
const uint is = idx >> 5;
|
||||
const uint sshift = (idx & 0x10) >> 2;
|
||||
const uint iqs = idx >> 3;
|
||||
|
||||
const uint16_t qs = bl.block.qs[iqs];
|
||||
const float dscale = float(bl.block.d) * 0.25 * (0.5 + float((bl.block.scales[is] >> sshift) & 0xF));
|
||||
|
||||
uint sign = uint(qs >> 9);
|
||||
sign |= bitCount(sign) << 7;
|
||||
const uint sb = sign >> (idx & 7u);
|
||||
|
||||
const uint g2 = iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2];
|
||||
const u8vec4 g = unpack8(g2);
|
||||
|
||||
return f16vec4(
|
||||
dscale * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0),
|
||||
dscale * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0),
|
||||
dscale * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0),
|
||||
dscale * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ2_S)
|
||||
@@ -576,6 +993,32 @@ float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(g2));
|
||||
return float16_t(v[idx & 1]);
|
||||
}
|
||||
|
||||
f16vec4 dequantFuncIQ2_S_v(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const uint idx = coordInBlock[1];
|
||||
|
||||
const uint ib32 = idx >> 5;
|
||||
const uint ib8 = idx >> 3;
|
||||
const uint qhshift = 2 * (ib8 % 4);
|
||||
|
||||
const uint scale = (bl.block.scales[ib32] >> ((idx & 0x10) >> 2)) & 0xf;
|
||||
const uint qs = bl.block.qs[ib8];
|
||||
const uint qh = bl.block.qh[ib32];
|
||||
const uint sb = uint(bl.block.qs[QUANT_K / 8 + ib8]) >> (idx & 0x6u);
|
||||
|
||||
const float d = float(bl.block.d);
|
||||
const float db = d * 0.25 * (0.5 + scale);
|
||||
|
||||
const uint g2 = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 4) >> 2];
|
||||
const u8vec4 g = unpack8(g2);
|
||||
|
||||
return f16vec4(
|
||||
db * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0),
|
||||
db * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0),
|
||||
db * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0),
|
||||
db * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ3_XXS)
|
||||
@@ -609,6 +1052,32 @@ float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCo
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
|
||||
return float16_t(v[idx & 1]);
|
||||
}
|
||||
|
||||
f16vec4 dequantFuncIQ3_XXS_v(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
decodeBufIQ3_XXS_packed16 bl16 = decodeBufIQ3_XXS_packed16(bl);
|
||||
const uint idx = coordInBlock[1];
|
||||
|
||||
const uint iqs = idx >> 2;
|
||||
const uint is = QUANT_K / 4 + ((idx & 0xE0) >> 3);
|
||||
|
||||
const float d = float(bl.block.d);
|
||||
const uint qs = bl.block.qs[iqs];
|
||||
const uint signs = pack32(u16vec2(bl16.block.qs[is/2+0], bl16.block.qs[is/2+1]));
|
||||
const float db = d * 0.5 * (0.5 + (signs >> 28));
|
||||
|
||||
const uint sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
|
||||
const uint sb = (sign7 | (bitCount(sign7) << 7)) >> (idx & 0x6u);
|
||||
|
||||
const uint grid = iq3xxs_grid[qs];
|
||||
const u8vec4 g = unpack8(grid);
|
||||
|
||||
return f16vec4(
|
||||
db * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0),
|
||||
db * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0),
|
||||
db * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0),
|
||||
db * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ3_S)
|
||||
@@ -635,6 +1104,30 @@ float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords
|
||||
|
||||
return float16_t(v[idx & 1]);
|
||||
}
|
||||
|
||||
f16vec4 dequantFuncIQ3_S_v(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const uint idx = coordInBlock[1];
|
||||
|
||||
const uint iqs = idx >> 2;
|
||||
const uint iqh = idx >> 5;
|
||||
|
||||
const float d = float(bl.block.d);
|
||||
const uint qs = bl.block.qs[iqs];
|
||||
const uint qh = bl.block.qh[iqh];
|
||||
const uint sb = uint(bl.block.signs[iqs / 2]) >> (idx & 0x6u);
|
||||
const uint scale = bl.block.scales[iqs / 16];
|
||||
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
|
||||
|
||||
const uint grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
|
||||
const u8vec4 g = unpack8(grid);
|
||||
|
||||
return f16vec4(
|
||||
db * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0),
|
||||
db * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0),
|
||||
db * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0),
|
||||
db * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ4_XS)
|
||||
@@ -642,6 +1135,10 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4
|
||||
block_iq4_xs block;
|
||||
};
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufIQ4_XS_packed32 {
|
||||
block_iq4_xs_packed32 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const float16_t d = bl.block.d;
|
||||
@@ -657,6 +1154,30 @@ float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoor
|
||||
float16_t ret = d * float16_t(int(sl | (sh << 4)) - 32) * float16_t(kvalues_iq4nl[q]);
|
||||
return ret;
|
||||
}
|
||||
|
||||
f16vec4 dequantFuncIQ4_XS_v(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
decodeBufIQ4_XS_packed32 bl32 = decodeBufIQ4_XS_packed32(bl);
|
||||
const float16_t d = bl.block.d;
|
||||
const uint idx = coordInBlock[1];
|
||||
|
||||
const uint ib32 = idx >> 5; // 0..7
|
||||
const uint sl = (bl32.block.scales_l >> (4 * ib32)) & 0xF;
|
||||
const uint sh = (uint(bl32.block.scales_h) >> (2 * ib32)) & 0x3;
|
||||
const uint qshift = (idx & 0x10) >> 2; // {0, 4}
|
||||
const uint qs_w = 4 * ib32 + ((idx & 0xC) >> 2); // iqs / 4, in [0,32)
|
||||
|
||||
const float16_t dl = d * float16_t(int(sl | (sh << 4)) - 32);
|
||||
|
||||
const uint qsw = bl32.block.qs[qs_w];
|
||||
const u8vec4 qv = unpack8((qsw >> qshift) & 0x0F0F0F0Fu);
|
||||
const vec4 ret = vec4(
|
||||
float(kvalues_iq4nl[qv.x]),
|
||||
float(kvalues_iq4nl[qv.y]),
|
||||
float(kvalues_iq4nl[qv.z]),
|
||||
float(kvalues_iq4nl[qv.w])) * float(dl);
|
||||
return f16vec4(ret);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
@@ -664,6 +1185,10 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4
|
||||
block_iq4_nl block;
|
||||
};
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL_packed16 {
|
||||
block_iq4_nl_packed16 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const float16_t d = bl.block.d;
|
||||
@@ -676,6 +1201,24 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
|
||||
float16_t ret = float16_t(kvalues_iq4nl[qs]) * d;
|
||||
return ret;
|
||||
}
|
||||
|
||||
f16vec4 dequantFuncIQ4_NL_v(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
decodeBufIQ4_NL_packed16 bl16 = decodeBufIQ4_NL_packed16(bl);
|
||||
const float16_t d = bl.block.d;
|
||||
const uint idx = coordInBlock[1];
|
||||
const uint shift = (idx & 0x10) >> 2; // 0 or 4
|
||||
const uint qs_i = (idx & 0xC) >> 1; // packed16 word index, in {0,2,4,6}
|
||||
const uint qsw = uint32_t(bl16.block.qs[qs_i ])
|
||||
| (uint32_t(bl16.block.qs[qs_i + 1u]) << 16);
|
||||
// shift in {0,4}: per-byte mask 0x0F isolates the wanted nibble in each byte.
|
||||
const u8vec4 q = unpack8((qsw >> shift) & 0x0F0F0F0Fu);
|
||||
return f16vec4(
|
||||
float(d) * float(kvalues_iq4nl[q.x]),
|
||||
float(d) * float(kvalues_iq4nl[q.y]),
|
||||
float(d) * float(kvalues_iq4nl[q.z]),
|
||||
float(d) * float(kvalues_iq4nl[q.w]));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_MXFP4)
|
||||
@@ -695,6 +1238,26 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
|
||||
float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5);
|
||||
return ret;
|
||||
}
|
||||
|
||||
f16vec4 dequantFuncMXFP4_v(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const float d = e8m0_to_fp32(bl.block.e);
|
||||
const uint idx = coordInBlock[1];
|
||||
const uint iqs = idx & 0xF;
|
||||
const uint shift = (idx & 0x10) >> 2;
|
||||
uvec4 qv = uvec4(
|
||||
uint(bl.block.qs[iqs]),
|
||||
uint(bl.block.qs[iqs + 1u]),
|
||||
uint(bl.block.qs[iqs + 2u]),
|
||||
uint(bl.block.qs[iqs + 3u]));
|
||||
qv = (qv >> shift) & 0xFu;
|
||||
const vec4 ret = vec4(
|
||||
float(kvalues_mxfp4[qv.x]),
|
||||
float(kvalues_mxfp4[qv.y]),
|
||||
float(kvalues_mxfp4[qv.z]),
|
||||
float(kvalues_mxfp4[qv.w])) * d * 0.5f;
|
||||
return f16vec4(ret);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_NVFP4)
|
||||
@@ -702,6 +1265,10 @@ layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufNVF
|
||||
block_nvfp4 block;
|
||||
};
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufNVFP4_packed32 {
|
||||
block_nvfp4_packed32 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncNVFP4(const in decodeBufNVFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const uint idx = coordInBlock[1];
|
||||
@@ -713,56 +1280,97 @@ float16_t dequantFuncNVFP4(const in decodeBufNVFP4 bl, const in uint blockCoords
|
||||
qs = (qs >> shift) & 0xF;
|
||||
return float16_t(kvalues_mxfp4[qs] * d * 0.5);
|
||||
}
|
||||
|
||||
f16vec4 dequantFuncNVFP4_v(const in decodeBufNVFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
decodeBufNVFP4_packed32 bl32 = decodeBufNVFP4_packed32(bl);
|
||||
const uint idx = coordInBlock[1];
|
||||
const uint sub = idx >> 4;
|
||||
const uint qs_w = ((idx & 0x30) >> 3) + ((idx & 0x4u) >> 2); // iqs / 4, in [0,8)
|
||||
const uint shift = (idx & 0x8) >> 1;
|
||||
const float d = ue4m3_to_fp32(bl.block.d[sub]);
|
||||
|
||||
const uint qsw = uint32_t(bl32.block.qs[qs_w]);
|
||||
const u8vec4 qv = unpack8((qsw >> shift) & 0x0F0F0F0Fu);
|
||||
const vec4 ret = vec4(
|
||||
float(kvalues_mxfp4[qv.x]),
|
||||
float(kvalues_mxfp4[qv.y]),
|
||||
float(kvalues_mxfp4[qv.z]),
|
||||
float(kvalues_mxfp4[qv.w])) * d * 0.5f;
|
||||
return f16vec4(ret);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q1_0)
|
||||
#define dequantFuncA dequantFuncQ1_0
|
||||
#define dequantFuncA_v dequantFuncQ1_0_v
|
||||
#elif defined(DATA_A_Q4_0)
|
||||
#define dequantFuncA dequantFuncQ4_0
|
||||
#define dequantFuncA_v dequantFuncQ4_0_v
|
||||
#elif defined(DATA_A_Q4_1)
|
||||
#define dequantFuncA dequantFuncQ4_1
|
||||
#define dequantFuncA_v dequantFuncQ4_1_v
|
||||
#elif defined(DATA_A_Q5_0)
|
||||
#define dequantFuncA dequantFuncQ5_0
|
||||
#define dequantFuncA_v dequantFuncQ5_0_v
|
||||
#elif defined(DATA_A_Q5_1)
|
||||
#define dequantFuncA dequantFuncQ5_1
|
||||
#define dequantFuncA_v dequantFuncQ5_1_v
|
||||
#elif defined(DATA_A_Q8_0)
|
||||
#define dequantFuncA dequantFuncQ8_0
|
||||
#define dequantFuncA_v dequantFuncQ8_0_v
|
||||
#elif defined(DATA_A_Q2_K)
|
||||
#define dequantFuncA dequantFuncQ2_K
|
||||
#define dequantFuncA_v dequantFuncQ2_K_v
|
||||
#elif defined(DATA_A_Q3_K)
|
||||
#define dequantFuncA dequantFuncQ3_K
|
||||
#define dequantFuncA_v dequantFuncQ3_K_v
|
||||
#elif defined(DATA_A_Q4_K)
|
||||
#define dequantFuncA dequantFuncQ4_K
|
||||
#define dequantFuncA_v dequantFuncQ4_K_v
|
||||
#define fetch_scales fetch_scalesQ4_K
|
||||
#define store_scales store_scalesQ4_K
|
||||
#elif defined(DATA_A_Q5_K)
|
||||
#define dequantFuncA dequantFuncQ5_K
|
||||
#define dequantFuncA_v dequantFuncQ5_K_v
|
||||
#define fetch_scales fetch_scalesQ5_K
|
||||
#define store_scales store_scalesQ4_K
|
||||
#elif defined(DATA_A_Q6_K)
|
||||
#define dequantFuncA dequantFuncQ6_K
|
||||
#define dequantFuncA_v dequantFuncQ6_K_v
|
||||
#elif defined(DATA_A_IQ1_S)
|
||||
#define dequantFuncA dequantFuncIQ1_S
|
||||
#define dequantFuncA_v dequantFuncIQ1_S_v
|
||||
#elif defined(DATA_A_IQ1_M)
|
||||
#define dequantFuncA dequantFuncIQ1_M
|
||||
#define dequantFuncA_v dequantFuncIQ1_M_v
|
||||
#elif defined(DATA_A_IQ2_XXS)
|
||||
#define dequantFuncA dequantFuncIQ2_XXS
|
||||
#define dequantFuncA_v dequantFuncIQ2_XXS_v
|
||||
#elif defined(DATA_A_IQ2_XS)
|
||||
#define dequantFuncA dequantFuncIQ2_XS
|
||||
#define dequantFuncA_v dequantFuncIQ2_XS_v
|
||||
#elif defined(DATA_A_IQ2_S)
|
||||
#define dequantFuncA dequantFuncIQ2_S
|
||||
#define dequantFuncA_v dequantFuncIQ2_S_v
|
||||
#elif defined(DATA_A_IQ3_XXS)
|
||||
#define dequantFuncA dequantFuncIQ3_XXS
|
||||
#define dequantFuncA_v dequantFuncIQ3_XXS_v
|
||||
#elif defined(DATA_A_IQ3_S)
|
||||
#define dequantFuncA dequantFuncIQ3_S
|
||||
#define dequantFuncA_v dequantFuncIQ3_S_v
|
||||
#elif defined(DATA_A_IQ4_XS)
|
||||
#define dequantFuncA dequantFuncIQ4_XS
|
||||
#define dequantFuncA_v dequantFuncIQ4_XS_v
|
||||
#elif defined(DATA_A_IQ4_NL)
|
||||
#define dequantFuncA dequantFuncIQ4_NL
|
||||
#define dequantFuncA_v dequantFuncIQ4_NL_v
|
||||
#elif defined(DATA_A_MXFP4)
|
||||
#define dequantFuncA dequantFuncMXFP4
|
||||
#define dequantFuncA_v dequantFuncMXFP4_v
|
||||
#elif defined(DATA_A_NVFP4)
|
||||
#define dequantFuncA dequantFuncNVFP4
|
||||
#define dequantFuncA_v dequantFuncNVFP4_v
|
||||
#elif defined(DATA_A_F32)
|
||||
#define dequantFuncA dequantFuncF32
|
||||
#endif
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
#version 460
|
||||
|
||||
#extension GL_NV_cooperative_matrix_decode_vector : require
|
||||
|
||||
void main()
|
||||
{
|
||||
}
|
||||
@@ -11,6 +11,9 @@
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
#extension GL_NV_cooperative_matrix2 : enable
|
||||
#ifdef GL_NV_cooperative_matrix_decode_vector
|
||||
#extension GL_NV_cooperative_matrix_decode_vector : enable
|
||||
#endif
|
||||
#extension GL_EXT_buffer_reference : enable
|
||||
#extension GL_KHR_shader_subgroup_ballot : enable
|
||||
#extension GL_KHR_shader_subgroup_vote : enable
|
||||
@@ -54,6 +57,41 @@ float16_t faDecodeV(const decodeBufFA_V bl_in, const uint blockCoords[2], const
|
||||
}
|
||||
}
|
||||
|
||||
// V=4 vector decode for K/V; dispatches to per-format _v decoders.
|
||||
f16vec4 faDecodeKVector(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) {
|
||||
switch (FaTypeK) {
|
||||
case 0u: return f16vec4(decodeBufF32(bl_in).block);
|
||||
case 2u: return dequantFuncQ4_0_v(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
|
||||
case 3u: return dequantFuncQ4_1_v(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
|
||||
case 6u: return dequantFuncQ5_0_v(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
|
||||
case 7u: return dequantFuncQ5_1_v(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
|
||||
case 8u: return dequantFuncQ8_0_v(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
|
||||
case 41u: return dequantFuncQ1_0_v(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
|
||||
default: return f16vec4(0);
|
||||
}
|
||||
}
|
||||
|
||||
f16vec4 faDecodeVVector(const decodeBufFA_V bl_in, const uint blockCoords[2], const uint coordInBlock[2]) {
|
||||
switch (FaTypeV) {
|
||||
case 0u: return f16vec4(decodeBufF32(bl_in).block);
|
||||
case 2u: return dequantFuncQ4_0_v(decodeBufQ4_0(bl_in), blockCoords, coordInBlock);
|
||||
case 3u: return dequantFuncQ4_1_v(decodeBufQ4_1(bl_in), blockCoords, coordInBlock);
|
||||
case 6u: return dequantFuncQ5_0_v(decodeBufQ5_0(bl_in), blockCoords, coordInBlock);
|
||||
case 7u: return dequantFuncQ5_1_v(decodeBufQ5_1(bl_in), blockCoords, coordInBlock);
|
||||
case 8u: return dequantFuncQ8_0_v(decodeBufQ8_0(bl_in), blockCoords, coordInBlock);
|
||||
case 41u: return dequantFuncQ1_0_v(decodeBufQ1_0(bl_in), blockCoords, coordInBlock);
|
||||
default: return f16vec4(0);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef GL_NV_cooperative_matrix_decode_vector
|
||||
#define FADECODEK , faDecodeK, faDecodeKVector
|
||||
#define FADECODEV , faDecodeV, faDecodeVVector
|
||||
#else
|
||||
#define FADECODEK , faDecodeK
|
||||
#define FADECODEV , faDecodeV
|
||||
#endif
|
||||
|
||||
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
|
||||
layout (binding = 1) readonly buffer K {uint8_t data_k[];};
|
||||
layout (binding = 2) readonly buffer V {uint8_t data_v[];};
|
||||
@@ -259,7 +297,7 @@ void main() {
|
||||
// F16: bs_k==1 (direct load). F32: bs_k==4 (vec4 / dequantFuncF32). Q4/Q8 family: bs_k==32. Q1_0: bs_k==128.
|
||||
const bool k_use_decode = (bs_k > 1u);
|
||||
if (k_use_decode) {
|
||||
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose, faDecodeK);
|
||||
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose FADECODEK);
|
||||
} else {
|
||||
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose);
|
||||
}
|
||||
@@ -325,7 +363,7 @@ void main() {
|
||||
uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
|
||||
const bool v_use_decode = (bs_v > 1u);
|
||||
if (v_use_decode) {
|
||||
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad), faDecodeV);
|
||||
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) FADECODEV);
|
||||
} else {
|
||||
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad));
|
||||
}
|
||||
|
||||
@@ -10,12 +10,38 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
#if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16)
|
||||
#define K_PER_ITER 8
|
||||
#else
|
||||
#define K_PER_ITER 2
|
||||
#define K_PER_ITER 4
|
||||
#endif
|
||||
|
||||
|
||||
uint a_offset, b_offset, d_offset, y_offset;
|
||||
|
||||
vec4 load_b(const uint j, const uint iybs, const uint iqs, const bool lastiter, out bool OOB_y, out bool OOB_z, out bool OOB_w) {
|
||||
// Check if the latter elements are OOB, and don't fetch B or accumulate it.
|
||||
OOB_y = lastiter && (iybs + iqs + y_offset >= p.ncols);
|
||||
OOB_z = lastiter && (iybs + iqs + y_offset*2 >= p.ncols);
|
||||
OOB_w = lastiter && (iybs + iqs + y_offset*3 >= p.ncols);
|
||||
|
||||
if (!OOB_w) {
|
||||
return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]),
|
||||
FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]),
|
||||
FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset*2]),
|
||||
FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset*3]));
|
||||
} else if (!OOB_z) {
|
||||
return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]),
|
||||
FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]),
|
||||
FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset*2]),
|
||||
0);
|
||||
} else if (!OOB_y) {
|
||||
return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]),
|
||||
FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]),
|
||||
0, 0);
|
||||
} else {
|
||||
return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]),
|
||||
0, 0, 0);
|
||||
}
|
||||
}
|
||||
|
||||
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
|
||||
{
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
@@ -25,6 +51,8 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
|
||||
|
||||
#if K_PER_ITER == 8
|
||||
#if QUANT_R == 2
|
||||
// Note that we end up fetching bogus elements here, but its fine as they'll be
|
||||
// within an accessible block.
|
||||
const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
|
||||
const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]);
|
||||
const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
|
||||
@@ -34,18 +62,11 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
|
||||
const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]);
|
||||
#endif
|
||||
#else
|
||||
// Check if the second of the pair of elements is OOB, and don't fetch B or
|
||||
// accumulate it. We still fetch a pair of elements for A, which is fine for
|
||||
// quantized formats since they'll be within the same block. We should
|
||||
// probably skip fetching the second element for F16/F32, but as of now we
|
||||
// still do.
|
||||
const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols);
|
||||
bool OOB_y;
|
||||
bool OOB_z;
|
||||
bool OOB_w;
|
||||
|
||||
FLOAT_TYPE b0 = 0, b1 = 0;
|
||||
b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]);
|
||||
if (!OOB) {
|
||||
b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]);
|
||||
}
|
||||
const vec4 b = load_b(j, iybs, iqs, lastiter, OOB_y, OOB_z, OOB_w);
|
||||
#endif
|
||||
uint ibi = first_row*p.ncols;
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
@@ -71,22 +92,60 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
|
||||
|
||||
temp[j][n] += rowtmp;
|
||||
#else
|
||||
const vec2 v = dequantize(ib, iqs, a_offset);
|
||||
|
||||
// matrix multiplication
|
||||
temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]);
|
||||
if (!OOB) {
|
||||
temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]);
|
||||
if (!OOB_w) {
|
||||
const vec4 v = dequantize4(ib, iqs, a_offset);
|
||||
temp[j][n] += dot(v, b);
|
||||
} else if (!OOB_z) {
|
||||
const vec2 v0 = dequantize(ib, iqs, a_offset);
|
||||
const FLOAT_TYPE v1 = dequantize1(ib + 2/QUANT_R, iqs, a_offset);
|
||||
const vec3 v = vec3(v0.x, v0.y, v1);
|
||||
const vec3 b0 = vec3(b.x, b.y, b.z);
|
||||
temp[j][n] += dot(v, b0);
|
||||
} else if (!OOB_y) {
|
||||
const vec2 v0 = dequantize(ib, iqs, a_offset);
|
||||
const vec2 b0 = vec2(b.x, b.y);
|
||||
temp[j][n] += dot(v0, b0);
|
||||
} else {
|
||||
const FLOAT_TYPE v = dequantize1(ib, iqs, a_offset);
|
||||
temp[j][n] = fma(v, b.x, temp[j][n]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
|
||||
void iter_aligned_nonquant(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i)
|
||||
{
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
|
||||
const uint iqs = 0; // quant index
|
||||
const uint iybs = col; // y block start index
|
||||
|
||||
const vec4 b = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4];
|
||||
|
||||
uint ibi = first_row*p.ncols;
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
const uint ib = (ibi + col)/QUANT_K; // block index
|
||||
ibi += p.ncols;
|
||||
|
||||
const vec4 v = dequantize4_2aligned(ib, iqs, a_offset);
|
||||
|
||||
// matrix multiplication
|
||||
temp[j][n] += dot(v, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
|
||||
get_offsets(a_offset, b_offset, d_offset);
|
||||
const bool is_aligned_nonquant =
|
||||
p.batch_stride_b % 4 == 0 && b_offset % 4 == 0 &&
|
||||
p.ncols % 4 == 0 && BLOCK_SIZE % 4 == 0 &&
|
||||
K_PER_ITER == 4;
|
||||
|
||||
y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
|
||||
|
||||
@@ -105,17 +164,26 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||
int unroll_count = 4;
|
||||
uint unrolled_iters = num_iters & ~(unroll_count - 1);
|
||||
|
||||
#if K_PER_ITER == 2
|
||||
uint i = 0;
|
||||
|
||||
#if K_PER_ITER == 4
|
||||
// If the K dimension is odd, we need lastiter==true on the last iteration
|
||||
// so OOB is computed correctly. Skip some unrolling to make that happen.
|
||||
if ((p.ncols & 1) != 0 &&
|
||||
if ((p.ncols & 3) != 0 &&
|
||||
unrolled_iters == num_iters &&
|
||||
unrolled_iters > 0) {
|
||||
unrolled_iters -= unroll_count;
|
||||
}
|
||||
if (is_aligned_nonquant) {
|
||||
while (i < unrolled_iters) {
|
||||
// Manually partially unroll the loop
|
||||
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
|
||||
iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#endif
|
||||
|
||||
uint i = 0;
|
||||
while (i < unrolled_iters) {
|
||||
// Manually partially unroll the loop
|
||||
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
|
||||
@@ -123,18 +191,30 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||
i++;
|
||||
}
|
||||
}
|
||||
#if K_PER_ITER == 4
|
||||
}
|
||||
#endif
|
||||
|
||||
unroll_count = 2;
|
||||
unrolled_iters = num_iters & ~(unroll_count - 1);
|
||||
|
||||
#if K_PER_ITER == 2
|
||||
if ((p.ncols & 1) != 0 &&
|
||||
#if K_PER_ITER == 4
|
||||
if ((p.ncols & 3) != 0 &&
|
||||
unrolled_iters == num_iters &&
|
||||
unrolled_iters > 0) {
|
||||
unrolled_iters -= unroll_count;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (is_aligned_nonquant) {
|
||||
while (i < unrolled_iters && is_aligned_nonquant) {
|
||||
// Manually partially unroll the loop
|
||||
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
|
||||
iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#endif
|
||||
while (i < unrolled_iters) {
|
||||
// Manually partially unroll the loop
|
||||
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
|
||||
@@ -142,10 +222,25 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||
i++;
|
||||
}
|
||||
}
|
||||
#if K_PER_ITER == 4
|
||||
}
|
||||
#endif
|
||||
|
||||
#if K_PER_ITER == 4
|
||||
if (is_aligned_nonquant) {
|
||||
while (i < num_iters) {
|
||||
iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER);
|
||||
i++;
|
||||
}
|
||||
} else {
|
||||
#endif
|
||||
while (i < num_iters) {
|
||||
iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true);
|
||||
i++;
|
||||
}
|
||||
#if K_PER_ITER == 4
|
||||
}
|
||||
#endif
|
||||
|
||||
reduce_result(temp, d_offset, first_row, num_rows, tid);
|
||||
}
|
||||
@@ -164,6 +259,6 @@ void main() {
|
||||
if (first_row >= p.stride_d) {
|
||||
return;
|
||||
}
|
||||
compute_outputs(first_row, p.stride_d - first_row);
|
||||
compute_outputs(first_row, min(NUM_ROWS, p.stride_d - first_row));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,10 +71,12 @@ layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
#if QUANT_K > 1
|
||||
#define DECODEFUNCA , dequantFuncA
|
||||
|
||||
#include "dequant_funcs_cm2.glsl"
|
||||
|
||||
#if defined(dequantFuncA_v) && defined(GL_NV_cooperative_matrix_decode_vector)
|
||||
#define DECODEFUNCA , dequantFuncA, dequantFuncA_v
|
||||
#else
|
||||
#define DECODEFUNCA , dequantFuncA
|
||||
#endif
|
||||
#else
|
||||
#define DECODEFUNCA
|
||||
#endif
|
||||
|
||||
@@ -31,6 +31,7 @@
|
||||
#else
|
||||
#define A_TYPE float16_t
|
||||
#endif
|
||||
#define A_TYPE_PACKED32 f16vec2
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_BF16)
|
||||
@@ -44,6 +45,7 @@
|
||||
#else
|
||||
#define A_TYPE uint16_t
|
||||
#endif
|
||||
#define A_TYPE_PACKED32 uint32_t
|
||||
#endif
|
||||
|
||||
#define QUANT_K_Q4_0 32
|
||||
@@ -1722,11 +1724,18 @@ struct block_nvfp4
|
||||
uint8_t qs[QUANT_K_NVFP4 / 2];
|
||||
};
|
||||
|
||||
struct block_nvfp4_packed32
|
||||
{
|
||||
uint32_t d[QUANT_K_NVFP4 / 16 / 4];
|
||||
uint32_t qs[QUANT_K_NVFP4 / 2 / 4];
|
||||
};
|
||||
|
||||
#if defined(DATA_A_NVFP4)
|
||||
#define QUANT_K QUANT_K_NVFP4
|
||||
#define QUANT_R QUANT_R_NVFP4
|
||||
#define QUANT_AUXF 1
|
||||
#define A_TYPE block_nvfp4
|
||||
#define A_TYPE_PACKED32 block_nvfp4_packed32
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
|
||||
|
||||
@@ -749,8 +749,11 @@ static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst),
|
||||
};
|
||||
|
||||
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
|
||||
uint32_t wg_x;
|
||||
uint32_t wg_y;
|
||||
uint32_t total_wg = CEIL_DIV(ne, decisions->wg_size);
|
||||
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
|
||||
}
|
||||
|
||||
static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx,
|
||||
@@ -974,9 +977,10 @@ static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx,
|
||||
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
uint32_t wg_x;
|
||||
uint32_t wg_y;
|
||||
uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
|
||||
uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg);
|
||||
uint32_t wg_y = CEIL_DIV(total_wg, wg_x);
|
||||
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
|
||||
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
|
||||
}
|
||||
@@ -1064,9 +1068,10 @@ static webgpu_encoded_op ggml_webgpu_im2col(webgpu_context & ctx,
|
||||
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
uint32_t wg_x;
|
||||
uint32_t wg_y;
|
||||
uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
|
||||
uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg);
|
||||
uint32_t wg_y = CEIL_DIV(total_wg, wg_x);
|
||||
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
|
||||
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
|
||||
}
|
||||
@@ -1689,14 +1694,11 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
|
||||
gathered_count_ids_binding_size),
|
||||
};
|
||||
|
||||
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
|
||||
const uint32_t gather_total_wg = param_n_expert;
|
||||
const uint32_t gather_wg_x = std::min(gather_total_wg, max_wg_per_dim);
|
||||
const uint32_t gather_wg_y = CEIL_DIV(gather_total_wg, gather_wg_x);
|
||||
// n_expert is much less than maxComputeWorkgroupsPerDimension (e.g., n_exeprt=256 at Qwen3.5-35B-A3B)
|
||||
const uint32_t gather_wg_x = param_n_expert;
|
||||
|
||||
dispatches.push_back({
|
||||
gather_pipeline, std::move(gather_params), std::move(gather_entries), { gather_wg_x, gather_wg_y }
|
||||
gather_pipeline, std::move(gather_params), std::move(gather_entries), { gather_wg_x, 1 }
|
||||
});
|
||||
|
||||
// params for mul_mat_id.wgsl
|
||||
@@ -1748,7 +1750,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
|
||||
uint32_t max_wg_n = CEIL_DIV(total_gathered, tile_n_s) + max_active_experts;
|
||||
uint32_t total_wg = wg_m * max_wg_n;
|
||||
|
||||
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
||||
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
|
||||
|
||||
dispatches.push_back({
|
||||
main_pipeline, std::move(main_params), std::move(main_entries), { wg_x, wg_y }
|
||||
@@ -2771,10 +2773,12 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor *
|
||||
block_size, npr, nrows
|
||||
};
|
||||
|
||||
const uint32_t total_wg_init = npr * nrows;
|
||||
const uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
const uint32_t wg_x_init = std::min(total_wg_init, max_wg);
|
||||
const uint32_t wg_y_init = CEIL_DIV(total_wg_init, wg_x_init);
|
||||
uint32_t wg_x_init;
|
||||
uint32_t wg_y_init;
|
||||
const uint32_t total_wg_init = npr * nrows;
|
||||
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
compute_2d_workgroups(total_wg_init, max_wg_per_dim, wg_x_init, wg_y_init);
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> init_entries = {
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src),
|
||||
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), init_align_offset, init_binding_size)
|
||||
@@ -2831,9 +2835,11 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor *
|
||||
ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), align_out, size_out)
|
||||
};
|
||||
|
||||
uint32_t wg_x_merge;
|
||||
uint32_t wg_y_merge;
|
||||
const uint32_t total_wg_merge = nm * nrows;
|
||||
const uint32_t wg_x_merge = std::min(total_wg_merge, max_wg);
|
||||
const uint32_t wg_y_merge = CEIL_DIV(total_wg_merge, wg_x_merge);
|
||||
compute_2d_workgroups(total_wg_merge, max_wg_per_dim, wg_x_merge, wg_y_merge);
|
||||
|
||||
dispatches.push_back({
|
||||
argsort_merge_pipeline, std::move(merge_params), std::move(merge_entries), { wg_x_merge, wg_y_merge }
|
||||
});
|
||||
@@ -2953,9 +2959,12 @@ static webgpu_encoded_op ggml_webgpu_upscale(webgpu_context ctx, ggml_tensor * s
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_upscale_pipeline(shader_lib_ctx);
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
|
||||
uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg);
|
||||
uint32_t wg_y = CEIL_DIV(total_wg, wg_x);
|
||||
|
||||
uint32_t wg_x;
|
||||
uint32_t wg_y;
|
||||
uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size);
|
||||
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
|
||||
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
|
||||
}
|
||||
|
||||
|
||||
@@ -49,12 +49,14 @@ struct Params{
|
||||
var<uniform> params: Params;
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
if (gid.x >= params.ne) {
|
||||
fn main(
|
||||
@builtin(global_invocation_index) gindex: u32,
|
||||
) {
|
||||
if (gindex >= params.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
var i = gid.x;
|
||||
var i = gindex;
|
||||
let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
|
||||
i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
|
||||
let i2 = i / (params.src_ne1 * params.src_ne0);
|
||||
@@ -62,7 +64,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
let i1 = i / params.src_ne0;
|
||||
let i0 = i % params.src_ne0;
|
||||
|
||||
var j = gid.x;
|
||||
var j = gindex;
|
||||
let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
|
||||
j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
|
||||
let j2 = j / (params.dst_ne1 * params.dst_ne0);
|
||||
|
||||
@@ -21,35 +21,32 @@ var<workgroup> count:atomic<u32>;
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_wg: vec3<u32>) {
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>) {
|
||||
|
||||
let thread_id = local_id.x;
|
||||
let own_expert = wg_id.y * num_wg.x + wg_id.x; // the expert assigned to this workgroup
|
||||
let own_expert = wg_id.x; // the expert assigned to this workgroup
|
||||
|
||||
if (own_expert < params.n_expert) {
|
||||
if (thread_id == 0u) {
|
||||
atomicStore(&count, 0);
|
||||
}
|
||||
if (thread_id == 0u) {
|
||||
atomicStore(&count, 0);
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
workgroupBarrier();
|
||||
|
||||
for (var i = thread_id;i < params.n_expert_used * params.n_tokens;i += WG_SIZE) {
|
||||
let row = i / params.n_expert_used;
|
||||
let col = i % params.n_expert_used;
|
||||
let expert = u32(ids[params.offset_ids + row * params.stride_ids_1 + col]);
|
||||
if (own_expert == expert) {
|
||||
let pos = atomicAdd(&count, 1u);
|
||||
let gathered_id = own_expert * params.n_tokens + pos;
|
||||
global_gathered_expert_used[gathered_id] = col;
|
||||
global_gathered_tokens[gathered_id] = row;
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
if (thread_id == 0u) {
|
||||
gathered_count_ids[own_expert] = atomicLoad(&count);
|
||||
for (var i = thread_id;i < params.n_expert_used * params.n_tokens;i += WG_SIZE) {
|
||||
let row = i / params.n_expert_used;
|
||||
let col = i % params.n_expert_used;
|
||||
let expert = u32(ids[params.offset_ids + row * params.stride_ids_1 + col]);
|
||||
if (own_expert == expert) {
|
||||
let pos = atomicAdd(&count, 1u);
|
||||
let gathered_id = own_expert * params.n_tokens + pos;
|
||||
global_gathered_expert_used[gathered_id] = col;
|
||||
global_gathered_tokens[gathered_id] = row;
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
if (thread_id == 0u) {
|
||||
gathered_count_ids[own_expert] = atomicLoad(&count);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,6 +51,9 @@ opbatch=
|
||||
opqueue=
|
||||
[ "$OQ" != "" ] && opqueue="GGML_HEXAGON_OPQUEUE=$OQ"
|
||||
|
||||
oppoll=
|
||||
[ "$OP" != "" ] && oppoll="GGML_HEXAGON_OPPOLL=$OP"
|
||||
|
||||
opflt=
|
||||
[ "$OF" != "" ] && opflt="GGML_HEXAGON_OPFILTER=$OF"
|
||||
|
||||
@@ -66,7 +69,7 @@ adb $adbserial $adbhost shell " \
|
||||
cd $basedir; ulimit -c unlimited; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $opflt $vmem $mbuf \
|
||||
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opflt $vmem $mbuf \
|
||||
./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--ctx-size 8192 --ubatch-size 1024 -fa on \
|
||||
|
||||
@@ -42,6 +42,15 @@ ndev=
|
||||
hb=
|
||||
[ "$HB" != "" ] && hb="GGML_HEXAGON_HOSTBUF=$HB"
|
||||
|
||||
opbatch=
|
||||
[ "$OB" != "" ] && opbatch="GGML_HEXAGON_OPBATCH=$OB"
|
||||
|
||||
opqueue=
|
||||
[ "$OQ" != "" ] && opqueue="GGML_HEXAGON_OPQUEUE=$OQ"
|
||||
|
||||
oppoll=
|
||||
[ "$OP" != "" ] && oppoll="GGML_HEXAGON_OPPOLL=$OP"
|
||||
|
||||
set -x
|
||||
|
||||
tool=$1; shift
|
||||
@@ -50,5 +59,5 @@ adb $adbserial $adbhost shell " \
|
||||
cd $basedir; ulimit -c unlimited; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb ./$branch/bin/$tool $@ \
|
||||
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll ./$branch/bin/$tool $@ \
|
||||
"
|
||||
|
||||
Reference in New Issue
Block a user