mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-02-12 14:03:20 +02:00
Compare commits
5 Commits
gg/kv-fix-
...
b5613
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
201b31dc2e | ||
|
|
e21d2d4ae2 | ||
|
|
dc0623fddb | ||
|
|
87d34b381d | ||
|
|
b460d16ae8 |
@@ -8,6 +8,7 @@
|
||||
- [DataType Supports](#datatype-supports)
|
||||
- [Docker](#docker)
|
||||
- [Linux](#linux)
|
||||
- [Environment variable setup](#environment-variable-setup)
|
||||
- [TODO](#todo)
|
||||
|
||||
|
||||
@@ -290,5 +291,24 @@ Authors from Peking University: Bizhao Shi (bshi@pku.edu.cn), Yuxin Yang (yxyang
|
||||
|
||||
We would like to thank Tuo Dai, Shanni Li, and all of the project maintainers from Huawei Technologies Co., Ltd for their help during the code development and pull request.
|
||||
|
||||
## Environment variable setup
|
||||
|
||||
### GGML_CANN_ASYNC_MODE
|
||||
|
||||
Enables asynchronous operator submission. Disabled by default.
|
||||
|
||||
### GGML_CANN_MEM_POOL
|
||||
|
||||
Specifies the memory pool management strategy:
|
||||
|
||||
- vmm: Utilizes a virtual memory manager pool. If hardware support for VMM is unavailable, falls back to the legacy (leg) memory pool.
|
||||
|
||||
- prio: Employs a priority queue-based memory pool management.
|
||||
- leg: Uses a fixed-size buffer pool.
|
||||
|
||||
### GGML_CANN_DISABLE_BUF_POOL_CLEAN
|
||||
|
||||
Controls automatic cleanup of the memory pool. This option is only effective when using the prio or leg memory pool strategies.
|
||||
|
||||
## TODO
|
||||
- Support more models and data types.
|
||||
|
||||
@@ -37,6 +37,7 @@
|
||||
#include <thread>
|
||||
#include <unistd.h>
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
|
||||
#include "../include/ggml-cann.h"
|
||||
#include "../include/ggml.h"
|
||||
@@ -103,6 +104,9 @@ const ggml_cann_device_info& ggml_cann_info();
|
||||
void ggml_cann_set_device(int32_t device);
|
||||
int32_t ggml_cann_get_device();
|
||||
|
||||
std::optional<std::string> get_env(const std::string& name);
|
||||
bool parse_bool(const std::string& value);
|
||||
|
||||
/**
|
||||
* @brief Abstract base class for memory pools used by CANN.
|
||||
*/
|
||||
@@ -354,7 +358,8 @@ struct ggml_backend_cann_context {
|
||||
: device(device), name("CANN" + std::to_string(device)), task_queue(1024, device) {
|
||||
ggml_cann_set_device(device);
|
||||
description = aclrtGetSocName();
|
||||
async_mode = (getenv("GGML_CANN_ASYNC_MODE") != nullptr);
|
||||
|
||||
bool async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
|
||||
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
|
||||
device, async_mode ? "ON" : "OFF");
|
||||
}
|
||||
|
||||
@@ -31,6 +31,8 @@
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <chrono>
|
||||
#include <unordered_set>
|
||||
#include <optional>
|
||||
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-backend-impl.h"
|
||||
@@ -93,6 +95,26 @@ int32_t ggml_cann_get_device() {
|
||||
return id;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the value of the specified environment variable (name).
|
||||
* if not empty, return a std::string object
|
||||
*/
|
||||
std::optional<std::string> get_env(const std::string& name) {
|
||||
const char* val = std::getenv(name.c_str());
|
||||
if (!val) return std::nullopt;
|
||||
std::string res = std::string(val);
|
||||
std::transform(res.begin(), res.end(), res.begin(), ::tolower);
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Verify whether the environment variable is a valid value.
|
||||
*/
|
||||
bool parse_bool(const std::string& value) {
|
||||
std::unordered_set<std::string> valid_values = {"on", "1", "yes", "y", "enable", "true"};
|
||||
return valid_values.find(value) != valid_values.end();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Initialize the CANN device information.
|
||||
*
|
||||
@@ -214,7 +236,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
|
||||
* @param device The device ID to associate with this buffer pool.
|
||||
*/
|
||||
explicit ggml_cann_pool_buf_prio(int device) : device(device) {
|
||||
disable_clean = getenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN") != nullptr;
|
||||
disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -410,7 +432,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
|
||||
* @param device The device ID to associate with this buffer pool.
|
||||
*/
|
||||
explicit ggml_cann_pool_buf(int device) : device(device) {
|
||||
disable_clean = getenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN") != nullptr;
|
||||
disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -731,16 +753,18 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
|
||||
*/
|
||||
std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
|
||||
int device) {
|
||||
bool disable_vmm = (getenv("GGML_CANN_DISABLE_VMM_POOL") != nullptr);
|
||||
if (!disable_vmm && ggml_cann_info().devices[device].vmm) {
|
||||
GGML_LOG_INFO("%s: device %d use vmm pool\n", __func__, device);
|
||||
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
|
||||
}
|
||||
bool enable_buf_prio = (getenv("GGML_CANN_ENABLE_BUF_PRIO_POOL") != nullptr);
|
||||
if (enable_buf_prio) {
|
||||
std::string mem_pool_type = get_env("GGML_CANN_MEM_POOL").value_or("");
|
||||
|
||||
if (mem_pool_type == "prio") {
|
||||
GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device);
|
||||
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf_prio(device));
|
||||
}
|
||||
|
||||
if (ggml_cann_info().devices[device].vmm && mem_pool_type != "leg") {
|
||||
GGML_LOG_INFO("%s: device %d use vmm pool\n", __func__, device);
|
||||
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
|
||||
}
|
||||
|
||||
GGML_LOG_INFO("%s: device %d use buffer pool\n", __func__, device);
|
||||
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf(device));
|
||||
}
|
||||
|
||||
@@ -265,6 +265,17 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename dst_t>
|
||||
static void dequantize_row_q6_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
|
||||
const int64_t nb = k / QK_K;
|
||||
|
||||
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
|
||||
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); });
|
||||
}
|
||||
|
||||
template <typename dst_t>
|
||||
static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||
dpct::queue_ptr stream) {
|
||||
@@ -530,7 +541,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
|
||||
case GGML_TYPE_Q5_K:
|
||||
return dequantize_row_q5_K_sycl;
|
||||
case GGML_TYPE_Q6_K:
|
||||
return dequantize_row_q6_K_sycl;
|
||||
if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
||||
return dequantize_row_q6_K_sycl_reorder;
|
||||
} else {
|
||||
return dequantize_row_q6_K_sycl;
|
||||
}
|
||||
case GGML_TYPE_IQ1_S:
|
||||
return dequantize_row_iq1_s_sycl;
|
||||
case GGML_TYPE_IQ1_M:
|
||||
@@ -587,7 +602,11 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
|
||||
case GGML_TYPE_Q5_K:
|
||||
return dequantize_row_q5_K_sycl;
|
||||
case GGML_TYPE_Q6_K:
|
||||
return dequantize_row_q6_K_sycl;
|
||||
if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
||||
return dequantize_row_q6_K_sycl_reorder;
|
||||
} else {
|
||||
return dequantize_row_q6_K_sycl;
|
||||
}
|
||||
case GGML_TYPE_IQ1_S:
|
||||
return dequantize_row_iq1_s_sycl;
|
||||
case GGML_TYPE_IQ1_M:
|
||||
|
||||
@@ -538,6 +538,38 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename dst_t>
|
||||
static void dequantize_block_q6_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
||||
const sycl::nd_item<3> & item_ct1, int64_t n_blocks) {
|
||||
const int64_t ib = item_ct1.get_group(2);
|
||||
|
||||
const int64_t tid = item_ct1.get_local_id(2);
|
||||
const int64_t ip = tid / 32; // ip is 0 or 1
|
||||
const int64_t il = tid - 32 * ip; // 0...32
|
||||
const int64_t is = 8 * ip + il / 16;
|
||||
|
||||
const uint8_t * base_ptr = static_cast<const uint8_t *>(vx);
|
||||
const auto ql_offset = ib * (QK_K / 2);
|
||||
const auto qh_offset = (QK_K / 2) * n_blocks + (QK_K / 4) * ib;
|
||||
const auto base_scales_offset = (QK_K / 2) * n_blocks + (QK_K / 4) * n_blocks + (QK_K / 16) * ib;
|
||||
const auto base_d_offset = ((QK_K / 2) + (QK_K / 4) + (QK_K / 16)) * n_blocks;
|
||||
const uint8_t * ql_ptr = base_ptr + ql_offset;
|
||||
const uint8_t * qh_ptr = base_ptr + qh_offset;
|
||||
const uint8_t * scales_ptr = base_ptr + base_scales_offset;
|
||||
const ggml_half * d = (const ggml_half *) (base_ptr + base_d_offset) + ib;
|
||||
|
||||
dst_t * y = yy + ib * QK_K + 128 * ip + il;
|
||||
|
||||
const uint8_t * ql = ql_ptr + 64 * ip + il;
|
||||
const uint8_t qh = *(qh_ptr + 32 * ip + il);
|
||||
const int8_t * sc = reinterpret_cast<const int8_t *>(scales_ptr + is);
|
||||
|
||||
y[0] = *d * sc[0] * ((int8_t) ((ql[0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
|
||||
y[32] = *d * sc[2] * ((int8_t) ((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
|
||||
y[64] = *d * sc[4] * ((int8_t) ((ql[0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
|
||||
y[96] = *d * sc[6] * ((int8_t) ((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
||||
const sycl::nd_item<3> &item_ct1,
|
||||
|
||||
@@ -354,7 +354,8 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
|
||||
assert(tensor->view_src->buffer->buft == buffer->buft);
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) && !g_ggml_sycl_disable_optimize) {
|
||||
if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) &&
|
||||
!g_ggml_sycl_disable_optimize) {
|
||||
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
|
||||
tensor->extra = extra;
|
||||
ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
|
||||
@@ -2989,6 +2990,7 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
return true;
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
return !g_ggml_sycl_prioritize_dmmv;
|
||||
default:
|
||||
return false;
|
||||
@@ -3008,6 +3010,7 @@ inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
@@ -3092,6 +3095,50 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
|
||||
sycl::free(tmp_buf, *stream);
|
||||
}
|
||||
|
||||
static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(size % sizeof(block_q6_K) == 0);
|
||||
GGML_ASSERT(offset % sizeof(block_q6_K) == 0);
|
||||
|
||||
const int nblocks = size / sizeof(block_q6_K);
|
||||
|
||||
auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
|
||||
SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
|
||||
|
||||
auto * ql_ptr = data_device;
|
||||
auto * qh_ptr = ql_ptr + (QK_K / 2) * nblocks;
|
||||
auto * scales_ptr = qh_ptr + (QK_K / 4) * nblocks;
|
||||
sycl::half * dm_ptr = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks);
|
||||
|
||||
stream
|
||||
->parallel_for(nblocks,
|
||||
[=](auto i) {
|
||||
const block_q6_K * x = (const block_q6_K *) tmp_buf;
|
||||
const int ib = i;
|
||||
|
||||
const uint8_t * ql = x[ib].ql;
|
||||
const uint8_t * qh = x[ib].qh;
|
||||
uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib;
|
||||
uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib;
|
||||
uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
|
||||
|
||||
for (int j = 0; j < QK_K / 2; ++j) {
|
||||
base_ql_ptr[j] = ql[j];
|
||||
}
|
||||
for (int j = 0; j < QK_K / 4; ++j) {
|
||||
base_qh_ptr[j] = qh[j];
|
||||
}
|
||||
|
||||
for (int j = 0; j < QK_K / 16; ++j) {
|
||||
base_scales_ptr[j] = x[ib].scales[j];
|
||||
}
|
||||
|
||||
dm_ptr[ib] = x[ib].d;
|
||||
})
|
||||
.wait_and_throw();
|
||||
|
||||
sycl::free(tmp_buf, *stream);
|
||||
}
|
||||
|
||||
static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
|
||||
uint8_t * data_device = (uint8_t *) src0->data;
|
||||
size_t ncols = src0->ne[0];
|
||||
@@ -3105,6 +3152,9 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
|
||||
case GGML_TYPE_Q4_K:
|
||||
reorder_qw_q4_k(data_device, size, 0, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
reorder_qw_q6_k(data_device, size, 0, stream);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("reorder_qw() called with unsupported type");
|
||||
break;
|
||||
|
||||
@@ -31,11 +31,10 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
|
||||
|
||||
float partial_sum = 0.0f;
|
||||
for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {
|
||||
const int ibx = row * blocks_per_row + i; // x block index
|
||||
// TODO: Generalize offsets, right now only works for quantizations that don't split high and low bits
|
||||
const int bx_offset = block_type::get_block_offset(ibx);
|
||||
const int d_offset = block_type::get_d_offset(nrows, ncols, ibx);
|
||||
const int ibx = row * blocks_per_row + i; // x block index
|
||||
|
||||
const auto bx_offset = block_type::get_block_offset(ibx, nblocks);
|
||||
const auto d_offset = block_type::get_d_offset(nrows, ncols, ibx);
|
||||
// Y block index that aligns with ibx
|
||||
const int iby = i * block_type::block_to_q8_1_ratio();
|
||||
const int8_t* q8_1_quant_ptr = (const int8_t*)vy + iby * QK8_1;
|
||||
@@ -46,7 +45,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
|
||||
// x block quant index when casting the quants to int
|
||||
const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
|
||||
|
||||
partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs, nblocks);
|
||||
partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -785,6 +784,24 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
}
|
||||
}
|
||||
|
||||
static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
|
||||
const int nrows, dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
|
||||
constexpr size_t num_subgroups = 16;
|
||||
GGML_ASSERT(block_num_y % num_subgroups == 0);
|
||||
|
||||
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
|
||||
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
||||
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
|
||||
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>>(vx, vy, dst, ncols, nrows,
|
||||
nd_item);
|
||||
});
|
||||
});
|
||||
}
|
||||
static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
float *dst, const int ncols,
|
||||
const int nrows,
|
||||
@@ -1070,7 +1087,14 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
|
||||
mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
|
||||
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
||||
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n");
|
||||
reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
} else {
|
||||
GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_k_q8_1_sycl\n");
|
||||
mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
}
|
||||
break;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
|
||||
@@ -14,12 +14,13 @@
|
||||
#ifndef GGML_SYCL_QUANTS_HPP
|
||||
#define GGML_SYCL_QUANTS_HPP
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "ggml-common.h"
|
||||
#include "ggml.h"
|
||||
|
||||
namespace ggml_sycl_reordered {
|
||||
|
||||
|
||||
// The reordered block moves quants (qs) and scales(d) to two
|
||||
// uniform regions of memory that is contiguous in the same tensor.
|
||||
// What this means is that instead of having:
|
||||
@@ -32,7 +33,6 @@ namespace ggml_sycl_reordered {
|
||||
|
||||
template <ggml_type type> struct block_q_t;
|
||||
|
||||
|
||||
// qk number of weights / quants in a block
|
||||
// qr number of weights in a byte (described as 'before dequantization')
|
||||
// for quantization types that has low and high bits split, qr is calculated with
|
||||
@@ -47,10 +47,12 @@ template <> struct block_q_t<GGML_TYPE_Q4_0> {
|
||||
static constexpr uint32_t vdr_mmvq = 2;
|
||||
};
|
||||
|
||||
static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
|
||||
static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {
|
||||
return { block_index * (traits::qk / traits::qr), 0 };
|
||||
}
|
||||
|
||||
static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
|
||||
return (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half);
|
||||
static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
|
||||
return { (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half), 0 };
|
||||
}
|
||||
|
||||
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
|
||||
@@ -64,20 +66,46 @@ template <> struct block_q_t<GGML_TYPE_Q4_K> {
|
||||
static constexpr uint32_t vdr_mmvq = 2;
|
||||
};
|
||||
|
||||
static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
|
||||
static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {
|
||||
return { block_index * (traits::qk / traits::qr), 0 };
|
||||
}
|
||||
|
||||
static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
|
||||
static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
|
||||
auto nblocks = (nrows * (ncols / traits::qk));
|
||||
return (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2));
|
||||
return { nblocks * (QK_K / 2),
|
||||
(nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2)) };
|
||||
}
|
||||
|
||||
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
|
||||
|
||||
constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; }
|
||||
|
||||
constexpr size_t get_dm_offset(int nblocks) { return get_total_qs_bytes(nblocks) + nblocks * K_SCALE_SIZE; }
|
||||
};
|
||||
|
||||
template <> struct block_q_t<GGML_TYPE_Q6_K> {
|
||||
struct traits {
|
||||
static constexpr uint32_t qk = QK_K;
|
||||
static constexpr uint32_t qi = QI6_K;
|
||||
static constexpr uint32_t qr = QR6_K;
|
||||
static constexpr uint32_t vdr_mmvq = 1;
|
||||
};
|
||||
|
||||
static constexpr std::pair<int, int> get_block_offset(const int block_index, const int n_blocks) {
|
||||
auto low_bits_index = block_index * (traits::qk / traits::qr);
|
||||
// the index of high bits it's after all low bits
|
||||
auto high_bits_index = n_blocks * (QK_K / 2) + (block_index * (QK_K / 4));
|
||||
return { low_bits_index, high_bits_index };
|
||||
}
|
||||
|
||||
static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
|
||||
auto nblocks = (nrows * (ncols / traits::qk));
|
||||
auto total_qs_bytes = nblocks * (QK_K / 2) + nblocks * (QK_K / 4);
|
||||
auto block_scales = total_qs_bytes + block_index * (QK_K / 16);
|
||||
auto sb_scale = total_qs_bytes + nblocks * (QK_K / 16);
|
||||
return { block_scales, sb_scale };
|
||||
}
|
||||
|
||||
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
|
||||
};
|
||||
} // namespace ggml_sycl_reordered
|
||||
|
||||
#endif // GGML_SYCL_QUANTS_HPP
|
||||
|
||||
@@ -284,10 +284,11 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
|
||||
return d4 * (sumi * ds8f.x() - (8 * q4_0_traits::vdr_mmvq / q4_0_traits::qi) * ds8f.y());
|
||||
}
|
||||
|
||||
__dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
|
||||
const int8_t* q8_1_quant_ptr, const sycl::half2* q8_1_ds, const int & iqs, int /* nblocks */) {
|
||||
const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset;
|
||||
const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset));
|
||||
__dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
|
||||
const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr,
|
||||
const sycl::half2 * q8_1_ds, const int & iqs) {
|
||||
const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset.first;
|
||||
const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset.first));
|
||||
int v[q4_0_traits::vdr_mmvq];
|
||||
int u[2 * q4_0_traits::vdr_mmvq];
|
||||
|
||||
@@ -346,15 +347,15 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
|
||||
using q4_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q4_K>;
|
||||
using q4_k_traits = typename q4_k_block::traits;
|
||||
|
||||
float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
|
||||
const int8_t* q8_1_quant_ptr, const sycl::half2* q8_1_ds, const int & iqs, int nblocks) {
|
||||
const int ib = ibx_offset / (QK_K / 2);
|
||||
__dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
|
||||
const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr,
|
||||
const sycl::half2 * q8_1_ds, const int & iqs) {
|
||||
const int ib = ibx_offset.first / (QK_K / 2);
|
||||
|
||||
const uint8_t * base = static_cast<const uint8_t *>(vbq);
|
||||
const uint8_t * qs = base + ibx_offset;
|
||||
const int total_qs_bytes = nblocks * (QK_K / 2);
|
||||
const uint8_t * scs = base + total_qs_bytes + ib * K_SCALE_SIZE;
|
||||
const ggml_half2 * dms = reinterpret_cast<const ggml_half2 *>(base + d_offset);
|
||||
const uint8_t * qs = base + ibx_offset.first;
|
||||
const uint8_t * scs = base + d_offset.first + ib * K_SCALE_SIZE;
|
||||
const ggml_half2 * dms = reinterpret_cast<const ggml_half2 *>(base + d_offset.second);
|
||||
|
||||
const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));
|
||||
const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
|
||||
@@ -395,6 +396,66 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
|
||||
}
|
||||
};
|
||||
|
||||
template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K> {
|
||||
static constexpr ggml_type gtype = GGML_TYPE_Q6_K;
|
||||
|
||||
using q6_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q6_K>;
|
||||
using q6_k_traits = typename q6_k_block::traits;
|
||||
|
||||
__dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq(const int vl, const int vh, const int * __restrict__ u,
|
||||
const int8_t * __restrict__ scales, const float d,
|
||||
const float * __restrict__ d8) {
|
||||
float sumf = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < QR6_K; ++i) {
|
||||
const int sc = scales[4 * i];
|
||||
|
||||
const int vil = (vl >> (4 * i)) & 0x0F0F0F0F;
|
||||
|
||||
const int vih = ((vh >> (4 * i)) << 4) & 0x30303030;
|
||||
|
||||
const int vi = dpct::vectorized_binary<sycl::char4>((vil | vih), 0x20202020,
|
||||
dpct::sub_sat()); // vi = (vil | vih) - 32
|
||||
|
||||
sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product
|
||||
}
|
||||
|
||||
return d * sumf;
|
||||
}
|
||||
|
||||
__dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
|
||||
const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr, const sycl::half2 * q8_1_ds,
|
||||
const int iqs) {
|
||||
const int ib = ibx_offset.first / (QK_K / 2);
|
||||
|
||||
const uint8_t * base = static_cast<const uint8_t *>(vbq);
|
||||
const uint8_t * ql = base + ibx_offset.first;
|
||||
const uint8_t * qh = base + ibx_offset.second;
|
||||
const int8_t * scales = reinterpret_cast<const int8_t *>(base + d_offset.first);
|
||||
const ggml_half * d = (const ggml_half *) (base + d_offset.second) + ib;
|
||||
|
||||
const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 4);
|
||||
const int scale_offset = (QI6_K / 4) * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 8);
|
||||
const int vh_shift = 2 * ((iqs % (QI6_K / 2)) / (QI6_K / 4));
|
||||
|
||||
const int vl = get_int_from_uint8(ql, iqs);
|
||||
const int vh = get_int_from_uint8(qh, (QI6_K / 4) * (iqs / (QI6_K / 2)) + iqs % (QI6_K / 4)) >> vh_shift;
|
||||
|
||||
const int8_t * scs = scales + scale_offset;
|
||||
|
||||
int u[QR6_K];
|
||||
float d8[QR6_K];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < QR6_K; ++i) {
|
||||
u[i] = get_int_from_int8_aligned(q8_1_quant_ptr + (bq8_offset + 2 * i) * QK8_1, iqs % QI8_1);
|
||||
const sycl::half2 ds_values = *(q8_1_ds + bq8_offset + 2 * i);
|
||||
d8[i] = ds_values[0];
|
||||
}
|
||||
return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scs, *d, d8);
|
||||
}
|
||||
};
|
||||
#define VDR_Q4_0_Q8_1_MMVQ 2
|
||||
#define VDR_Q4_0_Q8_1_MMQ 4
|
||||
|
||||
|
||||
@@ -663,22 +663,14 @@ ggml_tensor * llm_graph_context::build_ffn(
|
||||
{
|
||||
// Split into two equal parts
|
||||
int64_t split_point = cur->ne[0] / 2;
|
||||
ggml_tensor * output_ffn_up = ggml_cont(ctx0, ggml_view_2d(
|
||||
ctx0, cur, split_point,
|
||||
cur->ne[1], cur->nb[1], 0
|
||||
));
|
||||
ggml_tensor * output_ffn_gate = ggml_cont(ctx0, ggml_view_2d(
|
||||
ctx0, cur, split_point,
|
||||
cur->ne[1], cur->nb[1],
|
||||
split_point * ggml_element_size(cur)
|
||||
));
|
||||
// TODO: these conts should not be needed
|
||||
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
||||
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
||||
|
||||
// Apply GELU activation function to the first part
|
||||
output_ffn_up = ggml_gelu(ctx0, output_ffn_up);
|
||||
cb(output_ffn_up, "ffn_gelu", il);
|
||||
x0 = ggml_gelu(ctx0, x0);
|
||||
cb(x0, "ffn_gelu", il);
|
||||
|
||||
// Element-wise multiplication between the activated part and the gate part
|
||||
cur = ggml_mul(ctx0, output_ffn_up, output_ffn_gate);
|
||||
cur = ggml_mul(ctx0, x0, x1);
|
||||
cb(cur, "ffn_geglu", il);
|
||||
} break;
|
||||
}
|
||||
|
||||
@@ -462,7 +462,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
|
||||
for (uint32_t i = 0; i < n_kv; ++i) {
|
||||
assert(dinfo.ids[i] <= n_kv);
|
||||
|
||||
if (dinfo.ids[i] == n_kv || dinfo.ids[i] == i) {
|
||||
if (dinfo.ids[i] == n_kv) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -944,9 +944,11 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
|
||||
const auto & n_embd_head_k = hparams.n_embd_head_k;
|
||||
//const auto & n_embd_head_v = hparams.n_embd_head_v;
|
||||
|
||||
//GGML_ASSERT(kv_self->size == n_ctx);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_k_shift>(this);
|
||||
|
||||
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size());
|
||||
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx);
|
||||
ggml_set_input(inp->k_shift);
|
||||
|
||||
for (const auto & layer : layers) {
|
||||
|
||||
@@ -80,9 +80,6 @@ public:
|
||||
assert(isrc < pos.size());
|
||||
assert(idst < pos.size());
|
||||
|
||||
assert(pos[idst] == -1);
|
||||
assert(pos[isrc] != -1);
|
||||
|
||||
pos [idst] = pos [isrc];
|
||||
shift[idst] = shift[isrc];
|
||||
seq [idst] = seq [isrc];
|
||||
@@ -147,10 +144,9 @@ public:
|
||||
assert(pos[i] != -1);
|
||||
|
||||
seq_pos_rm(i);
|
||||
seq[i].reset();
|
||||
|
||||
pos[i] = -1;
|
||||
shift[i] = 0;
|
||||
seq[i].reset();
|
||||
|
||||
used.erase(i);
|
||||
}
|
||||
@@ -168,7 +164,6 @@ public:
|
||||
|
||||
if (seq[i].none()) {
|
||||
pos[i] = -1;
|
||||
shift[i] = 0;
|
||||
|
||||
used.erase(i);
|
||||
|
||||
@@ -197,7 +192,6 @@ public:
|
||||
seq[i].reset();
|
||||
|
||||
pos[i] = -1;
|
||||
shift[i] = 0;
|
||||
|
||||
used.erase(i);
|
||||
|
||||
@@ -323,20 +317,21 @@ public:
|
||||
pos[i] += d;
|
||||
shift[i] += d;
|
||||
|
||||
seq_pos_add(i);
|
||||
|
||||
has_shift = true;
|
||||
|
||||
if (pos[i] < 0) {
|
||||
seq_pos_rm(i);
|
||||
|
||||
seq[i].reset();
|
||||
pos[i] = -1;
|
||||
shift[i] = 0;
|
||||
|
||||
used.erase(i);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
seq_pos_add(i);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
Binary file not shown.
@@ -2142,7 +2142,8 @@ struct server_context {
|
||||
|
||||
// find the slot that has been least recently used
|
||||
if (ret == nullptr) {
|
||||
int64_t t_last = ggml_time_us();
|
||||
int64_t t_last = -1;
|
||||
|
||||
for (server_slot & slot : slots) {
|
||||
// skip the slot if it is not available
|
||||
if (slot.is_processing()) {
|
||||
@@ -2150,7 +2151,7 @@ struct server_context {
|
||||
}
|
||||
|
||||
// select the current slot if the criteria match
|
||||
if (slot.t_last_used < t_last) {
|
||||
if (!ret || slot.t_last_used <= t_last) {
|
||||
t_last = slot.t_last_used;
|
||||
ret = &slot;
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ function AppLayout() {
|
||||
<>
|
||||
<Sidebar />
|
||||
<main
|
||||
className="drawer-content grow flex flex-col h-screen w-screen mx-auto px-4 overflow-auto bg-base-100"
|
||||
className="drawer-content grow flex flex-col h-screen mx-auto px-4 overflow-auto bg-base-100"
|
||||
id="main-scroll"
|
||||
>
|
||||
<Header />
|
||||
|
||||
Reference in New Issue
Block a user