#include #include #include #include #include #include #include "logger.h" #include "msg.hh" #include "io.hh" #include static ssize_t ppd_read_ssl(void * _ctx, void *buf, size_t len) { struct ppd_bsock_io_ssl_ctx * ctx = (struct ppd_bsock_io_ssl_ctx *)_ctx; int status = SSL_read(ctx->ssl, buf, len); if (status > 0) { return status; } errno = SSL_get_error(ctx->ssl, status); return -1; } static ssize_t ppd_write_ssl(void * _ctx, void *buf, size_t len) { struct ppd_bsock_io_ssl_ctx * ctx = (struct ppd_bsock_io_ssl_ctx *)_ctx; int status = SSL_write(ctx->ssl, buf, len); if (status > 0) { return status; } errno = SSL_get_error(ctx->ssl, status); return -1; } static ssize_t ppd_readv_ssl(void * _ctx, const struct iovec * vec, int nvec) { struct ppd_bsock_io_ssl_ctx * ctx = (struct ppd_bsock_io_ssl_ctx *)_ctx; size_t total_sz = 0; for(int i = 0; i < nvec; i++) { total_sz += (vec + i)->iov_len; } if (total_sz > ctx->ssl_readbuf_len) { total_sz = ctx->ssl_readbuf_len; } int read_size = SSL_read(ctx->ssl, ctx->ssl_readbuf, total_sz); if (read_size <= 0) { errno = SSL_get_error(ctx->ssl, read_size); return -1; } int copied = 0; for (int i = 0; i < nvec; i++) { int cur_cpy = std::min((int)vec[i].iov_len, read_size - copied); memcpy(vec[i].iov_base, ctx->ssl_readbuf + copied, cur_cpy); copied += cur_cpy; if (copied == read_size) { break; } } return read_size; } static ssize_t ppd_writev_ssl(void * _ctx, const struct iovec * vec, int nvec) { struct ppd_bsock_io_ssl_ctx * ctx = (struct ppd_bsock_io_ssl_ctx *)_ctx; int copied = 0; for(int i = 0; i < nvec; i++) { int len = (vec + i)->iov_len; int cur_copy = std::min(len, (int)ctx->ssl_readbuf_len - copied); memcpy(ctx->ssl_readbuf + copied, vec->iov_base, len); copied += cur_copy; if (copied == (int)ctx->ssl_readbuf_len) { break; } } int write_size = SSL_write(ctx->ssl, ctx->ssl_readbuf, copied); if (write_size <= 0) { errno = SSL_get_error(ctx->ssl, write_size); return -1; } return write_size; } struct bsock_ringbuf_io ppd_bsock_io_ssl() { struct bsock_ringbuf_io io; io.read = &ppd_read_ssl; io.readv = &ppd_readv_ssl; io.write = &ppd_write_ssl; io.writev = &ppd_writev_ssl; return io; } /* * returns: 0 on success * -1 + errno on fail * recoverable errnos: * ENOMEM - buffer is not big enough to hold message * ENODATA - bsock does not have enough data for the entire message * unrecoverable errnos: * EIO - libbsock has enough data but didn't give us */ int ppd_readmsg(struct bsock * bsock, char *buf, size_t len) { int status; struct ppd_msg *msg = (struct ppd_msg *)buf; if (len < PPD_MSG_HDR_SZ) { errno = EINVAL; return -1; } status = bsock_peek(bsock, buf, PPD_MSG_HDR_SZ); if (status < 0) { return status; } else if (status != PPD_MSG_HDR_SZ) { errno = ENODATA; return -1; } int sz = ntohl(msg->size); if (sz > (int)(len - PPD_MSG_HDR_SZ)) { errno = ENOMEM; return -1; } if (bsock_read_avail_size(bsock) < (int)(PPD_MSG_HDR_SZ + sz)) { errno = ENODATA; return -1; } status = bsock_read(bsock, buf, PPD_MSG_HDR_SZ + sz); if (status < 0) { return status; } else if (status != PPD_MSG_HDR_SZ + sz) { // this should never happen unless there is a bug with libbsock errno = EIO; return -1; } msg->size = sz; return 0; } /* * returns: 0 on success * -1 + errno on fail * unrecoverable errnos: * EAGAIN - not all data is written. How much data is written is unknown. */ int ppd_writemsg(struct bsock * bsock, struct ppd_msg *msg) { int status; int sz = msg->size; msg->size = htonl(sz); status = bsock_write(bsock, (char *)msg, PPD_MSG_HDR_SZ + sz); if (status < 0) { // not all message was sent return status; } else if (status != PPD_MSG_HDR_SZ + sz) { errno = EAGAIN; return -1; } return 0; } int ppd_readbuf(int fd, void *buf, int len) { int status; while (len > 0) { if ((status = recv(fd, buf, len, 0)) > 0) { buf = (char *)buf + status; len -= status; } else if (status == 0) { errno = ECONNRESET; return -1; } else { if (errno != EINTR) { return -1; } } }; return 0; } int ppd_writebuf(int fd, void *buf, int len) { int status; while (len > 0) { if ((status = send(fd, buf, len, 0)) > 0) { buf = (char *)buf + status; len -= status; } else if (status == 0) { errno = ECONNRESET; return -1; } else { return -1; } }; return 0; }