Compare commits

...

3 Commits
b8799 ... b8802

Author SHA1 Message Date
Valeriy Dubov
adb541a6ad rpc : add native RDMA transport for RPC backend (RoCEv2) (#20590) 2026-04-15 16:44:02 +03:00
Xuan-Son Nguyen
80d8770804 docs: more extensive RoPE documentation [no ci] (#21953)
* more extensive ggml_rope documentation

* add more docs

* nits
2026-04-15 14:45:16 +02:00
Ruben Ortlam
8dc530b86d ci: disable test-backend-ops on Vulkan llvmpipe run and restore default timeout (#21901) 2026-04-15 10:55:21 +02:00
7 changed files with 683 additions and 42 deletions

View File

@@ -93,4 +93,5 @@ jobs:
export GGML_VK_DISABLE_F16=1
export GGML_VK_DISABLE_COOPMAT=1
# This is using llvmpipe and runs slower than other backends
- ctest -L main --verbose --timeout 4800
+ # test-backend-ops is too slow on llvmpipe, skip it
+ ctest -L main -E test-backend-ops --verbose --timeout 900

View File

@@ -130,6 +130,23 @@ Note:
- Adding a model-specific API or CLI is an anti-pattern in `libmtmd`. The goal of `libmtmd` is to provide an easy-to-use, model-agnostic library for multimodal pipelines.
- In most cases, `llama-mtmd-cli` should not be modified. If a model requires a specific prompt, either let the user provide it or bake it into the Jinja chat template.
## Tips and tricks
### Working with ggml_rope_ext
PyTorch implementations usually prefer explicitly calculating `freq_cis`/`sin`/`cos` components. However, in llama.cpp, most RoPE operations can be handled via `ggml_rope_ext`, which does not require a sin/cos matrix. This saves memory while allowing the GGML RoPE kernel to be fused with other ops.
However, since `ggml_rope_ext` only provides a subset of the RoPE implementations that models use, converting models from PyTorch to llama.cpp may require some creative adaptations.
For more information about `ggml_rope_ext`, please refer to the in-code documentation in `ggml.h`.
Examples:
- `libmtmd` implements 2D RoPE with `GGML_ROPE_TYPE_NORMAL` ordering by splitting the input tensor in half, applying `ggml_rope_ext` separately to each half, then joining them back together using `ggml_concat`.
- The [Kimi-K2.5](https://github.com/ggml-org/llama.cpp/pull/19170) vision encoder uses vision RoPE with interleaved frequencies. The weights must be permuted during conversion in order to reuse the `build_rope_2d()` function.
- [Gemma 4](https://github.com/ggml-org/llama.cpp/pull/21309) uses "proportional" RoPE. We employ a trick where `rope_freqs` is set to a very large value in the last dimensions to prevent those dimensions from being rotated. See the `Gemma4Model` class in `convert_hf_to_gguf.py`.
- Some models require scaling the input position. For example, `[0, 1, 2, ...]` becomes `[0, 0.5, 1, ...]`. In this case, you can provide the scaling via `freq_scale = 0.5f`.
- Some models use learned RoPE frequencies instead of relying on `powf(freq_base, -2.0 * i / n_dims)`. In this case, you can provide the learned frequencies via the `rope_freqs` tensor (corresponding to the `c` argument in `ggml_rope_ext`), then set `freq_base = 1.0f`. An important note is that `rope_freqs` in GGML is the **inverse** (`theta = pos[i] / rope_freqs`), so you may need to invert `rope_freqs` during conversion.
## GGUF specification
https://github.com/ggml-org/ggml/blob/master/docs/gguf.md

View File

@@ -6,9 +6,9 @@
extern "C" {
#endif
- #define RPC_PROTO_MAJOR_VERSION 3
- #define RPC_PROTO_MINOR_VERSION 6
- #define RPC_PROTO_PATCH_VERSION 1
+ #define RPC_PROTO_MAJOR_VERSION 4
+ #define RPC_PROTO_MINOR_VERSION 0
+ #define RPC_PROTO_PATCH_VERSION 0
#ifdef __cplusplus
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION");

View File

@@ -1773,8 +1773,32 @@ extern "C" {
int n_dims,
int mode);
- // custom RoPE
+ // RoPE operations with extended options
// a is the input tensor to apply RoPE to, shape [n_embd, n_head, n_token]
// b is an int32 vector with size n_token
// c is freq factors (e.g. phi3-128k), (optional)
// mode can be GGML_ROPE_TYPE_NORMAL or NEOX; for MROPE and VISION mode, use ggml_rope_multi
//
// pseudo-code for computing theta:
// for i in [0, n_dims/2):
// theta[i] = b[i] * powf(freq_base, -2.0 * i / n_dims);
// theta[i] = theta[i] / c[i]; # if c is provided, divide theta by c
// theta[i] = rope_yarn(theta[i], ...); # note: theta = theta * freq_scale is applied here
//
// other params are used by YaRN RoPE scaling, these default values will disable YaRN:
// freq_scale = 1.0f
// ext_factor = 0.0f
// attn_factor = 1.0f
// beta_fast = 0.0f
// beta_slow = 0.0f
//
// example:
// (marking: c = cos, s = sin, 0 = unrotated)
// given a single head with size = 8 --> [00000000]
// GGML_ROPE_TYPE_NORMAL n_dims = 4 --> [cscs0000]
// GGML_ROPE_TYPE_NORMAL n_dims = 8 --> [cscscscs]
// GGML_ROPE_TYPE_NEOX n_dims = 4 --> [ccss0000]
// GGML_ROPE_TYPE_NEOX n_dims = 8 --> [ccccssss]
GGML_API struct ggml_tensor * ggml_rope_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
@@ -1790,6 +1814,36 @@ extern "C" {
float beta_fast,
float beta_slow);
// multi-dimensional RoPE, for Qwen-VL and similar vision models
// mode can be VISION, MROPE, or IMROPE; it cannot be combined with NORMAL or NEOX
// sections specify how many dimensions to rotate in each section:
// section length is equivalent to the number of cos/sin pairs, NOT the number of dims
// (i.e. the sum of the 4 sections is expected to be n_dims/2)
// trailing sections can be 0, meaning they are ignored
// all other options are identical to ggml_rope_ext
//
// important note:
// - NEOX ordering is automatically applied and cannot be disabled for MROPE and VISION
// if you need normal ordering, there are 2 methods:
// (1) split the tensor manually using ggml_view
// (2) permute the weight upon conversion
// - for VISION, n_dims must be head_size/2
//
// example M-RoPE:
// given sections = [t=4, y=2, x=2, 0]
// given a single head with size = 18 --> [000000000000000000]
// GGML_ROPE_TYPE_MROPE n_dims = 16 --> [ttttyyxxttttyyxx00] (cos/sin are applied in NEOX ordering)
// GGML_ROPE_TYPE_IMROPE n_dims = 16 --> [ttyxttyxttyxttyx00] (interleaved M-RoPE, still NEOX ordering)
// note: the theta for each dim is computed the same way as ggml_rope_ext, no matter the section
// in other words, idx used for theta: [0123456789... until n_dims/2], not reset for each section
//
// example vision RoPE:
// given sections = [y=4, x=4, 0, 0] (last 2 sections are ignored)
// given a single head with size = 8 --> [00000000]
// GGML_ROPE_TYPE_VISION n_dims = 4 --> [yyyyxxxx]
// other values of n_dims are untested and result in undefined behavior
// note: unlike MROPE, the theta for each dim is computed differently for each section
// in other words, idx used for theta: [0123] for y section, then [0123] for x section
GGML_API struct ggml_tensor * ggml_rope_multi(
struct ggml_context * ctx,
struct ggml_tensor * a,

View File

@@ -7,3 +7,26 @@ ggml_add_backend_library(ggml-rpc
if (WIN32)
target_link_libraries(ggml-rpc PRIVATE ws2_32)
endif()
# RDMA auto-detection (Linux only, requires libibverbs)
if (NOT WIN32 AND NOT APPLE)
find_library(IBVERBS_LIB ibverbs)
if (IBVERBS_LIB)
option(GGML_RPC_RDMA "ggml: enable RDMA transport for RPC" ON)
else()
option(GGML_RPC_RDMA "ggml: enable RDMA transport for RPC" OFF)
endif()
else()
set(GGML_RPC_RDMA OFF CACHE BOOL "RDMA not available on this platform" FORCE)
endif()
if (GGML_RPC_RDMA)
if (NOT IBVERBS_LIB)
find_library(IBVERBS_LIB ibverbs REQUIRED)
endif()
target_compile_definitions(ggml-rpc PRIVATE GGML_RPC_RDMA)
target_link_libraries(ggml-rpc PRIVATE ${IBVERBS_LIB})
message(STATUS " RDMA transport enabled (auto-detected)")
else()
message(STATUS " RDMA transport disabled")
endif()
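Given the auto-detection logic above, the `GGML_RPC_RDMA` option can still be overridden explicitly at configure time; a sketch (build directory name is illustrative):

```shell
# force TCP-only RPC even when libibverbs is present
cmake -B build -DGGML_RPC_RDMA=OFF

# require RDMA: configuration fails if libibverbs cannot be found,
# because find_library is then re-run with REQUIRED
cmake -B build -DGGML_RPC_RDMA=ON
```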

View File

@@ -3,7 +3,9 @@
#include "ggml-backend-impl.h"
#include "ggml-cpp.h"
#include <array>
#include <cinttypes>
#include <optional>
#include <string>
#include <vector>
#include <memory>
@@ -31,6 +33,14 @@
#include <filesystem>
#include <algorithm>
#ifdef GGML_RPC_RDMA
# include <infiniband/verbs.h>
# include <time.h>
# ifndef _WIN32
# include <poll.h>
# endif
#endif // GGML_RPC_RDMA
static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG");
#define LOG_DBG(...) \
@@ -49,17 +59,116 @@ typedef int sockfd_t;
#endif
// cross-platform socket
#ifdef GGML_RPC_RDMA
static constexpr size_t RDMA_CHUNK = 256 * 1024; // 256 KiB per send/recv (fits default 8 MiB memlock)
static constexpr int RDMA_RX_DEPTH = 24; // pre-posted recv ring: 24 × 256 KiB = 6 MiB
static constexpr size_t RDMA_GID_SIZE = 16; // RoCE GID / IB GID is always 16 bytes
using rdma_gid_t = std::array<uint8_t, RDMA_GID_SIZE>;
struct rdma_conn {
struct ibv_context * ctx = nullptr;
struct ibv_pd * pd = nullptr;
struct ibv_cq * scq = nullptr; // send completions
struct ibv_cq * rcq = nullptr; // recv completions
struct ibv_qp * qp = nullptr;
void * tx_buf = nullptr;
struct ibv_mr * tx_mr = nullptr;
void * rx_buf = nullptr; // RDMA_RX_DEPTH × RDMA_CHUNK contiguous
struct ibv_mr * rx_mr = nullptr;
int rx_head = 0;
uint32_t max_inline = 0;
uint8_t * rx_slot(int i) const {
return static_cast<uint8_t *>(rx_buf) + static_cast<size_t>(i) * RDMA_CHUNK;
}
bool post_rx(int i) {
struct ibv_sge sge = {};
sge.addr = (uintptr_t)rx_slot(i);
sge.length = RDMA_CHUNK;
sge.lkey = rx_mr->lkey;
struct ibv_recv_wr wr = {}, * bad = nullptr;
wr.wr_id = (uint64_t)i;
wr.sg_list = &sge;
wr.num_sge = 1;
return ibv_post_recv(qp, &wr, &bad) == 0;
}
~rdma_conn() {
if (tx_mr) ibv_dereg_mr(tx_mr);
if (rx_mr) ibv_dereg_mr(rx_mr);
free(tx_buf);
free(rx_buf);
if (qp) ibv_destroy_qp(qp);
if (scq) ibv_destroy_cq(scq);
if (rcq) ibv_destroy_cq(rcq);
if (pd) ibv_dealloc_pd(pd);
if (ctx) ibv_close_device(ctx);
}
};
// Local RDMA parameters captured during the probe phase and later consumed
// by rdma_activate() after the remote side's caps arrive via HELLO.
struct rdma_local_info {
uint32_t qpn = 0;
uint32_t psn = 0;
uint8_t gid[RDMA_GID_SIZE] = {};
uint8_t ib_port = 0;
int gid_idx = 0;
enum ibv_mtu path_mtu = IBV_MTU_1024;
};
#endif // GGML_RPC_RDMA
// conn_caps size for transport-agnostic capability exchange
static constexpr size_t RPC_CONN_CAPS_SIZE = 24;
// conn_caps RDMA layout helper
#ifdef GGML_RPC_RDMA
struct rdma_caps {
uint32_t qpn;
uint32_t psn;
uint8_t gid[RDMA_GID_SIZE];
};
static_assert(sizeof(rdma_caps) == RPC_CONN_CAPS_SIZE, "rdma_caps must match conn_caps size");
#endif // GGML_RPC_RDMA
// Forward declarations for transport function pointers
struct socket_t;
static bool tcp_send_impl(socket_t * sock, const void * data, size_t size);
static bool tcp_recv_impl(socket_t * sock, void * data, size_t size);
struct socket_t {
sockfd_t fd;
bool (*fn_send)(socket_t *, const void *, size_t) = tcp_send_impl;
bool (*fn_recv)(socket_t *, void *, size_t) = tcp_recv_impl;
#ifdef GGML_RPC_RDMA
std::unique_ptr<rdma_conn> rdma;
rdma_local_info rdma_local = {};
#endif // GGML_RPC_RDMA
socket_t(sockfd_t fd) : fd(fd) {}
~socket_t() {
#ifdef GGML_RPC_RDMA
rdma.reset();
#endif // GGML_RPC_RDMA
LOG_DBG("[%s] closing socket %d\n", __func__, this->fd);
#ifdef _WIN32
- closesocket(this->fd);
+ if (fd != INVALID_SOCKET) closesocket(this->fd);
#else
- close(this->fd);
+ if (fd >= 0) close(this->fd);
#endif
}
// Advertise local transport capabilities into conn_caps.
// May probe RDMA and store the probe on this socket for update_caps.
void get_caps(uint8_t * caps);
// Activate transport upgrade based on remote conn_caps using the probe
// previously stored by get_caps.
void update_caps(const uint8_t * remote_caps);
};
// macro for nicer error messages on server crash
@@ -115,10 +224,16 @@ static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14");
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
struct rpc_msg_hello_req {
uint8_t conn_caps[RPC_CONN_CAPS_SIZE];
};
struct rpc_msg_hello_rsp {
uint8_t major;
uint8_t minor;
uint8_t patch;
uint8_t padding;
uint8_t conn_caps[RPC_CONN_CAPS_SIZE];
};
struct rpc_msg_device_count_rsp {
@@ -414,27 +529,414 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
return true;
}
- static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) {
-     if (!send_data(sockfd, &msg_size, sizeof(msg_size))) {
-         return false;
-     }
-     return send_data(sockfd, msg, msg_size);
+ // TCP transport implementations (for function-pointer dispatch)
+ static bool tcp_send_impl(socket_t * sock, const void * data, size_t size) {
+     return send_data(sock->fd, data, size);
}
- static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) {
+ static bool tcp_recv_impl(socket_t * sock, void * data, size_t size) {
+     return recv_data(sock->fd, data, size);
+ }
// RDMA transport (performance-optimized, auto-negotiated)
#ifdef GGML_RPC_RDMA
static bool rdma_send_impl(socket_t * sock, const void * data, size_t size);
static bool rdma_recv_impl(socket_t * sock, void * data, size_t size);
static inline bool tcp_peer_closed(int fd) {
if (fd < 0) return false;
#ifndef _WIN32
struct pollfd pfd = { fd, POLLIN | POLLRDHUP, 0 };
int r = poll(&pfd, 1, 0);
return r > 0 && (pfd.revents & (POLLHUP | POLLERR | POLLRDHUP));
#else
return false;
#endif
}
static inline bool rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc, int tcp_fd) {
for (uint64_t s = 0; ; s++) {
int n = ibv_poll_cq(cq, 1, wc);
if (n > 0) {
if (wc->status != IBV_WC_SUCCESS) {
GGML_LOG_ERROR("RDMA CQ wc error: status=%d (%s) vendor_err=0x%x\n",
wc->status, ibv_wc_status_str(wc->status), wc->vendor_err);
}
return wc->status == IBV_WC_SUCCESS;
}
if (n < 0) return false;
if ((s & 0xFFFFF) == 0 && s > 0) {
if (tcp_peer_closed(tcp_fd)) {
return false;
}
}
}
}
static bool rdma_send(rdma_conn * c, const void * data, size_t size, int tcp_fd) {
const uint8_t * src = (const uint8_t *)data;
size_t rem = size;
while (rem > 0) {
size_t chunk = std::min(rem, RDMA_CHUNK);
struct ibv_sge sge = {};
struct ibv_send_wr wr = {}, * bad = nullptr;
wr.opcode = IBV_WR_SEND;
wr.sg_list = &sge;
wr.num_sge = 1;
if (chunk <= c->max_inline) {
sge.addr = (uintptr_t)src;
sge.length = chunk;
wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_INLINE;
} else {
memcpy(c->tx_buf, src, chunk);
sge.addr = (uintptr_t)c->tx_buf;
sge.length = chunk;
sge.lkey = c->tx_mr->lkey;
wr.send_flags = IBV_SEND_SIGNALED;
}
if (ibv_post_send(c->qp, &wr, &bad) != 0) return false;
struct ibv_wc wc;
if (!rdma_poll(c->scq, &wc, tcp_fd)) return false;
src += chunk;
rem -= chunk;
}
return true;
}
static bool rdma_recv(rdma_conn * c, void * data, size_t size, int tcp_fd) {
uint8_t * dst = (uint8_t *)data;
size_t rem = size;
while (rem > 0) {
struct ibv_wc wc;
if (!rdma_poll(c->rcq, &wc, tcp_fd)) return false;
int slot = (int)wc.wr_id;
size_t got = wc.byte_len;
memcpy(dst, c->rx_slot(slot), got);
if (!c->post_rx(slot)) return false;
dst += got;
rem -= got;
}
return true;
}
static bool rdma_send_impl(socket_t * sock, const void * data, size_t size) {
return rdma_send(sock->rdma.get(), data, size, sock->fd);
}
static bool rdma_recv_impl(socket_t * sock, void * data, size_t size) {
return rdma_recv(sock->rdma.get(), data, size, sock->fd);
}
// Build a RoCE GID-shaped 16-byte target from a TCP socket's local address.
// Used to match the socket's local IP against the kernel's GID table so that
// a single memcmp handles IPv4, IPv4-mapped IPv6, and native IPv6 uniformly:
// AF_INET -> ::ffff:a.b.c.d (bytes 10-11 = 0xff, last 4 = IPv4)
// AF_INET6 (IPv4-mapped) -> ::ffff:a.b.c.d (already in GID shape)
// AF_INET6 (native v6) -> the 16-byte IPv6 address as-is
// Returns std::nullopt on unsupported family or getsockname failure.
static std::optional<rdma_gid_t> rdma_build_target_gid(sockfd_t tcp_fd) {
sockaddr_storage addr = {};
socklen_t addr_len = sizeof(addr);
if (getsockname(tcp_fd, reinterpret_cast<sockaddr *>(&addr), &addr_len) != 0) {
return std::nullopt;
}
rdma_gid_t target = {};
if (addr.ss_family == AF_INET) {
const auto * a = reinterpret_cast<const sockaddr_in *>(&addr);
target[10] = 0xff;
target[11] = 0xff;
memcpy(&target[12], &a->sin_addr, 4);
return target;
}
if (addr.ss_family == AF_INET6) {
const auto * a = reinterpret_cast<const sockaddr_in6 *>(&addr);
memcpy(target.data(), &a->sin6_addr, RDMA_GID_SIZE);
return target;
}
return std::nullopt;
}
static rdma_conn * rdma_probe(sockfd_t tcp_fd, rdma_local_info * out) {
const char * dev_env = std::getenv("GGML_RDMA_DEV");
const char * gid_env = std::getenv("GGML_RDMA_GID");
auto target_gid = rdma_build_target_gid(tcp_fd);
if (!target_gid) {
return nullptr;
}
const uint8_t ib_port = 1;
int num_devs = 0;
ibv_device ** devs = ibv_get_device_list(&num_devs);
if (!devs || num_devs == 0) return nullptr;
ibv_context * ibctx = nullptr;
const char * matched_dev = nullptr;
int gid_idx = gid_env ? atoi(gid_env) : -1;
int gid_version = IBV_GID_TYPE_IB; // 0 = unknown/IB
for (int d = 0; d < num_devs; d++) {
const char * dn = ibv_get_device_name(devs[d]);
if (dev_env && strcmp(dev_env, dn) != 0) continue;
ibv_context * ctx = ibv_open_device(devs[d]);
if (!ctx) continue;
ibv_port_attr pa;
if (ibv_query_port(ctx, ib_port, &pa) != 0) { ibv_close_device(ctx); continue; }
int found_gid = gid_idx;
int found_version = IBV_GID_TYPE_IB;
if (found_gid < 0) {
// Find a GID on this port whose bytes equal the local TCP address
// (IPv4 or IPv6). Prefer RoCE v2 (UDP/IP, L3-routable) over v1
// (raw Ethernet, same-L2 only) so silent hangs on L3-routed paths
// are avoided. ibv_query_gid_ex returns gid+type in one call.
int v2_idx = -1;
int v1_idx = -1;
for (int i = 0; i < pa.gid_tbl_len; i++) {
ibv_gid_entry entry = {};
if (ibv_query_gid_ex(ctx, ib_port, i, &entry, 0) != 0) continue;
if (memcmp(entry.gid.raw, target_gid->data(), RDMA_GID_SIZE) != 0) continue;
if (entry.gid_type == IBV_GID_TYPE_ROCE_V2 && v2_idx < 0) {
v2_idx = i;
} else if (entry.gid_type == IBV_GID_TYPE_ROCE_V1 && v1_idx < 0) {
v1_idx = i;
}
}
if (v2_idx >= 0) {
found_gid = v2_idx;
found_version = IBV_GID_TYPE_ROCE_V2;
} else if (v1_idx >= 0) {
found_gid = v1_idx;
found_version = IBV_GID_TYPE_ROCE_V1;
}
} else {
// Explicit GID index from GGML_RDMA_GID — fetch its type for logging.
ibv_gid_entry entry = {};
if (ibv_query_gid_ex(ctx, ib_port, found_gid, &entry, 0) == 0) {
found_version = entry.gid_type;
}
}
if (found_gid >= 0) {
ibctx = ctx;
gid_idx = found_gid;
gid_version = found_version;
matched_dev = dn;
out->path_mtu = pa.active_mtu;
break;
}
ibv_close_device(ctx);
}
ibv_free_device_list(devs);
if (!ibctx) return nullptr;
out->ib_port = ib_port;
out->gid_idx = gid_idx;
// unique_ptr owns ibctx and every subsequent resource via ~rdma_conn(),
// so each failure path is a plain `return nullptr;`.
auto c = std::make_unique<rdma_conn>();
c->ctx = ibctx;
c->pd = ibv_alloc_pd(ibctx);
if (!c->pd) return nullptr;
c->scq = ibv_create_cq(ibctx, 16, nullptr, nullptr, 0);
c->rcq = ibv_create_cq(ibctx, RDMA_RX_DEPTH + 4, nullptr, nullptr, 0);
if (!c->scq || !c->rcq) return nullptr;
ibv_qp_init_attr qia = {};
qia.send_cq = c->scq;
qia.recv_cq = c->rcq;
qia.qp_type = IBV_QPT_RC;
qia.cap.max_send_wr = 4;
qia.cap.max_recv_wr = RDMA_RX_DEPTH + 4;
qia.cap.max_send_sge = 1;
qia.cap.max_recv_sge = 1;
qia.cap.max_inline_data = 256;
c->qp = ibv_create_qp(c->pd, &qia);
if (!c->qp) return nullptr;
c->max_inline = qia.cap.max_inline_data;
c->tx_buf = aligned_alloc(4096, RDMA_CHUNK);
c->rx_buf = aligned_alloc(4096, static_cast<size_t>(RDMA_RX_DEPTH) * RDMA_CHUNK);
if (!c->tx_buf || !c->rx_buf) return nullptr;
c->tx_mr = ibv_reg_mr(c->pd, c->tx_buf, RDMA_CHUNK, IBV_ACCESS_LOCAL_WRITE);
c->rx_mr = ibv_reg_mr(c->pd, c->rx_buf, static_cast<size_t>(RDMA_RX_DEPTH) * RDMA_CHUNK,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
if (!c->tx_mr || !c->rx_mr) return nullptr;
ibv_gid local_gid;
if (ibv_query_gid(ibctx, ib_port, gid_idx, &local_gid) != 0) return nullptr;
out->qpn = c->qp->qp_num;
out->psn = c->qp->qp_num & 0xffffff;
memcpy(out->gid, &local_gid, RDMA_GID_SIZE);
const char * ver_str = "";
if (gid_version == IBV_GID_TYPE_ROCE_V2) {
ver_str = " RoCEv2";
} else if (gid_version == IBV_GID_TYPE_ROCE_V1) {
ver_str = " RoCEv1";
}
GGML_LOG_INFO("RDMA probed: dev=%s gid=%d%s qpn=%u inline=%u\n",
matched_dev, gid_idx, ver_str, out->qpn, c->max_inline);
return c.release();
}
// Phase 2: Given remote QPN/PSN/GID, transition QP: RESET->INIT->pre-post->RTR->RTS.
// On success, the connection is live and ready for rdma_send/rdma_recv.
static bool rdma_activate(rdma_conn * c, const rdma_local_info * local,
uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid) {
// RESET -> INIT
{
struct ibv_qp_attr a = {};
a.qp_state = IBV_QPS_INIT;
a.port_num = local->ib_port;
a.pkey_index = 0;
a.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_LOCAL_WRITE;
if (ibv_modify_qp(c->qp, &a,
IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) {
return false;
}
}
for (int i = 0; i < RDMA_RX_DEPTH; i++) {
if (!c->post_rx(i)) return false;
}
// INIT -> RTR
{
struct ibv_qp_attr a = {};
a.qp_state = IBV_QPS_RTR;
a.path_mtu = local->path_mtu;
a.dest_qp_num = remote_qpn;
a.rq_psn = remote_psn;
a.max_dest_rd_atomic = 1;
a.min_rnr_timer = 1;
a.ah_attr.is_global = 1;
memcpy(&a.ah_attr.grh.dgid, remote_gid, RDMA_GID_SIZE);
a.ah_attr.grh.hop_limit = 1;
a.ah_attr.grh.sgid_index = local->gid_idx;
a.ah_attr.dlid = 0;
a.ah_attr.port_num = local->ib_port;
if (ibv_modify_qp(c->qp, &a,
IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN |
IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER) != 0) {
return false;
}
}
// RTR -> RTS
{
struct ibv_qp_attr a = {};
a.qp_state = IBV_QPS_RTS;
a.timeout = 14;
a.retry_cnt = 7;
a.rnr_retry = 7;
a.sq_psn = local->psn;
a.max_rd_atomic = 1;
if (ibv_modify_qp(c->qp, &a,
IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY |
IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC) != 0) {
return false;
}
}
GGML_LOG_INFO("RDMA activated: qpn=%u->%u mtu=%d rx_depth=%d\n",
local->qpn, remote_qpn, 128 << local->path_mtu, RDMA_RX_DEPTH);
return true;
}
#endif // GGML_RPC_RDMA
// ---------------------------------------------------------------------------
// socket_t transport capability methods
// ---------------------------------------------------------------------------
void socket_t::get_caps(uint8_t * caps) {
memset(caps, 0, RPC_CONN_CAPS_SIZE);
#ifdef GGML_RPC_RDMA
rdma_local = {};
rdma.reset(rdma_probe(fd, &rdma_local));
if (rdma) {
rdma_caps rc = {};
rc.qpn = rdma_local.qpn;
rc.psn = rdma_local.psn;
memcpy(rc.gid, rdma_local.gid, RDMA_GID_SIZE);
memcpy(caps, &rc, sizeof(rc));
}
#endif // GGML_RPC_RDMA
}
void socket_t::update_caps(const uint8_t * remote_caps) {
#ifdef GGML_RPC_RDMA
if (!rdma) {
return;
}
rdma_caps rc = {};
memcpy(&rc, remote_caps, sizeof(rc));
if (rc.qpn == 0) {
rdma.reset();
return;
}
if (rdma_activate(rdma.get(), &rdma_local, rc.qpn, rc.psn, rc.gid)) {
fn_send = rdma_send_impl;
fn_recv = rdma_recv_impl;
} else {
GGML_LOG_ERROR("RDMA activate failed, staying on TCP\n");
rdma.reset();
}
#else
(void)remote_caps;
#endif // GGML_RPC_RDMA
}
// unified transport dispatch (via function pointers)
static bool send_data(socket_t * sock, const void * data, size_t size) {
return sock->fn_send(sock, data, size);
}
static bool recv_data(socket_t * sock, void * data, size_t size) {
return sock->fn_recv(sock, data, size);
}
static bool send_msg(socket_t * sock, const void * msg, size_t msg_size) {
if (!send_data(sock, &msg_size, sizeof(msg_size))) {
return false;
}
return send_data(sock, msg, msg_size);
}
static bool recv_msg(socket_t * sock, void * msg, size_t msg_size) {
uint64_t size;
- if (!recv_data(sockfd, &size, sizeof(size))) {
+ if (!recv_data(sock, &size, sizeof(size))) {
return false;
}
if (size != msg_size) {
return false;
}
- return recv_data(sockfd, msg, msg_size);
+ return recv_data(sock, msg, msg_size);
}
- static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
+ static bool recv_msg(socket_t * sock, std::vector<uint8_t> & input) {
uint64_t size;
- if (!recv_data(sockfd, &size, sizeof(size))) {
+ if (!recv_data(sock, &size, sizeof(size))) {
return false;
}
try {
@@ -443,7 +945,7 @@ static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
GGML_LOG_ERROR("Failed to allocate input buffer of size %" PRIu64 "\n", size);
return false;
}
- return recv_data(sockfd, input.data(), size);
+ return recv_data(sock, input.data(), size);
}
static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
@@ -452,7 +954,11 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int
return false;
}
host = endpoint.substr(0, pos);
- port = std::stoi(endpoint.substr(pos + 1));
+ try {
+     port = std::stoi(endpoint.substr(pos + 1));
+ } catch (...) {
+     return false;
+ }
return true;
}
@@ -460,13 +966,13 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int
// No response
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size) {
uint8_t cmd_byte = cmd;
- if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
+ if (!send_data(sock.get(), &cmd_byte, sizeof(cmd_byte))) {
return false;
}
- if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
+ if (!send_data(sock.get(), &input_size, sizeof(input_size))) {
return false;
}
- if (!send_data(sock->fd, input, input_size)) {
+ if (!send_data(sock.get(), input, input_size)) {
return false;
}
return true;
@@ -478,16 +984,14 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
if (!send_rpc_cmd(sock, cmd, input, input_size)) {
return false;
}
- // TODO: currently the output_size is always known, do we need support for commands with variable output size?
- //       even if we do, we can skip sending output_size from the server for commands with known output size
uint64_t out_size;
- if (!recv_data(sock->fd, &out_size, sizeof(out_size))) {
+ if (!recv_data(sock.get(), &out_size, sizeof(out_size))) {
return false;
}
if (out_size != output_size) {
return false;
}
- if (!recv_data(sock->fd, output, output_size)) {
+ if (!recv_data(sock.get(), output, output_size)) {
return false;
}
return true;
@@ -495,17 +999,25 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
// RPC client-side implementation
- static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
-     rpc_msg_hello_rsp response;
-     bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
+ // Performs HELLO handshake with transport auto-negotiation.
+ // Advertises local capabilities via conn_caps; if the server responds with
+ // matching capabilities, the socket is upgraded transparently.
+ static bool negotiate_hello(const std::shared_ptr<socket_t> & sock) {
+     rpc_msg_hello_req request = {};
+     rpc_msg_hello_rsp response = {};
+     sock->get_caps(request.conn_caps);
+     bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
- GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
+ GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n",
+         response.major, response.minor, response.patch);
return false;
}
if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
GGML_LOG_INFO("WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
}
sock->update_caps(response.conn_caps);
return true;
}
@@ -527,6 +1039,7 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
GGML_LOG_ERROR("Failed to parse endpoint: %s\n", endpoint.c_str());
return nullptr;
}
#ifdef _WIN32
if (!initialized) {
WSADATA wsaData;
@@ -543,10 +1056,10 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
if (sock == nullptr) {
return nullptr;
}
- if (!check_server_version(sock)) {
+ if (!negotiate_hello(sock)) {
return nullptr;
}
- LOG_DBG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
+ LOG_DBG("[%s] connected to %s\n", __func__, endpoint.c_str());
sockets[endpoint] = sock;
return sock;
}
@@ -1597,25 +2110,46 @@ rpc_server::~rpc_server() {
}
static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
- sockfd_t sockfd) {
+ socket_t * sockfd) {
rpc_server server(backends, cache_dir);
uint8_t cmd;
if (!recv_data(sockfd, &cmd, 1)) {
return;
}
// the first command sent by the client must be HELLO
if (cmd != RPC_CMD_HELLO) {
GGML_LOG_ERROR("Expected HELLO command, update client\n");
return;
}
- if (!recv_msg(sockfd, nullptr, 0)) {
+ // Read input_size and validate protocol version
+ uint64_t hello_input_size;
+ if (!recv_data(sockfd, &hello_input_size, sizeof(hello_input_size))) {
return;
}
- rpc_msg_hello_rsp response;
- server.hello(response);
- if (!send_msg(sockfd, &response, sizeof(response))) {
+ if (hello_input_size != sizeof(rpc_msg_hello_req)) {
GGML_LOG_ERROR("HELLO request size mismatch (%zu vs %zu) — client needs upgrade to protocol v%d.x\n",
(size_t)hello_input_size, sizeof(rpc_msg_hello_req), RPC_PROTO_MAJOR_VERSION);
return;
}
rpc_msg_hello_req req = {};
if (!recv_data(sockfd, &req, sizeof(req))) {
return;
}
rpc_msg_hello_rsp rsp = {};
server.hello(rsp);
// Advertise server transport capabilities based on client's caps
sockfd->get_caps(rsp.conn_caps);
if (!send_msg(sockfd, &rsp, sizeof(rsp))) {
return;
}
// Activate transport upgrade using client's caps
sockfd->update_caps(req.conn_caps);
while (true) {
if (!recv_data(sockfd, &cmd, 1)) {
break;
@@ -1884,6 +2418,12 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
if (!parse_endpoint(endpoint, host, port)) {
return;
}
#ifdef GGML_RPC_RDMA
printf(" transport : TCP (RDMA auto-negotiate enabled)\n");
#else
printf(" transport : TCP\n");
#endif // GGML_RPC_RDMA
#ifdef _WIN32
{
WSADATA wsaData;
@@ -1907,7 +2447,7 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
}
printf("Accepted client connection\n");
fflush(stdout);
- rpc_serve_client(backends, cache_dir, client_socket->fd);
+ rpc_serve_client(backends, cache_dir, client_socket.get());
printf("Client connection closed\n");
fflush(stdout);
}

View File

@@ -95,6 +95,12 @@ $ bin/rpc-server -c
By default, the cache is stored in the `$HOME/.cache/llama.cpp/rpc` directory and can be controlled via the `LLAMA_CACHE` environment variable.
### RDMA transport
On Linux systems with RoCEv2-capable NICs (e.g. Mellanox ConnectX), the RPC backend can use RDMA instead of TCP for lower latency and higher throughput. The transport is negotiated automatically -- no changes to command-line usage are required.
RDMA is enabled by default when `libibverbs` is found at build time.
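The device and GID index are normally auto-selected by matching the TCP socket's local IP against the kernel GID table, but the implementation also reads the `GGML_RDMA_DEV` and `GGML_RDMA_GID` environment variables to pin them explicitly. A usage sketch (the device name is illustrative; real names come from `ibv_devices`):

```shell
# optional: pin the verbs device and GID table index instead of auto-selection
export GGML_RDMA_DEV=mlx5_0   # illustrative device name
export GGML_RDMA_GID=3        # explicit GID table index
bin/rpc-server -p 50052
```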
### Troubleshooting
Use the `GGML_RPC_DEBUG` environment variable to enable debug messages from `rpc-server`: