Compare commits

...

10 Commits

Author SHA1 Message Date
Jared Van Bortel
7c87353e61 common : remove incorrect --model-draft default 2023-12-21 12:17:12 -05:00
howlger
880e352277 py : open merges file as 'utf-8' (#4566)
Otherwise, on Windows converting bling-phi-2-v0 (<https://huggingface.co/llmware/bling-phi-2-v0>) via convert-hf-to-gguf.py will fail with the following error:

```
Traceback (most recent call last):
  File "C:\Users\User\git\gguf\convert-hf-to-gguf.py", line 1061, in <module>
    model_instance.set_vocab()
  File "C:\Users\User\git\gguf\convert-hf-to-gguf.py", line 52, in set_vocab
    self._set_vocab_gpt2()
  File "C:\Users\User\git\gguf\convert-hf-to-gguf.py", line 264, in _set_vocab_gpt2
    special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
  File "C:\Users\User\git\gguf\gguf\vocab.py", line 33, in __init__
    self._load(Path(path))
  File "C:\Users\User\git\gguf\gguf\vocab.py", line 81, in _load
    self._try_load_merges_txt(path)
  File "C:\Users\User\git\gguf\gguf\vocab.py", line 95, in _try_load_merges_txt
    for line in fp:
  File "C:\Users\User\miniconda3\envs\gguf\lib\encodings\cp1252.py", line 23, in decode
    return codecs.charmap_decode(input,self.errors,decoding_table)[0]
UnicodeDecodeError: 'charmap' codec can't decode byte 0x81 in position 1415: character maps to <undefined>
```
2023-12-21 19:07:34 +02:00
bobqianic
66f35a2f48 cuda : better error message for ggml_get_rows (#4561)
* Update ggml-cuda.cu

* Update ggml-cuda.cu

* Update ggml-cuda.cu

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-12-21 19:06:44 +02:00
slaren
1398823922 cuda : replace asserts in wrong architecture checks with __trap (#4556)
* cuda : replace asserts in wrong architecture checks with __trap

* make bad_arch noreturn, remove returns
2023-12-21 18:02:30 +01:00
Johannes Gäßler
d3223afdad llama : disable per-tensor info prints on model load (#4562) 2023-12-21 18:34:17 +02:00
LoganDark
1d7a1912ce Fix access violation in ggml_cuda_free_data if tensor->extra is NULL (#4554) 2023-12-21 10:59:27 +01:00
Johannes Gäßler
799fc22689 CUDA: Faster Mixtral prompt processing (#4538)
* CUDA: make MoE tensors contiguous for batch size>1

* Update ggml-cuda.cu

Co-authored-by: slaren <slarengh@gmail.com>

---------

Co-authored-by: slaren <slarengh@gmail.com>
2023-12-20 15:41:22 +01:00
Eric Sommerlade
328b83de23 ggml : fixed check for _MSC_VER (#4535)
Co-authored-by: Eric Sommerlade <ersomme@microsoft.com>
2023-12-19 18:17:01 +02:00
arlo-phoenix
a7aee47b98 ggml-cuda: Fix HIP build (#4528)
regression of #4490
Adds defines for two new datatypes
cublasComputeType_t, cudaDataType_t.

Currently using deprecated hipblasDatatype_t since newer ones very recent.
2023-12-18 22:33:45 +01:00
Georgi Gerganov
0e18b2e7d0 llama.swiftui : add tinyllama 1.1B F16 2023-12-18 20:17:43 +02:00
6 changed files with 147 additions and 78 deletions

View File

@@ -920,7 +920,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -m FNAME, --model FNAME\n");
printf(" model path (default: %s)\n", params.model.c_str());
printf(" -md FNAME, --model-draft FNAME\n");
printf(" draft model for speculative decoding (default: %s)\n", params.model.c_str());
printf(" draft model for speculative decoding\n");
printf(" -ld LOGDIR, --logdir LOGDIR\n");
printf(" path under which to save YAML logs (no logging if unset)\n");
printf(" --override-kv KEY=TYPE:VALUE\n");

View File

@@ -91,6 +91,15 @@ struct ContentView: View {
)
.font(.system(size: 12))
DownloadButton(
llamaState: llamaState,
modelName: "TinyLlama-1.1B (F16, 2.2 GiB)",
modelUrl: "https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf?download=true",
filename: "tinyllama-1.1b-f16.gguf"
)
.font(.system(size: 12))
.frame(maxWidth: .infinity, alignment: .leading)
DownloadButton(
llamaState: llamaState,
modelName: "Phi-2.7B (Q4_0, 1.6 GiB)",
@@ -98,7 +107,6 @@ struct ContentView: View {
filename: "phi-2-q4_0.gguf"
)
.font(.system(size: 12))
.frame(maxWidth: .infinity, alignment: .leading)
DownloadButton(
llamaState: llamaState,
@@ -107,6 +115,7 @@ struct ContentView: View {
filename: "phi-2-q8_0.gguf"
)
.font(.system(size: 12))
.frame(maxWidth: .infinity, alignment: .leading)
DownloadButton(
llamaState: llamaState,
@@ -115,7 +124,6 @@ struct ContentView: View {
filename: "mistral-7b-v0.1.Q4_0.gguf"
)
.font(.system(size: 12))
.frame(maxWidth: .infinity, alignment: .leading)
Button("Clear downloaded models") {
ContentView.cleanupModelCaches()

View File

@@ -31,6 +31,7 @@
#define CUDA_R_16F HIPBLAS_R_16F
#define CUDA_R_32F HIPBLAS_R_32F
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasCreate hipblasCreate
#define cublasGemmEx hipblasGemmEx
#define cublasGemmBatchedEx hipblasGemmBatchedEx
@@ -40,6 +41,7 @@
#define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
@@ -510,6 +512,14 @@ static size_t g_scratch_offset = 0;
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
[[noreturn]]
static __device__ void bad_arch() {
printf("ERROR: ggml-cuda was compiled without support for the current GPU architecture.\n");
__trap();
(void) bad_arch; // suppress unused function warning
}
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
@@ -1970,8 +1980,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_imp
// second part effectively subtracts 8 from each quant value
return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y);
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2008,8 +2017,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp
// scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2044,8 +2052,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_imp
// second part effectively subtracts 16 from each quant value
return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y);
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2090,8 +2097,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2112,8 +2118,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_imp
return d8_0*d8_1 * sumi;
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2143,8 +2148,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
// scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2179,8 +2183,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
return dm2f.x*sumf_d - dm2f.y*sumf_m;
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2217,8 +2220,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m);
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2258,8 +2260,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
return d3 * sumf;
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2284,8 +2285,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
return d3*d8 * sumi;
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2318,8 +2318,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
return dm4f.x*sumf_d - dm4f.y*sumf_m;
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2352,8 +2351,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
return dm4f.x*sumf_d - dm4f.y*sumf_m;
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2393,8 +2391,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
return dm5f.x*sumf_d - dm5f.y*sumf_m;
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2427,8 +2424,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
return dm4f.x*sumf_d - dm4f.y*sumf_m;
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2458,8 +2454,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
return d*sumf;
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -2490,8 +2485,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
return d6 * sumf_d;
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
@@ -3357,8 +3351,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
return dall * sumf_d - dmin * sumf_m;
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
#endif
@@ -3541,8 +3534,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
return d * sumf_d;
#else
assert(false);
return 0.0f; // only to satisfy the compiler
bad_arch();
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
#endif
@@ -3952,7 +3944,7 @@ template <bool need_check> static __global__ void
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q4_0_q8_1_mul_mat;
assert(false);
bad_arch();
#endif // __CUDA_ARCH__ >= CC_VOLTA
}
@@ -4021,7 +4013,7 @@ template <bool need_check> static __global__ void
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q4_1_q8_1_mul_mat;
assert(false);
bad_arch();
#endif // __CUDA_ARCH__ >= CC_VOLTA
}
@@ -4088,7 +4080,7 @@ template <bool need_check> static __global__ void
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q5_0_q8_1_mul_mat;
assert(false);
bad_arch();
#endif // __CUDA_ARCH__ >= CC_VOLTA
}
@@ -4155,7 +4147,7 @@ mul_mat_q5_1(
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q5_1_q8_1_mul_mat;
assert(false);
bad_arch();
#endif // __CUDA_ARCH__ >= CC_VOLTA
}
@@ -4222,7 +4214,7 @@ template <bool need_check> static __global__ void
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q8_0_q8_1_mul_mat;
assert(false);
bad_arch();
#endif // __CUDA_ARCH__ >= CC_VOLTA
}
@@ -4289,7 +4281,7 @@ mul_mat_q2_K(
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q2_K_q8_1_mul_mat;
assert(false);
bad_arch();
#endif // __CUDA_ARCH__ >= CC_VOLTA
}
@@ -4358,7 +4350,7 @@ template <bool need_check> static __global__ void
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q3_K_q8_1_mul_mat;
assert(false);
bad_arch();
#endif // __CUDA_ARCH__ >= CC_VOLTA
}
@@ -4427,7 +4419,7 @@ template <bool need_check> static __global__ void
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q4_K_q8_1_mul_mat;
assert(false);
bad_arch();
#endif // __CUDA_ARCH__ >= CC_VOLTA
}
@@ -4494,7 +4486,7 @@ mul_mat_q5_K(
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q5_K_q8_1_mul_mat;
assert(false);
bad_arch();
#endif // __CUDA_ARCH__ >= CC_VOLTA
}
@@ -4563,7 +4555,7 @@ template <bool need_check> static __global__ void
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
(void) vec_dot_q6_K_q8_1_mul_mat;
assert(false);
bad_arch();
#endif // __CUDA_ARCH__ >= CC_VOLTA
}
@@ -6823,6 +6815,7 @@ static void ggml_cuda_op_get_rows(
break;
default:
// TODO: k-quants
fprintf(stderr, "%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
GGML_ASSERT(false);
break;
}
@@ -7828,6 +7821,11 @@ static void ggml_cuda_set_peer_access(const int n_tokens) {
}
#ifdef NDEBUG
for (int id = 0; id < g_device_count; ++id) {
CUDA_CHECK(ggml_cuda_set_device(id));
CUDA_CHECK(cudaDeviceSynchronize());
}
for (int id = 0; id < g_device_count; ++id) {
CUDA_CHECK(ggml_cuda_set_device(id));
@@ -7879,8 +7877,6 @@ static void ggml_cuda_op_mul_mat(
const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3];
ggml_cuda_set_peer_access(ne11);
GGML_ASSERT(dst->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT(src1->backend != GGML_BACKEND_GPU_SPLIT);
@@ -8779,16 +8775,21 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
const int64_t nb11 = src1->nb[1];
const int64_t nb1 = dst->nb[1];
const struct ggml_tensor * ids = src0;
const int32_t id = ((int32_t *) dst->op_params)[0];
const int32_t n_as = ((int32_t *) dst->op_params)[1];
std::vector<char> ids_host(ggml_nbytes(ids));
const cudaStream_t stream = g_cudaStreams[g_main_device][0];
if (ids->backend == GGML_BACKEND_GPU) {
const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
} else {
memcpy(ids_host.data(), ids->data, ggml_nbytes(ids));
}
@@ -8802,37 +8803,93 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
ggml_tensor src1_row = *src1;
ggml_tensor dst_row = *dst;
src1_row.ne[1] = 1;
dst_row.ne[1] = 1;
src1_row.nb[2] = src1_row.nb[1];
dst_row.nb[2] = dst_row.nb[1];
src1_row.nb[3] = src1_row.nb[1];
dst_row.nb[3] = dst_row.nb[1];
src1_row.extra = &src1_row_extra;
dst_row.extra = &dst_row_extra;
char * src1_original = (char *) src1_extra->data_device[g_main_device];
char * dst_original = (char *) dst_extra->data_device[g_main_device];
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
//int32_t row_id;
//CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
//CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
if (src1->ne[1] == 1) {
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
//int32_t row_id;
//CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
//CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
GGML_ASSERT(row_id >= 0 && row_id < n_as);
GGML_ASSERT(row_id >= 0 && row_id < n_as);
const struct ggml_tensor * src0_row = dst->src[row_id + 2];
const struct ggml_tensor * src0_row = dst->src[row_id + 2];
src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1];
src1_row.data = (char *) src1->data + i01*src1->nb[1];
src1_row_extra.data_device[g_main_device] = src1_original + i01*src1->nb[1];
src1_row.data = (char *) src1->data + i01*src1->nb[1]; // TODO why is this set?
dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1];
dst_row.data = (char *) dst->data + i01*dst->nb[1];
dst_row_extra.data_device[g_main_device] = dst_original + i01*dst->nb[1];
dst_row.data = (char *) dst->data + i01*dst->nb[1]; // TODO why is this set?
ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
}
} else {
size_t as_src1, as_dst;
char * src1_contiguous = (char *) ggml_cuda_pool_malloc(sizeof(float)*ggml_nelements(src1), &as_src1);
char * dst_contiguous = (char *) ggml_cuda_pool_malloc(sizeof(float)*ggml_nelements(dst), &as_dst);
src1_row_extra.data_device[g_main_device] = src1_contiguous;
dst_row_extra.data_device[g_main_device] = dst_contiguous;
for (int32_t row_id = 0; row_id < n_as; ++row_id) {
const struct ggml_tensor * src0_row = dst->src[row_id + 2];
int64_t num_src1_rows = 0;
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
if (row_id_i != row_id) {
continue;
}
GGML_ASSERT(row_id >= 0 && row_id < n_as);
CUDA_CHECK(cudaMemcpyAsync(src1_contiguous + num_src1_rows*nb11, src1_original + i01*nb11,
nb11, cudaMemcpyDeviceToDevice, stream));
num_src1_rows++;
}
if (num_src1_rows == 0) {
continue;
}
src1_row.ne[1] = num_src1_rows;
dst_row.ne[1] = num_src1_rows;
src1_row.nb[1] = nb11;
src1_row.nb[2] = num_src1_rows*nb11;
src1_row.nb[3] = num_src1_rows*nb11;
dst_row.nb[1] = nb1;
dst_row.nb[2] = num_src1_rows*nb1;
dst_row.nb[3] = num_src1_rows*nb1;
ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
num_src1_rows = 0;
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
if (row_id_i != row_id) {
continue;
}
GGML_ASSERT(row_id >= 0 && row_id < n_as);
CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous + num_src1_rows*nb1,
nb1, cudaMemcpyDeviceToDevice, stream));
num_src1_rows++;
}
}
ggml_cuda_pool_free(src1_contiguous, as_src1);
ggml_cuda_pool_free(dst_contiguous, as_dst);
}
}
@@ -9025,7 +9082,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
}
void ggml_cuda_free_data(struct ggml_tensor * tensor) {
if (!tensor || (tensor->backend != GGML_BACKEND_GPU && tensor->backend != GGML_BACKEND_GPU_SPLIT) ) {
if (!tensor || !tensor->extra || (tensor->backend != GGML_BACKEND_GPU && tensor->backend != GGML_BACKEND_GPU_SPLIT) ) {
return;
}
@@ -9368,6 +9425,10 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
return false;
}
if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT) {
ggml_cuda_set_peer_access(tensor->src[1]->ne[1]);
}
if (params->ith != 0) {
return true;
}

2
ggml.h
View File

@@ -303,7 +303,7 @@ extern "C" {
#if defined(__ARM_NEON) && defined(__CUDACC__)
typedef half ggml_fp16_t;
#elif defined(__ARM_NEON)
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
typedef __fp16 ggml_fp16_t;
#else
typedef uint16_t ggml_fp16_t;

View File

@@ -84,7 +84,7 @@ class SpecialVocab:
merges_file = path / 'merges.txt'
if not merges_file.is_file():
return False
with open(merges_file, 'r') as fp:
with open(merges_file, 'r', encoding = 'utf-8') as fp:
first_line = next(fp, '').strip()
if not first_line.startswith('#'):
fp.seek(0)

View File

@@ -2083,7 +2083,7 @@ struct llama_model_loader {
type_max = meta->type;
}
LLAMA_LOG_INFO("%s: - tensor %4d: %32s %-8s [ %s ]\n", __func__, i, name, ggml_type_name(meta->type), llama_format_tensor_shape(meta).c_str());
// LLAMA_LOG_INFO("%s: - tensor %4d: %32s %-8s [ %s ]\n", __func__, i, name, ggml_type_name(meta->type), llama_format_tensor_shape(meta).c_str());
}
switch (type_max) {