mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-05-07 16:57:34 +03:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
017f090442 | ||
|
|
ffdd983fb8 | ||
|
|
793d0a7931 | ||
|
|
8bc492ebb4 | ||
|
|
e5f070a1dc |
@@ -1,4 +1,3 @@
|
||||
#include "log.h"
|
||||
#include "value.h"
|
||||
#include "runtime.h"
|
||||
#include "caps.h"
|
||||
|
||||
@@ -2693,6 +2693,39 @@ static bool ggml_hexagon_supported_diag(const struct ggml_hexagon_session * sess
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_solve_tri(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const struct ggml_tensor * src0 = op->src[0]; // A
|
||||
const struct ggml_tensor * src1 = op->src[1]; // B
|
||||
const struct ggml_tensor * dst = op; // X
|
||||
|
||||
if (!src0 || !src1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src0->type != GGML_TYPE_F32 || src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src0->ne[0] != src0->ne[1]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src0->ne[1] != src1->ne[1]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dst->ne[0] != src1->ne[0] || dst->ne[1] != src1->ne[1] || dst->ne[2] != src1->ne[2] || dst->ne[3] != src1->ne[3]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
GGML_UNUSED(sess);
|
||||
return true;
|
||||
}
|
||||
|
||||
static const char * ggml_backend_hexagon_name(ggml_backend_t backend) {
|
||||
auto sess = static_cast<ggml_hexagon_session *>(backend->context);
|
||||
return sess->c_name();
|
||||
@@ -2731,7 +2764,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) {
|
||||
case GGML_OP_CUMSUM: return HTP_OP_CUMSUM;
|
||||
case GGML_OP_FILL: return HTP_OP_FILL;
|
||||
case GGML_OP_DIAG: return HTP_OP_DIAG;
|
||||
|
||||
case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(t)) {
|
||||
case GGML_UNARY_OP_SILU: return HTP_OP_UNARY_SILU;
|
||||
@@ -3277,6 +3310,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
||||
supp = ggml_hexagon_supported_diag(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
supp = ggml_hexagon_supported_solve_tri(sess, op);
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -36,6 +36,7 @@ add_library(${HTP_LIB} SHARED
|
||||
cumsum-ops.c
|
||||
fill-ops.c
|
||||
diag-ops.c
|
||||
solve-tri-ops.c
|
||||
)
|
||||
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE
|
||||
|
||||
@@ -103,5 +103,6 @@ int op_ssm_conv(struct htp_ops_context * octx);
|
||||
int op_cumsum(struct htp_ops_context * octx);
|
||||
int op_fill(struct htp_ops_context * octx);
|
||||
int op_diag(struct htp_ops_context * octx);
|
||||
int op_solve_tri(struct htp_ops_context * octx);
|
||||
|
||||
#endif /* HTP_CTX_H */
|
||||
|
||||
@@ -82,7 +82,7 @@ enum htp_op_code {
|
||||
HTP_OP_CUMSUM,
|
||||
HTP_OP_FILL,
|
||||
HTP_OP_DIAG,
|
||||
|
||||
HTP_OP_SOLVE_TRI,
|
||||
HTP_OP_INVALID
|
||||
};
|
||||
|
||||
|
||||
@@ -256,6 +256,18 @@ static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b)
|
||||
return Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b));
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_add_f32_f32(HVX_Vector a, HVX_Vector b) {
|
||||
return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b));
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_sub_f32_f32(HVX_Vector a, HVX_Vector b) {
|
||||
return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b));
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_mul_f32_f32(HVX_Vector a, HVX_Vector b) {
|
||||
return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b));
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
static inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b)
|
||||
@@ -273,6 +285,18 @@ static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b)
|
||||
return Q6_Vhf_vmpy_VhfVhf(a, b);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_add_f32_f32(HVX_Vector a, HVX_Vector b) {
|
||||
return Q6_Vsf_vadd_VsfVsf(a, b);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_sub_f32_f32(HVX_Vector a, HVX_Vector b) {
|
||||
return Q6_Vsf_vsub_VsfVsf(a, b);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_mul_f32_f32(HVX_Vector a, HVX_Vector b) {
|
||||
return Q6_Vsf_vmpy_VsfVsf(a, b);
|
||||
}
|
||||
|
||||
#endif // __HVX_ARCH__ < 79
|
||||
|
||||
#endif /* HVX_BASE_H */
|
||||
|
||||
@@ -573,6 +573,9 @@ static int execute_op(struct htp_ops_context * octx) {
|
||||
case HTP_OP_DIAG:
|
||||
return op_diag(octx);
|
||||
|
||||
case HTP_OP_SOLVE_TRI:
|
||||
return op_solve_tri(octx);
|
||||
|
||||
case HTP_OP_INVALID:
|
||||
break;
|
||||
|
||||
|
||||
267
ggml/src/ggml-hexagon/htp/solve-tri-ops.c
Normal file
267
ggml/src/ggml-hexagon/htp/solve-tri-ops.c
Normal file
@@ -0,0 +1,267 @@
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <string.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-types.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
struct htp_solve_tri_context {
|
||||
struct htp_ops_context * octx;
|
||||
uint32_t jobs_per_thread;
|
||||
uint32_t total_jobs;
|
||||
uint32_t k_chunks;
|
||||
uint32_t col_block;
|
||||
};
|
||||
|
||||
static inline void solve_tri_row_scalar(const float * A_row,
|
||||
const float * B_row,
|
||||
float * X,
|
||||
uint32_t row,
|
||||
uint32_t k,
|
||||
uint32_t col0,
|
||||
uint32_t coln,
|
||||
float inv_diag) {
|
||||
for (uint32_t col = col0; col < col0 + coln; ++col) {
|
||||
float sum = 0.0f;
|
||||
for (uint32_t t = 0; t < row; ++t) {
|
||||
sum += A_row[t] * X[t * k + col];
|
||||
}
|
||||
X[row * k + col] = (B_row[col] - sum) * inv_diag;
|
||||
}
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_load_partial_f32(const float * src, uint32_t n) {
|
||||
HVX_Vector v = *((const HVX_UVector *) src);
|
||||
HVX_VectorPred mask = Q6_Q_vsetq2_R(n * sizeof(float));
|
||||
return Q6_V_vmux_QVV(mask, v, Q6_V_vzero());
|
||||
}
|
||||
|
||||
static inline void solve_tri_row_hvx(const float * A_row,
|
||||
const float * B_row,
|
||||
float * X,
|
||||
uint32_t row,
|
||||
uint32_t k,
|
||||
uint32_t col0,
|
||||
uint32_t coln,
|
||||
float inv_diag) {
|
||||
const bool full = (coln == VLEN_FP32);
|
||||
|
||||
HVX_Vector sum_v = Q6_V_vzero();
|
||||
for (uint32_t t = 0; t < row; ++t) {
|
||||
const float a = A_row[t];
|
||||
const float * x_row_col = X + t * k + col0;
|
||||
|
||||
HVX_Vector x_v = full ? *((const HVX_UVector *) x_row_col) : hvx_load_partial_f32(x_row_col, coln);
|
||||
HVX_Vector a_v = hvx_vec_splat_f32(a);
|
||||
sum_v = hvx_vec_add_f32_f32(sum_v, hvx_vec_mul_f32_f32(x_v, a_v));
|
||||
}
|
||||
|
||||
const float * b_row_col = B_row + col0;
|
||||
float * x_out_col = X + row * k + col0;
|
||||
|
||||
HVX_Vector b_v = full ? *((const HVX_UVector *) b_row_col) : hvx_load_partial_f32(b_row_col, coln);
|
||||
HVX_Vector inv_diag_v = hvx_vec_splat_f32(inv_diag);
|
||||
|
||||
HVX_Vector out_v = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(b_v, sum_v), inv_diag_v);
|
||||
hvx_vec_store_u((void *) x_out_col, coln * sizeof(float), out_v);
|
||||
}
|
||||
|
||||
// Batch-level thread: each job is one full batch.
|
||||
static void solve_tri_batch_thread_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_solve_tri_context * sctx = (struct htp_solve_tri_context *) data;
|
||||
struct htp_ops_context * octx = sctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = octx->src[0]; // A
|
||||
const struct htp_tensor * src1 = octx->src[1]; // B
|
||||
const struct htp_tensor * dst = octx->dst; // X
|
||||
|
||||
const uint32_t n = src0->ne[0];
|
||||
const uint32_t k = src1->ne[0];
|
||||
|
||||
const uint32_t ne02 = src0->ne[2];
|
||||
|
||||
const uint32_t col_block = VLEN_FP32;
|
||||
const uint32_t k_full = (k / col_block) * col_block;
|
||||
|
||||
const uint32_t start_batch = sctx->jobs_per_thread * ith;
|
||||
const uint32_t end_batch = MIN(start_batch + sctx->jobs_per_thread, sctx->total_jobs);
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
for (uint32_t batch = start_batch; batch < end_batch; ++batch) {
|
||||
const uint32_t i03 = batch / ne02;
|
||||
const uint32_t i02 = batch - i03 * ne02;
|
||||
|
||||
const float * A_batch =
|
||||
(const float *) ((const uint8_t *) (uintptr_t) src0->data + i02 * src0->nb[2] + i03 * src0->nb[3]);
|
||||
const float * B_batch =
|
||||
(const float *) ((const uint8_t *) (uintptr_t) src1->data + i02 * src1->nb[2] + i03 * src1->nb[3]);
|
||||
float * X_batch = (float *) ((uint8_t *) (uintptr_t) dst->data + i02 * dst->nb[2] + i03 * dst->nb[3]);
|
||||
|
||||
for (uint32_t row = 0; row < n; ++row) {
|
||||
const float diag = A_batch[row * n + row];
|
||||
const float inv_diag = 1.0f / diag;
|
||||
const float * A_row = A_batch + row * n;
|
||||
const float * B_row = B_batch + row * k;
|
||||
|
||||
uint32_t col0 = 0;
|
||||
for (; col0 < k_full; col0 += col_block) {
|
||||
solve_tri_row_hvx(A_row, B_row, X_batch, row, k, col0, col_block, inv_diag);
|
||||
}
|
||||
|
||||
if (col0 < k) {
|
||||
const uint32_t coln = k - col0;
|
||||
if (coln >= 8) {
|
||||
solve_tri_row_hvx(A_row, B_row, X_batch, row, k, col0, coln, inv_diag);
|
||||
} else {
|
||||
solve_tri_row_scalar(A_row, B_row, X_batch, row, k, col0, coln, inv_diag);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "solve-tri-batch %d/%d: A=(%ux%u) B=(%ux%u) batch %u:%u usec %u\n",
|
||||
ith, nth, n, n, k, n, start_batch, end_batch,
|
||||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
// Chunk-level thread: each job is one (batch, col_chunk) pair.
|
||||
static void solve_tri_chunk_thread_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_solve_tri_context * sctx = (struct htp_solve_tri_context *) data;
|
||||
struct htp_ops_context * octx = sctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = octx->src[0]; // A
|
||||
const struct htp_tensor * src1 = octx->src[1]; // B
|
||||
const struct htp_tensor * dst = octx->dst; // X
|
||||
|
||||
const uint32_t n = src0->ne[0];
|
||||
const uint32_t k = src1->ne[0];
|
||||
|
||||
const uint32_t ne02 = src0->ne[2];
|
||||
|
||||
const uint32_t start_job = sctx->jobs_per_thread * ith;
|
||||
const uint32_t end_job = MIN(start_job + sctx->jobs_per_thread, sctx->total_jobs);
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
for (uint32_t job = start_job; job < end_job; ++job) {
|
||||
const uint32_t batch = job / sctx->k_chunks;
|
||||
const uint32_t chunk = job - batch * sctx->k_chunks;
|
||||
|
||||
const uint32_t i03 = batch / ne02;
|
||||
const uint32_t i02 = batch - i03 * ne02;
|
||||
|
||||
const uint32_t col0 = chunk * sctx->col_block;
|
||||
const uint32_t coln = MIN(sctx->col_block, k - col0);
|
||||
|
||||
const float * A_batch =
|
||||
(const float *) ((const uint8_t *) (uintptr_t) src0->data + i02 * src0->nb[2] + i03 * src0->nb[3]);
|
||||
const float * B_batch =
|
||||
(const float *) ((const uint8_t *) (uintptr_t) src1->data + i02 * src1->nb[2] + i03 * src1->nb[3]);
|
||||
float * X_batch = (float *) ((uint8_t *) (uintptr_t) dst->data + i02 * dst->nb[2] + i03 * dst->nb[3]);
|
||||
|
||||
const bool use_hvx = (coln >= 8);
|
||||
|
||||
for (uint32_t row = 0; row < n; ++row) {
|
||||
const float diag = A_batch[row * n + row];
|
||||
const float inv_diag = 1.0f / diag;
|
||||
|
||||
const float * A_row = A_batch + row * n;
|
||||
const float * B_row = B_batch + row * k;
|
||||
|
||||
if (use_hvx) {
|
||||
solve_tri_row_hvx(A_row, B_row, X_batch, row, k, col0, coln, inv_diag);
|
||||
} else {
|
||||
solve_tri_row_scalar(A_row, B_row, X_batch, row, k, col0, coln, inv_diag);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "solve-tri-chunk %d/%d: A=(%ux%u) B=(%ux%u) job %u:%u usec %u\n",
|
||||
ith, nth, n, n, k, n, start_job, end_job,
|
||||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
int op_solve_tri(struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * src0 = octx->src[0]; // A
|
||||
const struct htp_tensor * src1 = octx->src[1]; // B
|
||||
const struct htp_tensor * dst = octx->dst; // X
|
||||
|
||||
if (src0->type != HTP_TYPE_F32 || src1->type != HTP_TYPE_F32 || dst->type != HTP_TYPE_F32) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
// left=true, lower=true, uni=false only
|
||||
if (src0->ne[0] != src0->ne[1]) {
|
||||
return HTP_STATUS_INVAL_PARAMS;
|
||||
}
|
||||
if (src0->ne[1] != src1->ne[1]) {
|
||||
return HTP_STATUS_INVAL_PARAMS;
|
||||
}
|
||||
if (src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3]) {
|
||||
return HTP_STATUS_INVAL_PARAMS;
|
||||
}
|
||||
if (dst->ne[0] != src1->ne[0] || dst->ne[1] != src1->ne[1] || dst->ne[2] != src1->ne[2] ||
|
||||
dst->ne[3] != src1->ne[3]) {
|
||||
return HTP_STATUS_INVAL_PARAMS;
|
||||
}
|
||||
|
||||
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
const uint32_t k = src1->ne[0];
|
||||
|
||||
const uint32_t col_block = VLEN_FP32;
|
||||
const uint32_t k_chunks = (k + col_block - 1) / col_block;
|
||||
const uint32_t total_batches = src0->ne[2] * src0->ne[3];
|
||||
const bool batched = total_batches >= (uint32_t) octx->n_threads;
|
||||
|
||||
FARF(HIGH, "solve-tri: (%ux%ux%ux%u) x (%ux%ux%ux%u) -> (%ux%ux%ux%u) : batched %d\n",
|
||||
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
|
||||
src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], batched);
|
||||
|
||||
if (batched) {
|
||||
// Batch-level parallelism
|
||||
const uint32_t n_threads = MIN((uint32_t) octx->n_threads, total_batches);
|
||||
|
||||
struct htp_solve_tri_context sctx = {
|
||||
.octx = octx,
|
||||
.jobs_per_thread = (total_batches + n_threads - 1) / n_threads,
|
||||
.total_jobs = total_batches,
|
||||
.k_chunks = k_chunks,
|
||||
.col_block = col_block,
|
||||
};
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, solve_tri_batch_thread_f32, &sctx, n_threads);
|
||||
} else {
|
||||
// Chunk-level parallelism
|
||||
const uint32_t total_jobs = total_batches * k_chunks;
|
||||
const uint32_t n_threads = MIN((uint32_t) octx->n_threads, MAX(total_jobs, 1));
|
||||
|
||||
struct htp_solve_tri_context sctx = {
|
||||
.octx = octx,
|
||||
.jobs_per_thread = (total_jobs + n_threads - 1) / n_threads,
|
||||
.total_jobs = total_jobs,
|
||||
.k_chunks = k_chunks,
|
||||
.col_block = col_block,
|
||||
};
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, solve_tri_chunk_thread_f32, &sctx, n_threads);
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
@@ -197,11 +197,12 @@ struct ggml_webgpu_row_norm_pipeline_key_hash {
|
||||
/** RMS_NORM + MUL **/
|
||||
|
||||
struct ggml_webgpu_rms_norm_mul_pipeline_key {
|
||||
bool inplace;
|
||||
bool src_overlap;
|
||||
bool inplace; // rn_src == dst
|
||||
bool overlap; // mul_src == dst
|
||||
bool src_overlap; // rn_src == mul_src
|
||||
|
||||
bool operator==(const ggml_webgpu_rms_norm_mul_pipeline_key & other) const {
|
||||
return inplace == other.inplace && src_overlap == other.src_overlap;
|
||||
return inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -209,6 +210,7 @@ struct ggml_webgpu_rms_norm_mul_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_rms_norm_mul_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.inplace);
|
||||
ggml_webgpu_hash_combine(seed, key.overlap);
|
||||
ggml_webgpu_hash_combine(seed, key.src_overlap);
|
||||
return seed;
|
||||
}
|
||||
@@ -556,7 +558,7 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_
|
||||
const size_t q_tile = context.sg_mat_m;
|
||||
const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
|
||||
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
|
||||
size_t bytes_per_kv = 0;
|
||||
size_t bytes_per_kv = 0;
|
||||
if (!key.kv_direct) {
|
||||
bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v);
|
||||
}
|
||||
@@ -1878,6 +1880,7 @@ class ggml_webgpu_shader_lib {
|
||||
webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_rms_norm_mul_pipeline_key key = {};
|
||||
key.inplace = context.inplace;
|
||||
key.overlap = context.overlap;
|
||||
key.src_overlap = context.src_overlap;
|
||||
|
||||
auto it = rms_norm_mul_pipelines.find(key);
|
||||
@@ -1892,6 +1895,9 @@ class ggml_webgpu_shader_lib {
|
||||
if (key.inplace) {
|
||||
defines.push_back("INPLACE");
|
||||
variant += "_inplace";
|
||||
} else if (key.overlap) {
|
||||
defines.push_back("OVERLAP");
|
||||
variant += "_overlap";
|
||||
} else if (key.src_overlap) {
|
||||
defines.push_back("SRC_OVERLAP");
|
||||
variant += "_src_overlap";
|
||||
|
||||
@@ -2071,8 +2071,9 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
|
||||
GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1");
|
||||
}
|
||||
|
||||
bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) ||
|
||||
bool overlap = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) ||
|
||||
(ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst));
|
||||
bool inplace = ggml_webgpu_tensor_equal(rn_src, dst);
|
||||
bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src);
|
||||
|
||||
uint32_t offset_merged_rn_src = 0;
|
||||
@@ -2116,7 +2117,7 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries;
|
||||
|
||||
if (inplace) {
|
||||
if (inplace || overlap) {
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src));
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src));
|
||||
} else if (src_overlap) {
|
||||
@@ -2136,6 +2137,7 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
shader_lib_ctx.inplace = inplace;
|
||||
shader_lib_ctx.overlap = overlap;
|
||||
shader_lib_ctx.src_overlap = src_overlap;
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx);
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#ifdef INPLACE
|
||||
#ifdef OVERLAP
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> rn_src: array<f32>;
|
||||
@@ -13,6 +13,21 @@ fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32)
|
||||
mul_src[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset];
|
||||
}
|
||||
|
||||
#elif INPLACE
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> rn_src: array<f32>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> mul_src: array<f32>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<uniform> params: Params;
|
||||
|
||||
fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) {
|
||||
rn_src[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset];
|
||||
}
|
||||
|
||||
#elif SRC_OVERLAP
|
||||
|
||||
@group(0) @binding(0)
|
||||
|
||||
@@ -675,6 +675,10 @@ private:
|
||||
|
||||
int32_t n_ctx; // total context for all clients / slots
|
||||
|
||||
// set to llama_model_n_swa(model)
|
||||
// if swa_full is enabled, this is set to 0 to simulate a non-SWA model
|
||||
int32_t n_swa;
|
||||
|
||||
// slots / clients
|
||||
std::vector<server_slot> slots;
|
||||
|
||||
@@ -719,7 +723,7 @@ private:
|
||||
return;
|
||||
}
|
||||
SLT_INF(slot, "%s", "saving idle slot to prompt cache\n");
|
||||
SLT_DBG(slot, "%s", "__TEST_TAG_CLEAR_IDLE_SLOT__\n");
|
||||
SLT_DBG(slot, "%s", "__TEST_TAG_CACHE_IDLE_SLOT__\n");
|
||||
slot.prompt_save(*prompt_cache);
|
||||
slot.prompt_clear(false);
|
||||
prompt_cache->update();
|
||||
@@ -854,6 +858,8 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
n_swa = params_base.swa_full ? 0 : llama_model_n_swa(model);
|
||||
|
||||
// Necessary similarity of prompt for slot selection
|
||||
slot_prompt_similarity = params_base.slot_prompt_similarity;
|
||||
|
||||
@@ -996,7 +1002,7 @@ private:
|
||||
params_base.cache_idle_slots = false;
|
||||
} else {
|
||||
SRV_INF("%s: idle slots will be saved to prompt cache and cleared upon starting a new task\n", __func__);
|
||||
SRV_DBG("%s", "__TEST_TAG_CLEAR_IDLE_ENABLED__\n");
|
||||
SRV_DBG("%s", "__TEST_TAG_CACHE_IDLE_SLOTS_ENABLED__\n");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2415,9 +2421,6 @@ private:
|
||||
|
||||
llama_pos pos_next = slot.prompt.tokens.pos_next(n_past);
|
||||
|
||||
// note: when n_swa == 0, the model does not use SWA
|
||||
const auto n_swa = std::max(0, llama_model_n_swa(model));
|
||||
|
||||
// the largest pos_min required for a checkpoint to be useful
|
||||
const auto pos_min_thold = std::max(0, pos_next - n_swa);
|
||||
|
||||
@@ -2589,10 +2592,10 @@ private:
|
||||
// make a checkpoint of the parts of the memory that cannot be rolled back.
|
||||
// checkpoints are created only if:
|
||||
// - the model does not support partial sequence removal
|
||||
// - the model uses SWA and we are not using `swa_full`
|
||||
// - the model uses SWA (and we are not using `swa_full`)
|
||||
do_checkpoint = do_checkpoint && (
|
||||
(slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) ||
|
||||
(llama_model_n_swa(model) > 0 && !params_base.swa_full));
|
||||
(n_swa > 0));
|
||||
|
||||
bool has_mtmd = false;
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ def test_clear_and_restore():
|
||||
log = LogReader(server.log_path)
|
||||
|
||||
# verify feature is enabled
|
||||
assert "__TEST_TAG_CLEAR_IDLE_ENABLED__" in log.drain()
|
||||
assert "__TEST_TAG_CACHE_IDLE_SLOTS_ENABLED__" in log.drain()
|
||||
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": LONG_PROMPT,
|
||||
@@ -59,7 +59,7 @@ def test_clear_and_restore():
|
||||
original_prompt_n = res.body["timings"]["prompt_n"]
|
||||
|
||||
# Slot 0 is the only slot with KV — should NOT be cleared
|
||||
assert "__TEST_TAG_CLEAR_IDLE_SLOT__" not in log.drain()
|
||||
assert "__TEST_TAG_CACHE_IDLE_SLOT__" not in log.drain()
|
||||
|
||||
# Launching slot 1 clears idle slot 0
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
@@ -68,7 +68,7 @@ def test_clear_and_restore():
|
||||
"cache_prompt": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "__TEST_TAG_CLEAR_IDLE_SLOT__" in log.drain()
|
||||
assert "__TEST_TAG_CACHE_IDLE_SLOT__" in log.drain()
|
||||
|
||||
# Re-send same prompt — should restore from cache-ram
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
@@ -86,7 +86,7 @@ def test_clear_and_restore():
|
||||
"cache_prompt": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "__TEST_TAG_CLEAR_IDLE_SLOT__" not in log.drain()
|
||||
assert "__TEST_TAG_CACHE_IDLE_SLOT__" not in log.drain()
|
||||
|
||||
|
||||
def test_disabled_with_flag():
|
||||
@@ -96,7 +96,7 @@ def test_disabled_with_flag():
|
||||
log = LogReader(server.log_path)
|
||||
|
||||
# Feature should not be enabled
|
||||
assert "__TEST_TAG_CLEAR_IDLE_ENABLED__" not in log.drain()
|
||||
assert "__TEST_TAG_CACHE_IDLE_SLOTS_ENABLED__" not in log.drain()
|
||||
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": LONG_PROMPT,
|
||||
@@ -112,4 +112,4 @@ def test_disabled_with_flag():
|
||||
"cache_prompt": True,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "__TEST_TAG_CLEAR_IDLE_SLOT__" not in log.drain()
|
||||
assert "__TEST_TAG_CACHE_IDLE_SLOT__" not in log.drain()
|
||||
|
||||
Reference in New Issue
Block a user