mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-03-05 14:33:24 +02:00
Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4720819d45 | ||
|
|
d979f2b176 | ||
|
|
ecbcb7ea9d | ||
|
|
3e6ab244ad | ||
|
|
5596a35791 | ||
|
|
8d3b962f47 | ||
|
|
d903f30e25 | ||
|
|
8387ffb28d | ||
|
|
2e7e638523 | ||
|
|
a8b192b6ec | ||
|
|
c17dce4f5c | ||
|
|
88cf781f51 | ||
|
|
4e76d24f28 | ||
|
|
723c71064d |
2
.github/workflows/gguf-publish.yml
vendored
2
.github/workflows/gguf-publish.yml
vendored
@@ -21,7 +21,7 @@ on:
|
||||
jobs:
|
||||
deploy:
|
||||
|
||||
runs-on: ubuntu-slim
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
@@ -2520,11 +2520,28 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
));
|
||||
add_opt(common_arg(
|
||||
{"-a", "--alias"}, "STRING",
|
||||
"set alias for model name (to be used by REST API)",
|
||||
"set model name aliases, comma-separated (to be used by API)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.model_alias = value;
|
||||
for (auto & alias : string_split<std::string>(value, ',')) {
|
||||
alias = string_strip(alias);
|
||||
if (!alias.empty()) {
|
||||
params.model_alias.insert(alias);
|
||||
}
|
||||
}
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ALIAS"));
|
||||
add_opt(common_arg(
|
||||
{"--tags"}, "STRING",
|
||||
"set model tags, comma-separated (informational, not used for routing)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
for (auto & tag : string_split<std::string>(value, ',')) {
|
||||
tag = string_strip(tag);
|
||||
if (!tag.empty()) {
|
||||
params.model_tags.insert(tag);
|
||||
}
|
||||
}
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TAGS"));
|
||||
add_opt(common_arg(
|
||||
{"-m", "--model"}, "FNAME",
|
||||
ex == LLAMA_EXAMPLE_EXPORT_LORA
|
||||
|
||||
@@ -410,7 +410,8 @@ struct common_params {
|
||||
|
||||
struct common_params_model model;
|
||||
|
||||
std::string model_alias = ""; // model alias // NOLINT
|
||||
std::set<std::string> model_alias; // model aliases // NOLINT
|
||||
std::set<std::string> model_tags; // model tags (informational, not used for routing) // NOLINT
|
||||
std::string hf_token = ""; // HF token // NOLINT
|
||||
std::string prompt = ""; // NOLINT
|
||||
std::string system_prompt = ""; // NOLINT
|
||||
|
||||
@@ -22,7 +22,7 @@
|
||||
|
||||
**Llama.cpp + ZenDNN**
|
||||
|
||||
The llama.cpp ZenDNN backend leverages AMD's optimized matrix multiplication primitives to accelerate inference on AMD CPUs. It utilizes ZenDNN's **LowOHA (Low Overhead Hardware Accelerated)** MatMul operator for efficient GEMM operations with minimal execution overhead, built-in weight caching, and direct access to backend libraries (AOCL BLIS, LibXSMM, OneDNN).
|
||||
The llama.cpp ZenDNN backend leverages AMD's optimized matrix multiplication primitives to accelerate inference on AMD CPUs. It utilizes ZenDNN's **LowOHA (Low Overhead Hardware Accelerated)** MatMul operator for efficient GEMM operations with minimal execution overhead, built-in weight caching, and direct access to backend libraries (AOCL DLP, LibXSMM, OneDNN).
|
||||
|
||||
For more information about ZenDNN, visit: https://www.amd.com/en/developer/zendnn.html
|
||||
|
||||
@@ -32,7 +32,7 @@ For more information about ZenDNN, visit: https://www.amd.com/en/developer/zendn
|
||||
|:-------:|:-------:|:----------------------------------------------:|
|
||||
| Linux | Support | Ubuntu 20.04, 22.04, 24.04 |
|
||||
|
||||
For the latest list of supported operating systems, see the [ZenDNN Supported OS](https://github.com/amd/ZenDNN/blob/zendnnl/README.md#15-supported-os).
|
||||
For the latest list of supported operating systems, see the [ZenDNN Supported OS](https://github.com/amd/ZenDNN/blob/a18adf8c605fb5f5e52cefd7eda08a7b18febbaf/README.md#15-supported-os).
|
||||
|
||||
## Hardware
|
||||
|
||||
@@ -44,9 +44,9 @@ ZenDNN is optimized for AMD EPYC™ processors and AMD Ryzen™ processors based
|
||||
|
||||
| CPU Family | Status | Notes |
|
||||
|:-----------------------------:|:-------:|:----------------------------------:|
|
||||
| AMD EPYC™ 9005 Series (Turin)| Support | 5th Gen - Zen 5 architecture |
|
||||
| AMD EPYC™ 9004 Series (Genoa)| Support | 4th Gen - Zen 4 architecture |
|
||||
| AMD EPYC™ 7003 Series (Milan)| Support | 3rd Gen - Zen 3 architecture |
|
||||
| AMD EPYC™ 9005 Series (Turin) | Support | 5th Gen - Zen 5 architecture |
|
||||
| AMD EPYC™ 9004 Series (Genoa) | Support | 4th Gen - Zen 4 architecture |
|
||||
| AMD EPYC™ 7003 Series (Milan) | Support | 3rd Gen - Zen 3 architecture |
|
||||
| AMD Ryzen™ AI MAX (Strix Halo)| Support | High-performance mobile processors |
|
||||
|
||||
*Notes:*
|
||||
@@ -61,7 +61,7 @@ The ZenDNN backend currently accelerates **matrix multiplication (MUL_MAT)** ope
|
||||
|
||||
| Operation | Status | Notes |
|
||||
|:-------------|:-------:|:----------------------------------------------:|
|
||||
| MUL_MAT | ✓ | Accelerated via ZenDNN LowOHA MatMul |
|
||||
| MUL_MAT | Support | Accelerated via ZenDNN LowOHA MatMul |
|
||||
|
||||
*Note:* Since only MUL_MAT is accelerated, models will benefit most from ZenDNN when matrix multiplications dominate the computational workload (which is typical for transformer-based LLMs).
|
||||
|
||||
@@ -104,7 +104,6 @@ If you want to build ZenDNN yourself or use a specific version:
|
||||
# Clone ZenDNN repository
|
||||
git clone https://github.com/amd/ZenDNN.git
|
||||
cd ZenDNN
|
||||
git checkout zendnnl
|
||||
|
||||
# Build and install (requires CMake >= 3.25)
|
||||
mkdir build && cd build
|
||||
@@ -114,7 +113,7 @@ cmake --build . --target all
|
||||
|
||||
Default installation path: `ZenDNN/build/install`
|
||||
|
||||
**For detailed build instructions**, refer to the [ZenDNN README](https://github.com/amd/ZenDNN/blob/zendnnl/README.md).
|
||||
**For detailed build instructions**, refer to the [ZenDNN README](https://github.com/amd/ZenDNN/blob/a18adf8c605fb5f5e52cefd7eda08a7b18febbaf/README.md).
|
||||
|
||||
**Step 2: Build llama.cpp with custom ZenDNN path**
|
||||
|
||||
@@ -146,8 +145,7 @@ Run llama.cpp server with ZenDNN acceleration:
|
||||
|
||||
```sh
|
||||
# Set optimal configuration
|
||||
export OMP_NUM_THREADS=64 # Adjust to your CPU core count
|
||||
export ZENDNNL_MATMUL_ALGO=2 # Blocked AOCL BLIS for best performance
|
||||
export ZENDNNL_MATMUL_ALGO=1 # Blocked AOCL DLP algo for best performance
|
||||
|
||||
# Start server
|
||||
./build/bin/llama-server \
|
||||
@@ -160,62 +158,26 @@ export ZENDNNL_MATMUL_ALGO=2 # Blocked AOCL BLIS for best performance
|
||||
Access the server at `http://localhost:8080`.
|
||||
|
||||
**Performance tips**:
|
||||
- Set `OMP_NUM_THREADS` to match your physical core count
|
||||
- Use `ZENDNNL_MATMUL_ALGO=2` for optimal performance
|
||||
- Use `ZENDNNL_MATMUL_ALGO=1` for optimal performance
|
||||
- For NUMA systems: `numactl --cpunodebind=0 --membind=0 ./build/bin/llama-server ...`
|
||||
|
||||
## Environment Variable
|
||||
|
||||
### Build Time
|
||||
For environment variables related to ZenDNN, refer to the [ZenDNN Environment Variables Documentation](https://github.com/amd/ZenDNN/blob/a18adf8c605fb5f5e52cefd7eda08a7b18febbaf/docs/runtime_env.md).
|
||||
|
||||
| Name | Value | Function |
|
||||
|--------------------|---------------------------------------|---------------------------------------------|
|
||||
| GGML_ZENDNN | ON/OFF | Enable ZenDNN backend support |
|
||||
| ZENDNN_ROOT | Path to ZenDNN installation | Set ZenDNN installation directory |
|
||||
| GGML_OPENMP | ON/OFF (recommended: ON) | Enable OpenMP for multi-threading |
|
||||
### Performance Optimization
|
||||
|
||||
### Runtime
|
||||
|
||||
| Name | Value | Function |
|
||||
|-------------------------|--------------------------|-------------------------------------------------------------------|
|
||||
| OMP_NUM_THREADS | Number (e.g., 64) | Set number of OpenMP threads (recommended: physical core count) |
|
||||
| ZENDNNL_MATMUL_ALGO | 0-5 | Select MatMul backend algorithm (see Performance Optimization) |
|
||||
| ZENDNNL_PROFILE_LOG_LEVEL | 0-4 | Profiling log level (0=disabled, 4=verbose) |
|
||||
| ZENDNNL_ENABLE_PROFILER | 0 or 1 | Enable detailed profiling (1=enabled) |
|
||||
| ZENDNNL_API_LOG_LEVEL | 0-4 | API log level (0=disabled, 4=verbose) |
|
||||
|
||||
**Example**:
|
||||
ZenDNN's LowOHA MatMul supports multiple backend algorithms. For **best performance**, use the **Blocked AOCL DLP** algorithm:
|
||||
|
||||
```sh
|
||||
export OMP_NUM_THREADS=64
|
||||
export ZENDNNL_MATMUL_ALGO=2 # Use Blocked AOCL BLIS for best performance
|
||||
./build/bin/llama-cli -m models/llama-2-7b.Q4_0.gguf -p "Test" -n 100
|
||||
export ZENDNNL_MATMUL_ALGO=1 # Blocked AOCL DLP algo (recommended)
|
||||
```
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
### MatMul Algorithm Selection
|
||||
|
||||
ZenDNN's LowOHA MatMul supports multiple backend algorithms. For **best performance**, use the **Blocked AOCL BLIS** algorithm:
|
||||
|
||||
```sh
|
||||
export ZENDNNL_MATMUL_ALGO=2 # Blocked AOCL BLIS (recommended)
|
||||
```
|
||||
|
||||
**Available algorithms**:
|
||||
|
||||
| Value | Algorithm | Description |
|
||||
|:-----:|:-----------------------|:----------------------------------------------|
|
||||
| 0 | Dynamic Dispatch | Automatic backend selection (default) |
|
||||
| 1 | AOCL BLIS | AOCL BLIS backend |
|
||||
| 2 | AOCL BLIS Blocked | **Blocked AOCL BLIS (recommended)** |
|
||||
| 3 | OneDNN | OneDNN backend |
|
||||
| 4 | OneDNN Blocked | Blocked OneDNN |
|
||||
| 5 | LibXSMM | LibXSMM backend |
|
||||
For more details on available algorithms, see the [ZenDNN MatMul Algorithm Documentation](https://github.com/amd/ZenDNN/blob/a18adf8c605fb5f5e52cefd7eda08a7b18febbaf/docs/runtime_env.md#algorithm-details).
|
||||
|
||||
### Profiling and Debugging
|
||||
|
||||
For detailed profiling and logging options, refer to the [ZenDNN Logging Documentation](https://github.com/amd/ZenDNN/blob/zendnnl/docs/logging.md).
|
||||
For detailed profiling and logging options, refer to the [ZenDNN Logging Documentation](https://github.com/amd/ZenDNN/blob/a18adf8c605fb5f5e52cefd7eda08a7b18febbaf/docs/logging.md).
|
||||
|
||||
## Known Issues
|
||||
|
||||
@@ -245,10 +207,9 @@ A: Currently, ZenDNN primarily supports FP32 and BF16 data types. Quantized mode
|
||||
|
||||
A: Ensure:
|
||||
1. You're using an AMD EPYC or Ryzen processor (Zen 2 or newer)
|
||||
2. `OMP_NUM_THREADS` is set appropriately (physical core count)
|
||||
3. `ZENDNNL_MATMUL_ALGO=2` is set for best performance (Blocked AOCL BLIS)
|
||||
4. You're using a sufficiently large model (small models may not benefit as much)
|
||||
5. Enable profiling to verify ZenDNN MatMul is being called
|
||||
2. `ZENDNNL_MATMUL_ALGO=1` is set for best performance (Blocked AOCL DLP)
|
||||
3. You're using a sufficiently large model (small models may not benefit as much)
|
||||
4. Enable profiling to verify ZenDNN MatMul is being called
|
||||
|
||||
### **GitHub Contribution**:
|
||||
Please add the **[ZenDNN]** prefix/tag in issues/PRs titles to help the ZenDNN-team check/address them without delay.
|
||||
|
||||
@@ -141,27 +141,50 @@ static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_typ
|
||||
namespace ggml::cpu::amx {
|
||||
class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
||||
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
|
||||
// handle only 2d gemm for now
|
||||
auto is_contiguous_2d = [](const struct ggml_tensor * t) {
|
||||
return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
|
||||
};
|
||||
|
||||
if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous
|
||||
is_contiguous_2d(op->src[1]) && // src1 must be contiguous
|
||||
op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() &&
|
||||
op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 && // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315)
|
||||
op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x
|
||||
(qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) {
|
||||
// src1 must be host buffer
|
||||
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
|
||||
return false;
|
||||
}
|
||||
// src1 must be float32
|
||||
if (op->src[1]->type == GGML_TYPE_F32) {
|
||||
return true;
|
||||
}
|
||||
if (op->op != GGML_OP_MUL_MAT) {
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
auto * src0 = op->src[0];
|
||||
auto * src1 = op->src[1];
|
||||
|
||||
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
|
||||
return false;
|
||||
}
|
||||
if (!src0->buffer || src0->buffer->buft != ggml_backend_amx_buffer_type()) {
|
||||
return false;
|
||||
}
|
||||
if (src1->buffer && !ggml_backend_buft_is_host(src1->buffer->buft)) {
|
||||
return false;
|
||||
}
|
||||
if (op->ne[0] % (TILE_N * 2)) {
|
||||
return false;
|
||||
}
|
||||
int alignment;
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
alignment = TILE_K;
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
alignment = 256; // QK_K
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
alignment = 16;
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
if (src0->ne[0] % alignment) {
|
||||
return false;
|
||||
}
|
||||
if (src1->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
#if defined(__GNUC__)
|
||||
#pragma GCC diagnostic ignored "-Wpedantic"
|
||||
#pragma GCC diagnostic ignored "-Wunused-local-typedefs"
|
||||
@@ -202,35 +201,27 @@ struct tile_config_t{
|
||||
// advanced-matrix-extensions-intrinsics-functions.html
|
||||
//
|
||||
|
||||
#define TC_CONFIG_TILE(i, r, cb) tc.rows[i] = r; tc.colsb[i] = cb
|
||||
void ggml_tile_config_init(void) {
|
||||
static thread_local bool is_first_time = true;
|
||||
inline void ggml_tile_config_init(void) {
|
||||
static thread_local bool done = false;
|
||||
|
||||
if (!is_first_time) {
|
||||
if (done) {
|
||||
return;
|
||||
}
|
||||
|
||||
static thread_local tile_config_t tc;
|
||||
tile_config_t current_tc;
|
||||
_tile_storeconfig(¤t_tc);
|
||||
alignas(64) tile_config_t tc = {};
|
||||
tc.palette_id = 1;
|
||||
tc.start_row = 0;
|
||||
tc.rows[0] = 8; tc.colsb[0] = 64;
|
||||
tc.rows[1] = 8; tc.colsb[1] = 64;
|
||||
tc.rows[2] = 16; tc.colsb[2] = 32;
|
||||
tc.rows[3] = 16; tc.colsb[3] = 32;
|
||||
tc.rows[4] = 16; tc.colsb[4] = 64;
|
||||
tc.rows[5] = 16; tc.colsb[5] = 64;
|
||||
tc.rows[6] = 16; tc.colsb[6] = 64;
|
||||
tc.rows[7] = 16; tc.colsb[7] = 64;
|
||||
|
||||
// load only when config changes
|
||||
if (tc.palette_id == 0 || (memcmp(¤t_tc.colsb, &tc.colsb, sizeof(uint16_t) * 8) != 0 &&
|
||||
memcmp(¤t_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) {
|
||||
tc.palette_id = 1;
|
||||
tc.start_row = 0;
|
||||
TC_CONFIG_TILE(TMM0, 8, 64);
|
||||
TC_CONFIG_TILE(TMM1, 8, 64);
|
||||
TC_CONFIG_TILE(TMM2, 16, 32);
|
||||
TC_CONFIG_TILE(TMM3, 16, 32);
|
||||
TC_CONFIG_TILE(TMM4, 16, 64);
|
||||
TC_CONFIG_TILE(TMM5, 16, 64);
|
||||
TC_CONFIG_TILE(TMM6, 16, 64);
|
||||
TC_CONFIG_TILE(TMM7, 16, 64);
|
||||
_tile_loadconfig(&tc);
|
||||
}
|
||||
|
||||
is_first_time = false;
|
||||
_tile_loadconfig(&tc);
|
||||
done = true;
|
||||
}
|
||||
|
||||
// we need an extra 16 * 4B (TILE_N * int32_t) for each NB/KB block for compensation.
|
||||
@@ -268,33 +259,6 @@ int get_row_size(int K) {
|
||||
return row_size;
|
||||
}
|
||||
|
||||
// vectorized dtype conversion
|
||||
inline float FP16_TO_FP32(ggml_half val) {
|
||||
__m256i v = _mm256_setr_epi16(
|
||||
val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
|
||||
__m512 o = _mm512_cvtph_ps(v);
|
||||
return _mm512_cvtss_f32(o);
|
||||
}
|
||||
|
||||
inline __m512 FP16_TO_FP32_VEC(ggml_half val) {
|
||||
__m256i v = _mm256_set1_epi16(val);
|
||||
return _mm512_cvtph_ps(v);
|
||||
}
|
||||
|
||||
// horizontal reduce
|
||||
inline float _mm512_reduce_max_ps(const __m512 x) {
|
||||
__m512 v = x;
|
||||
__m512 v1 = _mm512_shuffle_f32x4(v, v, 0x4E);
|
||||
v = _mm512_max_ps(v, v1);
|
||||
v1 = _mm512_shuffle_f32x4(v, v, 0xB1);
|
||||
v = _mm512_max_ps(v, v1);
|
||||
v1 = _mm512_shuffle_ps(v, v, 0x4E);
|
||||
v = _mm512_max_ps(v, v1);
|
||||
v1 = _mm512_shuffle_ps(v, v, 0xB1);
|
||||
v = _mm512_max_ps(v, v1);
|
||||
return _mm512_cvtss_f32(v);
|
||||
}
|
||||
|
||||
// transpose utils
|
||||
#define SHUFFLE_EPI32(a, b, mask) \
|
||||
_mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask))
|
||||
@@ -1370,9 +1334,9 @@ struct tinygemm_kernel_avx<float, ggml_fp16_t, float, BLOCK_M, BLOCK_N, BLOCK_K>
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_AVX(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_avx<float, type, float, MB_SIZE, NB_SIZE, blck_size>::apply( \
|
||||
K, (const float *)src1->data + mb_start * K, \
|
||||
(const type *)src0->data + nb_start * K, \
|
||||
(float *)dst->data + mb_start * ldc + nb_start, ldc);
|
||||
K, (const float *)src1->data + src1_offset + mb_start * K, \
|
||||
(const type *)src0->data + src0_offset + nb_start * K, \
|
||||
(float *)dst->data + dst_offset + mb_start * ldc + nb_start, ldc)
|
||||
|
||||
|
||||
// re-organize in the format {NB, KB, TILE_SIZE}:
|
||||
@@ -2019,11 +1983,11 @@ struct tinygemm_kernel_vnni<block_q8_K, block_iq4_xs, float, BLOCK_M, BLOCK_N, B
|
||||
}
|
||||
};
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE) \
|
||||
tinygemm_kernel_vnni<vec_dot_type, type, float, 1, NB_SIZE, blck_size>::apply( \
|
||||
KB, (const char *)wdata + 0 * row_size_A, \
|
||||
(const char *)src0->data + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \
|
||||
(float *) dst->data + 0 * N + nb_start, ldc)
|
||||
#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE) \
|
||||
tinygemm_kernel_vnni<vec_dot_type, type, float, 1, NB_SIZE, blck_size>::apply( \
|
||||
KB, wdata_batch, \
|
||||
(const char *)src0->data + src0_offset + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \
|
||||
(float *) dst->data + dst_offset + nb_start, ldc)
|
||||
|
||||
template <typename TA, typename TB, typename TC, int BLOCK_K,
|
||||
typename std::enable_if<!is_type_qkk<TB>::value, int>::type = 0>
|
||||
@@ -2079,7 +2043,7 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v
|
||||
_tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t));
|
||||
|
||||
if (need_unpack) {
|
||||
unpack_B<TB>(Tile1, B_blk0);
|
||||
unpack_B<TB>(Tile1, B_blk1);
|
||||
_tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
|
||||
} else {
|
||||
_tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
|
||||
@@ -2336,6 +2300,13 @@ void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * d
|
||||
});
|
||||
}
|
||||
|
||||
// ne2 is passed explicitly to help compiler optimize repeated calls
|
||||
inline int64_t ggml_batch_offset(const ggml_tensor * t, int64_t batch_idx, int64_t ne2) {
|
||||
const int64_t i2 = batch_idx % ne2;
|
||||
const int64_t i3 = batch_idx / ne2;
|
||||
return i3 * t->nb[3] + i2 * t->nb[2];
|
||||
}
|
||||
|
||||
size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {
|
||||
struct ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
@@ -2348,12 +2319,13 @@ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {
|
||||
|
||||
const int M = dst->ne[1];
|
||||
const int K = src0->ne[0];
|
||||
const int64_t n_batch = dst->ne[2] * dst->ne[3];
|
||||
|
||||
size_t desired_wsize = 0;
|
||||
|
||||
GGML_DISPATCH_QTYPES(TYPE, [&] {
|
||||
const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
|
||||
desired_wsize = M * row_size_A;
|
||||
desired_wsize = n_batch * M * row_size_A;
|
||||
});
|
||||
|
||||
return desired_wsize;
|
||||
@@ -2365,7 +2337,7 @@ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {
|
||||
// src1: input in shape of {M, K}, float32
|
||||
// dst: output in shape of {M, N}, float32
|
||||
//
|
||||
// the function performs: dst = src1 @ src0.T
|
||||
// the function performs: dst = src1 @ src0.T for each batch
|
||||
//
|
||||
void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||
struct ggml_tensor * src0 = dst->src[0];
|
||||
@@ -2382,17 +2354,26 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
|
||||
const int K = src0->ne[0];
|
||||
const int ldc = dst->nb[1] / dst->nb[0];
|
||||
|
||||
const int64_t ne2 = dst->ne[2];
|
||||
const int64_t n_batch = ne2 * dst->ne[3];
|
||||
|
||||
if (is_floating_type) {
|
||||
constexpr int BLOCK_M = 4;
|
||||
constexpr int BLOCK_N = 6;
|
||||
const int MB = div_up(M, BLOCK_M);
|
||||
const int NB = div_up(N, BLOCK_N);
|
||||
|
||||
parallel_for_ggml(params, MB * NB, [&](int begin, int end) {
|
||||
parallel_for_ggml(params, n_batch * MB * NB, [&](int begin, int end) {
|
||||
GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] {
|
||||
for (int i = begin; i < end; ++i) {
|
||||
int mb = i / NB;
|
||||
int nb = i % NB;
|
||||
int batch_idx = i / (MB * NB);
|
||||
int remaining = i % (MB * NB);
|
||||
int mb = remaining / NB;
|
||||
int nb = remaining % NB;
|
||||
|
||||
int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);
|
||||
int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2);
|
||||
int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2);
|
||||
|
||||
int mb_start = mb * BLOCK_M;
|
||||
int mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
@@ -2424,10 +2405,10 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
|
||||
void * wdata = params->wdata;
|
||||
|
||||
//TODO: performance improvement: merge quant A
|
||||
if (params->ith == 0) {
|
||||
// if (params->ith == 0) {
|
||||
GGML_DISPATCH_QTYPES(TYPE, [&] {
|
||||
const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
|
||||
const size_t desired_wsize = M * row_size_A;
|
||||
const size_t desired_wsize = n_batch * M * row_size_A;
|
||||
if (params->wsize < desired_wsize) {
|
||||
GGML_ABORT("insufficient work space size");
|
||||
}
|
||||
@@ -2436,12 +2417,19 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
|
||||
// Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size
|
||||
GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size);
|
||||
|
||||
const float * A_data = static_cast<const float *>(src1->data);
|
||||
for (int m = 0; m < M; ++m) {
|
||||
from_float<vec_dot_type>(A_data + m * K, (char *)wdata + m * row_size_A, K);
|
||||
}
|
||||
parallel_for_ggml(params, n_batch, [&](int begin, int end) {
|
||||
for (int batch_idx = begin; batch_idx < end; ++batch_idx) {
|
||||
int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2);
|
||||
const float * A_data = (const float *)((const char *)src1->data + src1_offset);
|
||||
char * wdata_batch = (char *)wdata + batch_idx * M * row_size_A;
|
||||
|
||||
for (int m = 0; m < M; ++m) {
|
||||
from_float<vec_dot_type>(A_data + m * K, wdata_batch + m * row_size_A, K);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
// }
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
@@ -2451,13 +2439,19 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
|
||||
constexpr int BLOCK_N = TILE_N * kTilesN;
|
||||
const int NB = div_up(N, BLOCK_N);
|
||||
|
||||
parallel_for_ggml(params, NB, [&](int begin, int end) {
|
||||
parallel_for_ggml(params, n_batch * NB, [&](int begin, int end) {
|
||||
GGML_DISPATCH_QTYPES(TYPE, [&] {
|
||||
const int KB = K / blck_size;
|
||||
const int TILE_SIZE = get_tile_size<type>();
|
||||
const int row_size_A = KB * sizeof(vec_dot_type);
|
||||
for (int i = begin; i < end; ++i) {
|
||||
int nb = i;
|
||||
int batch_idx = i / NB;
|
||||
int nb = i % NB;
|
||||
|
||||
int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);
|
||||
int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2);
|
||||
const char * wdata_batch = (const char *)wdata + batch_idx * row_size_A;
|
||||
|
||||
int nb_start = nb * BLOCK_N;
|
||||
int nb_size = std::min(BLOCK_N, N - nb_start); // 32, 64, 96
|
||||
|
||||
@@ -2481,7 +2475,7 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
|
||||
const int MB = div_up(M, BLOCK_M);
|
||||
const int NB = div_up(N, BLOCK_N);
|
||||
|
||||
parallel_for_ggml(params, MB * NB, [&](int begin, int end) {
|
||||
parallel_for_ggml(params, n_batch * MB * NB, [&](int begin, int end) {
|
||||
// init tile config for each thread
|
||||
ggml_tile_config_init();
|
||||
|
||||
@@ -2491,8 +2485,14 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
|
||||
const int row_size_A = KB * sizeof(vec_dot_type);
|
||||
|
||||
for (int i = begin; i < end; ++i) {
|
||||
int mb = i / NB;
|
||||
int nb = i % NB;
|
||||
int batch_idx = i / (MB * NB);
|
||||
int remaining = i % (MB * NB);
|
||||
int mb = remaining / NB;
|
||||
int nb = remaining % NB;
|
||||
|
||||
int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);
|
||||
int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2);
|
||||
const char * wdata_batch = (const char *)wdata + batch_idx * M * row_size_A;
|
||||
|
||||
int mb_start = mb * BLOCK_M;
|
||||
int mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
@@ -2501,9 +2501,9 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
|
||||
|
||||
tinygemm_kernel_amx<vec_dot_type, type, float, blck_size>(
|
||||
mb_size, nb_size, KB,
|
||||
(const char *)wdata + mb_start * row_size_A,
|
||||
(const char *)src0->data + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE),
|
||||
(float *) dst->data + mb_start * N + nb_start, ldc);
|
||||
wdata_batch + mb_start * row_size_A,
|
||||
(const char *)src0->data + src0_offset + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE),
|
||||
(float *) dst->data + dst_offset + mb_start * N + nb_start, ldc);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
@@ -48,6 +48,8 @@
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
|
||||
#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
@@ -62,6 +64,8 @@
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
|
||||
#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
|
||||
#elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
|
||||
@@ -69,8 +73,10 @@
|
||||
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
|
||||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
|
||||
// repack.cpp
|
||||
@@ -84,6 +90,7 @@
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
@@ -94,6 +101,7 @@
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
|
||||
#elif defined(__POWERPC__) || defined(__powerpc__)
|
||||
@@ -120,6 +128,8 @@
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
|
||||
#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
@@ -134,6 +144,8 @@
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
|
||||
#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
|
||||
#elif defined(__loongarch64)
|
||||
@@ -160,6 +172,8 @@
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
|
||||
#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
@@ -174,6 +188,8 @@
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
|
||||
#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
|
||||
#elif defined(__riscv)
|
||||
@@ -201,6 +217,8 @@
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
|
||||
#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
@@ -214,6 +232,8 @@
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
|
||||
#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
|
||||
#elif defined(__s390x__)
|
||||
@@ -246,6 +266,8 @@
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
|
||||
#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
@@ -260,6 +282,8 @@
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
|
||||
#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
|
||||
#elif defined(__wasm__)
|
||||
@@ -294,6 +318,8 @@
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
|
||||
#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
@@ -308,6 +334,8 @@
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
|
||||
#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
|
||||
#endif
|
||||
|
||||
@@ -498,6 +498,81 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 4;
|
||||
const int blocklen = 4;
|
||||
|
||||
assert (n % qk == 0);
|
||||
assert (nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(s);
|
||||
UNUSED(bs);
|
||||
UNUSED(vx);
|
||||
UNUSED(vy);
|
||||
UNUSED(nr);
|
||||
UNUSED(nc);
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
const int8x16_t kvalues = vld1q_s8(kvalues_mxfp4);
|
||||
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
||||
float * res_ptr = s;
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);
|
||||
|
||||
float32x4_t sumf = vdupq_n_f32(0);
|
||||
for (int l = 0; l < nb; l++) {
|
||||
uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
|
||||
uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
|
||||
uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
|
||||
uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
|
||||
|
||||
int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4);
|
||||
int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F);
|
||||
int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4);
|
||||
int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F);
|
||||
int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4);
|
||||
int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F);
|
||||
int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4);
|
||||
int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F);
|
||||
|
||||
int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
|
||||
int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
|
||||
|
||||
int32x4_t sumi = vdupq_n_s32(0);
|
||||
sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
|
||||
sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
|
||||
sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
|
||||
sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1);
|
||||
sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2);
|
||||
sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
|
||||
sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
|
||||
sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
|
||||
|
||||
float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
|
||||
float32x4_t b_d = {
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[0]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[1]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[2]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[3]),
|
||||
};
|
||||
float32x4_t d = a_d * b_d;
|
||||
|
||||
sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi));
|
||||
}
|
||||
|
||||
vst1q_f32(res_ptr + x * 4, sumf);
|
||||
}
|
||||
return;
|
||||
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
|
||||
ggml_gemv_mxfp4_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
@@ -3164,6 +3239,87 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 4;
|
||||
const int blocklen = 4;
|
||||
|
||||
assert (n % qk == 0);
|
||||
assert (nr % 4 == 0);
|
||||
assert (nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(s);
|
||||
UNUSED(bs);
|
||||
UNUSED(vx);
|
||||
UNUSED(vy);
|
||||
UNUSED(nr);
|
||||
UNUSED(nc);
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
const int8x16_t kvalues = vld1q_s8(kvalues_mxfp4);
|
||||
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);
|
||||
|
||||
float32x4_t sumf[4];
|
||||
for (int m = 0; m < 4; m++) {
|
||||
sumf[m] = vdupq_n_f32(0);
|
||||
}
|
||||
|
||||
for (int l = 0; l < nb; l++) {
|
||||
float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
|
||||
float32x4_t b_d = {
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[0]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[1]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[2]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[3]),
|
||||
};
|
||||
|
||||
int32x4_t sumi_0 = vdupq_n_s32(0);
|
||||
int32x4_t sumi_1 = vdupq_n_s32(0);
|
||||
int32x4_t sumi_2 = vdupq_n_s32(0);
|
||||
int32x4_t sumi_3 = vdupq_n_s32(0);
|
||||
|
||||
for (int k = 0; k < 4; k++) {
|
||||
int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
|
||||
int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
|
||||
|
||||
uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
|
||||
int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
|
||||
int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
|
||||
|
||||
sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
|
||||
sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
|
||||
sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
|
||||
sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
|
||||
sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
|
||||
sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
|
||||
sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
|
||||
sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
|
||||
}
|
||||
|
||||
sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
|
||||
sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
|
||||
sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
|
||||
sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
|
||||
}
|
||||
|
||||
for (int m = 0; m < 4; m++) {
|
||||
vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
|
||||
ggml_gemm_mxfp4_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
||||
@@ -522,7 +522,8 @@ template<typename block_tx8>
|
||||
static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, __m256i signextendlut) {
|
||||
static_assert(
|
||||
std::is_same_v<block_tx8, block_q4_0x8> ||
|
||||
std::is_same_v<block_tx8, block_iq4_nlx8>,
|
||||
std::is_same_v<block_tx8, block_iq4_nlx8> ||
|
||||
std::is_same_v<block_tx8, block_mxfp4x8>,
|
||||
"Unsupported block type");
|
||||
|
||||
const int qk = QK8_0;
|
||||
@@ -580,6 +581,18 @@ static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
|
||||
std::is_same_v<block_tx8, block_q4_0x8> ||
|
||||
std::is_same_v<block_tx8, block_iq4_nlx8>) {
|
||||
col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask);
|
||||
} else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) {
|
||||
// Load 8 E8M0 exponents and convert to float via LUT
|
||||
// Rearranged to match changemask order: 0,4,1,5,2,6,3,7
|
||||
col_scale_f32 = _mm256_set_ps(
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0]));
|
||||
}
|
||||
|
||||
// Load and convert to FP32 scale from block_q8_0
|
||||
@@ -628,7 +641,8 @@ template<typename block_tx8>
|
||||
static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, __m256i signextendlut) {
|
||||
static_assert(
|
||||
std::is_same_v<block_tx8, block_q4_0x8> ||
|
||||
std::is_same_v<block_tx8, block_iq4_nlx8>,
|
||||
std::is_same_v<block_tx8, block_iq4_nlx8> ||
|
||||
std::is_same_v<block_tx8, block_mxfp4x8>,
|
||||
"Unsupported block type");
|
||||
|
||||
const int qk = QK8_0;
|
||||
@@ -749,6 +763,25 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
|
||||
std::is_same_v<block_tx8, block_q4_0x8> ||
|
||||
std::is_same_v<block_tx8, block_iq4_nlx8>) {
|
||||
col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
|
||||
} else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) {
|
||||
//TODO: simd-ify
|
||||
col_scale_f32 = _mm512_set_ps(
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[7]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[6]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[5]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[4]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[3]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[2]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[1]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[0]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[7]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[6]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[5]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[4]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[3]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[2]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[1]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[0]));
|
||||
}
|
||||
|
||||
// Process LHS in pairs of rows
|
||||
@@ -941,6 +974,25 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
|
||||
std::is_same_v<block_tx8, block_q4_0x8> ||
|
||||
std::is_same_v<block_tx8, block_iq4_nlx8>) {
|
||||
col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
|
||||
} else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) {
|
||||
//TODO: simd-ify
|
||||
col_scale_f32 = _mm512_set_ps(
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[7]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[6]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[5]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[4]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[3]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[2]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[1]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[0]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[7]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[6]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[5]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[4]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[3]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[2]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[1]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[0]));
|
||||
}
|
||||
|
||||
// Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
|
||||
@@ -1123,6 +1175,16 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
|
||||
std::is_same_v<block_tx8, block_q4_0x8> ||
|
||||
std::is_same_v<block_tx8, block_iq4_nlx8>) {
|
||||
col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
|
||||
} else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) {
|
||||
col_scale_f32 = _mm256_set_ps(
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0]));
|
||||
}
|
||||
|
||||
// Process LHS in groups of four
|
||||
@@ -1283,6 +1345,16 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
|
||||
std::is_same_v<block_tx8, block_q4_0x8> ||
|
||||
std::is_same_v<block_tx8, block_iq4_nlx8>) {
|
||||
col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
|
||||
} else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) {
|
||||
col_scale_f32 = _mm256_set_ps(
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]),
|
||||
GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0]));
|
||||
}
|
||||
|
||||
// Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
|
||||
@@ -1625,6 +1697,19 @@ void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
ggml_gemv_iq4_nl_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
#if defined(__AVX2__)
|
||||
__m256i signextendlut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)kvalues_mxfp4));
|
||||
signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
|
||||
|
||||
gemv_q4_b32_8x8_q8_0_lut_avx<block_mxfp4x8>(n, s, bs, vx, vy, nr, nc, signextendlut);
|
||||
|
||||
return;
|
||||
#endif
|
||||
|
||||
ggml_gemv_mxfp4_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
@@ -3423,6 +3508,21 @@ void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
#if defined(__AVX2__) || defined(__AVX512F__)
|
||||
{
|
||||
__m256i signextendlut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)kvalues_mxfp4));
|
||||
signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
|
||||
|
||||
gemm_q4_b32_8x8_q8_0_lut_avx<block_mxfp4x8>(n, s, bs, vx, vy, nr, nc, signextendlut);
|
||||
|
||||
return;
|
||||
}
|
||||
#endif // defined(__AVX2__) || defined(__AVX512F__)
|
||||
|
||||
ggml_gemm_mxfp4_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
||||
@@ -1098,6 +1098,82 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 4;
|
||||
const int blocklen = 4;
|
||||
|
||||
assert(nr == 1);
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(bs);
|
||||
UNUSED(nr);
|
||||
|
||||
float sumf[4];
|
||||
int sumi;
|
||||
|
||||
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
|
||||
const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
|
||||
sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
|
||||
}
|
||||
sumf[j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 8;
|
||||
|
||||
assert(nr == 1);
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(bs);
|
||||
UNUSED(nr);
|
||||
|
||||
float sumf[8];
|
||||
int sumi;
|
||||
|
||||
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_mxfp4x8 * b_ptr = (const block_mxfp4x8 *) vx + (x * nb);
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
|
||||
const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
|
||||
sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
|
||||
}
|
||||
sumf[j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemv_q8_0_4x4_q8_0_generic(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
@@ -1726,6 +1802,94 @@ void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 4;
|
||||
const int blocklen = 4;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
float sumf[4][4];
|
||||
int sumi;
|
||||
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
|
||||
const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
|
||||
sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
|
||||
(v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
|
||||
}
|
||||
sumf[m][j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++)
|
||||
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 8;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
float sumf[4][8];
|
||||
int sumi;
|
||||
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_mxfp4x8 * b_ptr = (const block_mxfp4x8 *) vx + (x * nb);
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
|
||||
const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
|
||||
sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
|
||||
(v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
|
||||
}
|
||||
sumf[m][j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++)
|
||||
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_q8_0_4x4_q8_0_generic(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
@@ -2510,6 +2674,121 @@ static int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_b
|
||||
GGML_UNUSED(data_size);
|
||||
}
|
||||
|
||||
|
||||
static block_mxfp4x4 make_block_mxfp4x4(block_mxfp4 * in, unsigned int blck_size_interleave) {
|
||||
block_mxfp4x4 out;
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
out.e[i] = in[i].e;
|
||||
}
|
||||
|
||||
const int end = QK_MXFP4 * 2 / blck_size_interleave;
|
||||
|
||||
if (blck_size_interleave == 4) {
|
||||
for (int i = 0; i < end; ++i) {
|
||||
int src_id = i % 4;
|
||||
int src_offset = (i / 4) * blck_size_interleave;
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t));
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
static int repack_mxfp4_to_mxfp4_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
||||
GGML_ASSERT(t->type == GGML_TYPE_MXFP4);
|
||||
GGML_ASSERT(interleave_block == 4);
|
||||
|
||||
const block_mxfp4 * src = (const block_mxfp4 *)data;
|
||||
block_mxfp4x4 * dst = ( block_mxfp4x4 *)t->data;
|
||||
|
||||
block_mxfp4 dst_tmp[4];
|
||||
|
||||
int nrow = ggml_nrows(t);
|
||||
int nrows_interleaved = 4;
|
||||
int nblocks = t->ne[0] / QK_MXFP4;
|
||||
|
||||
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4));
|
||||
|
||||
if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
for (int b = 0; b < nrow; b += nrows_interleaved) {
|
||||
for (int64_t x = 0; x < nblocks; x++) {
|
||||
for (int i = 0; i < nrows_interleaved; i++) {
|
||||
dst_tmp[i] = src[x + i * nblocks];
|
||||
}
|
||||
*dst++ = make_block_mxfp4x4(dst_tmp, interleave_block);
|
||||
}
|
||||
src += nrows_interleaved * nblocks;
|
||||
}
|
||||
return 0;
|
||||
|
||||
GGML_UNUSED(data_size);
|
||||
}
|
||||
|
||||
static block_mxfp4x8 make_block_mxfp4x8(block_mxfp4 * in, unsigned int blck_size_interleave) {
|
||||
block_mxfp4x8 out;
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
out.e[i] = in[i].e;
|
||||
}
|
||||
|
||||
const int end = QK_MXFP4 * 4 / blck_size_interleave;
|
||||
|
||||
if (blck_size_interleave == 8) {
|
||||
for (int i = 0; i < end; ++i) {
|
||||
int src_id = i % 8;
|
||||
int src_offset = (i / 8) * blck_size_interleave;
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
static int repack_mxfp4_to_mxfp4_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
||||
GGML_ASSERT(t->type == GGML_TYPE_MXFP4);
|
||||
GGML_ASSERT(interleave_block == 8);
|
||||
|
||||
const block_mxfp4 * src = (const block_mxfp4 *)data;
|
||||
block_mxfp4x8 * dst = ( block_mxfp4x8 *)t->data;
|
||||
|
||||
block_mxfp4 dst_tmp[8];
|
||||
|
||||
int nrow = ggml_nrows(t);
|
||||
int nrows_interleaved = 8;
|
||||
int nblocks = t->ne[0] / QK_MXFP4;
|
||||
|
||||
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4));
|
||||
|
||||
if (t->ne[1] % nrows_interleaved != 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
for (int b = 0; b < nrow; b += nrows_interleaved) {
|
||||
for (int64_t x = 0; x < nblocks; x++) {
|
||||
for (int i = 0; i < nrows_interleaved; i++) {
|
||||
dst_tmp[i] = src[x + i * nblocks];
|
||||
}
|
||||
*dst++ = make_block_mxfp4x8(dst_tmp, interleave_block);
|
||||
}
|
||||
src += nrows_interleaved * nblocks;
|
||||
}
|
||||
return 0;
|
||||
|
||||
GGML_UNUSED(data_size);
|
||||
}
|
||||
|
||||
namespace ggml::cpu::repack {
|
||||
// repack
|
||||
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
|
||||
@@ -2569,6 +2848,14 @@ template <> int repack<block_iq4_nl, 8, 8>(struct ggml_tensor * t, const void *
|
||||
return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_mxfp4, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_mxfp4_to_mxfp4_4_bl(t, 4, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_mxfp4, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_mxfp4_to_mxfp4_8_bl(t, 8, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_q8_0, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_q8_0_to_q8_0_4_bl(t, 4, data, data_size);
|
||||
}
|
||||
@@ -2636,6 +2923,14 @@ template <> void gemv<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size
|
||||
ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_mxfp4, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_mxfp4_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_mxfp4, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_mxfp4_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_q8_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
@@ -2703,6 +2998,14 @@ template <> void gemm<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size
|
||||
ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_mxfp4, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_mxfp4_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_mxfp4, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_mxfp4_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q8_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
@@ -3111,6 +3414,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
||||
static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
|
||||
static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0> iq4_nl_8x8_q8_0;
|
||||
|
||||
// instance for MXFP4
|
||||
static const ggml::cpu::repack::tensor_traits<block_mxfp4, 4, 4, GGML_TYPE_Q8_0> mxfp4_4x4_q8_0;
|
||||
static const ggml::cpu::repack::tensor_traits<block_mxfp4, 8, 8, GGML_TYPE_Q8_0> mxfp4_8x8_q8_0;
|
||||
|
||||
// instance for Q8_0
|
||||
static const ggml::cpu::repack::tensor_traits<block_q8_0, 4, 4, GGML_TYPE_Q8_0> q8_0_4x4_q8_0;
|
||||
static const ggml::cpu::repack::tensor_traits<block_q8_0, 8, 4, GGML_TYPE_Q8_0> q8_0_4x8_q8_0;
|
||||
@@ -3187,6 +3494,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
||||
return &iq4_nl_4x4_q8_0;
|
||||
}
|
||||
}
|
||||
} else if (cur->type == GGML_TYPE_MXFP4) {
|
||||
if (ggml_cpu_has_avx2()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
return &mxfp4_8x8_q8_0;
|
||||
}
|
||||
}
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
|
||||
if (cur->ne[1] % 4 == 0) {
|
||||
return &mxfp4_4x4_q8_0;
|
||||
}
|
||||
}
|
||||
} else if (cur->type == GGML_TYPE_Q8_0) {
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
|
||||
if (cur->ne[1] % 4 == 0) {
|
||||
|
||||
@@ -97,6 +97,19 @@ struct block_iq4_nlx8 {
|
||||
|
||||
static_assert(sizeof(block_iq4_nlx8) == 8 * sizeof(ggml_half) + QK4_NL * 4, "wrong iq4_nlx8 block size/padding");
|
||||
|
||||
struct block_mxfp4x4 {
|
||||
uint8_t e[4];
|
||||
uint8_t qs[QK_MXFP4 * 2];
|
||||
};
|
||||
static_assert(sizeof(block_mxfp4x4) == 4 + QK_MXFP4 * 2, "wrong mxfp4x4 block size/padding");
|
||||
|
||||
struct block_mxfp4x8 {
|
||||
uint8_t e[8];
|
||||
uint8_t qs[QK_MXFP4 * 4];
|
||||
};
|
||||
static_assert(sizeof(block_mxfp4x8) == 8 + QK_MXFP4 * 4, "wrong mxfp4x8 block size/padding");
|
||||
|
||||
|
||||
#if defined(__cplusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
@@ -117,6 +130,8 @@ void ggml_gemv_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
@@ -129,6 +144,8 @@ void ggml_gemm_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
@@ -151,6 +168,8 @@ void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
@@ -163,6 +182,8 @@ void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
|
||||
@@ -111,6 +111,44 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
||||
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
|
||||
}
|
||||
|
||||
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_cdna(const int DKQ, const int DV, const int ncols) {
|
||||
// Conservative configs for CDNA (MI100+): 64KB LDS, wavefront64, nstages=1 (no cp.async).
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 32, 32, 32, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 256, 2, 64, 40, 40, 40, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 256, 2, 64, 48, 48, 48, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 256, 2, 64, 56, 56, 56, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 256, 2, 64, 64, 64, 64, 1, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 1, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 256, 2, 32, 128, 128, 128, 1, true);
|
||||
|
||||
// Fallback for unsupported DKQ values (e.g. 576). Must return non-zero values to satisfy
|
||||
// compile-time static_asserts even though the kernel guard prevents runtime execution.
|
||||
// nthreads=256 gives nwarps=4 (warp_size=64) or 8 (warp_size=32), nbatch_fa=128 satisfies np*16 divisibility.
|
||||
return fattn_mma_config(256, 1, 128, 4, 4, 4, 1, false);
|
||||
}
|
||||
|
||||
static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
|
||||
if (ampere_mma_available(cc)) {
|
||||
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
|
||||
@@ -118,6 +156,9 @@ static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, c
|
||||
if (turing_mma_available(cc)) {
|
||||
return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
|
||||
}
|
||||
if (amd_mfma_available(cc)) {
|
||||
return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols);
|
||||
}
|
||||
if (amd_wmma_available(cc)) {
|
||||
return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
|
||||
}
|
||||
@@ -130,6 +171,8 @@ static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(cons
|
||||
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
|
||||
#elif defined(TURING_MMA_AVAILABLE)
|
||||
return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols);
|
||||
#elif defined(VOLTA_MMA_AVAILABLE)
|
||||
return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
@@ -205,15 +248,15 @@ static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ,
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_cols_per_thread() {
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
return 1; // RDNA has a single column.
|
||||
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
return 1; // AMD has a single column per thread.
|
||||
#else
|
||||
return 2; // This is specifically KQ columns, Volta only has a single VKQ column.
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
|
||||
static __host__ int get_cols_per_warp(const int cc) {
|
||||
if (turing_mma_available(cc) || amd_wmma_available(cc)) {
|
||||
if (turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc)) {
|
||||
return 16;
|
||||
} else {
|
||||
// Volta
|
||||
@@ -241,6 +284,7 @@ static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, c
|
||||
template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
|
||||
static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) {
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
// K/V data is loaded with decreasing granularity for D for better memory bandwidth.
|
||||
// The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
|
||||
if constexpr (use_cp_async) {
|
||||
@@ -252,10 +296,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
|
||||
|
||||
auto load = [&] __device__ (auto n) {
|
||||
const int stride_k = WARP_SIZE >> n;
|
||||
const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
|
||||
const int stride_k = warp_size >> n;
|
||||
const int k0_start = stride_k == warp_size ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
|
||||
const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
|
||||
const int stride_i = WARP_SIZE / stride_k;
|
||||
const int stride_i = warp_size / stride_k;
|
||||
|
||||
if (k0_start == k0_stop) {
|
||||
return;
|
||||
@@ -263,7 +307,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
|
||||
const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
||||
const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
|
||||
|
||||
if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
|
||||
break;
|
||||
@@ -271,7 +315,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
|
||||
#pragma unroll
|
||||
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
||||
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
||||
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
||||
|
||||
cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
|
||||
}
|
||||
@@ -287,10 +331,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
} else {
|
||||
// TODO use ggml_cuda_memcpy_1
|
||||
auto load = [&] __device__ (const int n) {
|
||||
const int stride_k = WARP_SIZE >> n;
|
||||
const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
|
||||
const int stride_k = warp_size >> n;
|
||||
const int k0_start = stride_k == warp_size ? 0 : D2 - D2 % (2*stride_k);
|
||||
const int k0_stop = D2 - D2 % (1*stride_k);
|
||||
const int stride_i = WARP_SIZE / stride_k;
|
||||
const int stride_i = warp_size / stride_k;
|
||||
|
||||
if (k0_start == k0_stop) {
|
||||
return;
|
||||
@@ -298,7 +342,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
|
||||
const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
||||
const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
|
||||
|
||||
if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
|
||||
break;
|
||||
@@ -306,7 +350,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
||||
|
||||
#pragma unroll
|
||||
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
||||
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
||||
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
||||
|
||||
tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f);
|
||||
}
|
||||
@@ -324,18 +368,19 @@ template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_chec
|
||||
static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
||||
const half * const __restrict__ mask_h, half * const __restrict__ tile_mask,
|
||||
const int stride_mask, const int i_sup, const int j0, const uint3 ne01) {
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
if constexpr (use_cp_async) {
|
||||
static_assert(nbatch_fa <= 8*WARP_SIZE && nbatch_fa % 8 == 0, "bad nbatch_fa");
|
||||
static_assert(nbatch_fa <= 8*warp_size && nbatch_fa % 8 == 0, "bad nbatch_fa");
|
||||
static_assert(!oob_check, "OOB check incompatible with cp_async");
|
||||
constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
|
||||
constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
|
||||
constexpr int cols_per_warp = 8*warp_size/nbatch_fa;
|
||||
constexpr int stride_j = nwarps * cols_per_warp;
|
||||
|
||||
const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);
|
||||
|
||||
#pragma unroll
|
||||
for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
|
||||
const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
|
||||
const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp);
|
||||
const int j_vram = fastmodulo(j0 + j_sram, ne01);
|
||||
|
||||
if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
|
||||
@@ -357,25 +402,25 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < nbatch_fa; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f);
|
||||
}
|
||||
}
|
||||
} else if constexpr (nbatch_fa < 2*WARP_SIZE) {
|
||||
constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
|
||||
} else if constexpr (nbatch_fa < 2*warp_size) {
|
||||
constexpr int cols_per_warp = 2*warp_size/nbatch_fa;
|
||||
constexpr int stride_j = nwarps * cols_per_warp;
|
||||
#pragma unroll
|
||||
for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
|
||||
const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
|
||||
const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp);
|
||||
const int j_vram = fastmodulo(j0 + j_sram, ne01);
|
||||
|
||||
if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
|
||||
break;
|
||||
}
|
||||
|
||||
const int i = threadIdx.x % (WARP_SIZE/cols_per_warp);
|
||||
const int i = threadIdx.x % (warp_size/cols_per_warp);
|
||||
|
||||
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i);
|
||||
}
|
||||
@@ -390,7 +435,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) {
|
||||
for (int i0 = 0; i0 < nbatch_fa; i0 += 2*warp_size) {
|
||||
const int i = i0 + 2*threadIdx.x;
|
||||
|
||||
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i);
|
||||
@@ -428,7 +473,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
const int jt,
|
||||
const int kb0,
|
||||
const int k_VKQ_sup) {
|
||||
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
|
||||
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
constexpr int cols_per_warp = T_B_KQ::I;
|
||||
constexpr int cols_per_thread = get_cols_per_thread();
|
||||
@@ -447,7 +493,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
const int k_VKQ_0 = kb0 * nbatch_fa;
|
||||
#if defined(TURING_MMA_AVAILABLE)
|
||||
T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))];
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
|
||||
#else // Volta
|
||||
T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
|
||||
@@ -500,13 +546,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
|
||||
} else {
|
||||
// Wide version of KQ_C is column-major
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
// RDNA matrix C is column-major.
|
||||
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
// AMD matrix C is column-major.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
|
||||
#else
|
||||
// swap A and B for CUDA.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A);
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -526,13 +572,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||
} else {
|
||||
// Wide version of KQ_C is column-major
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
// RDNA matrix C is column-major.
|
||||
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
// AMD matrix C is column-major.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||
#else
|
||||
// swap A and B for CUDA.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -585,12 +631,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
#pragma unroll
|
||||
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
||||
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
constexpr int KQ_idx = 0;
|
||||
#else
|
||||
// Turing + Volta:
|
||||
const int KQ_idx = l % 2;
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
|
||||
}
|
||||
}
|
||||
@@ -601,7 +647,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
for (int col = 0; col < cols_per_thread; ++col) {
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset >= 4; offset >>= 1) {
|
||||
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
|
||||
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -611,12 +657,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
#pragma unroll
|
||||
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
||||
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
constexpr int KQ_idx = 0;
|
||||
#else
|
||||
// Turing + Volta:
|
||||
const int KQ_idx = l % 2;
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]);
|
||||
KQ_rowsum_add[KQ_idx] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
|
||||
} else {
|
||||
@@ -649,12 +695,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
#pragma unroll
|
||||
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
||||
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
constexpr int KQ_idx = 0;
|
||||
#else
|
||||
// Turing + Volta:
|
||||
const int KQ_idx = (l/2) % 2;
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
|
||||
}
|
||||
}
|
||||
@@ -666,6 +712,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
// Values per KQ column are spread across 4 threads:
|
||||
constexpr int offset_first = 2;
|
||||
constexpr int offset_last = 1;
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
// MFMA: 4 threads per Q column (threadIdx.x % 16 == col, spaced by 16).
|
||||
constexpr int offset_first = 32;
|
||||
constexpr int offset_last = 16;
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
// Values per KQ column are spread across 2 threads:
|
||||
constexpr int offset_first = 16;
|
||||
@@ -677,7 +727,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
#endif // defined(TURING_MMA_AVAILABLE)
|
||||
#pragma unroll
|
||||
for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
|
||||
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
|
||||
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -687,12 +737,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
#pragma unroll
|
||||
for (int l = 0; l < T_C_KQ::ne; ++l) {
|
||||
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
constexpr int KQ_idx = 0;
|
||||
#else
|
||||
// Turing + Volta:
|
||||
const int KQ_idx = (l/2) % 2;
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]);
|
||||
KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
|
||||
} else {
|
||||
@@ -739,7 +789,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
const half2 KQ_max_scale_h2 = make_half2(
|
||||
KQ_max_scale[0], KQ_max_scale[0]);
|
||||
#pragma unroll
|
||||
@@ -818,7 +868,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
}
|
||||
const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2;
|
||||
|
||||
#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
|
||||
#pragma unroll
|
||||
for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
|
||||
@@ -830,24 +880,38 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
|
||||
#if defined(LDMATRIX_TRANS_AVAILABLE)
|
||||
load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
// MFMA A register layout: A_mat[i=lane%16][k=4*(lane/16)+reg].
|
||||
// Normal load gives A_mat[seq][dv] but we need A_mat[dv][seq] = V^T.
|
||||
// Load with transposed addressing: 4 strided half loads.
|
||||
{
|
||||
const half2 * xs0 = tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2;
|
||||
const half * xs0_h = (const half *) xs0;
|
||||
const int stride_h = stride_tile_V * 2; // stride in half units
|
||||
half * A_h = (half *) A.x;
|
||||
#pragma unroll
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
A_h[l] = xs0_h[(4*(threadIdx.x / 16) + l) * stride_h + threadIdx.x % 16];
|
||||
}
|
||||
}
|
||||
#else
|
||||
// TODO: Try to transpose tile_V when loading gmem to smem.
|
||||
// Use mma to transpose T_A_VKQ for RDNA.
|
||||
T_A_VKQ A_trans;
|
||||
load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
||||
mma(A, A_trans, A_identity);
|
||||
#endif // defined(TURING_MMA_AVAILABLE)
|
||||
#endif // defined(LDMATRIX_TRANS_AVAILABLE)
|
||||
if constexpr (T_B_KQ::I == 8) {
|
||||
mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
|
||||
} else {
|
||||
// Wide version of VKQ_C is column-major.
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
// RDNA matrix C is column-major.
|
||||
#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
// AMD matrix C is column-major.
|
||||
mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
|
||||
#else
|
||||
// swap A and B for CUDA.
|
||||
mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A);
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -866,7 +930,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A);
|
||||
}
|
||||
}
|
||||
#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
|
||||
if constexpr (nstages <= 1) {
|
||||
__syncthreads(); // Only needed if tile_K == tile_V.
|
||||
@@ -879,7 +943,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
tile_Q, tile_K, tile_V, tile_mask,
|
||||
Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
|
||||
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
|
||||
#if defined(TURING_MMA_AVAILABLE)
|
||||
@@ -899,7 +963,7 @@ template<> struct mma_tile_sizes<8> {
|
||||
using T_B_VKQ = tile< 8, 8, half2>; // column-major
|
||||
using T_C_VKQ = tile<16, 4, half2>; // row-major
|
||||
};
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
template<int ncols> struct mma_tile_sizes {
|
||||
using T_A_KQ = tile<16, 8, half2>; // row-major
|
||||
using T_B_KQ = tile<16, 8, half2>; // column-major
|
||||
@@ -944,9 +1008,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const int zt_gqa,
|
||||
const int kb0_start,
|
||||
const int kb0_stop) {
|
||||
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
|
||||
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
using T_A_KQ = typename mma_tile_sizes<ncols>::T_A_KQ;
|
||||
using T_B_KQ = typename mma_tile_sizes<ncols>::T_B_KQ;
|
||||
@@ -986,7 +1051,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)];
|
||||
#if defined(TURING_MMA_AVAILABLE)
|
||||
T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)];
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
|
||||
#else // Volta
|
||||
T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
|
||||
@@ -1004,10 +1069,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
// The loading is done with decreasing granularity for D for better memory bandwidth.
|
||||
const half2 scale_h2 = make_half2(scale, scale);
|
||||
#pragma unroll
|
||||
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
|
||||
const int k0_start = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
|
||||
for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) {
|
||||
const int k0_start = stride_k == warp_size ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
|
||||
const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k);
|
||||
const int stride_jc = WARP_SIZE / stride_k;
|
||||
const int stride_jc = warp_size / stride_k;
|
||||
|
||||
if (k0_start == k0_stop) {
|
||||
continue;
|
||||
@@ -1015,7 +1080,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
|
||||
#pragma unroll
|
||||
for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) {
|
||||
const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
||||
const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
|
||||
|
||||
if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) {
|
||||
break;
|
||||
@@ -1027,7 +1092,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) {
|
||||
#pragma unroll
|
||||
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
||||
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
||||
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
||||
|
||||
const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
|
||||
tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y);
|
||||
@@ -1035,7 +1100,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
||||
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
||||
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
||||
|
||||
tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f);
|
||||
}
|
||||
@@ -1127,6 +1192,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
// The partial sums are spread across 8/4 threads.
|
||||
constexpr int offset_first = cols_per_warp == 8 ? 16 : 2;
|
||||
constexpr int offset_last = cols_per_warp == 8 ? 4 : 1;
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
// The partial sums are spread across 4 threads (wavefront64, 16 cols).
|
||||
constexpr int offset_first = 32;
|
||||
constexpr int offset_last = 16;
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
// The partial sums are spread across 2 threads.
|
||||
constexpr int offset_first = 16;
|
||||
@@ -1140,7 +1209,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
for (int col = 0; col < cols_per_thread; ++col) {
|
||||
#pragma unroll
|
||||
for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
|
||||
KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
|
||||
KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, warp_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1189,7 +1258,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
|
||||
@@ -1249,7 +1318,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4);
|
||||
const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]);
|
||||
const bool thread_should_write = threadIdx.x % 4 < cols_per_thread;
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0);
|
||||
const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]);
|
||||
const bool thread_should_write = threadIdx.x / 16 < cols_per_thread;
|
||||
@@ -1283,14 +1352,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
// Warps with threadIdx.y % np != 0 must NOT return early.
|
||||
// All threads must return simultaneously to avoid race conditions with work on the next tile.
|
||||
|
||||
constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1;
|
||||
constexpr int nmeta = np*cols_per_warp >= warp_size ? np*cols_per_warp/warp_size : 1;
|
||||
|
||||
const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
|
||||
const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < warp_size ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
|
||||
float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2;
|
||||
float2 meta[nmeta];
|
||||
#pragma unroll
|
||||
for (int imeta = 0; imeta < nmeta; ++imeta) {
|
||||
meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2];
|
||||
meta[imeta] = meta_ptr[imeta * warp_size * tile_stride/2];
|
||||
}
|
||||
|
||||
float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
|
||||
@@ -1300,8 +1369,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
}
|
||||
#pragma unroll
|
||||
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
|
||||
if (offset < WARP_SIZE) {
|
||||
KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
|
||||
if (offset < warp_size) {
|
||||
KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, warp_size));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1318,8 +1387,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
}
|
||||
#pragma unroll
|
||||
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
|
||||
if (offset < WARP_SIZE) {
|
||||
KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
|
||||
if (offset < warp_size) {
|
||||
KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, warp_size);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1328,19 +1397,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
// Write back combined meta data:
|
||||
#pragma unroll
|
||||
for (int imeta = 0; imeta < nmeta; ++imeta) {
|
||||
if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) {
|
||||
if (np*cols_per_warp >= warp_size || threadIdx.x < np*cols_per_warp) {
|
||||
// Combined KQ max scale + rowsum.
|
||||
meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
|
||||
meta_ptr[imeta * warp_size * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
|
||||
}
|
||||
}
|
||||
|
||||
// Combined KQ max + rowsum.
|
||||
static_assert(cols_per_warp <= WARP_SIZE);
|
||||
if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
|
||||
static_assert(cols_per_warp <= warp_size);
|
||||
if (needs_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
|
||||
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
|
||||
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
|
||||
}
|
||||
if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
|
||||
if (is_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
|
||||
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
|
||||
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
|
||||
}
|
||||
@@ -1388,10 +1457,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2));
|
||||
|
||||
#pragma unroll
|
||||
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
|
||||
const int k0_start = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
|
||||
for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) {
|
||||
const int k0_start = stride_k == warp_size ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
|
||||
const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k);
|
||||
const int stride_jc = WARP_SIZE / stride_k;
|
||||
const int stride_jc = warp_size / stride_k;
|
||||
|
||||
if (k0_start == k0_stop) {
|
||||
continue;
|
||||
@@ -1399,7 +1468,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
|
||||
#pragma unroll
|
||||
for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
|
||||
const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
||||
const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
|
||||
|
||||
if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
|
||||
break;
|
||||
@@ -1417,7 +1486,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine;
|
||||
#pragma unroll
|
||||
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
||||
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
||||
const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
|
||||
|
||||
float2 dstk_val = make_float2(0.0f, 0.0f);
|
||||
#pragma unroll
|
||||
@@ -1453,7 +1522,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
|
||||
jt, kb0_start, kb0_stop);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
|
||||
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool V_is_K_view>
|
||||
@@ -1480,7 +1549,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
||||
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
||||
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)))
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
|
||||
@@ -1508,10 +1577,18 @@ static __global__ void flash_attn_ext_f16(
|
||||
}
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
|
||||
#if defined(AMD_MFMA_AVAILABLE)
|
||||
if (DKQ != 64 && DKQ != 80 && DKQ != 96 && DKQ != 112 && DKQ != 128) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#endif // defined(AMD_MFMA_AVAILABLE)
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
|
||||
constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols);
|
||||
constexpr int nwarps = nthreads / WARP_SIZE;
|
||||
constexpr int nwarps = nthreads / warp_size;
|
||||
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
|
||||
@@ -1624,7 +1701,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
ne31, ne32, ne33,
|
||||
nb31, nb32, nb33);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)))
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))
|
||||
}
|
||||
|
||||
template <int DKQ, int DV, int ncols1, int ncols2>
|
||||
@@ -1644,7 +1721,8 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc);
|
||||
|
||||
const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc));
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
const int warp_size_host = ggml_cuda_info().devices[ctx.device].warp_size;
|
||||
const int nwarps = nthreads / warp_size_host;
|
||||
|
||||
constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu
|
||||
|
||||
@@ -1694,7 +1772,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
}
|
||||
|
||||
launch_fattn<DV, ncols1, ncols2>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true);
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true, warp_size_host);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -440,6 +440,18 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
return BEST_FATTN_KERNEL_MMA_F16;
|
||||
}
|
||||
|
||||
// Use MFMA flash attention for CDNA (MI100+):
|
||||
if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 576) {
|
||||
const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1);
|
||||
// MMA vs tile crossover benchmarked on MI300X @ d32768:
|
||||
// hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%)
|
||||
// hsk=128 (gqa=4): MMA wins at eff >= 128 (+4%)
|
||||
if (eff_nq >= (GGML_CUDA_CC_IS_CDNA1(cc) && Q->ne[0] == 64 ? 64 : 128)) {
|
||||
return BEST_FATTN_KERNEL_MMA_F16;
|
||||
}
|
||||
// Fall through to tile kernel for small effective batch sizes.
|
||||
}
|
||||
|
||||
// If there are no tensor cores available, use the generic tile kernel:
|
||||
if (can_use_vector_kernel) {
|
||||
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
|
||||
|
||||
@@ -668,7 +668,7 @@ namespace ggml_cuda_mma {
|
||||
|
||||
return ret;
|
||||
}
|
||||
#elif defined(AMD_WMMA_AVAILABLE)
|
||||
#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
|
||||
template <int I, int J>
|
||||
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
|
||||
tile<I, J/2, half2> ret;
|
||||
@@ -964,6 +964,34 @@ namespace ggml_cuda_mma {
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(RDNA4)
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
// MFMA: FP16 input, FP32 accumulate, convert back to half2.
|
||||
using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;
|
||||
using floatx4_t = __attribute__((ext_vector_type(4))) float;
|
||||
|
||||
// Convert existing half2 accumulator to float for MFMA:
|
||||
floatx4_t acc_f32;
|
||||
{
|
||||
const halfx4_t acc_h = reinterpret_cast<const halfx4_t&>(D.x[0]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
acc_f32[i] = (float)acc_h[i];
|
||||
}
|
||||
}
|
||||
|
||||
const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]);
|
||||
const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]);
|
||||
acc_f32 = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_f32, 0, 0, 0);
|
||||
|
||||
// Convert back to half2:
|
||||
{
|
||||
halfx4_t result_h;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
result_h[i] = (_Float16)acc_f32[i];
|
||||
}
|
||||
reinterpret_cast<halfx4_t&>(D.x[0]) = result_h;
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
|
||||
@@ -55,7 +55,11 @@ void ggml_sycl_add_id(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||
const int32_t* src2_d = (const int32_t*)src2->data;
|
||||
float* dst_d = (float*)dst->data;
|
||||
|
||||
int threads = std::min((int)ne00, 768); // cols
|
||||
const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device];
|
||||
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
|
||||
|
||||
int threads = std::min((unsigned int)ne00, max_work_group_size); // cols
|
||||
|
||||
ctx.stream()->parallel_for(
|
||||
sycl::nd_range<3>(
|
||||
sycl::range<3>(1, ne02, ne01) * sycl::range<3>(1, 1, threads),
|
||||
|
||||
@@ -624,8 +624,6 @@ struct vk_device_struct {
|
||||
// floor(log2(maxComputeWorkGroupInvocations))
|
||||
uint32_t max_workgroup_size_log2 {};
|
||||
|
||||
bool flash_attention_fp16;
|
||||
|
||||
bool coopmat_support;
|
||||
bool coopmat_acc_f32_support {};
|
||||
bool coopmat_acc_f16_support {};
|
||||
@@ -2978,11 +2976,15 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
|
||||
}
|
||||
}
|
||||
|
||||
static vk_fa_pipeline_state get_fa_pipeline_state(const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc,
|
||||
static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc,
|
||||
bool use_mask, bool use_mask_opt, bool use_logit_softcap) {
|
||||
const bool old_amd_windows = device->vendor_id == VK_VENDOR_ID_AMD && device->driver_id == vk::DriverId::eAmdProprietary &&
|
||||
(device->architecture == AMD_GCN || device->architecture == AMD_RDNA1 || device->architecture == AMD_RDNA2);
|
||||
|
||||
uint32_t flags = (use_mask_opt ? 1 : 0) |
|
||||
(use_mask ? 2 : 0) |
|
||||
(use_logit_softcap ? 4 : 0);
|
||||
(use_logit_softcap ? 4 : 0) |
|
||||
(old_amd_windows ? 8 : 0);
|
||||
|
||||
const uint32_t subgroup_size = params.disable_subgroups ? 0 : params.subgroup_size;
|
||||
|
||||
@@ -3384,7 +3386,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
} \
|
||||
}
|
||||
|
||||
if (device->flash_attention_fp16) {
|
||||
if (device->fp16) {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
||||
@@ -5423,10 +5425,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
device->mmvq_mode = 1;
|
||||
}
|
||||
|
||||
// Driver issues with older AMD GPUs on Windows, see https://github.com/ggml-org/llama.cpp/pull/19625#issuecomment-3940840613
|
||||
const bool is_amd_proprietary_gcn = device->vendor_id == VK_VENDOR_ID_AMD && device->architecture == AMD_GCN && device->driver_id == vk::DriverId::eAmdProprietary;
|
||||
device->flash_attention_fp16 = device->fp16 && !is_amd_proprietary_gcn;
|
||||
|
||||
return device;
|
||||
}
|
||||
|
||||
@@ -8567,7 +8565,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
|
||||
const uint32_t Br = params.block_rows;
|
||||
const uint32_t Bc = params.block_cols;
|
||||
|
||||
const uint32_t float_type_size = device->flash_attention_fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
|
||||
const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
|
||||
|
||||
// tmpsh is overestimated slightly
|
||||
const uint32_t tmpsh = wg_size * sizeof(float);
|
||||
@@ -8690,7 +8688,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
uint32_t workgroups_y = (uint32_t)neq2;
|
||||
uint32_t workgroups_z = (uint32_t)neq3;
|
||||
|
||||
const bool f32acc = !ctx->device->flash_attention_fp16 || dst->op_params[3] == GGML_PREC_F32;
|
||||
const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32;
|
||||
|
||||
// For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
|
||||
// For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
|
||||
@@ -8745,7 +8743,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
|
||||
// Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
|
||||
bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768;
|
||||
vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(tuning_params, HSK, HSV, aligned, f32acc,
|
||||
vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc,
|
||||
mask != nullptr, use_mask_opt, logit_softcap != 0);
|
||||
|
||||
vk_pipeline pipeline = nullptr;
|
||||
|
||||
@@ -465,7 +465,14 @@ void main() {
|
||||
|
||||
if (SubGroupSize > 0) {
|
||||
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
|
||||
Of[r][d] += subgroupShuffleXor(Of[r][d], s);
|
||||
if (!OLD_AMD_WINDOWS) {
|
||||
Of[r][d] += subgroupShuffleXor(Of[r][d], s);
|
||||
} else {
|
||||
// Something about f16vec4 subgroupShuffleXor is broken on AMD Windows RDNA2 and below.
|
||||
// Shuffle full vec4 as workaround.
|
||||
// See https://github.com/ggml-org/llama.cpp/issues/19881#issuecomment-3958643697
|
||||
Of[r][d] += FLOAT_TYPEV4(subgroupShuffleXor(vec4(Of[r][d]), s));
|
||||
}
|
||||
}
|
||||
if (row_split == 1) {
|
||||
barrier();
|
||||
|
||||
@@ -14,9 +14,10 @@ layout (constant_id = 9) const uint32_t SHMEM_STAGING = 0;
|
||||
layout (constant_id = 10) const uint32_t Flags = 0;
|
||||
layout (constant_id = 11) const uint32_t LIMIT_OCCUPANCY_SHMEM = 0;
|
||||
|
||||
const bool USE_MASK_OPT = (Flags & 1) != 0;
|
||||
const bool MASK_ENABLE = (Flags & 2) != 0;
|
||||
const bool LOGIT_SOFTCAP = (Flags & 4) != 0;
|
||||
const bool USE_MASK_OPT = (Flags & 1) != 0;
|
||||
const bool MASK_ENABLE = (Flags & 2) != 0;
|
||||
const bool LOGIT_SOFTCAP = (Flags & 4) != 0;
|
||||
const bool OLD_AMD_WINDOWS = (Flags & 8) != 0;
|
||||
|
||||
// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
|
||||
const uint32_t HSK_pad = (HSK + 15) & ~15;
|
||||
|
||||
@@ -1,12 +1,19 @@
|
||||
ggml_add_backend_library(ggml-zendnn
|
||||
ggml-zendnn.cpp)
|
||||
|
||||
# Get ZenDNN path
|
||||
if (NOT DEFINED ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "")
|
||||
set(ZENDNN_ROOT "$ENV{ZENDNN_ROOT}")
|
||||
endif()
|
||||
|
||||
# Check if path is still empty or OFF
|
||||
if (BUILD_SHARED_LIBS)
|
||||
set(ZENDNN_SHARED_LIB ON)
|
||||
set(ZENDNN_ARCHIVE_LIB OFF)
|
||||
else()
|
||||
set(ZENDNN_SHARED_LIB OFF)
|
||||
set(ZENDNN_ARCHIVE_LIB ON)
|
||||
endif()
|
||||
|
||||
# Download and build ZenDNN if not provided
|
||||
if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
|
||||
message(STATUS "ZENDNN_ROOT not set. Automatically downloading and building ZenDNN...")
|
||||
message(STATUS "This will take several minutes on first build...")
|
||||
@@ -21,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
|
||||
ExternalProject_Add(
|
||||
zendnn
|
||||
GIT_REPOSITORY https://github.com/amd/ZenDNN.git
|
||||
GIT_TAG 21ce8f7879c86bf3637f707fae6f29e0951db5fe
|
||||
GIT_TAG a18adf8c605fb5f5e52cefd7eda08a7b18febbaf # ZenDNN-2026-WW08
|
||||
PREFIX ${ZENDNN_PREFIX}
|
||||
SOURCE_DIR ${ZENDNN_SOURCE_DIR}
|
||||
BINARY_DIR ${ZENDNN_BUILD_DIR}
|
||||
@@ -32,7 +39,9 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
|
||||
-DZENDNNL_BUILD_DOXYGEN=OFF
|
||||
-DZENDNNL_BUILD_GTEST=OFF
|
||||
-DZENDNNL_BUILD_BENCHDNN=OFF
|
||||
# Enable ALL matmul algorithm backends
|
||||
-DZENDNNL_DEPENDS_FBGEMM=OFF
|
||||
-DZENDNNL_LIB_BUILD_ARCHIVE=${ZENDNN_ARCHIVE_LIB}
|
||||
-DZENDNNL_LIB_BUILD_SHARED=${ZENDNN_SHARED_LIB}
|
||||
-DZENDNNL_DEPENDS_AOCLDLP=ON
|
||||
-DZENDNNL_DEPENDS_ONEDNN=ON
|
||||
-DZENDNNL_DEPENDS_LIBXSMM=ON
|
||||
@@ -45,47 +54,37 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
|
||||
LOG_INSTALL ON
|
||||
)
|
||||
|
||||
# Add dependency so ZenDNN builds before our library
|
||||
add_dependencies(ggml-zendnn zendnn)
|
||||
|
||||
# Set ZENDNN_ROOT to the installation directory
|
||||
set(ZENDNN_ROOT ${ZENDNN_INSTALL_DIR})
|
||||
|
||||
message(STATUS "ZenDNN will be built to: ${ZENDNN_ROOT}")
|
||||
else()
|
||||
message(STATUS "Using custom ZenDNN installation at: ${ZENDNN_ROOT}")
|
||||
endif()
|
||||
|
||||
# ZenDNN headers + libs
|
||||
target_include_directories(ggml-zendnn PRIVATE
|
||||
${ZENDNN_ROOT}/zendnnl/include
|
||||
${ZENDNN_ROOT}/deps/aocldlp/include
|
||||
${ZENDNN_ROOT}/deps/aoclutils/include
|
||||
${ZENDNN_ROOT}/deps/json/include
|
||||
${ZENDNN_ROOT}/deps/libxsmm/include
|
||||
${ZENDNN_ROOT}/deps/aoclutils/include
|
||||
${ZENDNN_ROOT}/deps/aocldlp/include
|
||||
${ZENDNN_ROOT}/deps/onednn/include
|
||||
)
|
||||
${ZENDNN_ROOT}/deps/libxsmm/include)
|
||||
|
||||
target_link_directories(ggml-zendnn PRIVATE
|
||||
${ZENDNN_ROOT}/zendnnl/lib
|
||||
${ZENDNN_ROOT}/deps/aocldlp/lib
|
||||
${ZENDNN_ROOT}/deps/aoclutils/lib
|
||||
${ZENDNN_ROOT}/deps/libxsmm/lib
|
||||
${ZENDNN_ROOT}/deps/onednn/lib
|
||||
)
|
||||
if (ZENDNN_SHARED_LIB)
|
||||
target_link_directories(ggml-zendnn PRIVATE ${ZENDNN_ROOT}/zendnnl/lib)
|
||||
target_link_libraries(ggml-zendnn PRIVATE zendnnl)
|
||||
elseif (ZENDNN_ARCHIVE_LIB)
|
||||
target_link_libraries(ggml-zendnn PRIVATE
|
||||
${ZENDNN_ROOT}/zendnnl/lib/libzendnnl_archive.a
|
||||
${ZENDNN_ROOT}/deps/aoclutils/${CMAKE_INSTALL_LIBDIR}/libaoclutils.a
|
||||
${ZENDNN_ROOT}/deps/aoclutils/${CMAKE_INSTALL_LIBDIR}/libau_cpuid.a
|
||||
${ZENDNN_ROOT}/deps/aocldlp/lib/libaocl-dlp.a
|
||||
${ZENDNN_ROOT}/deps/onednn/${CMAKE_INSTALL_LIBDIR}/libdnnl.a
|
||||
${ZENDNN_ROOT}/deps/libxsmm/lib/libxsmm.a
|
||||
${ZENDNN_ROOT}/deps/libxsmm/lib/libxsmmext.a
|
||||
${ZENDNN_ROOT}/deps/libxsmm/lib/libxsmmnoblas.a)
|
||||
endif()
|
||||
|
||||
target_link_libraries(ggml-zendnn PRIVATE
|
||||
zendnnl_archive # ZenDNN main
|
||||
aocl-dlp # AOCL libraries
|
||||
aoclutils
|
||||
au_cpuid
|
||||
dnnl # OneDNN
|
||||
xsmm # libxsmm small matrix math
|
||||
xsmmext
|
||||
xsmmnoblas
|
||||
m
|
||||
pthread
|
||||
)
|
||||
target_link_libraries(ggml-zendnn PRIVATE m pthread)
|
||||
|
||||
if (GGML_OPENMP)
|
||||
target_link_libraries(ggml-zendnn PRIVATE OpenMP::OpenMP_CXX)
|
||||
|
||||
@@ -41,13 +41,13 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int
|
||||
const TA * A, int64_t lda, const TB * B, int64_t ldb, TC * C,
|
||||
int64_t ldc) {
|
||||
|
||||
zendnnl::lowoha::lowoha_params params;
|
||||
zendnnl::lowoha::matmul::matmul_params params;
|
||||
params.dtypes.src = ggml_to_zendnn_type<TB>();
|
||||
params.dtypes.wei = ggml_to_zendnn_type<TA>();
|
||||
params.dtypes.dst = ggml_to_zendnn_type<TC>();
|
||||
params.num_threads = ctx->n_threads;
|
||||
|
||||
zendnnl::lowoha::status_t status = zendnnl::lowoha::matmul_direct(
|
||||
zendnnl::error_handling::status_t status = zendnnl::lowoha::matmul::matmul_direct(
|
||||
'r', false, true, // row-major, don't transpose B, transpose A (because it's column-major)
|
||||
n, // M: rows of B and C
|
||||
m, // N: cols of A^T and C
|
||||
@@ -63,7 +63,7 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int
|
||||
params // params
|
||||
);
|
||||
|
||||
if (status != zendnnl::lowoha::status_t::success) {
|
||||
if (status != zendnnl::error_handling::status_t::success) {
|
||||
GGML_LOG_ERROR("%s, ZenDNN matmul failed: status=%d\n", __func__, static_cast<int>(status));
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "gguf"
|
||||
version = "0.17.1"
|
||||
version = "0.18.0"
|
||||
description = "Read and write ML models in GGUF for GGML"
|
||||
authors = ["GGML <ggml@ggml.ai>"]
|
||||
packages = [
|
||||
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
HTTPLIB_VERSION = "refs/tags/v0.34.0"
|
||||
HTTPLIB_VERSION = "refs/tags/v0.35.0"
|
||||
|
||||
vendor = {
|
||||
"https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp",
|
||||
|
||||
@@ -152,7 +152,7 @@ if (NOT WIN32 OR NOT BUILD_SHARED_LIBS)
|
||||
llama_build_and_test(test-grammar-parser.cpp)
|
||||
llama_build_and_test(test-grammar-integration.cpp)
|
||||
llama_build_and_test(test-llama-grammar.cpp)
|
||||
llama_build_and_test(test-chat.cpp)
|
||||
llama_build_and_test(test-chat.cpp WORKING_DIRECTORY ${PROJECT_SOURCE_DIR})
|
||||
# TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8
|
||||
if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
|
||||
llama_build_and_test(test-json-schema-to-grammar.cpp WORKING_DIRECTORY ${PROJECT_SOURCE_DIR})
|
||||
@@ -257,6 +257,21 @@ set(LLAMA_TEST_NAME test-mtmd-c-api)
|
||||
llama_build_and_test(test-mtmd-c-api.c)
|
||||
target_link_libraries(${LLAMA_TEST_NAME} PRIVATE mtmd)
|
||||
|
||||
# GGUF model data fetcher library for tests that need real model metadata
|
||||
# Only compile when cpp-httplib has SSL support (CPPHTTPLIB_OPENSSL_SUPPORT)
|
||||
if (TARGET cpp-httplib)
|
||||
get_target_property(_cpp_httplib_defs cpp-httplib INTERFACE_COMPILE_DEFINITIONS)
|
||||
if (_cpp_httplib_defs MATCHES "CPPHTTPLIB_OPENSSL_SUPPORT")
|
||||
add_library(gguf-model-data STATIC gguf-model-data.cpp)
|
||||
target_link_libraries(gguf-model-data PRIVATE common cpp-httplib)
|
||||
target_include_directories(gguf-model-data PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
|
||||
add_executable(test-gguf-model-data test-gguf-model-data.cpp)
|
||||
target_link_libraries(test-gguf-model-data PRIVATE gguf-model-data common)
|
||||
llama_test(test-gguf-model-data LABEL "model")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# dummy executable - not installed
|
||||
get_filename_component(TEST_TARGET test-c.c NAME_WE)
|
||||
add_executable(${TEST_TARGET} test-c.c)
|
||||
|
||||
613
tests/gguf-model-data.cpp
Normal file
613
tests/gguf-model-data.cpp
Normal file
@@ -0,0 +1,613 @@
|
||||
// GGUF binary parser adapted from the huggingface/gguf package.
|
||||
// Reference: https://github.com/huggingface/huggingface.js
|
||||
|
||||
#include "gguf-model-data.h"
|
||||
|
||||
#include "common.h"
|
||||
#include "gguf.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
|
||||
#include "http.h"
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
// Equivalent of RangeView
|
||||
struct gguf_buf_reader {
|
||||
const char * data;
|
||||
size_t size;
|
||||
size_t pos;
|
||||
|
||||
gguf_buf_reader(const std::vector<char> & buf) : data(buf.data()), size(buf.size()), pos(0) {}
|
||||
|
||||
bool has_n_bytes(size_t n) const {
|
||||
return pos + n <= size;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool read_val(T & out) {
|
||||
if (!has_n_bytes(sizeof(T))) {
|
||||
return false;
|
||||
}
|
||||
memcpy(&out, data + pos, sizeof(T));
|
||||
pos += sizeof(T);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool read_str(std::string & out) {
|
||||
uint64_t len;
|
||||
if (!read_val(len)) {
|
||||
return false;
|
||||
}
|
||||
if (!has_n_bytes((size_t)len)) {
|
||||
return false;
|
||||
}
|
||||
out.assign(data + pos, (size_t)len);
|
||||
pos += (size_t)len;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool skip(size_t n) {
|
||||
if (!has_n_bytes(n)) {
|
||||
return false;
|
||||
}
|
||||
pos += n;
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
static size_t gguf_val_type_size(int32_t vtype) {
|
||||
switch (vtype) {
|
||||
case GGUF_TYPE_UINT8: return 1;
|
||||
case GGUF_TYPE_INT8: return 1;
|
||||
case GGUF_TYPE_UINT16: return 2;
|
||||
case GGUF_TYPE_INT16: return 2;
|
||||
case GGUF_TYPE_UINT32: return 4;
|
||||
case GGUF_TYPE_INT32: return 4;
|
||||
case GGUF_TYPE_FLOAT32: return 4;
|
||||
case GGUF_TYPE_BOOL: return 1;
|
||||
case GGUF_TYPE_UINT64: return 8;
|
||||
case GGUF_TYPE_INT64: return 8;
|
||||
case GGUF_TYPE_FLOAT64: return 8;
|
||||
default: return 0; // string/array handled separately
|
||||
}
|
||||
}
|
||||
|
||||
// Equivalent of readMetadataValue(), skips unused values rather than storing
|
||||
static bool gguf_skip_value(gguf_buf_reader & r, int32_t vtype) {
|
||||
if (vtype == GGUF_TYPE_STRING) {
|
||||
std::string tmp;
|
||||
return r.read_str(tmp);
|
||||
}
|
||||
if (vtype == GGUF_TYPE_ARRAY) {
|
||||
int32_t elem_type;
|
||||
uint64_t count;
|
||||
if (!r.read_val(elem_type)) {
|
||||
return false;
|
||||
}
|
||||
if (!r.read_val(count)) {
|
||||
return false;
|
||||
}
|
||||
if (elem_type == GGUF_TYPE_STRING) {
|
||||
for (uint64_t i = 0; i < count; i++) {
|
||||
std::string tmp;
|
||||
if (!r.read_str(tmp)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (elem_type == GGUF_TYPE_ARRAY) {
|
||||
// nested arrays - recurse
|
||||
for (uint64_t i = 0; i < count; i++) {
|
||||
if (!gguf_skip_value(r, GGUF_TYPE_ARRAY)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
size_t elem_sz = gguf_val_type_size(elem_type);
|
||||
if (elem_sz == 0) {
|
||||
return false;
|
||||
}
|
||||
return r.skip((size_t)count * elem_sz);
|
||||
}
|
||||
size_t sz = gguf_val_type_size(vtype);
|
||||
if (sz == 0) {
|
||||
return false;
|
||||
}
|
||||
return r.skip(sz);
|
||||
}
|
||||
|
||||
static bool gguf_read_uint32_val(gguf_buf_reader & r, int32_t vtype, uint32_t & out) {
|
||||
if (vtype == GGUF_TYPE_UINT8) {
|
||||
uint8_t v;
|
||||
if (!r.read_val(v)) {
|
||||
return false;
|
||||
}
|
||||
out = v;
|
||||
return true;
|
||||
}
|
||||
if (vtype == GGUF_TYPE_INT8) {
|
||||
int8_t v;
|
||||
if (!r.read_val(v)) {
|
||||
return false;
|
||||
}
|
||||
out = (uint32_t)v;
|
||||
return true;
|
||||
}
|
||||
if (vtype == GGUF_TYPE_UINT16) {
|
||||
uint16_t v;
|
||||
if (!r.read_val(v)) {
|
||||
return false;
|
||||
}
|
||||
out = v;
|
||||
return true;
|
||||
}
|
||||
if (vtype == GGUF_TYPE_INT16) {
|
||||
int16_t v;
|
||||
if (!r.read_val(v)) {
|
||||
return false;
|
||||
}
|
||||
out = (uint32_t)v;
|
||||
return true;
|
||||
}
|
||||
if (vtype == GGUF_TYPE_UINT32) {
|
||||
uint32_t v;
|
||||
if (!r.read_val(v)) {
|
||||
return false;
|
||||
}
|
||||
out = v;
|
||||
return true;
|
||||
}
|
||||
if (vtype == GGUF_TYPE_INT32) {
|
||||
int32_t v;
|
||||
if (!r.read_val(v)) {
|
||||
return false;
|
||||
}
|
||||
out = (uint32_t)v;
|
||||
return true;
|
||||
}
|
||||
if (vtype == GGUF_TYPE_UINT64) {
|
||||
uint64_t v;
|
||||
if (!r.read_val(v)) {
|
||||
return false;
|
||||
}
|
||||
out = (uint32_t)v;
|
||||
return true;
|
||||
}
|
||||
if (vtype == GGUF_TYPE_INT64) {
|
||||
int64_t v;
|
||||
if (!r.read_val(v)) {
|
||||
return false;
|
||||
}
|
||||
out = (uint32_t)v;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Follows the same header -> KV -> tensor parsing sequence as gguf() huggingface/gguf
|
||||
static std::optional<gguf_remote_model> gguf_parse_meta(const std::vector<char> & buf) {
|
||||
gguf_buf_reader r(buf);
|
||||
|
||||
// Header: magic(4) + version(4) + tensor_count(8) + kv_count(8) = 24 bytes minimum
|
||||
uint32_t magic_raw;
|
||||
if (!r.read_val(magic_raw)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
if (memcmp(&magic_raw, "GGUF", 4) != 0) {
|
||||
fprintf(stderr, "gguf_parse_meta: invalid magic\n");
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
uint32_t version;
|
||||
if (!r.read_val(version)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
if (version < 2 || version > 3) {
|
||||
fprintf(stderr, "gguf_parse_meta: unsupported version %u\n", version);
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
int64_t tensor_count_raw;
|
||||
int64_t kv_count_raw;
|
||||
if (!r.read_val(tensor_count_raw)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
if (!r.read_val(kv_count_raw)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
uint64_t tensor_count = (uint64_t)tensor_count_raw;
|
||||
uint64_t kv_count = (uint64_t)kv_count_raw;
|
||||
|
||||
gguf_remote_model model;
|
||||
|
||||
std::string arch_prefix;
|
||||
|
||||
// Parse KV pairs
|
||||
for (uint64_t i = 0; i < kv_count; i++) {
|
||||
std::string key;
|
||||
if (!r.read_str(key)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
int32_t vtype;
|
||||
if (!r.read_val(vtype)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (key == "general.architecture" && vtype == GGUF_TYPE_STRING) {
|
||||
if (!r.read_str(model.architecture)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
arch_prefix = model.architecture + ".";
|
||||
continue;
|
||||
}
|
||||
|
||||
// Extract split.count for proper handling of split files
|
||||
if (key == "split.count") {
|
||||
uint32_t v;
|
||||
if (!gguf_read_uint32_val(r, vtype, v)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
model.n_split = (uint16_t)v;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Extract split.tensors.count so we can verify we have all tensors
|
||||
if (key == "split.tensors.count") {
|
||||
uint32_t v;
|
||||
if (!gguf_read_uint32_val(r, vtype, v)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
model.n_split_tensors = v;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!arch_prefix.empty()) {
|
||||
uint32_t * target = nullptr;
|
||||
|
||||
if (key == arch_prefix + "embedding_length") { target = &model.n_embd; }
|
||||
else if (key == arch_prefix + "feed_forward_length") { target = &model.n_ff; }
|
||||
else if (key == arch_prefix + "block_count") { target = &model.n_layer; }
|
||||
else if (key == arch_prefix + "attention.head_count") { target = &model.n_head; }
|
||||
else if (key == arch_prefix + "attention.head_count_kv") { target = &model.n_head_kv; }
|
||||
else if (key == arch_prefix + "expert_count") { target = &model.n_expert; }
|
||||
else if (key == arch_prefix + "attention.key_length") { target = &model.n_embd_head_k; }
|
||||
else if (key == arch_prefix + "attention.value_length") { target = &model.n_embd_head_v; }
|
||||
|
||||
if (target) {
|
||||
if (!gguf_read_uint32_val(r, vtype, *target)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (!gguf_skip_value(r, vtype)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
// Parse tensor info entries
|
||||
model.tensors.reserve((size_t)tensor_count);
|
||||
for (uint64_t i = 0; i < tensor_count; i++) {
|
||||
gguf_remote_tensor t;
|
||||
|
||||
if (!r.read_str(t.name)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
if (!r.read_val(t.n_dims)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (t.n_dims > 4) {
|
||||
fprintf(stderr, "gguf_parse_meta: tensor '%s' has %u dims (max 4)\n", t.name.c_str(), t.n_dims);
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
for (uint32_t d = 0; d < t.n_dims; d++) {
|
||||
if (!r.read_val(t.ne[d])) {
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
int32_t type_raw;
|
||||
if (!r.read_val(type_raw)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
t.type = (ggml_type)type_raw;
|
||||
|
||||
uint64_t offset;
|
||||
if (!r.read_val(offset)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Infer n_vocab from token_embd.weight
|
||||
if (t.name == "token_embd.weight") {
|
||||
model.n_vocab = (uint32_t)t.ne[1];
|
||||
}
|
||||
|
||||
model.tensors.push_back(std::move(t));
|
||||
}
|
||||
|
||||
return model;
|
||||
}
|
||||
|
||||
// cache handling for local download
|
||||
static std::string get_default_cache_dir() {
|
||||
return fs_get_cache_directory() + "gguf-headers/";
|
||||
}
|
||||
|
||||
static std::string sanitize_for_path(const std::string & s) {
|
||||
std::string out = s;
|
||||
for (char & c : out) {
|
||||
if (c == '/' || c == '\\' || c == ':') {
|
||||
c = '_';
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
static bool read_file(const std::string & path, std::vector<char> & out) {
|
||||
std::ifstream f(path, std::ios::binary | std::ios::ate);
|
||||
if (!f.good()) {
|
||||
return false;
|
||||
}
|
||||
auto sz = f.tellg();
|
||||
if (sz <= 0) {
|
||||
return false;
|
||||
}
|
||||
out.resize((size_t)sz);
|
||||
f.seekg(0);
|
||||
f.read(out.data(), sz);
|
||||
return f.good();
|
||||
}
|
||||
|
||||
static bool write_file(const std::string & path, const std::vector<char> & data) {
|
||||
std::ofstream f(path, std::ios::binary | std::ios::trunc);
|
||||
if (!f.good()) {
|
||||
return false;
|
||||
}
|
||||
f.write(data.data(), (std::streamsize)data.size());
|
||||
return f.good();
|
||||
}
|
||||
|
||||
// HuggingFace file auto-detection and HTTP download
|
||||
static std::pair<long, std::vector<char>> gguf_http_get(
|
||||
const std::string & url,
|
||||
const httplib::Headers & headers = {},
|
||||
int timeout_sec = 60) {
|
||||
try {
|
||||
auto [cli, parts] = common_http_client(url);
|
||||
|
||||
if (timeout_sec > 0) {
|
||||
cli.set_read_timeout(timeout_sec, 0);
|
||||
cli.set_write_timeout(timeout_sec, 0);
|
||||
}
|
||||
cli.set_connection_timeout(30, 0);
|
||||
|
||||
std::vector<char> body;
|
||||
auto res = cli.Get(parts.path, headers,
|
||||
[&](const char * data, size_t len) {
|
||||
body.insert(body.end(), data, data + len);
|
||||
return true;
|
||||
}, nullptr);
|
||||
|
||||
if (!res) {
|
||||
fprintf(stderr, "gguf_fetch: HTTP request failed for %s (error %d)\n",
|
||||
url.c_str(), (int)res.error());
|
||||
return {-1, {}};
|
||||
}
|
||||
return {res->status, std::move(body)};
|
||||
} catch (const std::exception & e) {
|
||||
fprintf(stderr, "gguf_fetch: HTTP error: %s\n", e.what());
|
||||
return {-1, {}};
|
||||
}
|
||||
}
|
||||
|
||||
// Find the filename for given repo/quant.
|
||||
// For split models, returns the first shard (the one containing "00001-of-")
|
||||
// split_prefix is set to the portion before "-00001-of-XXXXX.gguf" when a split file is found
|
||||
static std::string detect_gguf_filename(const std::string & repo, const std::string & quant,
|
||||
std::string & split_prefix) {
|
||||
split_prefix.clear();
|
||||
std::string api_url = "https://huggingface.co/api/models/" + repo;
|
||||
|
||||
auto [code, body] = gguf_http_get(api_url, {}, 30);
|
||||
if (code != 200 || body.empty()) {
|
||||
fprintf(stderr, "gguf_fetch: failed to query HF API for %s (HTTP %ld)\n", repo.c_str(), code);
|
||||
return "";
|
||||
}
|
||||
|
||||
nlohmann::json j;
|
||||
try {
|
||||
j = nlohmann::json::parse(body.begin(), body.end());
|
||||
} catch (...) {
|
||||
fprintf(stderr, "gguf_fetch: failed to parse HF API response\n");
|
||||
return "";
|
||||
}
|
||||
|
||||
if (!j.contains("siblings") || !j["siblings"].is_array()) {
|
||||
fprintf(stderr, "gguf_fetch: unexpected HF API response format\n");
|
||||
return "";
|
||||
}
|
||||
|
||||
std::vector<std::string> matches;
|
||||
std::string quant_upper = quant;
|
||||
for (char & c : quant_upper) { c = (char)toupper(c); }
|
||||
|
||||
for (const auto & sibling : j["siblings"]) {
|
||||
if (!sibling.contains("rfilename")) { continue; }
|
||||
std::string fname = sibling["rfilename"].get<std::string>();
|
||||
if (fname.size() < 5 || fname.substr(fname.size() - 5) != ".gguf") {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string fname_upper = fname;
|
||||
for (char & c : fname_upper) { c = (char)toupper(c); }
|
||||
if (fname_upper.find(quant_upper) != std::string::npos) {
|
||||
matches.push_back(fname);
|
||||
}
|
||||
}
|
||||
|
||||
if (matches.empty()) {
|
||||
fprintf(stderr, "gguf_fetch: no .gguf files matching '%s' in %s\n", quant.c_str(), repo.c_str());
|
||||
return "";
|
||||
}
|
||||
|
||||
std::sort(matches.begin(), matches.end());
|
||||
|
||||
// Prefer non-split, non-supplementary file
|
||||
for (const auto & m : matches) {
|
||||
if (m.find("-of-") == std::string::npos && m.find("mmproj") == std::string::npos) {
|
||||
return m;
|
||||
}
|
||||
}
|
||||
|
||||
// Return the first shard (00001-of-) and extract the prefix
|
||||
for (const auto & m : matches) {
|
||||
auto pos = m.find("-00001-of-");
|
||||
if (pos != std::string::npos) {
|
||||
split_prefix = m.substr(0, pos);
|
||||
return m;
|
||||
}
|
||||
}
|
||||
|
||||
return matches[0];
|
||||
}
|
||||
|
||||
static std::optional<gguf_remote_model> fetch_and_parse(
|
||||
const std::string & repo,
|
||||
const std::string & filename,
|
||||
const std::string & cache_path) {
|
||||
std::string url = "https://huggingface.co/" + repo + "/resolve/main/" + filename;
|
||||
|
||||
// Progressive download inspired by RangeView.fetchChunk()
|
||||
// Start at 2MB, double each time, cap at 64MB
|
||||
size_t chunk_size = 2 * 1024 * 1024;
|
||||
const size_t max_chunk = 64 * 1024 * 1024;
|
||||
|
||||
while (chunk_size <= max_chunk) {
|
||||
fprintf(stderr, "gguf_fetch: downloading %zu bytes from %s\n", chunk_size, filename.c_str());
|
||||
|
||||
char range_buf[64];
|
||||
snprintf(range_buf, sizeof(range_buf), "bytes=0-%zu", chunk_size - 1);
|
||||
httplib::Headers headers = {{"Range", range_buf}};
|
||||
|
||||
auto [code, body] = gguf_http_get(url, headers, 120);
|
||||
if (code != 200 && code != 206) {
|
||||
fprintf(stderr, "gguf_fetch: HTTP %ld fetching %s\n", code, url.c_str());
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (body.empty()) {
|
||||
fprintf(stderr, "gguf_fetch: empty response\n");
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto result = gguf_parse_meta(body);
|
||||
if (result.has_value()) {
|
||||
write_file(cache_path, body);
|
||||
return result;
|
||||
}
|
||||
|
||||
if (code == 200) {
|
||||
fprintf(stderr, "gguf_fetch: server returned full response but metadata parse failed\n");
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Parse failed, try larger chunk
|
||||
chunk_size *= 2;
|
||||
}
|
||||
|
||||
fprintf(stderr, "gguf_fetch: metadata exceeds 64MB, giving up\n");
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Try cache first, then fetch and parse a single GGUF shard.
|
||||
static std::optional<gguf_remote_model> fetch_or_cached(
|
||||
const std::string & repo,
|
||||
const std::string & filename,
|
||||
const std::string & cdir,
|
||||
const std::string & repo_part) {
|
||||
std::string fname_part = sanitize_for_path(filename);
|
||||
std::string cache_path = cdir + "/" + repo_part + "--" + fname_part + ".partial";
|
||||
|
||||
{
|
||||
std::vector<char> cached;
|
||||
if (std::filesystem::exists(cache_path) && read_file(cache_path, cached)) {
|
||||
auto result = gguf_parse_meta(cached);
|
||||
if (result.has_value()) {
|
||||
fprintf(stderr, "gguf_fetch: loaded from cache: %s\n", cache_path.c_str());
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fs_create_directory_with_parents(cdir);
|
||||
return fetch_and_parse(repo, filename, cache_path);
|
||||
}
|
||||
|
||||
std::optional<gguf_remote_model> gguf_fetch_model_meta(
|
||||
const std::string & repo,
|
||||
const std::string & quant,
|
||||
const std::string & cache_dir) {
|
||||
std::string cdir = cache_dir.empty() ? get_default_cache_dir() : cache_dir;
|
||||
std::string repo_part = sanitize_for_path(repo);
|
||||
|
||||
std::string split_prefix;
|
||||
std::string filename = detect_gguf_filename(repo, quant, split_prefix);
|
||||
if (filename.empty()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part);
|
||||
if (!model_opt.has_value()) {
|
||||
fprintf(stderr, "gguf_fetch: failed to fetch %s\n", filename.c_str());
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto & model = model_opt.value();
|
||||
|
||||
// If the model is split across multiple files we need to fetch the remaining shards metadata
|
||||
if (model.n_split > 1) {
|
||||
if (split_prefix.empty()) {
|
||||
fprintf(stderr, "gguf_fetch: model reports %u splits but filename has no split pattern\n", model.n_split);
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
|
||||
model.n_split, model.n_split - 1);
|
||||
|
||||
for (int i = 2; i <= model.n_split; i++) {
|
||||
char num_buf[6], total_buf[6];
|
||||
snprintf(num_buf, sizeof(num_buf), "%05d", i);
|
||||
snprintf(total_buf, sizeof(total_buf), "%05d", (int)model.n_split);
|
||||
std::string shard_name = split_prefix + "-" + num_buf + "-of-" + total_buf + ".gguf";
|
||||
|
||||
auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part);
|
||||
if (!shard.has_value()) {
|
||||
fprintf(stderr, "gguf_fetch: failed to fetch shard %d: %s\n", i, shard_name.c_str());
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
model.tensors.insert(model.tensors.end(),
|
||||
std::make_move_iterator(shard->tensors.begin()),
|
||||
std::make_move_iterator(shard->tensors.end()));
|
||||
}
|
||||
|
||||
if (model.n_split_tensors > 0 && model.tensors.size() != model.n_split_tensors) {
|
||||
fprintf(stderr, "gguf_fetch: WARNING: expected %u tensors from split.tensors.count, got %zu\n",
|
||||
model.n_split_tensors, model.tensors.size());
|
||||
}
|
||||
}
|
||||
|
||||
return model_opt;
|
||||
}
|
||||
42
tests/gguf-model-data.h
Normal file
42
tests/gguf-model-data.h
Normal file
@@ -0,0 +1,42 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
struct gguf_remote_tensor {
|
||||
std::string name;
|
||||
ggml_type type = GGML_TYPE_F32;
|
||||
int64_t ne[4] = {1, 1, 1, 1}; // dimensions, unused dims = 1
|
||||
uint32_t n_dims = 0;
|
||||
};
|
||||
|
||||
struct gguf_remote_model {
|
||||
// Selected KV metadata
|
||||
std::string architecture; // general.architecture
|
||||
uint32_t n_embd = 0; // <arch>.embedding_length
|
||||
uint32_t n_ff = 0; // <arch>.feed_forward_length
|
||||
uint32_t n_vocab = 0; // inferred from token_embd.weight ne[1]
|
||||
uint32_t n_layer = 0; // <arch>.block_count
|
||||
uint32_t n_head = 0; // <arch>.attention.head_count
|
||||
uint32_t n_head_kv = 0; // <arch>.attention.head_count_kv
|
||||
uint32_t n_expert = 0; // <arch>.expert_count (0 if absent)
|
||||
uint32_t n_embd_head_k = 0; // <arch>.attention.key_length
|
||||
uint32_t n_embd_head_v = 0; // <arch>.attention.value_length
|
||||
uint16_t n_split = 0; // split.count (0 = not split)
|
||||
uint32_t n_split_tensors = 0; // split.tensors.count (0 if not split)
|
||||
|
||||
std::vector<gguf_remote_tensor> tensors;
|
||||
};
|
||||
|
||||
// Fetch model metadata from HuggingFace with local caching.
|
||||
// repo: e.g., "ggml-org/Qwen3-32B-GGUF"
|
||||
// quant: e.g., "Q8_0" -- auto-detects filename (including first shard of split models)
|
||||
// Returns nullopt if download fails or network is unavailable.
|
||||
std::optional<gguf_remote_model> gguf_fetch_model_meta(
|
||||
const std::string & repo,
|
||||
const std::string & quant = "Q8_0",
|
||||
const std::string & cache_dir = ""); // empty = default
|
||||
121
tests/test-gguf-model-data.cpp
Normal file
121
tests/test-gguf-model-data.cpp
Normal file
@@ -0,0 +1,121 @@
|
||||
#include "gguf-model-data.h"
|
||||
|
||||
#include <cstdio>
|
||||
|
||||
#define TEST_ASSERT(cond, msg) \
|
||||
do { \
|
||||
if (!(cond)) { \
|
||||
fprintf(stderr, "FAIL: %s (line %d): %s\n", #cond, __LINE__, msg); \
|
||||
return 1; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
int main() {
|
||||
fprintf(stderr, "=== test-gguf-model-data ===\n");
|
||||
|
||||
// Fetch Qwen3-0.6B Q8_0 metadata
|
||||
auto result = gguf_fetch_model_meta("ggml-org/Qwen3-0.6B-GGUF", "Q8_0");
|
||||
|
||||
if (!result.has_value()) {
|
||||
fprintf(stderr, "SKIP: could not fetch model metadata (no network or HTTP disabled)\n");
|
||||
return 0;
|
||||
}
|
||||
|
||||
const auto & model = result.value();
|
||||
|
||||
fprintf(stderr, "Architecture: %s\n", model.architecture.c_str());
|
||||
fprintf(stderr, "n_embd: %u\n", model.n_embd);
|
||||
fprintf(stderr, "n_ff: %u\n", model.n_ff);
|
||||
fprintf(stderr, "n_vocab: %u\n", model.n_vocab);
|
||||
fprintf(stderr, "n_layer: %u\n", model.n_layer);
|
||||
fprintf(stderr, "n_head: %u\n", model.n_head);
|
||||
fprintf(stderr, "n_head_kv: %u\n", model.n_head_kv);
|
||||
fprintf(stderr, "n_expert: %u\n", model.n_expert);
|
||||
fprintf(stderr, "n_embd_head_k: %u\n", model.n_embd_head_k);
|
||||
fprintf(stderr, "n_embd_head_v: %u\n", model.n_embd_head_v);
|
||||
fprintf(stderr, "tensors: %zu\n", model.tensors.size());
|
||||
|
||||
// Verify architecture
|
||||
TEST_ASSERT(model.architecture == "qwen3", "expected architecture 'qwen3'");
|
||||
|
||||
// Verify key dimensions (Qwen3-0.6B)
|
||||
TEST_ASSERT(model.n_layer == 28, "expected n_layer == 28");
|
||||
TEST_ASSERT(model.n_embd == 1024, "expected n_embd == 1024");
|
||||
TEST_ASSERT(model.n_head == 16, "expected n_head == 16");
|
||||
TEST_ASSERT(model.n_head_kv == 8, "expected n_head_kv == 8");
|
||||
TEST_ASSERT(model.n_expert == 0, "expected n_expert == 0 (not MoE)");
|
||||
TEST_ASSERT(model.n_vocab == 151936, "expected n_vocab == 151936");
|
||||
|
||||
// Verify tensor count
|
||||
TEST_ASSERT(model.tensors.size() == 311, "expected tensor count == 311");
|
||||
|
||||
// Verify known tensor names exist
|
||||
bool found_attn_q = false;
|
||||
bool found_token_embd = false;
|
||||
bool found_output_norm = false;
|
||||
for (const auto & t : model.tensors) {
|
||||
if (t.name == "blk.0.attn_q.weight") {
|
||||
found_attn_q = true;
|
||||
}
|
||||
if (t.name == "token_embd.weight") {
|
||||
found_token_embd = true;
|
||||
}
|
||||
if (t.name == "output_norm.weight") {
|
||||
found_output_norm = true;
|
||||
}
|
||||
}
|
||||
TEST_ASSERT(found_attn_q, "expected tensor 'blk.0.attn_q.weight'");
|
||||
TEST_ASSERT(found_token_embd, "expected tensor 'token_embd.weight'");
|
||||
TEST_ASSERT(found_output_norm, "expected tensor 'output_norm.weight'");
|
||||
|
||||
// Verify token_embd.weight shape
|
||||
for (const auto & t : model.tensors) {
|
||||
if (t.name == "token_embd.weight") {
|
||||
TEST_ASSERT(t.ne[0] == 1024, "expected token_embd.weight ne[0] == 1024");
|
||||
TEST_ASSERT(t.n_dims == 2, "expected token_embd.weight to be 2D");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Test that second call uses cache (just call again, it should work)
|
||||
auto result2 = gguf_fetch_model_meta("ggml-org/Qwen3-0.6B-GGUF", "Q8_0");
|
||||
TEST_ASSERT(result2.has_value(), "cached fetch should succeed");
|
||||
TEST_ASSERT(result2->tensors.size() == model.tensors.size(), "cached result should match");
|
||||
|
||||
// Test a split MoE model without specifying quant (should default to Q8_0)
|
||||
auto result3 = gguf_fetch_model_meta("ggml-org/GLM-4.6V-GGUF");
|
||||
if (!result3.has_value()) {
|
||||
fprintf(stderr, "SKIP: could not fetch GLM-4.6V metadata (no network?)\n");
|
||||
return 0;
|
||||
}
|
||||
const auto & model3 = result3.value();
|
||||
|
||||
fprintf(stderr, "Architecture: %s\n", model3.architecture.c_str());
|
||||
fprintf(stderr, "n_embd: %u\n", model3.n_embd);
|
||||
fprintf(stderr, "n_ff: %u\n", model3.n_ff);
|
||||
fprintf(stderr, "n_vocab: %u\n", model3.n_vocab);
|
||||
fprintf(stderr, "n_layer: %u\n", model3.n_layer);
|
||||
fprintf(stderr, "n_head: %u\n", model3.n_head);
|
||||
fprintf(stderr, "n_head_kv: %u\n", model3.n_head_kv);
|
||||
fprintf(stderr, "n_expert: %u\n", model3.n_expert);
|
||||
fprintf(stderr, "n_embd_head_k: %u\n", model3.n_embd_head_k);
|
||||
fprintf(stderr, "n_embd_head_v: %u\n", model3.n_embd_head_v);
|
||||
fprintf(stderr, "tensors: %zu\n", model3.tensors.size());
|
||||
|
||||
// Verify architecture
|
||||
TEST_ASSERT(model3.architecture == "glm4moe", "expected architecture 'glm4moe'");
|
||||
|
||||
// Verify key dimensions (GLM-4.6V)
|
||||
TEST_ASSERT(model3.n_layer == 46, "expected n_layer == 46");
|
||||
TEST_ASSERT(model3.n_embd == 4096, "expected n_embd == 4096");
|
||||
TEST_ASSERT(model3.n_head == 96, "expected n_head == 96");
|
||||
TEST_ASSERT(model3.n_head_kv == 8, "expected n_head_kv == 8");
|
||||
TEST_ASSERT(model3.n_expert == 128, "expected n_expert == 128 (MoE)");
|
||||
TEST_ASSERT(model3.n_vocab == 151552, "expected n_vocab == 151552");
|
||||
|
||||
// Verify tensor count
|
||||
TEST_ASSERT(model3.tensors.size() == 780, "expected tensor count == 780");
|
||||
|
||||
fprintf(stderr, "=== ALL TESTS PASSED ===\n");
|
||||
return 0;
|
||||
}
|
||||
@@ -57,8 +57,8 @@
|
||||
| `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)<br/>(env: LLAMA_ARG_DEFRAG_THOLD) |
|
||||
| `-np, --parallel N` | number of parallel sequences to decode (default: 1)<br/>(env: LLAMA_ARG_N_PARALLEL) |
|
||||
| `--mlock` | force system to keep model in RAM rather than swapping or compressing<br/>(env: LLAMA_ARG_MLOCK) |
|
||||
| `--mmap, --no-mmap` | whether to memory-map model. Explicitly enabling mmap disables direct-io. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)<br/>(env: LLAMA_ARG_MMAP) |
|
||||
| `-dio, --direct-io, -ndio, --no-direct-io` | use DirectIO if available. Takes precedence over --mmap (default: enabled)<br/>(env: LLAMA_ARG_DIO) |
|
||||
| `--mmap, --no-mmap` | whether to memory-map model. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)<br/>(env: LLAMA_ARG_MMAP) |
|
||||
| `-dio, --direct-io, -ndio, --no-direct-io` | use DirectIO if available. (default: disabled)<br/>(env: LLAMA_ARG_DIO) |
|
||||
| `--numa TYPE` | attempt optimizations that help on some NUMA systems<br/>- distribute: spread execution evenly over all nodes<br/>- isolate: only spawn threads on CPUs on the node that execution started on<br/>- numactl: use the CPU map provided by numactl<br/>if run without this previously, it is recommended to drop the system page cache before using this<br/>see https://github.com/ggml-org/llama.cpp/issues/1437<br/>(env: LLAMA_ARG_NUMA) |
|
||||
| `-dev, --device <dev1,dev2,..>` | comma-separated list of devices to use for offloading (none = don't offload)<br/>use --list-devices to see a list of available devices<br/>(env: LLAMA_ARG_DEVICE) |
|
||||
| `--list-devices` | print list of available devices and exit |
|
||||
@@ -109,14 +109,14 @@
|
||||
| `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) |
|
||||
| `--sampler-seq, --sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: edskypmxt) |
|
||||
| `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) |
|
||||
| `--temp N` | temperature (default: 0.80) |
|
||||
| `--temp, --temperature N` | temperature (default: 0.80) |
|
||||
| `--top-k N` | top-k sampling (default: 40, 0 = disabled)<br/>(env: LLAMA_ARG_TOP_K) |
|
||||
| `--top-p N` | top-p sampling (default: 0.95, 1.0 = disabled) |
|
||||
| `--min-p N` | min-p sampling (default: 0.05, 0.0 = disabled) |
|
||||
| `--top-nsigma N` | top-n-sigma sampling (default: -1.00, -1.0 = disabled) |
|
||||
| `--top-nsigma, --top-n-sigma N` | top-n-sigma sampling (default: -1.00, -1.0 = disabled) |
|
||||
| `--xtc-probability N` | xtc probability (default: 0.00, 0.0 = disabled) |
|
||||
| `--xtc-threshold N` | xtc threshold (default: 0.10, 1.0 = disabled) |
|
||||
| `--typical N` | locally typical sampling, parameter p (default: 1.00, 1.0 = disabled) |
|
||||
| `--typical, --typical-p N` | locally typical sampling, parameter p (default: 1.00, 1.0 = disabled) |
|
||||
| `--repeat-last-n N` | last n tokens to consider for penalize (default: 64, 0 = disabled, -1 = ctx_size) |
|
||||
| `--repeat-penalty N` | penalize repeat sequence of tokens (default: 1.00, 1.0 = disabled) |
|
||||
| `--presence-penalty N` | repeat alpha presence penalty (default: 0.00, 0.0 = disabled) |
|
||||
|
||||
@@ -140,8 +140,8 @@ llama-completion.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1
|
||||
| `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)<br/>(env: LLAMA_ARG_DEFRAG_THOLD) |
|
||||
| `-np, --parallel N` | number of parallel sequences to decode (default: 1)<br/>(env: LLAMA_ARG_N_PARALLEL) |
|
||||
| `--mlock` | force system to keep model in RAM rather than swapping or compressing<br/>(env: LLAMA_ARG_MLOCK) |
|
||||
| `--mmap, --no-mmap` | whether to memory-map model. Explicitly enabling mmap disables direct-io. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)<br/>(env: LLAMA_ARG_MMAP) |
|
||||
| `-dio, --direct-io, -ndio, --no-direct-io` | use DirectIO if available. Takes precedence over --mmap (default: enabled)<br/>(env: LLAMA_ARG_DIO) |
|
||||
| `--mmap, --no-mmap` | whether to memory-map model. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)<br/>(env: LLAMA_ARG_MMAP) |
|
||||
| `-dio, --direct-io, -ndio, --no-direct-io` | use DirectIO if available. (default: disabled)<br/>(env: LLAMA_ARG_DIO) |
|
||||
| `--numa TYPE` | attempt optimizations that help on some NUMA systems<br/>- distribute: spread execution evenly over all nodes<br/>- isolate: only spawn threads on CPUs on the node that execution started on<br/>- numactl: use the CPU map provided by numactl<br/>if run without this previously, it is recommended to drop the system page cache before using this<br/>see https://github.com/ggml-org/llama.cpp/issues/1437<br/>(env: LLAMA_ARG_NUMA) |
|
||||
| `-dev, --device <dev1,dev2,..>` | comma-separated list of devices to use for offloading (none = don't offload)<br/>use --list-devices to see a list of available devices<br/>(env: LLAMA_ARG_DEVICE) |
|
||||
| `--list-devices` | print list of available devices and exit |
|
||||
@@ -192,14 +192,14 @@ llama-completion.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1
|
||||
| `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) |
|
||||
| `--sampler-seq, --sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: edskypmxt) |
|
||||
| `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) |
|
||||
| `--temp N` | temperature (default: 0.80) |
|
||||
| `--temp, --temperature N` | temperature (default: 0.80) |
|
||||
| `--top-k N` | top-k sampling (default: 40, 0 = disabled)<br/>(env: LLAMA_ARG_TOP_K) |
|
||||
| `--top-p N` | top-p sampling (default: 0.95, 1.0 = disabled) |
|
||||
| `--min-p N` | min-p sampling (default: 0.05, 0.0 = disabled) |
|
||||
| `--top-nsigma N` | top-n-sigma sampling (default: -1.00, -1.0 = disabled) |
|
||||
| `--top-nsigma, --top-n-sigma N` | top-n-sigma sampling (default: -1.00, -1.0 = disabled) |
|
||||
| `--xtc-probability N` | xtc probability (default: 0.00, 0.0 = disabled) |
|
||||
| `--xtc-threshold N` | xtc threshold (default: 0.10, 1.0 = disabled) |
|
||||
| `--typical N` | locally typical sampling, parameter p (default: 1.00, 1.0 = disabled) |
|
||||
| `--typical, --typical-p N` | locally typical sampling, parameter p (default: 1.00, 1.0 = disabled) |
|
||||
| `--repeat-last-n N` | last n tokens to consider for penalize (default: 64, 0 = disabled, -1 = ctx_size) |
|
||||
| `--repeat-penalty N` | penalize repeat sequence of tokens (default: 1.00, 1.0 = disabled) |
|
||||
| `--presence-penalty N` | repeat alpha presence penalty (default: 0.00, 0.0 = disabled) |
|
||||
|
||||
@@ -74,8 +74,8 @@ For the full list of features, please refer to [server's changelog](https://gith
|
||||
| `-ctv, --cache-type-v TYPE` | KV cache data type for V<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_V) |
|
||||
| `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)<br/>(env: LLAMA_ARG_DEFRAG_THOLD) |
|
||||
| `--mlock` | force system to keep model in RAM rather than swapping or compressing<br/>(env: LLAMA_ARG_MLOCK) |
|
||||
| `--mmap, --no-mmap` | whether to memory-map model. Explicitly enabling mmap disables direct-io. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)<br/>(env: LLAMA_ARG_MMAP) |
|
||||
| `-dio, --direct-io, -ndio, --no-direct-io` | use DirectIO if available. Takes precedence over --mmap (default: enabled)<br/>(env: LLAMA_ARG_DIO) |
|
||||
| `--mmap, --no-mmap` | whether to memory-map model. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)<br/>(env: LLAMA_ARG_MMAP) |
|
||||
| `-dio, --direct-io, -ndio, --no-direct-io` | use DirectIO if available. (default: disabled)<br/>(env: LLAMA_ARG_DIO) |
|
||||
| `--numa TYPE` | attempt optimizations that help on some NUMA systems<br/>- distribute: spread execution evenly over all nodes<br/>- isolate: only spawn threads on CPUs on the node that execution started on<br/>- numactl: use the CPU map provided by numactl<br/>if run without this previously, it is recommended to drop the system page cache before using this<br/>see https://github.com/ggml-org/llama.cpp/issues/1437<br/>(env: LLAMA_ARG_NUMA) |
|
||||
| `-dev, --device <dev1,dev2,..>` | comma-separated list of devices to use for offloading (none = don't offload)<br/>use --list-devices to see a list of available devices<br/>(env: LLAMA_ARG_DEVICE) |
|
||||
| `--list-devices` | print list of available devices and exit |
|
||||
@@ -126,14 +126,14 @@ For the full list of features, please refer to [server's changelog](https://gith
|
||||
| `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) |
|
||||
| `--sampler-seq, --sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: edskypmxt) |
|
||||
| `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) |
|
||||
| `--temp N` | temperature (default: 0.80) |
|
||||
| `--temp, --temperature N` | temperature (default: 0.80) |
|
||||
| `--top-k N` | top-k sampling (default: 40, 0 = disabled)<br/>(env: LLAMA_ARG_TOP_K) |
|
||||
| `--top-p N` | top-p sampling (default: 0.95, 1.0 = disabled) |
|
||||
| `--min-p N` | min-p sampling (default: 0.05, 0.0 = disabled) |
|
||||
| `--top-nsigma N` | top-n-sigma sampling (default: -1.00, -1.0 = disabled) |
|
||||
| `--top-nsigma, --top-n-sigma N` | top-n-sigma sampling (default: -1.00, -1.0 = disabled) |
|
||||
| `--xtc-probability N` | xtc probability (default: 0.00, 0.0 = disabled) |
|
||||
| `--xtc-threshold N` | xtc threshold (default: 0.10, 1.0 = disabled) |
|
||||
| `--typical N` | locally typical sampling, parameter p (default: 1.00, 1.0 = disabled) |
|
||||
| `--typical, --typical-p N` | locally typical sampling, parameter p (default: 1.00, 1.0 = disabled) |
|
||||
| `--repeat-last-n N` | last n tokens to consider for penalize (default: 64, 0 = disabled, -1 = ctx_size) |
|
||||
| `--repeat-penalty N` | penalize repeat sequence of tokens (default: 1.00, 1.0 = disabled) |
|
||||
| `--presence-penalty N` | repeat alpha presence penalty (default: 0.00, 0.0 = disabled) |
|
||||
@@ -162,9 +162,11 @@ For the full list of features, please refer to [server's changelog](https://gith
|
||||
|
||||
| Argument | Explanation |
|
||||
| -------- | ----------- |
|
||||
| `-lcs, --lookup-cache-static FNAME` | path to static lookup cache to use for lookup decoding (not updated by generation) |
|
||||
| `-lcd, --lookup-cache-dynamic FNAME` | path to dynamic lookup cache to use for lookup decoding (updated by generation) |
|
||||
| `--ctx-checkpoints, --swa-checkpoints N` | max number of context checkpoints to create per slot (default: 8)[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)<br/>(env: LLAMA_ARG_CTX_CHECKPOINTS) |
|
||||
| `-cram, --cache-ram N` | set the maximum cache size in MiB (default: 8192, -1 - no limit, 0 - disable)[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)<br/>(env: LLAMA_ARG_CACHE_RAM) |
|
||||
| `-kvu, --kv-unified` | use single unified KV buffer shared across all sequences (default: enabled if number of slots is auto)<br/>(env: LLAMA_ARG_KV_UNIFIED) |
|
||||
| `-kvu, --kv-unified, -no-kvu, --no-kv-unified` | use single unified KV buffer shared across all sequences (default: enabled if number of slots is auto)<br/>(env: LLAMA_ARG_KV_UNIFIED) |
|
||||
| `--context-shift, --no-context-shift` | whether to use context shift on infinite text generation (default: disabled)<br/>(env: LLAMA_ARG_CONTEXT_SHIFT) |
|
||||
| `-r, --reverse-prompt PROMPT` | halt generation at PROMPT, return control in interactive mode |
|
||||
| `-sp, --special` | special tokens output enabled (default: false) |
|
||||
@@ -182,7 +184,8 @@ For the full list of features, please refer to [server's changelog](https://gith
|
||||
| `-otd, --override-tensor-draft <tensor name pattern>=<buffer type>,...` | override tensor buffer type for draft model |
|
||||
| `-cmoed, --cpu-moe-draft` | keep all Mixture of Experts (MoE) weights in the CPU for the draft model<br/>(env: LLAMA_ARG_CPU_MOE_DRAFT) |
|
||||
| `-ncmoed, --n-cpu-moe-draft N` | keep the Mixture of Experts (MoE) weights of the first N layers in the CPU for the draft model<br/>(env: LLAMA_ARG_N_CPU_MOE_DRAFT) |
|
||||
| `-a, --alias STRING` | set alias for model name (to be used by REST API)<br/>(env: LLAMA_ARG_ALIAS) |
|
||||
| `-a, --alias STRING` | set model name aliases, comma-separated (to be used by API)<br/>(env: LLAMA_ARG_ALIAS) |
|
||||
| `--tags STRING` | set model tags, comma-separated (informational, not used for routing)<br/>(env: LLAMA_ARG_TAGS) |
|
||||
| `--host HOST` | ip address to listen, or bind to an UNIX socket if the address ends with .sock (default: 127.0.0.1)<br/>(env: LLAMA_ARG_HOST) |
|
||||
| `--port PORT` | port to listen (default: 8080)<br/>(env: LLAMA_ARG_PORT) |
|
||||
| `--path PATH` | path to serve static files from (default: )<br/>(env: LLAMA_ARG_STATIC_PATH) |
|
||||
@@ -229,6 +232,10 @@ For the full list of features, please refer to [server's changelog](https://gith
|
||||
| `-ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)<br/>(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) |
|
||||
| `-md, --model-draft FNAME` | draft model for speculative decoding (default: unused)<br/>(env: LLAMA_ARG_MODEL_DRAFT) |
|
||||
| `--spec-replace TARGET DRAFT` | translate the string in TARGET into DRAFT if the draft model and main model are not compatible |
|
||||
| `--spec-type [none\|ngram-cache\|ngram-simple\|ngram-map-k\|ngram-map-k4v\|ngram-mod]` | type of speculative decoding to use when no draft model is provided (default: none) |
|
||||
| `--spec-ngram-size-n N` | ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: 12) |
|
||||
| `--spec-ngram-size-m N` | ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: 48) |
|
||||
| `--spec-ngram-min-hits N` | minimum hits for ngram-map speculative decoding (default: 1) |
|
||||
| `-mv, --model-vocoder FNAME` | vocoder model for audio generation (default: unused) |
|
||||
| `--tts-use-guide-tokens` | Use guide tokens to improve TTS word recall |
|
||||
| `--embd-gemma-default` | use default EmbeddingGemma model (note: can download weights from the internet) |
|
||||
|
||||
@@ -580,6 +580,8 @@ private:
|
||||
float slot_prompt_similarity = 0.0f;
|
||||
|
||||
std::string model_name; // name of the loaded model, to be used by API
|
||||
std::set<std::string> model_aliases; // additional names for the model
|
||||
std::set<std::string> model_tags; // informational tags
|
||||
|
||||
bool sleeping = false;
|
||||
|
||||
@@ -813,10 +815,9 @@ private:
|
||||
SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n");
|
||||
|
||||
if (!params_base.model_alias.empty()) {
|
||||
// user explicitly specified model name
|
||||
model_name = params_base.model_alias;
|
||||
// backward compat: use first alias as model name
|
||||
model_name = *params_base.model_alias.begin();
|
||||
} else if (!params_base.model.name.empty()) {
|
||||
// use model name in registry format (for models in cache)
|
||||
model_name = params_base.model.name;
|
||||
} else {
|
||||
// fallback: derive model name from file name
|
||||
@@ -824,6 +825,9 @@ private:
|
||||
model_name = model_path.filename().string();
|
||||
}
|
||||
|
||||
model_aliases = params_base.model_alias;
|
||||
model_tags = params_base.model_tags;
|
||||
|
||||
if (!is_resume) {
|
||||
return init();
|
||||
}
|
||||
@@ -2892,6 +2896,8 @@ server_context_meta server_context::get_meta() const {
|
||||
return server_context_meta {
|
||||
/* build_info */ build_info,
|
||||
/* model_name */ impl->model_name,
|
||||
/* model_aliases */ impl->model_aliases,
|
||||
/* model_tags */ impl->model_tags,
|
||||
/* model_path */ impl->params_base.model.path,
|
||||
/* has_mtmd */ impl->mctx != nullptr,
|
||||
/* has_inp_image */ impl->chat_params.allow_image,
|
||||
@@ -3688,6 +3694,8 @@ void server_routes::init_routes() {
|
||||
{"data", {
|
||||
{
|
||||
{"id", meta->model_name},
|
||||
{"aliases", meta->model_aliases},
|
||||
{"tags", meta->model_tags},
|
||||
{"object", "model"},
|
||||
{"created", std::time(0)},
|
||||
{"owned_by", "llamacpp"},
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include "server-http.h"
|
||||
#include "server-task.h"
|
||||
#include "server-queue.h"
|
||||
@@ -6,12 +8,15 @@
|
||||
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
|
||||
struct server_context_impl; // private implementation
|
||||
|
||||
struct server_context_meta {
|
||||
std::string build_info;
|
||||
std::string model_name;
|
||||
std::set<std::string> model_aliases;
|
||||
std::set<std::string> model_tags;
|
||||
std::string model_path;
|
||||
bool has_mtmd;
|
||||
bool has_inp_image;
|
||||
|
||||
@@ -184,6 +184,51 @@ void server_models::add_model(server_model_meta && meta) {
|
||||
if (mapping.find(meta.name) != mapping.end()) {
|
||||
throw std::runtime_error(string_format("model '%s' appears multiple times", meta.name.c_str()));
|
||||
}
|
||||
|
||||
// check model name does not conflict with existing aliases
|
||||
for (const auto & [key, inst] : mapping) {
|
||||
if (inst.meta.aliases.count(meta.name)) {
|
||||
throw std::runtime_error(string_format("model name '%s' conflicts with alias of model '%s'",
|
||||
meta.name.c_str(), key.c_str()));
|
||||
}
|
||||
}
|
||||
|
||||
// parse aliases from preset's --alias option (comma-separated)
|
||||
std::string alias_str;
|
||||
if (meta.preset.get_option("LLAMA_ARG_ALIAS", alias_str) && !alias_str.empty()) {
|
||||
for (auto & alias : string_split<std::string>(alias_str, ',')) {
|
||||
alias = string_strip(alias);
|
||||
if (!alias.empty()) {
|
||||
meta.aliases.insert(alias);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parse tags from preset's --tags option (comma-separated)
|
||||
std::string tags_str;
|
||||
if (meta.preset.get_option("LLAMA_ARG_TAGS", tags_str) && !tags_str.empty()) {
|
||||
for (auto & tag : string_split<std::string>(tags_str, ',')) {
|
||||
tag = string_strip(tag);
|
||||
if (!tag.empty()) {
|
||||
meta.tags.insert(tag);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validate aliases do not conflict with existing names or aliases
|
||||
for (const auto & alias : meta.aliases) {
|
||||
if (mapping.find(alias) != mapping.end()) {
|
||||
throw std::runtime_error(string_format("alias '%s' for model '%s' conflicts with existing model name",
|
||||
alias.c_str(), meta.name.c_str()));
|
||||
}
|
||||
for (const auto & [key, inst] : mapping) {
|
||||
if (inst.meta.aliases.count(alias)) {
|
||||
throw std::runtime_error(string_format("alias '%s' for model '%s' conflicts with alias of model '%s'",
|
||||
alias.c_str(), meta.name.c_str(), key.c_str()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
meta.update_args(ctx_preset, bin_path); // render args
|
||||
std::string name = meta.name;
|
||||
mapping[name] = instance_t{
|
||||
@@ -249,6 +294,8 @@ void server_models::load_models() {
|
||||
server_model_meta meta{
|
||||
/* preset */ preset.second,
|
||||
/* name */ preset.first,
|
||||
/* aliases */ {},
|
||||
/* tags */ {},
|
||||
/* port */ 0,
|
||||
/* status */ SERVER_MODEL_STATUS_UNLOADED,
|
||||
/* last_used */ 0,
|
||||
@@ -265,10 +312,28 @@ void server_models::load_models() {
|
||||
for (const auto & [name, preset] : custom_presets) {
|
||||
custom_names.insert(name);
|
||||
}
|
||||
auto join_set = [](const std::set<std::string> & s) {
|
||||
std::string result;
|
||||
for (const auto & v : s) {
|
||||
if (!result.empty()) {
|
||||
result += ", ";
|
||||
}
|
||||
result += v;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
SRV_INF("Available models (%zu) (*: custom preset)\n", mapping.size());
|
||||
for (const auto & [name, inst] : mapping) {
|
||||
bool has_custom = custom_names.find(name) != custom_names.end();
|
||||
SRV_INF(" %c %s\n", has_custom ? '*' : ' ', name.c_str());
|
||||
std::string info;
|
||||
if (!inst.meta.aliases.empty()) {
|
||||
info += " (aliases: " + join_set(inst.meta.aliases) + ")";
|
||||
}
|
||||
if (!inst.meta.tags.empty()) {
|
||||
info += " [tags: " + join_set(inst.meta.tags) + "]";
|
||||
}
|
||||
SRV_INF(" %c %s%s\n", has_custom ? '*' : ' ', name.c_str(), info.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -320,7 +385,15 @@ void server_models::update_meta(const std::string & name, const server_model_met
|
||||
|
||||
bool server_models::has_model(const std::string & name) {
|
||||
std::lock_guard<std::mutex> lk(mutex);
|
||||
return mapping.find(name) != mapping.end();
|
||||
if (mapping.find(name) != mapping.end()) {
|
||||
return true;
|
||||
}
|
||||
for (const auto & [key, inst] : mapping) {
|
||||
if (inst.meta.aliases.count(name)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::optional<server_model_meta> server_models::get_meta(const std::string & name) {
|
||||
@@ -329,6 +402,11 @@ std::optional<server_model_meta> server_models::get_meta(const std::string & nam
|
||||
if (it != mapping.end()) {
|
||||
return it->second.meta;
|
||||
}
|
||||
for (const auto & [key, inst] : mapping) {
|
||||
if (inst.meta.aliases.count(name)) {
|
||||
return inst.meta;
|
||||
}
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
@@ -766,7 +844,7 @@ static void res_err(std::unique_ptr<server_http_res> & res, const json & error_d
|
||||
res->data = safe_json_to_str({{ "error", error_data }});
|
||||
}
|
||||
|
||||
static bool router_validate_model(const std::string & name, server_models & models, bool models_autoload, std::unique_ptr<server_http_res> & res) {
|
||||
static bool router_validate_model(std::string & name, server_models & models, bool models_autoload, std::unique_ptr<server_http_res> & res) {
|
||||
if (name.empty()) {
|
||||
res_err(res, format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST));
|
||||
return false;
|
||||
@@ -776,6 +854,8 @@ static bool router_validate_model(const std::string & name, server_models & mode
|
||||
res_err(res, format_error_response(string_format("model '%s' not found", name.c_str()), ERROR_TYPE_INVALID_REQUEST));
|
||||
return false;
|
||||
}
|
||||
// resolve alias to canonical model name
|
||||
name = meta->name;
|
||||
if (models_autoload) {
|
||||
models.ensure_model_loaded(name);
|
||||
} else {
|
||||
@@ -847,16 +927,16 @@ void server_models_routes::init_routes() {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
json body = json::parse(req.body);
|
||||
std::string name = json_value(body, "model", std::string());
|
||||
auto model = models.get_meta(name);
|
||||
if (!model.has_value()) {
|
||||
auto meta = models.get_meta(name);
|
||||
if (!meta.has_value()) {
|
||||
res_err(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND));
|
||||
return res;
|
||||
}
|
||||
if (model->status == SERVER_MODEL_STATUS_LOADED) {
|
||||
if (meta->status == SERVER_MODEL_STATUS_LOADED) {
|
||||
res_err(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST));
|
||||
return res;
|
||||
}
|
||||
models.load(name);
|
||||
models.load(meta->name);
|
||||
res_ok(res, {{"success", true}});
|
||||
return res;
|
||||
};
|
||||
@@ -877,6 +957,7 @@ void server_models_routes::init_routes() {
|
||||
preset_copy.unset_option("LLAMA_ARG_HOST");
|
||||
preset_copy.unset_option("LLAMA_ARG_PORT");
|
||||
preset_copy.unset_option("LLAMA_ARG_ALIAS");
|
||||
preset_copy.unset_option("LLAMA_ARG_TAGS");
|
||||
status["preset"] = preset_copy.to_ini();
|
||||
}
|
||||
if (meta.is_failed()) {
|
||||
@@ -885,6 +966,8 @@ void server_models_routes::init_routes() {
|
||||
}
|
||||
models_json.push_back(json {
|
||||
{"id", meta.name},
|
||||
{"aliases", meta.aliases},
|
||||
{"tags", meta.tags},
|
||||
{"object", "model"}, // for OAI-compat
|
||||
{"owned_by", "llamacpp"}, // for OAI-compat
|
||||
{"created", t}, // for OAI-compat
|
||||
@@ -912,7 +995,7 @@ void server_models_routes::init_routes() {
|
||||
res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
|
||||
return res;
|
||||
}
|
||||
models.unload(name);
|
||||
models.unload(model->name);
|
||||
res_ok(res, {{"success", true}});
|
||||
return res;
|
||||
};
|
||||
|
||||
@@ -52,6 +52,8 @@ static std::string server_model_status_to_string(server_model_status status) {
|
||||
struct server_model_meta {
|
||||
common_preset preset;
|
||||
std::string name;
|
||||
std::set<std::string> aliases; // additional names that resolve to this model
|
||||
std::set<std::string> tags; // informational tags, not used for routing
|
||||
int port = 0;
|
||||
server_model_status status = SERVER_MODEL_STATUS_UNLOADED;
|
||||
int64_t last_used = 0; // for LRU unloading
|
||||
|
||||
@@ -92,7 +92,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// for consistency between server router mode and single-model mode, we set the same model name as alias
|
||||
if (params.model_alias.empty() && !params.model.name.empty()) {
|
||||
params.model_alias = params.model.name;
|
||||
params.model_alias.insert(params.model.name);
|
||||
}
|
||||
|
||||
common_init();
|
||||
@@ -178,6 +178,7 @@ int main(int argc, char ** argv) {
|
||||
ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions));
|
||||
ctx_http.post("/api/chat", ex_wrapper(routes.post_chat_completions)); // ollama specific endpoint
|
||||
ctx_http.post("/v1/responses", ex_wrapper(routes.post_responses_oai));
|
||||
ctx_http.post("/responses", ex_wrapper(routes.post_responses_oai));
|
||||
ctx_http.post("/v1/messages", ex_wrapper(routes.post_anthropic_messages)); // anthropic messages API
|
||||
ctx_http.post("/v1/messages/count_tokens", ex_wrapper(routes.post_anthropic_count_tokens)); // anthropic token counting
|
||||
ctx_http.post("/infill", ex_wrapper(routes.post_infill));
|
||||
|
||||
@@ -94,3 +94,20 @@ def test_no_webui():
|
||||
server.start()
|
||||
res = requests.get(url)
|
||||
assert res.status_code == 404
|
||||
|
||||
|
||||
def test_server_model_aliases_and_tags():
|
||||
global server
|
||||
server.model_alias = "tinyllama-2,fim,code"
|
||||
server.model_tags = "chat,fim,small"
|
||||
server.start()
|
||||
res = server.make_request("GET", "/models")
|
||||
assert res.status_code == 200
|
||||
assert len(res.body["data"]) == 1
|
||||
model = res.body["data"][0]
|
||||
# aliases field must contain all aliases
|
||||
assert set(model["aliases"]) == {"tinyllama-2", "fim", "code"}
|
||||
# tags field must contain all tags
|
||||
assert set(model["tags"]) == {"chat", "fim", "small"}
|
||||
# id is derived from first alias (alphabetical order from std::set)
|
||||
assert model["id"] == "code"
|
||||
|
||||
@@ -56,6 +56,7 @@ class ServerProcess:
|
||||
|
||||
# custom options
|
||||
model_alias: str | None = None
|
||||
model_tags: str | None = None
|
||||
model_url: str | None = None
|
||||
model_file: str | None = None
|
||||
model_draft: str | None = None
|
||||
@@ -180,6 +181,8 @@ class ServerProcess:
|
||||
server_args.extend(["--pooling", self.pooling])
|
||||
if self.model_alias:
|
||||
server_args.extend(["--alias", self.model_alias])
|
||||
if self.model_tags:
|
||||
server_args.extend(["--tags", self.model_tags])
|
||||
if self.n_ctx:
|
||||
server_args.extend(["--ctx-size", self.n_ctx])
|
||||
if self.n_slots:
|
||||
|
||||
1
vendor/cpp-httplib/CMakeLists.txt
vendored
1
vendor/cpp-httplib/CMakeLists.txt
vendored
@@ -171,7 +171,6 @@ endif()
|
||||
if (CPPHTTPLIB_OPENSSL_SUPPORT)
|
||||
target_compile_definitions(${TARGET} PUBLIC CPPHTTPLIB_OPENSSL_SUPPORT) # used in server.cpp
|
||||
if (APPLE AND CMAKE_SYSTEM_NAME STREQUAL "Darwin")
|
||||
target_compile_definitions(${TARGET} PRIVATE CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
|
||||
find_library(CORE_FOUNDATION_FRAMEWORK CoreFoundation REQUIRED)
|
||||
find_library(SECURITY_FRAMEWORK Security REQUIRED)
|
||||
target_link_libraries(${TARGET} PUBLIC ${CORE_FOUNDATION_FRAMEWORK} ${SECURITY_FRAMEWORK})
|
||||
|
||||
174
vendor/cpp-httplib/httplib.cpp
vendored
174
vendor/cpp-httplib/httplib.cpp
vendored
@@ -2571,10 +2571,46 @@ find_content_type(const std::string &path,
|
||||
}
|
||||
}
|
||||
|
||||
std::string
|
||||
extract_media_type(const std::string &content_type,
|
||||
std::map<std::string, std::string> *params = nullptr) {
|
||||
// Extract type/subtype from Content-Type value (RFC 2045)
|
||||
// e.g. "application/json; charset=utf-8" -> "application/json"
|
||||
auto media_type = content_type;
|
||||
auto semicolon_pos = media_type.find(';');
|
||||
if (semicolon_pos != std::string::npos) {
|
||||
auto param_str = media_type.substr(semicolon_pos + 1);
|
||||
media_type = media_type.substr(0, semicolon_pos);
|
||||
|
||||
if (params) {
|
||||
// Parse parameters: key=value pairs separated by ';'
|
||||
split(param_str.data(), param_str.data() + param_str.size(), ';',
|
||||
[&](const char *b, const char *e) {
|
||||
std::string key;
|
||||
std::string val;
|
||||
split(b, e, '=', [&](const char *b2, const char *e2) {
|
||||
if (key.empty()) {
|
||||
key.assign(b2, e2);
|
||||
} else {
|
||||
val.assign(b2, e2);
|
||||
}
|
||||
});
|
||||
if (!key.empty()) {
|
||||
params->emplace(trim_copy(key), trim_double_quotes_copy(val));
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Trim whitespace from media type
|
||||
return trim_copy(media_type);
|
||||
}
|
||||
|
||||
bool can_compress_content_type(const std::string &content_type) {
|
||||
using udl::operator""_t;
|
||||
|
||||
auto tag = str2tag(content_type);
|
||||
auto mime_type = extract_media_type(content_type);
|
||||
auto tag = str2tag(mime_type);
|
||||
|
||||
switch (tag) {
|
||||
case "image/svg+xml"_t:
|
||||
@@ -2586,7 +2622,7 @@ bool can_compress_content_type(const std::string &content_type) {
|
||||
|
||||
case "text/event-stream"_t: return false;
|
||||
|
||||
default: return !content_type.rfind("text/", 0);
|
||||
default: return !mime_type.rfind("text/", 0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3141,7 +3177,8 @@ bool is_chunked_transfer_encoding(const Headers &headers) {
|
||||
template <typename T, typename U>
|
||||
bool prepare_content_receiver(T &x, int &status,
|
||||
ContentReceiverWithProgress receiver,
|
||||
bool decompress, U callback) {
|
||||
bool decompress, size_t payload_max_length,
|
||||
bool &exceed_payload_max_length, U callback) {
|
||||
if (decompress) {
|
||||
std::string encoding = x.get_header_value("Content-Encoding");
|
||||
std::unique_ptr<decompressor> decompressor;
|
||||
@@ -3157,12 +3194,22 @@ bool prepare_content_receiver(T &x, int &status,
|
||||
|
||||
if (decompressor) {
|
||||
if (decompressor->is_valid()) {
|
||||
size_t decompressed_size = 0;
|
||||
ContentReceiverWithProgress out = [&](const char *buf, size_t n,
|
||||
size_t off, size_t len) {
|
||||
return decompressor->decompress(buf, n,
|
||||
[&](const char *buf2, size_t n2) {
|
||||
return receiver(buf2, n2, off, len);
|
||||
});
|
||||
return decompressor->decompress(
|
||||
buf, n, [&](const char *buf2, size_t n2) {
|
||||
// Guard against zip-bomb: check
|
||||
// decompressed size against limit.
|
||||
if (payload_max_length > 0 &&
|
||||
(decompressed_size >= payload_max_length ||
|
||||
n2 > payload_max_length - decompressed_size)) {
|
||||
exceed_payload_max_length = true;
|
||||
return false;
|
||||
}
|
||||
decompressed_size += n2;
|
||||
return receiver(buf2, n2, off, len);
|
||||
});
|
||||
};
|
||||
return callback(std::move(out));
|
||||
} else {
|
||||
@@ -3183,11 +3230,14 @@ template <typename T>
|
||||
bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status,
|
||||
DownloadProgress progress,
|
||||
ContentReceiverWithProgress receiver, bool decompress) {
|
||||
bool exceed_payload_max_length = false;
|
||||
return prepare_content_receiver(
|
||||
x, status, std::move(receiver), decompress,
|
||||
[&](const ContentReceiverWithProgress &out) {
|
||||
x, status, std::move(receiver), decompress, payload_max_length,
|
||||
exceed_payload_max_length, [&](const ContentReceiverWithProgress &out) {
|
||||
auto ret = true;
|
||||
auto exceed_payload_max_length = false;
|
||||
// Note: exceed_payload_max_length may also be set by the decompressor
|
||||
// wrapper in prepare_content_receiver when the decompressed payload
|
||||
// size exceeds the limit.
|
||||
|
||||
if (is_chunked_transfer_encoding(x.headers)) {
|
||||
auto result = read_content_chunked(strm, x, payload_max_length, out);
|
||||
@@ -3603,12 +3653,11 @@ std::string normalize_query_string(const std::string &query) {
|
||||
|
||||
bool parse_multipart_boundary(const std::string &content_type,
|
||||
std::string &boundary) {
|
||||
auto boundary_keyword = "boundary=";
|
||||
auto pos = content_type.find(boundary_keyword);
|
||||
if (pos == std::string::npos) { return false; }
|
||||
auto end = content_type.find(';', pos);
|
||||
auto beg = pos + strlen(boundary_keyword);
|
||||
boundary = trim_double_quotes_copy(content_type.substr(beg, end - beg));
|
||||
std::map<std::string, std::string> params;
|
||||
extract_media_type(content_type, ¶ms);
|
||||
auto it = params.find("boundary");
|
||||
if (it == params.end()) { return false; }
|
||||
boundary = it->second;
|
||||
return !boundary.empty();
|
||||
}
|
||||
|
||||
@@ -3776,11 +3825,7 @@ bool parse_accept_header(const std::string &s,
|
||||
}
|
||||
|
||||
// Remove additional parameters from media type
|
||||
auto param_pos = accept_entry.media_type.find(';');
|
||||
if (param_pos != std::string::npos) {
|
||||
accept_entry.media_type =
|
||||
trim_copy(accept_entry.media_type.substr(0, param_pos));
|
||||
}
|
||||
accept_entry.media_type = extract_media_type(accept_entry.media_type);
|
||||
|
||||
// Basic validation of media type format
|
||||
if (accept_entry.media_type.empty()) {
|
||||
@@ -5610,7 +5655,7 @@ size_t Request::get_param_value_count(const std::string &key) const {
|
||||
|
||||
bool Request::is_multipart_form_data() const {
|
||||
const auto &content_type = get_header_value("Content-Type");
|
||||
return !content_type.rfind("multipart/form-data", 0);
|
||||
return detail::extract_media_type(content_type) == "multipart/form-data";
|
||||
}
|
||||
|
||||
// Multipart FormData implementation
|
||||
@@ -7092,7 +7137,8 @@ bool Server::read_content(Stream &strm, Request &req, Response &res) {
|
||||
return true;
|
||||
})) {
|
||||
const auto &content_type = req.get_header_value("Content-Type");
|
||||
if (!content_type.find("application/x-www-form-urlencoded")) {
|
||||
if (detail::extract_media_type(content_type) ==
|
||||
"application/x-www-form-urlencoded") {
|
||||
if (req.body.size() > CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH) {
|
||||
res.status = StatusCode::PayloadTooLarge_413; // NOTE: should be 414?
|
||||
output_error_log(Error::ExceedMaxPayloadSize, &req);
|
||||
@@ -7479,45 +7525,63 @@ bool Server::routing(Request &req, Response &res, Stream &strm) {
|
||||
if (detail::expect_content(req)) {
|
||||
// Content reader handler
|
||||
{
|
||||
// Track whether the ContentReader was aborted due to the decompressed
|
||||
// payload exceeding `payload_max_length_`.
|
||||
// The user handler runs after the lambda returns, so we must restore the
|
||||
// 413 status if the handler overwrites it.
|
||||
bool content_reader_payload_too_large = false;
|
||||
|
||||
ContentReader reader(
|
||||
[&](ContentReceiver receiver) {
|
||||
auto result = read_content_with_content_receiver(
|
||||
strm, req, res, std::move(receiver), nullptr, nullptr);
|
||||
if (!result) { output_error_log(Error::Read, &req); }
|
||||
if (!result) {
|
||||
output_error_log(Error::Read, &req);
|
||||
if (res.status == StatusCode::PayloadTooLarge_413) {
|
||||
content_reader_payload_too_large = true;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
},
|
||||
[&](FormDataHeader header, ContentReceiver receiver) {
|
||||
auto result = read_content_with_content_receiver(
|
||||
strm, req, res, nullptr, std::move(header),
|
||||
std::move(receiver));
|
||||
if (!result) { output_error_log(Error::Read, &req); }
|
||||
if (!result) {
|
||||
output_error_log(Error::Read, &req);
|
||||
if (res.status == StatusCode::PayloadTooLarge_413) {
|
||||
content_reader_payload_too_large = true;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
});
|
||||
|
||||
bool dispatched = false;
|
||||
if (req.method == "POST") {
|
||||
if (dispatch_request_for_content_reader(
|
||||
req, res, std::move(reader),
|
||||
post_handlers_for_content_reader_)) {
|
||||
return true;
|
||||
}
|
||||
dispatched = dispatch_request_for_content_reader(
|
||||
req, res, std::move(reader), post_handlers_for_content_reader_);
|
||||
} else if (req.method == "PUT") {
|
||||
if (dispatch_request_for_content_reader(
|
||||
req, res, std::move(reader),
|
||||
put_handlers_for_content_reader_)) {
|
||||
return true;
|
||||
}
|
||||
dispatched = dispatch_request_for_content_reader(
|
||||
req, res, std::move(reader), put_handlers_for_content_reader_);
|
||||
} else if (req.method == "PATCH") {
|
||||
if (dispatch_request_for_content_reader(
|
||||
req, res, std::move(reader),
|
||||
patch_handlers_for_content_reader_)) {
|
||||
return true;
|
||||
}
|
||||
dispatched = dispatch_request_for_content_reader(
|
||||
req, res, std::move(reader), patch_handlers_for_content_reader_);
|
||||
} else if (req.method == "DELETE") {
|
||||
if (dispatch_request_for_content_reader(
|
||||
req, res, std::move(reader),
|
||||
delete_handlers_for_content_reader_)) {
|
||||
return true;
|
||||
dispatched = dispatch_request_for_content_reader(
|
||||
req, res, std::move(reader), delete_handlers_for_content_reader_);
|
||||
}
|
||||
|
||||
if (dispatched) {
|
||||
if (content_reader_payload_too_large) {
|
||||
// Enforce the limit: override any status the handler may have set
|
||||
// and return false so the error path sends a plain 413 response.
|
||||
res.status = StatusCode::PayloadTooLarge_413;
|
||||
res.body.clear();
|
||||
res.content_length_ = 0;
|
||||
res.content_provider_ = nullptr;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7930,16 +7994,6 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
|
||||
routed = true;
|
||||
} else {
|
||||
res.status = StatusCode::InternalServerError_500;
|
||||
std::string val;
|
||||
auto s = e.what();
|
||||
for (size_t i = 0; s[i]; i++) {
|
||||
switch (s[i]) {
|
||||
case '\r': val += "\\r"; break;
|
||||
case '\n': val += "\\n"; break;
|
||||
default: val += s[i]; break;
|
||||
}
|
||||
}
|
||||
res.set_header("EXCEPTION_WHAT", val);
|
||||
}
|
||||
} catch (...) {
|
||||
if (exception_handler_) {
|
||||
@@ -7948,7 +8002,6 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
|
||||
routed = true;
|
||||
} else {
|
||||
res.status = StatusCode::InternalServerError_500;
|
||||
res.set_header("EXCEPTION_WHAT", "UNKNOWN");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@@ -11629,8 +11682,7 @@ void SSLClient::set_session_verifier(
|
||||
session_verifier_ = std::move(verifier);
|
||||
}
|
||||
|
||||
#if defined(_WIN32) && \
|
||||
!defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE)
|
||||
#ifdef CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE
|
||||
void SSLClient::enable_windows_certificate_verification(bool enabled) {
|
||||
enable_windows_cert_verification_ = enabled;
|
||||
}
|
||||
@@ -11788,8 +11840,7 @@ bool SSLClient::initialize_ssl(Socket &socket, Error &error) {
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(_WIN32) && \
|
||||
!defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE)
|
||||
#ifdef CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE
|
||||
// Additional Windows Schannel verification.
|
||||
// This provides real-time certificate validation with Windows Update
|
||||
// integration, working with both OpenSSL and MbedTLS backends.
|
||||
@@ -11835,8 +11886,7 @@ void Client::enable_server_hostname_verification(bool enabled) {
|
||||
cli_->enable_server_hostname_verification(enabled);
|
||||
}
|
||||
|
||||
#if defined(_WIN32) && \
|
||||
!defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE)
|
||||
#ifdef CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE
|
||||
void Client::enable_windows_certificate_verification(bool enabled) {
|
||||
if (is_ssl_) {
|
||||
static_cast<SSLClient &>(*cli_).enable_windows_certificate_verification(
|
||||
@@ -11959,7 +12009,7 @@ bool enumerate_windows_system_certs(Callback cb) {
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(__APPLE__) && defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
|
||||
#ifdef CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN
|
||||
// Enumerate macOS Keychain certificates and call callback with DER data
|
||||
template <typename Callback>
|
||||
bool enumerate_macos_keychain_certs(Callback cb) {
|
||||
|
||||
47
vendor/cpp-httplib/httplib.h
vendored
47
vendor/cpp-httplib/httplib.h
vendored
@@ -8,8 +8,8 @@
|
||||
#ifndef CPPHTTPLIB_HTTPLIB_H
|
||||
#define CPPHTTPLIB_HTTPLIB_H
|
||||
|
||||
#define CPPHTTPLIB_VERSION "0.34.0"
|
||||
#define CPPHTTPLIB_VERSION_NUM "0x002200"
|
||||
#define CPPHTTPLIB_VERSION "0.35.0"
|
||||
#define CPPHTTPLIB_VERSION_NUM "0x002300"
|
||||
|
||||
/*
|
||||
* Platform compatibility check
|
||||
@@ -357,14 +357,32 @@ using socket_t = int;
|
||||
#include <any>
|
||||
#endif
|
||||
|
||||
// On macOS with a TLS backend, enable Keychain root certificates by default
|
||||
// unless the user explicitly opts out.
|
||||
#if defined(__APPLE__) && \
|
||||
!defined(CPPHTTPLIB_DISABLE_MACOSX_AUTOMATIC_ROOT_CERTIFICATES) && \
|
||||
(defined(CPPHTTPLIB_OPENSSL_SUPPORT) || \
|
||||
defined(CPPHTTPLIB_MBEDTLS_SUPPORT) || \
|
||||
defined(CPPHTTPLIB_WOLFSSL_SUPPORT))
|
||||
#ifndef CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN
|
||||
#define CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// On Windows, enable Schannel certificate verification by default
|
||||
// unless the user explicitly opts out.
|
||||
#if defined(_WIN32) && \
|
||||
!defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE)
|
||||
#define CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE
|
||||
#endif
|
||||
|
||||
#if defined(CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO) || \
|
||||
defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
|
||||
#if TARGET_OS_MAC
|
||||
#include <CFNetwork/CFHost.h>
|
||||
#include <CoreFoundation/CoreFoundation.h>
|
||||
#endif
|
||||
#endif // CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO or
|
||||
// CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN
|
||||
#endif
|
||||
|
||||
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
||||
#ifdef _WIN32
|
||||
@@ -382,11 +400,11 @@ using socket_t = int;
|
||||
#endif
|
||||
#endif // _WIN32
|
||||
|
||||
#if defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
|
||||
#ifdef CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN
|
||||
#if TARGET_OS_MAC
|
||||
#include <Security/Security.h>
|
||||
#endif
|
||||
#endif // CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO
|
||||
#endif
|
||||
|
||||
#include <openssl/err.h>
|
||||
#include <openssl/evp.h>
|
||||
@@ -430,11 +448,11 @@ using socket_t = int;
|
||||
#pragma comment(lib, "crypt32.lib")
|
||||
#endif
|
||||
#endif // _WIN32
|
||||
#if defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
|
||||
#ifdef CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN
|
||||
#if TARGET_OS_MAC
|
||||
#include <Security/Security.h>
|
||||
#endif
|
||||
#endif // CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN
|
||||
#endif
|
||||
|
||||
// Mbed TLS 3.x API compatibility
|
||||
#if MBEDTLS_VERSION_MAJOR >= 3
|
||||
@@ -473,11 +491,11 @@ using socket_t = int;
|
||||
#pragma comment(lib, "crypt32.lib")
|
||||
#endif
|
||||
#endif // _WIN32
|
||||
#if defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
|
||||
#ifdef CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN
|
||||
#if TARGET_OS_MAC
|
||||
#include <Security/Security.h>
|
||||
#endif
|
||||
#endif // CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN
|
||||
#endif
|
||||
#endif // CPPHTTPLIB_WOLFSSL_SUPPORT
|
||||
|
||||
// Define CPPHTTPLIB_SSL_ENABLED if any SSL backend is available
|
||||
@@ -2557,8 +2575,7 @@ public:
|
||||
|
||||
tls::ctx_t tls_context() const;
|
||||
|
||||
#if defined(_WIN32) && \
|
||||
!defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE)
|
||||
#ifdef CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE
|
||||
void enable_windows_certificate_verification(bool enabled);
|
||||
#endif
|
||||
|
||||
@@ -2679,8 +2696,7 @@ public:
|
||||
|
||||
tls::ctx_t tls_context() const { return ctx_; }
|
||||
|
||||
#if defined(_WIN32) && \
|
||||
!defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE)
|
||||
#ifdef CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE
|
||||
void enable_windows_certificate_verification(bool enabled);
|
||||
#endif
|
||||
|
||||
@@ -2712,8 +2728,7 @@ private:
|
||||
|
||||
std::function<SSLVerifierResponse(tls::session_t)> session_verifier_;
|
||||
|
||||
#if defined(_WIN32) && \
|
||||
!defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE)
|
||||
#ifdef CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE
|
||||
bool enable_windows_cert_verification_ = true;
|
||||
#endif
|
||||
|
||||
|
||||
Reference in New Issue
Block a user