Compare commits


8 Commits
b1663 ... b1671

Author SHA1 Message Date
Jared Van Bortel
8fe03ffdda common : remove incorrect --model-draft default (#4568) 2023-12-21 19:55:34 +02:00
Johannes Gäßler
9154494808 CUDA: mul_mat_id always on GPU for batches >= 32 (#4553) 2023-12-21 18:42:59 +01:00
Georgi Gerganov
c083718c89 readme : update coding guidelines 2023-12-21 19:27:14 +02:00
howlger
880e352277 py : open merges file as 'utf-8' (#4566)
Otherwise, converting bling-phi-2-v0 (<https://huggingface.co/llmware/bling-phi-2-v0>) with convert-hf-to-gguf.py on Windows fails with the following error:

```
Traceback (most recent call last):
  File "C:\Users\User\git\gguf\convert-hf-to-gguf.py", line 1061, in <module>
    model_instance.set_vocab()
  File "C:\Users\User\git\gguf\convert-hf-to-gguf.py", line 52, in set_vocab
    self._set_vocab_gpt2()
  File "C:\Users\User\git\gguf\convert-hf-to-gguf.py", line 264, in _set_vocab_gpt2
    special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
  File "C:\Users\User\git\gguf\gguf\vocab.py", line 33, in __init__
    self._load(Path(path))
  File "C:\Users\User\git\gguf\gguf\vocab.py", line 81, in _load
    self._try_load_merges_txt(path)
  File "C:\Users\User\git\gguf\gguf\vocab.py", line 95, in _try_load_merges_txt
    for line in fp:
  File "C:\Users\User\miniconda3\envs\gguf\lib\encodings\cp1252.py", line 23, in decode
    return codecs.charmap_decode(input,self.errors,decoding_table)[0]
UnicodeDecodeError: 'charmap' codec can't decode byte 0x81 in position 1415: character maps to <undefined>
```
2023-12-21 19:07:34 +02:00
bobqianic
66f35a2f48 cuda : better error message for ggml_get_rows (#4561)
* Update ggml-cuda.cu

* Update ggml-cuda.cu

* Update ggml-cuda.cu

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-12-21 19:06:44 +02:00
slaren
1398823922 cuda : replace asserts in wrong architecture checks with __trap (#4556)
* cuda : replace asserts in wrong architecture checks with __trap

* make bad_arch noreturn, remove returns
2023-12-21 18:02:30 +01:00
Johannes Gäßler
d3223afdad llama : disable per-tensor info prints on model load (#4562) 2023-12-21 18:34:17 +02:00
LoganDark
1d7a1912ce Fix access violation in ggml_cuda_free_data if tensor->extra is NULL (#4554) 2023-12-21 10:59:27 +01:00
5 changed files with 65 additions and 57 deletions

View File

@@ -982,6 +982,8 @@ docker run --gpus all -v /path/to/models:/models local/llama.cpp:light-cuda -m /
- There are no strict rules for the code style, but try to follow the patterns in the code (indentation, spaces, etc.). Vertical alignment makes things more readable and easier to batch edit
- Clean-up any trailing whitespaces, use 4 spaces for indentation, brackets on the same line, `void * ptr`, `int & a`
- See [good first issues](https://github.com/ggerganov/llama.cpp/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) for tasks suitable for first contributions
- Tensors store data in row-major order. We refer to dimension 0 as columns, 1 as rows, 2 as matrices
- Matrix multiplication is unconventional: [`z = ggml_mul_mat(ctx, x, y)`](https://github.com/ggerganov/llama.cpp/blob/880e352277fc017df4d5794f0c21c44e1eae2b84/ggml.h#L1058-L1064) means `zT = x @ yT`
### Docs
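
To make the multiplication guideline above concrete, here is a minimal host-side sketch of the shape convention. It is not taken from the diff; the toy dimensions and the 16 MB scratch buffer are arbitrary choices for illustration.

```
#include <stdio.h>
#include "ggml.h"

int main(void) {
    // small scratch context; the 16 MB size is an arbitrary choice for this sketch
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // dimension 0 is "columns" (the contiguous dimension), dimension 1 is "rows"
    const int K = 8, M = 4, N = 2;
    struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, K, M); // M rows of K columns
    struct ggml_tensor * y = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, K, N); // N rows of K columns

    // z = ggml_mul_mat(ctx, x, y) means zT = x @ yT, so z ends up with N rows of M columns
    struct ggml_tensor * z = ggml_mul_mat(ctx, x, y);
    printf("z: ne[0]=%lld ne[1]=%lld\n", (long long) z->ne[0], (long long) z->ne[1]); // expect 4, 2

    ggml_free(ctx);
    return 0;
}
```

Both operands keep the shared dimension K in dimension 0, which is why the guideline phrases the operation as a transposed product.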

View File

@@ -920,7 +920,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -m FNAME, --model FNAME\n");
printf(" model path (default: %s)\n", params.model.c_str());
printf(" -md FNAME, --model-draft FNAME\n");
printf(" draft model for speculative decoding (default: %s)\n", params.model.c_str());
printf(" draft model for speculative decoding\n");
printf(" -ld LOGDIR, --logdir LOGDIR\n");
printf(" path under which to save YAML logs (no logging if unset)\n");
printf(" --override-kv KEY=TYPE:VALUE\n");

View File

@@ -512,6 +512,14 @@ static size_t g_scratch_offset = 0;
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
[[noreturn]]
static __device__ void bad_arch() {
printf("ERROR: ggml-cuda was compiled without support for the current GPU architecture.\n");
__trap();
(void) bad_arch; // suppress unused function warning
}
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
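
Taken out of context, the pattern repeated in the hunks below is: the `#else` branch of an architecture guard used to hit `assert(false)` and then return a dummy value; it now calls the `[[noreturn]]` `bad_arch()` added above, which prints once and `__trap()`s. A self-contained sketch of that pattern, assuming the usual `MIN_CC_DP4A` value of 610 and a made-up `dot_example` function:

```
#include <cstdio>

#define MIN_CC_DP4A 610  // minimum compute capability for __dp4a (assumed, as in ggml-cuda.cu)

[[noreturn]]
static __device__ void bad_arch() {
    printf("ERROR: compiled without support for the current GPU architecture.\n");
    __trap();            // aborts the kernel; execution never continues past this point

    (void) bad_arch;     // suppress the unused-function warning, same trick as in the diff
}

// illustrative device function: new architectures take the DP4A path, old ones trap
static __device__ __forceinline__ float dot_example(int a, int b, float scale) {
#if __CUDA_ARCH__ >= MIN_CC_DP4A
    return scale * (float) __dp4a(a, b, 0);
#else
    (void) a; (void) b; (void) scale;
    bad_arch();          // noreturn, so no "return 0.0f; // only to satisfy the compiler" is needed
#endif
}

__global__ void demo(const int * a, const int * b, float * out) {
    const int i = threadIdx.x;
    out[i] = dot_example(a[i], b[i], 1.0f);
}
```

Making the stub `noreturn` is what lets the callers drop their placeholder `return 0.0f;` lines without compiler warnings, which is exactly what the repeated hunks below do.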
@@ -1972,8 +1980,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_imp
// second part effectively subtracts 8 from each quant value
return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y);
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2010,8 +2017,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp
// scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2046,8 +2052,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_imp
// second part effectively subtracts 16 from each quant value
return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y);
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2092,8 +2097,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2114,8 +2118,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_imp
return d8_0*d8_1 * sumi;
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2145,8 +2148,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
// scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2181,8 +2183,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
return dm2f.x*sumf_d - dm2f.y*sumf_m;
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2219,8 +2220,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m);
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2260,8 +2260,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
return d3 * sumf;
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2286,8 +2285,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
return d3*d8 * sumi;
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2320,8 +2318,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
return dm4f.x*sumf_d - dm4f.y*sumf_m;
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2354,8 +2351,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
return dm4f.x*sumf_d - dm4f.y*sumf_m;
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2395,8 +2391,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
return dm5f.x*sumf_d - dm5f.y*sumf_m;
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2429,8 +2424,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
return dm4f.x*sumf_d - dm4f.y*sumf_m;
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2460,8 +2454,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
return d*sumf;
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2492,8 +2485,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
return d6 * sumf_d;
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -3359,8 +3351,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
return dall * sumf_d - dmin * sumf_m;
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
#endif
@@ -3543,8 +3534,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
return d * sumf_d;
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
#endif
@@ -3954,7 +3944,7 @@ template <bool need_check> static __global__ void
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q4_0_q8_1_mul_mat;
assert(false);
bad_arch();
#endif // __CUDA_ARCH__ >= CC_VOLTA
}
@@ -4023,7 +4013,7 @@ template <bool need_check> static __global__ void
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q4_1_q8_1_mul_mat;
assert(false);
bad_arch();
#endif // __CUDA_ARCH__ >= CC_VOLTA
}
@@ -4090,7 +4080,7 @@ template <bool need_check> static __global__ void
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q5_0_q8_1_mul_mat;
assert(false);
bad_arch();
#endif // __CUDA_ARCH__ >= CC_VOLTA
}
@@ -4157,7 +4147,7 @@ mul_mat_q5_1(
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q5_1_q8_1_mul_mat;
assert(false);
bad_arch();
#endif // __CUDA_ARCH__ >= CC_VOLTA
}
@@ -4224,7 +4214,7 @@ template <bool need_check> static __global__ void
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q8_0_q8_1_mul_mat;
assert(false);
bad_arch();
#endif // __CUDA_ARCH__ >= CC_VOLTA
}
@@ -4291,7 +4281,7 @@ mul_mat_q2_K(
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q2_K_q8_1_mul_mat;
assert(false);
bad_arch();
#endif // __CUDA_ARCH__ >= CC_VOLTA
}
@@ -4360,7 +4350,7 @@ template <bool need_check> static __global__ void
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q3_K_q8_1_mul_mat;
assert(false);
bad_arch();
#endif // __CUDA_ARCH__ >= CC_VOLTA
}
@@ -4429,7 +4419,7 @@ template <bool need_check> static __global__ void
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q4_K_q8_1_mul_mat;
assert(false);
bad_arch();
#endif // __CUDA_ARCH__ >= CC_VOLTA
}
@@ -4496,7 +4486,7 @@ mul_mat_q5_K(
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q5_K_q8_1_mul_mat;
assert(false);
bad_arch();
#endif // __CUDA_ARCH__ >= CC_VOLTA
}
@@ -4565,7 +4555,7 @@ template <bool need_check> static __global__ void
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q6_K_q8_1_mul_mat;
assert(false);
bad_arch();
#endif // __CUDA_ARCH__ >= CC_VOLTA
}
@@ -6825,6 +6815,7 @@ static void ggml_cuda_op_get_rows(
break;
default:
// TODO: k-quants
fprintf(stderr, "%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
GGML_ASSERT(false);
break;
}
@@ -8782,8 +8773,6 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
// TODO: mmq/mmv support
#endif
GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
const int64_t nb11 = src1->nb[1];
const int64_t nb1 = dst->nb[1];
@@ -8812,13 +8801,21 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
ggml_tensor src1_row = *src1;
ggml_tensor dst_row = *dst;
src1_row.backend = GGML_BACKEND_GPU;
dst_row.backend = GGML_BACKEND_GPU;
src1_row.extra = &src1_row_extra;
dst_row.extra = &dst_row_extra;
char * src1_original = (char *) src1_extra->data_device[g_main_device];
char * dst_original = (char *) dst_extra->data_device[g_main_device];
char * src1_original = src1->backend == GGML_BACKEND_CPU ?
(char *) src1->data : (char *) src1_extra->data_device[g_main_device];
char * dst_original = dst->backend == GGML_BACKEND_CPU ?
(char *) dst->data : (char *) dst_extra->data_device[g_main_device];
if (src1->ne[1] == 1) {
GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
//int32_t row_id;
//CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
@@ -8846,6 +8843,11 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
src1_row_extra.data_device[g_main_device] = src1_contiguous;
dst_row_extra.data_device[g_main_device] = dst_contiguous;
const cudaMemcpyKind src1_kind = src1->backend == GGML_BACKEND_CPU ?
cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
const cudaMemcpyKind dst_kind = dst->backend == GGML_BACKEND_CPU ?
cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
for (int32_t row_id = 0; row_id < n_as; ++row_id) {
const struct ggml_tensor * src0_row = dst->src[row_id + 2];
@@ -8860,7 +8862,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
GGML_ASSERT(row_id >= 0 && row_id < n_as);
CUDA_CHECK(cudaMemcpyAsync(src1_contiguous + num_src1_rows*nb11, src1_original + i01*nb11,
nb11, cudaMemcpyDeviceToDevice, stream));
nb11, src1_kind, stream));
num_src1_rows++;
}
@@ -8892,7 +8894,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
GGML_ASSERT(row_id >= 0 && row_id < n_as);
CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous + num_src1_rows*nb1,
nb1, cudaMemcpyDeviceToDevice, stream));
nb1, dst_kind, stream));
num_src1_rows++;
}
}
@@ -8900,6 +8902,10 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
ggml_cuda_pool_free(src1_contiguous, as_src1);
ggml_cuda_pool_free(dst_contiguous, as_dst);
}
if (dst->backend == GGML_BACKEND_CPU) {
CUDA_CHECK(cudaStreamSynchronize(stream));
}
}
static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
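
The backend-dependent pieces of the `ggml_cuda_mul_mat_id` change above boil down to two decisions: pick the `cudaMemcpyKind` from where `src1`/`dst` actually live, and synchronize the stream before returning when `dst` stays on the CPU. A reduced sketch of just those two decisions, with a hypothetical `backend_t` enum and helper names that are not the real function signature:

```
#include <cuda_runtime.h>
#include <stddef.h>

// hypothetical stand-in for the GGML_BACKEND_CPU / GGML_BACKEND_GPU values used in the diff
enum backend_t { BACKEND_CPU, BACKEND_GPU };

// copy one src1 row into the contiguous device staging buffer, choosing the copy
// direction from where the source data lives
static void stage_row(char * dst_contiguous, const char * src_row, size_t row_bytes,
                      backend_t src1_backend, cudaStream_t stream) {
    const cudaMemcpyKind kind = (src1_backend == BACKEND_CPU)
        ? cudaMemcpyHostToDevice     // small batch: src1 was left in host memory
        : cudaMemcpyDeviceToDevice;  // src1 already resides on the GPU
    cudaMemcpyAsync(dst_contiguous, src_row, row_bytes, kind, stream);
}

// when dst lives on the CPU, the async copies back must finish before the caller reads dst
static void finish_mul_mat_id(backend_t dst_backend, cudaStream_t stream) {
    if (dst_backend == BACKEND_CPU) {
        cudaStreamSynchronize(stream);
    }
}
```

The matching change in `ggml_cuda_compute_forward` (last hunk of this file) is what routes `GGML_OP_MUL_MAT_ID` into this path even when none of its inputs are marked as living on the device.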
@@ -9091,7 +9097,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
}
void ggml_cuda_free_data(struct ggml_tensor * tensor) {
if (!tensor || (tensor->backend != GGML_BACKEND_GPU && tensor->backend != GGML_BACKEND_GPU_SPLIT) ) {
if (!tensor || !tensor->extra || (tensor->backend != GGML_BACKEND_GPU && tensor->backend != GGML_BACKEND_GPU_SPLIT) ) {
return;
}
@@ -9298,7 +9304,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|| (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
|| (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU);
if (!any_on_device && tensor->op != GGML_OP_MUL_MAT) {
if (!any_on_device && tensor->op != GGML_OP_MUL_MAT && tensor->op != GGML_OP_MUL_MAT_ID) {
return false;
}

View File

@@ -84,7 +84,7 @@ class SpecialVocab:
merges_file = path / 'merges.txt'
if not merges_file.is_file():
return False
with open(merges_file, 'r') as fp:
with open(merges_file, 'r', encoding = 'utf-8') as fp:
first_line = next(fp, '').strip()
if not first_line.startswith('#'):
fp.seek(0)

View File

@@ -2083,7 +2083,7 @@ struct llama_model_loader {
type_max = meta->type;
}
LLAMA_LOG_INFO("%s: - tensor %4d: %32s %-8s [ %s ]\n", __func__, i, name, ggml_type_name(meta->type), llama_format_tensor_shape(meta).c_str());
// LLAMA_LOG_INFO("%s: - tensor %4d: %32s %-8s [ %s ]\n", __func__, i, name, ggml_type_name(meta->type), llama_format_tensor_shape(meta).c_str());
}
switch (type_max) {