mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-04-09 16:17:31 +03:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e9fd96283d | ||
|
|
3ba12fed0a | ||
|
|
5473949070 | ||
|
|
dcdcbad42a | ||
|
|
5764d7c6a6 | ||
|
|
87f4744a80 | ||
|
|
85d482e6b6 |
@@ -1963,7 +1963,7 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
params.add_generation_prompt = true;
|
||||
std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
|
||||
auto diff = calculate_diff_split(no_gen_prompt, gen_prompt);
|
||||
params.generation_prompt = diff.right;
|
||||
params.generation_prompt = diff.right + diff.suffix;
|
||||
|
||||
params.add_generation_prompt = inputs.add_generation_prompt;
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <vector>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <optional>
|
||||
#include <regex>
|
||||
|
||||
static void print_usage(int /*argc*/, char ** argv) {
|
||||
@@ -222,7 +223,10 @@ int main(int argc, char ** argv) {
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
base_callback_data cb_data(params, params.tensor_filter);
|
||||
std::optional<base_callback_data> cb_data;
|
||||
if (!params.save_logits) {
|
||||
cb_data.emplace(params, params.tensor_filter);
|
||||
}
|
||||
|
||||
auto llama_init = common_init_from_params(params);
|
||||
|
||||
|
||||
@@ -736,6 +736,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta
|
||||
suffix = ne00 % 4 == 0 ? "_4" : "";
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_Q1_0:
|
||||
{
|
||||
nsg = N_SG_Q1_0;
|
||||
nr0 = N_R0_Q1_0;
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
{
|
||||
nsg = N_SG_Q4_0;
|
||||
@@ -948,6 +953,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m
|
||||
smem = 32*sizeof(float)*nr0;
|
||||
suffix = ne00 % 4 == 0 ? "_4" : "";
|
||||
} break;
|
||||
case GGML_TYPE_Q1_0:
|
||||
{
|
||||
nsg = N_SG_Q1_0;
|
||||
nr0 = N_R0_Q1_0;
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
{
|
||||
nsg = N_SG_Q4_0;
|
||||
|
||||
@@ -1184,6 +1184,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_BF16:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q1_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
@@ -1210,6 +1211,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
case GGML_TYPE_Q1_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
|
||||
@@ -8,6 +8,9 @@
|
||||
//
|
||||
// TODO: for optimal performance, become function of the device and work size
|
||||
|
||||
#define N_R0_Q1_0 8
|
||||
#define N_SG_Q1_0 2
|
||||
|
||||
#define N_R0_Q4_0 4
|
||||
#define N_SG_Q4_0 2
|
||||
|
||||
|
||||
@@ -2047,6 +2047,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function
|
||||
op->src[0]->type == GGML_TYPE_F16 ||
|
||||
op->src[0]->type == GGML_TYPE_BF16 ||
|
||||
op->src[0]->type == GGML_TYPE_Q1_0 ||
|
||||
op->src[0]->type == GGML_TYPE_Q4_0 ||
|
||||
op->src[0]->type == GGML_TYPE_Q4_1 ||
|
||||
op->src[0]->type == GGML_TYPE_Q5_0 ||
|
||||
|
||||
@@ -118,6 +118,56 @@ void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) {
|
||||
device const uint8_t * qs = xb->qs;
|
||||
const float d = xb->d;
|
||||
const float neg_d = -d;
|
||||
|
||||
const int byte_offset = il * 2; // il*16 bits = il*2 bytes
|
||||
const uint8_t b0 = qs[byte_offset];
|
||||
const uint8_t b1 = qs[byte_offset + 1];
|
||||
|
||||
float4x4 reg_f;
|
||||
|
||||
reg_f[0][0] = select(neg_d, d, bool(b0 & 0x01));
|
||||
reg_f[0][1] = select(neg_d, d, bool(b0 & 0x02));
|
||||
reg_f[0][2] = select(neg_d, d, bool(b0 & 0x04));
|
||||
reg_f[0][3] = select(neg_d, d, bool(b0 & 0x08));
|
||||
reg_f[1][0] = select(neg_d, d, bool(b0 & 0x10));
|
||||
reg_f[1][1] = select(neg_d, d, bool(b0 & 0x20));
|
||||
reg_f[1][2] = select(neg_d, d, bool(b0 & 0x40));
|
||||
reg_f[1][3] = select(neg_d, d, bool(b0 & 0x80));
|
||||
|
||||
reg_f[2][0] = select(neg_d, d, bool(b1 & 0x01));
|
||||
reg_f[2][1] = select(neg_d, d, bool(b1 & 0x02));
|
||||
reg_f[2][2] = select(neg_d, d, bool(b1 & 0x04));
|
||||
reg_f[2][3] = select(neg_d, d, bool(b1 & 0x08));
|
||||
reg_f[3][0] = select(neg_d, d, bool(b1 & 0x10));
|
||||
reg_f[3][1] = select(neg_d, d, bool(b1 & 0x20));
|
||||
reg_f[3][2] = select(neg_d, d, bool(b1 & 0x40));
|
||||
reg_f[3][3] = select(neg_d, d, bool(b1 & 0x80));
|
||||
|
||||
reg = (type4x4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) {
|
||||
const float d = xb->d;
|
||||
const float neg_d = -d;
|
||||
const int base = il * 4;
|
||||
const uint8_t byte = xb->qs[base / 8];
|
||||
const int s = base % 8;
|
||||
|
||||
float4 reg_f;
|
||||
reg_f[0] = select(neg_d, d, bool((byte >> (s )) & 1));
|
||||
reg_f[1] = select(neg_d, d, bool((byte >> (s + 1)) & 1));
|
||||
reg_f[2] = select(neg_d, d, bool((byte >> (s + 2)) & 1));
|
||||
reg_f[3] = select(neg_d, d, bool((byte >> (s + 3)) & 1));
|
||||
|
||||
reg = (type4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
||||
@@ -152,6 +202,23 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_q1_0(device const float * src, device block_q1_0 & dst) {
|
||||
float sum_abs = 0.0f;
|
||||
for (int j = 0; j < QK1_0; j++) {
|
||||
sum_abs += fabs(src[j]);
|
||||
}
|
||||
dst.d = sum_abs / QK1_0;
|
||||
|
||||
for (int j = 0; j < QK1_0 / 8; j++) {
|
||||
dst.qs[j] = 0;
|
||||
}
|
||||
for (int j = 0; j < QK1_0; j++) {
|
||||
if (src[j] >= 0.0f) {
|
||||
dst.qs[j / 8] |= (1 << (j % 8));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
|
||||
#pragma METAL fp math_mode(safe)
|
||||
float amax = 0.0f; // absolute max
|
||||
@@ -3116,6 +3183,35 @@ kernel void kernel_group_norm_f32(
|
||||
}
|
||||
}
|
||||
|
||||
// Q1_0 dot product: dot = d * (2 * Σ(yl[i] where bit=1) - sumy)
|
||||
inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) {
|
||||
device const uint8_t * qs = qb_curr->qs + il / 8;
|
||||
const uint8_t b0 = qs[0];
|
||||
const uint8_t b1 = qs[1];
|
||||
|
||||
float acc = 0.0f;
|
||||
|
||||
acc += select(0.0f, yl[ 0], bool(b0 & 0x01));
|
||||
acc += select(0.0f, yl[ 1], bool(b0 & 0x02));
|
||||
acc += select(0.0f, yl[ 2], bool(b0 & 0x04));
|
||||
acc += select(0.0f, yl[ 3], bool(b0 & 0x08));
|
||||
acc += select(0.0f, yl[ 4], bool(b0 & 0x10));
|
||||
acc += select(0.0f, yl[ 5], bool(b0 & 0x20));
|
||||
acc += select(0.0f, yl[ 6], bool(b0 & 0x40));
|
||||
acc += select(0.0f, yl[ 7], bool(b0 & 0x80));
|
||||
|
||||
acc += select(0.0f, yl[ 8], bool(b1 & 0x01));
|
||||
acc += select(0.0f, yl[ 9], bool(b1 & 0x02));
|
||||
acc += select(0.0f, yl[10], bool(b1 & 0x04));
|
||||
acc += select(0.0f, yl[11], bool(b1 & 0x08));
|
||||
acc += select(0.0f, yl[12], bool(b1 & 0x10));
|
||||
acc += select(0.0f, yl[13], bool(b1 & 0x20));
|
||||
acc += select(0.0f, yl[14], bool(b1 & 0x40));
|
||||
acc += select(0.0f, yl[15], bool(b1 & 0x80));
|
||||
|
||||
return qb_curr->d * (2.0f * acc - sumy);
|
||||
}
|
||||
|
||||
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
|
||||
// il indicates where the q4 quants begin (0 or QK4_0/4)
|
||||
// we assume that the yl's have been multiplied with the appropriate scale factor
|
||||
@@ -3337,6 +3433,85 @@ void mul_vec_q_n_f32_impl(
|
||||
}
|
||||
}
|
||||
|
||||
template<int nr0, typename args_t>
|
||||
void kernel_mul_mv_q1_0_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
const short NSG = FC_mul_mv_nsg;
|
||||
|
||||
const int nb = args.ne00/QK1_0;
|
||||
|
||||
const int r0 = tgpig.x;
|
||||
const int r1 = tgpig.y;
|
||||
const int im = tgpig.z;
|
||||
|
||||
const int first_row = (r0 * NSG + sgitg) * nr0;
|
||||
|
||||
const uint i12 = im%args.ne12;
|
||||
const uint i13 = im/args.ne12;
|
||||
|
||||
const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13;
|
||||
|
||||
device const float * y = (device const float *) (src1 + offset1);
|
||||
|
||||
device const block_q1_0 * ax[nr0];
|
||||
for (int row = 0; row < nr0; ++row) {
|
||||
const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0);
|
||||
}
|
||||
|
||||
float yl[16];
|
||||
float sumf[nr0] = {0.f};
|
||||
|
||||
const short ix = (tiisg/8);
|
||||
const short il = (tiisg%8)*16;
|
||||
|
||||
device const float * yb = y + ix*QK1_0 + il;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) {
|
||||
float sumy = 0.f;
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 16; i++) {
|
||||
yl[i] = yb[i];
|
||||
sumy += yb[i];
|
||||
}
|
||||
|
||||
FOR_UNROLL (short row = 0; row < nr0; row++) {
|
||||
sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy, yl, il);
|
||||
}
|
||||
|
||||
yb += QK1_0 * (N_SIMDWIDTH/8);
|
||||
}
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < nr0; ++row) {
|
||||
const float tot = simd_sum(sumf[row]);
|
||||
|
||||
if (tiisg == 0 && first_row + row < args.ne01) {
|
||||
dst_f32[first_row + row] = tot;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_q1_0_f32")]]
|
||||
kernel void kernel_mul_mv_q1_0_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
kernel void kernel_mul_mv_q4_0_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const char * src0,
|
||||
@@ -3729,6 +3904,11 @@ template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4
|
||||
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>;
|
||||
#endif
|
||||
|
||||
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q1_0, 128, dequantize_q1_0_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q1_0, 128, dequantize_q1_0_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q1_0, 128, dequantize_q1_0_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q1_0, 128, dequantize_q1_0_t4>;
|
||||
|
||||
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
|
||||
@@ -7133,6 +7313,7 @@ kernel void kernel_cpy_f32_q(
|
||||
typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;
|
||||
|
||||
template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>;
|
||||
template [[host_name("kernel_cpy_f32_q1_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK1_0, block_q1_0, quantize_q1_0>;
|
||||
template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>;
|
||||
template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>;
|
||||
template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>;
|
||||
@@ -7173,12 +7354,14 @@ kernel void kernel_cpy_q_f32(
|
||||
|
||||
typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
|
||||
|
||||
template [[host_name("kernel_cpy_q1_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q1_0, 8, dequantize_q1_0>;
|
||||
template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
|
||||
template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
|
||||
|
||||
template [[host_name("kernel_cpy_q1_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q1_0, 8, dequantize_q1_0>;
|
||||
template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
|
||||
@@ -9776,6 +9959,7 @@ template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_ro
|
||||
|
||||
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
|
||||
|
||||
template [[host_name("kernel_get_rows_q1_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q1_0, 8, dequantize_q1_0>;
|
||||
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
|
||||
@@ -9838,6 +10022,7 @@ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_m
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
|
||||
@@ -9861,6 +10046,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m
|
||||
|
||||
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
|
||||
@@ -10070,6 +10256,7 @@ template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4
|
||||
|
||||
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;
|
||||
|
||||
template [[host_name("kernel_mul_mv_id_q1_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0>>>;
|
||||
|
||||
@@ -4033,8 +4033,14 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()");
|
||||
|
||||
static ggml_backend_webgpu_reg_context ctx;
|
||||
static ggml_backend_reg reg = {
|
||||
/* .api_version = */ GGML_BACKEND_API_VERSION,
|
||||
/* .iface = */ ggml_backend_webgpu_reg_i,
|
||||
/* .context = */ &ctx,
|
||||
};
|
||||
|
||||
ctx.name = GGML_WEBGPU_NAME;
|
||||
ctx.device_count = 1;
|
||||
ctx.device_count = 0;
|
||||
|
||||
wgpu::InstanceDescriptor instance_descriptor{};
|
||||
std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
|
||||
@@ -4053,19 +4059,28 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
||||
ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct());
|
||||
ctx.webgpu_global_ctx->instance = std::move(inst);
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
if (ctx.webgpu_global_ctx->instance == nullptr) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n");
|
||||
return nullptr;
|
||||
}
|
||||
#endif
|
||||
GGML_ASSERT(ctx.webgpu_global_ctx->instance != nullptr);
|
||||
wgpu::Adapter adapter;
|
||||
if (ctx.webgpu_global_ctx->instance != nullptr) {
|
||||
wgpu::RequestAdapterOptions options = {};
|
||||
|
||||
// probe for adapter support
|
||||
ctx.webgpu_global_ctx->instance.WaitAny(
|
||||
ctx.webgpu_global_ctx->instance.RequestAdapter(
|
||||
&options, wgpu::CallbackMode::AllowSpontaneous,
|
||||
[&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) {
|
||||
if (status != wgpu::RequestAdapterStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
|
||||
return;
|
||||
}
|
||||
adapter = std::move(_adapter);
|
||||
}),
|
||||
UINT64_MAX);
|
||||
}
|
||||
|
||||
if (adapter != nullptr) {
|
||||
ctx.device_count = 1;
|
||||
}
|
||||
|
||||
static ggml_backend_reg reg = {
|
||||
/* .api_version = */ GGML_BACKEND_API_VERSION,
|
||||
/* .iface = */ ggml_backend_webgpu_reg_i,
|
||||
/* .context = */ &ctx,
|
||||
};
|
||||
return ®
|
||||
}
|
||||
|
||||
|
||||
@@ -558,20 +558,20 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
|
||||
// example: https://github.com/ggml-org/llama.cpp/pull/17548
|
||||
//
|
||||
static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // do the norms on the first layer (not the input layer)
|
||||
{LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
|
||||
{LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
|
||||
{LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_OUTPUT_NORM_LFM2, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
|
||||
{LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
|
||||
{LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_OUTPUT_NORM_LFM2, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ROPE_FREQS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}},
|
||||
{LLM_TENSOR_ROPE_FACTORS_LONG, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}},
|
||||
{LLM_TENSOR_ROPE_FACTORS_SHORT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}},
|
||||
@@ -708,9 +708,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_FFN_UP_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
|
||||
{LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
// altup / laurel (gemma 3n)
|
||||
{LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ALTUP_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ALTUP_UNEMBD_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_PER_LAYER_INP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
|
||||
@@ -2942,7 +2942,7 @@ llama_context * llama_init_from_model(
|
||||
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
|
||||
}
|
||||
|
||||
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) {
|
||||
if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_k)) {
|
||||
const uint32_t blck_size = ggml_blck_size(params.type_k);
|
||||
for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
|
||||
if (model->hparams.n_embd_head_k(il) % blck_size != 0) {
|
||||
@@ -2953,7 +2953,7 @@ llama_context * llama_init_from_model(
|
||||
}
|
||||
}
|
||||
|
||||
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) {
|
||||
if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_v)) {
|
||||
const uint32_t blck_size = ggml_blck_size(params.type_v);
|
||||
for (uint32_t il = 0; il < model->hparams.n_layer; ++il) {
|
||||
if (model->hparams.n_embd_head_v(il) % blck_size != 0) {
|
||||
|
||||
@@ -4211,13 +4211,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0);
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
|
||||
altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
|
||||
per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_altup * n_layer}, 0);
|
||||
per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_altup}, 0);
|
||||
altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
|
||||
altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
|
||||
|
||||
per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0);
|
||||
per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_altup * n_layer}, 0);
|
||||
per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_altup}, 0);
|
||||
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
|
||||
@@ -4276,9 +4277,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
if (n_embd_per_layer > 0) {
|
||||
tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_per_layer * n_layer, n_vocab}, 0);
|
||||
per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_per_layer * n_layer}, 0);
|
||||
per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_per_layer}, 0);
|
||||
per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_per_layer * n_layer, n_vocab}, 0);
|
||||
per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_per_layer * n_layer}, 0);
|
||||
per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_per_layer}, 0);
|
||||
}
|
||||
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
|
||||
@@ -534,9 +534,9 @@ struct llama_model {
|
||||
struct ggml_tensor * conv1d_b = nullptr;
|
||||
|
||||
// gemma3n altup
|
||||
struct ggml_tensor * tok_embd_per_layer = nullptr;
|
||||
struct ggml_tensor * altup_proj = nullptr;
|
||||
struct ggml_tensor * altup_unembd_proj = nullptr;
|
||||
struct ggml_tensor * per_layer_tok_embd = nullptr;
|
||||
struct ggml_tensor * per_layer_model_proj = nullptr;
|
||||
struct ggml_tensor * per_layer_proj_norm = nullptr;
|
||||
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
#include "models.h"
|
||||
|
||||
// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
|
||||
static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) {
|
||||
GGML_ASSERT(idx < (int) x->ne[2]);
|
||||
return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
|
||||
idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
|
||||
}
|
||||
|
||||
llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params),
|
||||
model(model),
|
||||
@@ -22,8 +29,11 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
|
||||
// TODO: is causal == true correct? might need some changes
|
||||
auto * inp_attn = build_attn_inp_kv_iswa();
|
||||
|
||||
// inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
|
||||
ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());
|
||||
ggml_tensor * inp_per_layer = build_inp_per_layer();
|
||||
ggml_build_forward_expand(gf, inp_per_layer);
|
||||
|
||||
// inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]
|
||||
inp_per_layer = project_per_layer_inputs(inpL, inp_per_layer);
|
||||
|
||||
// inpL now has only 1 altup, project it to the rest of the altups
|
||||
// these "added" altups will be concat to the last dim of inpL
|
||||
@@ -37,8 +47,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
|
||||
inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup]
|
||||
cb(inpL, "inp_stacked", -1);
|
||||
}
|
||||
// inpL now has shape: [n_embd, n_tokens, n_altup]
|
||||
// inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]
|
||||
// inpL now has shape: [n_embd, n_tokens, n_altup]
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
// this block is made to be closely resemble Gemma3p5DecoderLayer on python code
|
||||
@@ -49,8 +58,8 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
|
||||
ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup]
|
||||
|
||||
// predicted value will go through self-attention and laurel
|
||||
ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens]
|
||||
cur = active_prediction;
|
||||
ggml_tensor * active_prediction = ggml_view_2d_slice(ctx0, predictions, i_altup_act); // [n_embd, n_tokens]
|
||||
cur = active_prediction;
|
||||
cb(cur, "active_prediction", il);
|
||||
|
||||
// norm
|
||||
@@ -151,12 +160,13 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
|
||||
|
||||
ggml_tensor * first_prediction; // [n_embd, n_tokens]
|
||||
{
|
||||
first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens]
|
||||
first_prediction = ggml_view_2d_slice(ctx0, corrected, i_altup_act); // [n_embd, n_tokens]
|
||||
first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale);
|
||||
first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction);
|
||||
first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens]
|
||||
cb(first_prediction, "first_prediction_gated", il);
|
||||
ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_altup, n_tokens]
|
||||
|
||||
ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_altup, n_tokens]
|
||||
first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens]
|
||||
cb(first_prediction, "first_prediction_scaled", il);
|
||||
|
||||
@@ -167,7 +177,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
|
||||
}
|
||||
// equivalent to python code: corrected_predictions[1:] += first_prediction
|
||||
{
|
||||
ggml_tensor * slice_first = view_2d_slice(corrected, 0);
|
||||
ggml_tensor * slice_first = ggml_view_2d_slice(ctx0, corrected, 0);
|
||||
ggml_tensor * slice_rest = ggml_view_3d(
|
||||
ctx0, corrected, n_embd, n_tokens, n_altup - 1, ggml_row_size(corrected->type, n_embd),
|
||||
ggml_row_size(corrected->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(corrected));
|
||||
@@ -185,7 +195,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
|
||||
|
||||
// cur now has multiple altup(s), we want to merge them back to 1 altup
|
||||
{
|
||||
ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens]
|
||||
ggml_tensor * target_magnitude = calc_magnitude(ggml_view_2d_slice(ctx0, cur, i_altup_act)); // [n_embd, n_tokens]
|
||||
// do a view to skip the first slice (active altup)
|
||||
ggml_tensor * alt_slice =
|
||||
ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1, ggml_row_size(cur->type, n_embd),
|
||||
@@ -197,9 +207,9 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const
|
||||
cb(altup_unembd, "altup_unembd", -1);
|
||||
|
||||
// equivalent to torch.mean(hidden_states, dim=0)
|
||||
cur = view_2d_slice(cur, 0); // [n_embd, n_tokens]
|
||||
cur = ggml_view_2d_slice(ctx0, cur, 0); // [n_embd, n_tokens]
|
||||
for (int i = 0; i < n_altup - 1; ++i) {
|
||||
cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i));
|
||||
cur = ggml_add(ctx0, cur, ggml_view_2d_slice(ctx0, altup_unembd, i));
|
||||
}
|
||||
cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens]
|
||||
cb(cur, "unembd_merged", -1);
|
||||
@@ -235,23 +245,16 @@ ggml_tensor * llm_build_gemma3n_iswa::calc_magnitude(ggml_tensor * x) {
|
||||
return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x)));
|
||||
}
|
||||
|
||||
// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
|
||||
ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) {
|
||||
GGML_ASSERT(idx < (int) x->ne[2]);
|
||||
return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
|
||||
idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
|
||||
}
|
||||
|
||||
// equivalent to get_per_layer_inputs() in python code
|
||||
// output shape: [n_embd_altup, n_layer, n_tokens]
|
||||
ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
|
||||
ggml_tensor * llm_build_gemma3n_iswa::build_inp_per_layer() {
|
||||
auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
|
||||
ggml_tensor * inp_per_layer;
|
||||
if (ubatch.token) {
|
||||
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
||||
ggml_set_input(inp->tokens);
|
||||
res->t_inp_tokens = inp->tokens;
|
||||
inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
|
||||
inp_per_layer = ggml_get_rows(ctx0, model.per_layer_tok_embd, inp->tokens);
|
||||
inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
|
||||
inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup));
|
||||
cb(inp_per_layer, "inp_per_layer_selected", -1);
|
||||
@@ -259,10 +262,10 @@ ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
|
||||
} else {
|
||||
// Vision embedding path: use padding token (ID=0) embedding
|
||||
// TODO: verify if this is the correct behavior in transformers implementation
|
||||
const int64_t embd_size = model.tok_embd_per_layer->ne[0]; // n_embd_altup * n_layer
|
||||
const int64_t embd_size = model.per_layer_tok_embd->ne[0]; // n_embd_altup * n_layer
|
||||
|
||||
// Extract and dequantize padding token embedding (row 0)
|
||||
ggml_tensor * padding = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0);
|
||||
ggml_tensor * padding = ggml_view_1d(ctx0, model.per_layer_tok_embd, embd_size, 0);
|
||||
inp_per_layer = ggml_cast(ctx0, padding, GGML_TYPE_F32);
|
||||
|
||||
// Reshape to [n_embd_altup, n_layer, 1]
|
||||
@@ -275,18 +278,19 @@ ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
|
||||
// equivalent to project_per_layer_inputs() in python code
|
||||
// this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
|
||||
// output shape: [n_embd_altup, n_tokens, n_layer]
|
||||
ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
|
||||
ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) {
|
||||
const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd);
|
||||
const float per_layer_input_scale = 1.0f / sqrtf(2.0f);
|
||||
|
||||
ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds);
|
||||
per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale);
|
||||
per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
|
||||
per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS,
|
||||
-1); // [n_embd_altup, n_layer, n_tokens]
|
||||
ggml_tensor * per_layer_proj;
|
||||
per_layer_proj = ggml_mul_mat (ctx0, model.per_layer_model_proj, inp_batch);
|
||||
per_layer_proj = ggml_scale (ctx0, per_layer_proj, per_layer_projection_scale);
|
||||
per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
|
||||
|
||||
per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS, -1);
|
||||
cb(per_layer_proj, "per_layer_proj", -1);
|
||||
|
||||
inp_per_layer = ggml_add(ctx0, per_layer_proj, inp_per_layer);
|
||||
inp_per_layer = ggml_add (ctx0, per_layer_proj, inp_per_layer);
|
||||
inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
|
||||
cb(inp_per_layer, "inp_per_layer", -1);
|
||||
|
||||
@@ -337,7 +341,7 @@ ggml_tensor * llm_build_gemma3n_iswa::altup_compute_router_modalities(ggml_tenso
|
||||
// input cur shape: [n_embd, n_tokens, n_altup]
|
||||
// output shape: [n_embd, n_tokens, n_altup]
|
||||
ggml_tensor * llm_build_gemma3n_iswa::altup_predict(ggml_tensor * cur, int il) {
|
||||
ggml_tensor * activated = view_2d_slice(cur, i_altup_act); // [n_embd, n_tokens]
|
||||
ggml_tensor * activated = ggml_view_2d_slice(ctx0, cur, i_altup_act); // [n_embd, n_tokens]
|
||||
ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
|
||||
cb(modalities, "modalities", il);
|
||||
|
||||
@@ -365,7 +369,7 @@ ggml_tensor * llm_build_gemma3n_iswa::altup_correct(ggml_tensor * predictions, g
|
||||
ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
|
||||
cb(modalities, "modalities", il);
|
||||
|
||||
ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act);
|
||||
ggml_tensor * active_prediction = ggml_view_2d_slice(ctx0, predictions, i_altup_act);
|
||||
ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens]
|
||||
cb(innovation, "innovation", il);
|
||||
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
#include "models.h"
|
||||
|
||||
// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
|
||||
static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) {
|
||||
GGML_ASSERT(idx < (int) x->ne[2]);
|
||||
return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
|
||||
idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
|
||||
}
|
||||
|
||||
llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params),
|
||||
model(model),
|
||||
@@ -19,14 +26,17 @@ llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const ll
|
||||
// TODO: is causal == true correct? might need some changes
|
||||
auto * inp_attn = build_attn_inp_kv_iswa();
|
||||
|
||||
// inp_per_layer shape: [n_embd_per_layer, n_tokens, n_layer]
|
||||
ggml_tensor * inp_per_layer = nullptr;
|
||||
if (model.tok_embd_per_layer) {
|
||||
inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());
|
||||
}
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
ggml_tensor * inp_per_layer = nullptr;
|
||||
if (model.per_layer_tok_embd) {
|
||||
inp_per_layer = build_inp_per_layer();
|
||||
ggml_build_forward_expand(gf, inp_per_layer);
|
||||
|
||||
// inp_per_layer shape: [n_embd_per_layer, n_tokens, n_layer]
|
||||
inp_per_layer = project_per_layer_inputs(inpL, inp_per_layer);
|
||||
}
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_k(il);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_v(il));
|
||||
@@ -196,7 +206,8 @@ llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const ll
|
||||
|
||||
cur = build_lora_mm(model.layers[il].per_layer_inp_gate, cur); // [n_embd_per_layer, n_tokens]
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_per_layer, n_tokens]
|
||||
|
||||
ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_per_layer, n_tokens]
|
||||
|
||||
// TODO @ngxson : improve this
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
@@ -248,34 +259,30 @@ llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const ll
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
|
||||
ggml_tensor * llm_build_gemma4_iswa::view_2d_slice(ggml_tensor * x, int idx) {
|
||||
GGML_ASSERT(idx < (int) x->ne[2]);
|
||||
return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]),
|
||||
idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
|
||||
}
|
||||
|
||||
// equivalent to get_per_layer_inputs() in python code
|
||||
// output shape: [n_embd_per_layer, n_layer, n_tokens]
|
||||
ggml_tensor * llm_build_gemma4_iswa::get_per_layer_inputs() {
|
||||
ggml_tensor * llm_build_gemma4_iswa::build_inp_per_layer() {
|
||||
auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
|
||||
|
||||
ggml_tensor * inp_per_layer;
|
||||
if (ubatch.token) {
|
||||
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
||||
ggml_set_input(inp->tokens);
|
||||
res->t_inp_tokens = inp->tokens;
|
||||
inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
|
||||
|
||||
inp_per_layer = ggml_get_rows (ctx0, model.per_layer_tok_embd, inp->tokens);
|
||||
inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_per_layer, n_layer, n_tokens);
|
||||
inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_per_layer));
|
||||
inp_per_layer = ggml_scale (ctx0, inp_per_layer, sqrtf((float) n_embd_per_layer));
|
||||
cb(inp_per_layer, "inp_per_layer_selected", -1);
|
||||
|
||||
res->add_input(std::move(inp));
|
||||
} else {
|
||||
// Vision embedding path: use padding token (ID=0) embedding
|
||||
// TODO: verify if this is the correct behavior in transformers implementation
|
||||
const int64_t embd_size = model.tok_embd_per_layer->ne[0]; // n_embd_per_layer * n_layer
|
||||
const int64_t embd_size = model.per_layer_tok_embd->ne[0]; // n_embd_per_layer * n_layer
|
||||
|
||||
// Extract and dequantize padding token embedding (row 0)
|
||||
ggml_tensor * padding = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0);
|
||||
ggml_tensor * padding = ggml_view_1d(ctx0, model.per_layer_tok_embd, embd_size, 0);
|
||||
inp_per_layer = ggml_cast(ctx0, padding, GGML_TYPE_F32);
|
||||
|
||||
// Reshape to [n_embd_per_layer, n_layer, 1]
|
||||
@@ -287,21 +294,23 @@ ggml_tensor * llm_build_gemma4_iswa::get_per_layer_inputs() {
|
||||
|
||||
// equivalent to project_per_layer_inputs() in python code
|
||||
// this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
|
||||
// inputs_embeds shape: [n_embd, n_tokens]
|
||||
// inp_per_layer shape: [n_embd_per_layer, n_layer, n_tokens] (from get_per_layer_inputs)
|
||||
// inp_batch shape: [n_embd, n_tokens]
|
||||
// inp_per_layer shape: [n_embd_per_layer, n_layer, n_tokens] (from build_inp_per_layer)
|
||||
// output shape: [n_embd_per_layer, n_tokens, n_layer]
|
||||
ggml_tensor * llm_build_gemma4_iswa::project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
|
||||
ggml_tensor * llm_build_gemma4_iswa::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) {
|
||||
const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd);
|
||||
const float per_layer_input_scale = 1.0f / sqrtf(2.0f);
|
||||
|
||||
ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds);
|
||||
per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale);
|
||||
per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_per_layer, n_layer, n_tokens);
|
||||
per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, nullptr, LLM_NORM_RMS,
|
||||
-1); // [n_embd_per_layer, n_layer, n_tokens]
|
||||
// note: this matrix multiplication will be performed in the input layer (i.e. on the CPU)
|
||||
ggml_tensor * per_layer_proj;
|
||||
per_layer_proj = ggml_mul_mat (ctx0, model.per_layer_model_proj, inp_batch);
|
||||
per_layer_proj = ggml_scale (ctx0, per_layer_proj, per_layer_projection_scale);
|
||||
per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_per_layer, n_layer, n_tokens);
|
||||
|
||||
per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, nullptr, LLM_NORM_RMS, -1);
|
||||
cb(per_layer_proj, "per_layer_proj", -1);
|
||||
|
||||
inp_per_layer = ggml_add(ctx0, per_layer_proj, inp_per_layer);
|
||||
inp_per_layer = ggml_add (ctx0, per_layer_proj, inp_per_layer);
|
||||
inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
|
||||
cb(inp_per_layer, "inp_per_layer", -1);
|
||||
|
||||
|
||||
@@ -256,9 +256,11 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
|
||||
|
||||
llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params);
|
||||
ggml_tensor * calc_magnitude(ggml_tensor * x);
|
||||
ggml_tensor * view_2d_slice(ggml_tensor * x, int idx);
|
||||
ggml_tensor * get_per_layer_inputs();
|
||||
ggml_tensor * project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer);
|
||||
|
||||
// TODO: refactor in common "per-layer" functionality [TAG_PER_LAYER]
|
||||
ggml_tensor * build_inp_per_layer();
|
||||
ggml_tensor * project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer);
|
||||
|
||||
ggml_tensor * gaussian_topk(ggml_tensor * x);
|
||||
ggml_tensor * altup_compute_router_modalities(ggml_tensor * x, int il);
|
||||
ggml_tensor * altup_predict(ggml_tensor * cur, int il);
|
||||
@@ -272,9 +274,10 @@ struct llm_build_gemma4_iswa : public llm_graph_context {
|
||||
const int64_t n_embd_per_layer;
|
||||
|
||||
llm_build_gemma4_iswa(const llama_model & model, const llm_graph_params & params);
|
||||
ggml_tensor * view_2d_slice(ggml_tensor * x, int idx);
|
||||
ggml_tensor * get_per_layer_inputs();
|
||||
ggml_tensor * project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer);
|
||||
|
||||
// TODO: refactor in common "per-layer" functionality [TAG_PER_LAYER]
|
||||
ggml_tensor * build_inp_per_layer();
|
||||
ggml_tensor * project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer);
|
||||
};
|
||||
|
||||
struct llm_build_gemma_embedding : public llm_graph_context {
|
||||
|
||||
@@ -7251,6 +7251,7 @@ static const ggml_type all_types[] = {
|
||||
GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
|
||||
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
|
||||
GGML_TYPE_Q8_0,
|
||||
GGML_TYPE_Q1_0,
|
||||
GGML_TYPE_MXFP4, GGML_TYPE_NVFP4,
|
||||
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
|
||||
GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
|
||||
@@ -7275,6 +7276,7 @@ static const ggml_type other_types[] = {
|
||||
GGML_TYPE_Q4_1,
|
||||
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
|
||||
GGML_TYPE_Q8_0,
|
||||
GGML_TYPE_Q1_0,
|
||||
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
|
||||
GGML_TYPE_Q5_K,
|
||||
GGML_TYPE_Q6_K,
|
||||
|
||||
@@ -998,6 +998,7 @@ static void test_peg_parser(common_chat_templates * tmpls,
|
||||
auto parser = make_peg_parser(tmpls, tc.params, detailed_debug);
|
||||
if (detailed_debug) {
|
||||
LOG_DBG("Using parser: \n%s\n", parser.arena_.dump(parser.arena_.root()).c_str());
|
||||
LOG_DBG("Generation prompt: '%s'\n", parser.params_.generation_prompt.c_str());
|
||||
}
|
||||
|
||||
common_chat_msg msg_accum;
|
||||
@@ -3102,8 +3103,19 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
// Format: <minimax:tool_call><invoke name="func"><parameter name="key">value</parameter></invoke></minimax:tool_call>
|
||||
{
|
||||
auto tst = peg_tester("models/templates/MiniMax-M2.jinja", detailed_debug);
|
||||
tst.test("</think>Hello, world!\nWhat's up?").enable_thinking(true).reasoning_format(COMMON_REASONING_FORMAT_AUTO).expect(message_assist).run();
|
||||
|
||||
tst.test("I'm\nthinking</think>Hello, world!\nWhat's up?").enable_thinking(true).reasoning_format(COMMON_REASONING_FORMAT_AUTO).expect(message_assist_thoughts).run();
|
||||
|
||||
tst.test("Let's call a tool:</think><minimax:tool_call>\n<invoke name=\"empty_args\">\n</invoke>\n</minimax:tool_call>").
|
||||
enable_thinking(true).
|
||||
reasoning_format(COMMON_REASONING_FORMAT_AUTO).
|
||||
tools({ empty_args_tool }).
|
||||
expect(message_with_reasoning_and_tool_call("Let's call a tool:", "empty_args", "{}")).
|
||||
run();
|
||||
|
||||
tst.test(
|
||||
"<minimax:tool_call>\n<invoke name=\"special_function\">\n<parameter "
|
||||
"</think><minimax:tool_call>\n<invoke name=\"special_function\">\n<parameter "
|
||||
"name=\"arg1\">1</parameter>\n</invoke>\n</minimax:tool_call>")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
@@ -3442,7 +3454,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
},
|
||||
"replaceAll": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to replace all occurences."
|
||||
"description": "Whether to replace all occurrences."
|
||||
}
|
||||
},
|
||||
"required": ["oldString", "newString"]
|
||||
|
||||
@@ -135,7 +135,7 @@ def test_completion_stream_with_openai_library_stops():
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||
res = client.completions.create(
|
||||
model="davinci-002",
|
||||
prompt="System: You are helpfull assistant.\nAssistant:\nHey! How could I help?\nUser:\nTell me a joke.\nAssistant:\n",
|
||||
prompt="System: You are helpful assistant.\nAssistant:\nHey! How could I help?\nUser:\nTell me a joke.\nAssistant:\n",
|
||||
stop=["User:\n", "Assistant:\n"],
|
||||
max_tokens=200,
|
||||
stream=True,
|
||||
|
||||
Reference in New Issue
Block a user