diff --git a/.gitignore b/.gitignore index a498fdc..f1100cb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,9 @@ # ---> C # Prerequisites *.d +.cache +build +.vscode # Object files *.o diff --git a/CMakeLists.txt b/CMakeLists.txt index 7454b33..15c1795 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,46 +8,40 @@ add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/msg/msg.pb.cc COMMAND protoc --cpp_out=${CMAKE_CURRENT_BINARY_DIR}/msg/ --proto_path=${CMAKE_CURRENT_SOURCE_DIR}/msg/ msg.proto DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/msg/msg.proto) +set(CMAKE_EXPORT_COMPILE_COMMANDS True) find_package(PkgConfig REQUIRED) -pkg_check_modules(rocksdb rocksdb) -pkg_check_modules(protobuf REQUIRED protobuf) +pkg_check_modules(bsock REQUIRED bsock) -if (${ENABLE_FSTACK} MATCHES "y") - pkg_check_modules(dpdk REQUIRED libdpdk) - pkg_check_modules(bsdtopo REQUIRED bsdtopo) - pkg_check_modules(ssl REQUIRED libssl) - include_directories(${dpdk_INCLUDE_DIRS}) - include_directories(${ssl_INCLUDE_DIRS}) - include_directories(${bsdtopo_INCLUDE_DIRS}) -endif() +# if (${ENABLE_FSTACK} MATCHES "y") +# pkg_check_modules(dpdk REQUIRED libdpdk) +# pkg_check_modules(bsdtopo REQUIRED bsdtopo) +# pkg_check_modules(ssl REQUIRED libssl) +# include_directories(${dpdk_INCLUDE_DIRS}) +# include_directories(${ssl_INCLUDE_DIRS}) +# include_directories(${bsdtopo_INCLUDE_DIRS}) +# endif() -include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) -include_directories(${rocksdb_INCLUDE_DIRS}) -include_directories(${CMAKE_CURRENT_BINARY_DIR}/msg) -include_directories(${protobuf_INCLUDE_DIRS}) set(CFLAGS -Wall -Wextra -Werror -Wno-unused-parameter -Wno-unused-variable -std=c++17 -O2 -g) -add_executable(dismember ${CMAKE_CURRENT_SOURCE_DIR}/dismember/dismember.cc - ${CMAKE_CURRENT_SOURCE_DIR}/dismember/Generator.cc - ${CMAKE_CURRENT_SOURCE_DIR}/dismember/reqgen.cc - ${CMAKE_CURRENT_SOURCE_DIR}/dismember/util.cc - ${CMAKE_CURRENT_BINARY_DIR}/msg/msg.pb.cc) -target_link_libraries(dismember 
${protobuf_LINK_LIBRARIES} ${rocksdb_LINK_LIBRARIES} bz2 z pthread) -target_compile_options(dismember PRIVATE ${CFLAGS}) +add_executable(dsmbr ${CMAKE_CURRENT_SOURCE_DIR}/ppd/dsmbr.cc + ${CMAKE_CURRENT_SOURCE_DIR}/ppd/util.cc) +target_link_libraries(dsmbr pthread bsock) +target_compile_options(dsmbr PRIVATE ${CFLAGS} ${bsock_CFLAGS}) +target_include_directories(dsmbr PRIVATE ${bsock_INCLUDE_DIRS} ${CMAKE_CURRENT_SOURCE_DIR}/include) add_executable(ppd ${CMAKE_CURRENT_SOURCE_DIR}/ppd/ppd.cc - ${CMAKE_CURRENT_SOURCE_DIR}/ppd/reqproc.cc - ${CMAKE_CURRENT_BINARY_DIR}/msg/msg.pb.cc) -target_link_libraries(ppd ${protobuf_LINK_LIBRARIES} ${rocksdb_LINK_LIBRARIES} bz2 z pthread) -target_compile_options(ppd PRIVATE ${CFLAGS}) + ${CMAKE_CURRENT_SOURCE_DIR}/ppd/util.cc) +target_link_libraries(ppd pthread bsock) +target_compile_options(ppd PRIVATE ${CFLAGS} ${bsock_CFLAGS}) +target_include_directories(ppd PRIVATE ${bsock_INCLUDE_DIRS} ${CMAKE_CURRENT_SOURCE_DIR}/include) -if (${ENABLE_FSTACK} MATCHES "y") - add_executable(ppd_ff ${CMAKE_CURRENT_SOURCE_DIR}/ppd_ff/ppd.cc - ${CMAKE_CURRENT_SOURCE_DIR}/ppd_ff/reqproc.cc - ${CMAKE_CURRENT_BINARY_DIR}/msg/msg.pb.cc) - target_link_libraries(ppd_ff ${protobuf_LINK_LIBRARIES} fstack ${ssl_LINK_LIBRARIES} bz2 z crypto ${dpdk_LIBRARIES} ${bsdtopo_LIBRARIES} librte_net_bond.a librte_bus_vdev.a) - target_link_directories(ppd_ff PRIVATE /usr/local/lib ${dpdk_LIBRARY_DIRS} ${bsdtopo_LIBRARY_DIRS}) - target_compile_options(ppd_ff PRIVATE ${CFLAGS} ${dpdk_CFLAGS}) -endif() +# if (${ENABLE_FSTACK} MATCHES "y") +# add_executable(ppd_ff ${CMAKE_CURRENT_SOURCE_DIR}/ppd_ff/ppd.cc +# ${CMAKE_CURRENT_SOURCE_DIR}/ppd_ff/reqproc.cc +# ${CMAKE_CURRENT_BINARY_DIR}/msg/msg.pb.cc) +# target_link_libraries(ppd_ff ${protobuf_LINK_LIBRARIES} fstack ${ssl_LINK_LIBRARIES} bz2 z crypto ${dpdk_LIBRARIES} ${bsdtopo_LIBRARIES} librte_net_bond.a librte_bus_vdev.a) +# target_link_directories(ppd_ff PRIVATE /usr/local/lib ${dpdk_LIBRARY_DIRS} 
${bsdtopo_LIBRARY_DIRS}) +# target_compile_options(ppd_ff PRIVATE ${CFLAGS} ${dpdk_CFLAGS}) +# endif() diff --git a/ppd/ppd.cc b/ppd/ppd.cc index 5535177..7b8f4f1 100755 --- a/ppd/ppd.cc +++ b/ppd/ppd.cc @@ -15,6 +15,8 @@ #include #include +#include + #include "logger.h" #include "mod.h" #include "msg.h" @@ -35,6 +37,7 @@ static constexpr int NEVENT = 64; static constexpr int SOCK_BACKLOG = 10000; static constexpr int SINGLE_LEGACY = -1; static constexpr int DEFAULT_PORT = 9898; +static constexpr int BSOCK_BUF_SZ = 4096; // 16MB max per message static constexpr int MBUF_SZ = 1024 * 1024 * 16; static constexpr int MAX_MODE_PARAMS = 16; @@ -71,8 +74,12 @@ struct ppd_options { }; struct ppd_conn { + struct bsock * bsock; int conn_fd; - SSL *ssl; + + SSL * ssl; + char * ssl_readbuf; + struct ppd_bsock_io_ssl_ctx ssl_io_ctx; void *m_conn_ctx; }; @@ -215,6 +222,12 @@ ppd_conn_free_no_ctx(struct ppd_conn *conn) SSL_shutdown(conn->ssl); SSL_free(conn->ssl); } + if (conn->bsock != nullptr) { + bsock_free(conn->bsock); + } + if (conn->ssl_readbuf != nullptr) { + delete[] conn->ssl_readbuf; + } close(conn->conn_fd); delete conn; } @@ -266,13 +279,25 @@ handle_event(struct ppd_thread_ctx *tinfo, struct kevent *kev) goto fail; } - status = ppd_readmsg(conn_fd, hint->ssl, tinfo->m_buf, MBUF_SZ); - if (status != 0) { - W("Thread %d dropped connection %d due to ppd_readmsg error %d\n", tinfo->tid, - conn_fd, errno); + // read data first + status = bsock_poll(hint->bsock); + if (status == 0) { + // connection reset basically + W("Thread %d dropped connection %d due to bsock_poll ret %d errno %d\n", tinfo->tid, conn_fd, status, errno); goto fail; } + status = ppd_readmsg(hint->bsock, tinfo->m_buf, MBUF_SZ); + if (status != 0) { + if (errno == ERANGE) { + // not enough data yet. try again later. 
+ goto end; + } else { + W("Thread %d dropped connection %d due to ppd_readmsg error %d\n", tinfo->tid, conn_fd, errno); + goto fail; + } + } + msg = (struct ppd_msg *)tinfo->m_buf; status = options.m_info->conn_recv_cb(msg->data, msg->size, options.m_global_ctx, tinfo->m_thread_ctx, hint->m_conn_ctx); @@ -291,14 +316,22 @@ handle_event(struct ppd_thread_ctx *tinfo, struct kevent *kev) } msg->size = out_sz; - status = ppd_writemsg(conn_fd, hint->ssl, msg); + status = ppd_writemsg(hint->bsock, msg); if (status != 0) { - W("Thread %d dropped connection %d due to ppd_writemsg error %d\n", tinfo->tid, - conn_fd, errno); + // shouldn't be error here unless msg is too big to fit in bsock buffer + W("Thread %d dropped connection %d due to ppd_writemsg error %d\n", tinfo->tid, conn_fd, errno); + goto fail; + } + + // flush bsock immediately + status = bsock_flush(hint->bsock); + if (status <= 0) { + W("Thread %d dropped connection %d due to bsock_flush ret %d errno %d\n", tinfo->tid, conn_fd, status, errno); goto fail; } tinfo->evcnt++; +end: return 0; fail: drop_conn(tinfo, kev); @@ -589,6 +622,7 @@ loop_main(int m_kq, std::vector *workers) 0) { W("setsockopt() nodelay failed on conn %d: err %d\n", conn_fd, errno); + close(conn_fd); continue; } } @@ -605,9 +639,24 @@ loop_main(int m_kq, std::vector *workers) if (options.enable_tls) { conn->ssl = tls_handshake_server(conn_fd); + conn->ssl_readbuf = new char[BSOCK_BUF_SZ]; + struct bsock_ringbuf_io io = ppd_bsock_io_ssl(); + conn->ssl_io_ctx.ssl_readbuf = conn->ssl_readbuf; + conn->ssl_io_ctx.ssl = conn->ssl; + conn->ssl_io_ctx.ssl_readbuf_len = BSOCK_BUF_SZ; V("Established TLS on connection %d...\n", conn_fd); + conn->bsock = bsock_create((void*)&conn->ssl_io_ctx, &io, BSOCK_BUF_SZ, BSOCK_BUF_SZ); } else { conn->ssl = nullptr; + conn->ssl_readbuf = nullptr; + struct bsock_ringbuf_io io = bsock_io_posix(); + conn->bsock = bsock_create((void*)(uintptr_t)conn_fd, &io, BSOCK_BUF_SZ, BSOCK_BUF_SZ); + } + + if (conn->bsock == 
nullptr) { + W("Failed to create bsock on connection %d...\n", conn_fd); + ppd_conn_free_no_ctx(conn); + continue; } int worker_idx = cur_conn % workers->size(); diff --git a/ppd/util.cc b/ppd/util.cc index 077c762..ff16230 100644 --- a/ppd/util.cc +++ b/ppd/util.cc @@ -5,10 +5,12 @@ #include #include +#include "bsock/bsock.h" #include "logger.h" #include "msg.h" #include "util.h" +#include #include struct ppd_mod_info * @@ -27,149 +29,141 @@ ppd_load_module(const char *path) return fn(); } -static int -ppd_ssl_error_retryable(int err) +static ssize_t +ppd_read_ssl(void * _ctx, void *buf, size_t len) /* Read callback for the SSL bsock io table: _ctx is a struct ppd_bsock_io_ssl_ctx; issues a single SSL_read of up to len bytes into buf. */ { - return (err == SSL_ERROR_WANT_READ) || (err == SSL_ERROR_WANT_WRITE) || - (err == SSL_ERROR_WANT_CONNECT) || (err == SSL_ERROR_WANT_ACCEPT) || - (err == SSL_ERROR_WANT_X509_LOOKUP) || (err == SSL_ERROR_WANT_CLIENT_HELLO_CB); + struct ppd_bsock_io_ssl_ctx * ctx = (struct ppd_bsock_io_ssl_ctx *)_ctx; + int status = SSL_read(ctx->ssl, buf, len); + if (status > 0) { + return status; /* number of bytes read */ + } + errno = SSL_get_error(ctx->ssl, status); /* errno carries an SSL_ERROR_* code here, not a POSIX errno value */ + return -1; } -int -ppd_readbuf_ssl(SSL *ssl, void *buf, int len) +static ssize_t +ppd_write_ssl(void * _ctx, void *buf, size_t len) /* Write callback for the SSL bsock io table: issues a single SSL_write of len bytes from buf. */ { - int status; + struct ppd_bsock_io_ssl_ctx * ctx = (struct ppd_bsock_io_ssl_ctx *)_ctx; + int status = SSL_write(ctx->ssl, buf, len); + if (status > 0) { + return status; /* number of bytes written */ + } + errno = SSL_get_error(ctx->ssl, status); /* errno carries an SSL_ERROR_* code here, not a POSIX errno value */ + return -1; +} - while (len > 0) { - if ((status = SSL_read(ssl, buf, len)) > 0) { - buf = (char *)buf + status; - len -= status; - } else { - status = SSL_get_error(ssl, status); - if (!ppd_ssl_error_retryable(status)) { - errno = status; - return -1; - } +static ssize_t +ppd_readv_ssl(void * _ctx, const struct iovec * vec, int nvec) /* Scatter-read callback: performs one SSL_read into ctx->ssl_readbuf, then copies the received bytes out across the iovecs in order. */ +{ + struct ppd_bsock_io_ssl_ctx * ctx = (struct ppd_bsock_io_ssl_ctx *)_ctx; + + size_t total_sz = 0; /* combined capacity of all iovecs */ + for(int i = 0; i < nvec; i++) { + total_sz += (vec + i)->iov_len; + } + + if (total_sz > ctx->ssl_readbuf_len) { + total_sz = ctx->ssl_readbuf_len; /* never request more than the staging buffer can hold */ + } + + int 
read_size = SSL_read(ctx->ssl, ctx->ssl_readbuf, total_sz); + if (read_size <= 0) { + errno = SSL_get_error(ctx->ssl, read_size); + return -1; + } + + int copied = 0; + for (int i = 0; i < nvec; i++) { + int cur_cpy = std::min((int)vec[i].iov_len, read_size - copied); + + memcpy(vec[i].iov_base, ctx->ssl_readbuf + copied, cur_cpy); + copied += cur_cpy; + + if (copied == read_size) { + break; } - }; + } - return 0; + return read_size; } -int -ppd_writebuf_ssl(SSL *ssl, void *buf, int len) +static ssize_t +ppd_writev_ssl(void * _ctx, const struct iovec * vec, int nvec) /* Gather-write callback: packs the iovecs into ctx->ssl_readbuf (at most ssl_readbuf_len bytes) and sends them with one SSL_write; returns the bytes written, or -1 with errno set to the SSL_get_error code. */ { - int status; + struct ppd_bsock_io_ssl_ctx * ctx = (struct ppd_bsock_io_ssl_ctx *)_ctx; - while (len > 0) { - if ((status = SSL_write(ssl, buf, len)) > 0) { - buf = (char *)buf + status; - len -= status; - } else { - status = SSL_get_error(ssl, status); - if (!ppd_ssl_error_retryable(status)) { - errno = status; - return -1; - } + int copied = 0; + for(int i = 0; i < nvec; i++) { + int len = (vec + i)->iov_len; + int cur_copy = std::min(len, (int)ctx->ssl_readbuf_len - copied); + memcpy(ctx->ssl_readbuf + copied, vec[i].iov_base, cur_copy); /* copy from the i-th iovec, clamped to the space left in the staging buffer */ + copied += cur_copy; + + if (copied == (int)ctx->ssl_readbuf_len) { + break; } - }; + } - return 0; + int write_size = SSL_write(ctx->ssl, ctx->ssl_readbuf, copied); + if (write_size <= 0) { + errno = SSL_get_error(ctx->ssl, write_size); + return -1; + } + + return write_size; } -int -ppd_readbuf(int fd, void *buf, int len) +struct bsock_ringbuf_io ppd_bsock_io_ssl() { - int status; - - while (len > 0) { - if ((status = recv(fd, buf, len, 0)) > 0) { - buf = (char *)buf + status; - len -= status; - } else if (status == 0) { - errno = ECONNRESET; - return -1; - } else { - if (errno != EINTR) { - return -1; - } - } - }; - - return 0; + struct bsock_ringbuf_io io; + io.read = &ppd_read_ssl; + io.readv = &ppd_readv_ssl; + io.write = &ppd_write_ssl; + io.writev = &ppd_writev_ssl; + return io; } int -ppd_writebuf(int fd, void *buf, int len) -{ - int status; - - while (len > 
0) { - if ((status = send(fd, buf, len, 0)) > 0) { - buf = (char *)buf + status; - len -= status; - } else if (status == 0) { - errno = ECONNRESET; - return -1; - } else { - return -1; - } - }; - - return 0; -} - -int -ppd_readmsg(int fd, SSL *ssl, char *buf, size_t len) +ppd_readmsg(struct bsock * bsock, char *buf, size_t len) /* Peeks a ppd_msg header from bsock, validates its size, then reads header + payload into buf. Returns 0 on success with msg->size in host byte order, -1 with errno = EOVERFLOW when buf cannot hold the message, otherwise the non-zero status of bsock_peek/bsock_read. */ { int status; struct ppd_msg *msg = (struct ppd_msg *)buf; if (len < sizeof(struct ppd_msg)) { - return EOVERFLOW; + errno = EOVERFLOW; + return -1; } - if (ssl != nullptr) { - status = ppd_readbuf_ssl(ssl, msg, sizeof(struct ppd_msg)); - } else { - status = ppd_readbuf(fd, msg, sizeof(struct ppd_msg)); - } + status = bsock_peek(bsock, buf, sizeof(struct ppd_msg)); if (status != 0) { return status; } int sz = ntohl(msg->size); - msg->size = sz; - if (sz > (int)len) { + if (sz < 0 || (size_t)sz > len - sizeof(struct ppd_msg)) { /* header + payload must both fit in buf; sz < 0 catches wire sizes above INT_MAX */ - return EOVERFLOW; + errno = EOVERFLOW; + return -1; } - if (((struct ppd_msg *)buf)->size > 0) { - if (ssl != nullptr) { - status = ppd_readbuf_ssl(ssl, buf, sz); - } else { - status = ppd_readbuf(fd, buf, sz); - } + status = bsock_read(bsock, buf, sizeof(struct ppd_msg) + sz); + if (status != 0) { + return status; } + msg->size = sz; return status; } int -ppd_writemsg(int fd, SSL *ssl, struct ppd_msg *msg) +ppd_writemsg(struct bsock * bsock, struct ppd_msg *msg) { int status; int sz = msg->size; msg->size = htonl(msg->size); - if (ssl != nullptr) { - status = ppd_writebuf_ssl(ssl, msg, sizeof(struct ppd_msg) + sz); - } else { - status = ppd_writebuf(fd, msg, sizeof(struct ppd_msg) + sz); - } - - return status; + return bsock_write(bsock, (char *)msg, sizeof(struct ppd_msg) + sz); } diff --git a/ppd/util.h b/ppd/util.h index f26ec90..c33aabf 100644 --- a/ppd/util.h +++ b/ppd/util.h @@ -10,6 +10,7 @@ #include #include #include +#include #include "mod.h" @@ -27,17 +28,25 @@ cpulist_to_cpuset(char *cpulist, cpuset_t *cpuset) } } +struct ppd_bsock_io_ssl_ctx { + SSL * ssl; + char * ssl_readbuf; + size_t ssl_readbuf_len; +}; + +struct bsock_ringbuf_io ppd_bsock_io_ssl(); + int 
ppd_readbuf_ssl(SSL *ssl, void *buf, int len); int ppd_writebuf_ssl(SSL *ssl, void *buf, int len); -int ppd_readbuf(int fd, void *buf, int len); +int ppd_readbuf(struct bsock *bsock, int len); /* NOTE(review): this patch removes ppd_readbuf from ppd/util.cc and shows no replacement definition; confirm one exists or drop the prototype */ -int ppd_writebuf(int fd, void *buf, int len); +int ppd_writebuf(struct bsock *bsock, int len); /* NOTE(review): same as ppd_readbuf, no definition is visible in this patch */ -int ppd_readmsg(int fd, SSL *ssl, char *buf, size_t len); +int ppd_readmsg(struct bsock *bsock, char *buf, size_t len); /* Reads one length-prefixed ppd_msg (header + payload) from bsock into buf; len is the capacity of buf. */ -int ppd_writemsg(int fd, SSL *ssl, struct ppd_msg *msg); +int ppd_writemsg(struct bsock *bsock, struct ppd_msg *msg); /* Converts msg->size with htonl in place, then passes header + payload to bsock_write. */ static inline uint64_t get_time_us()