ppd/common/io.cc
2023-03-13 09:49:28 +01:00

230 lines
4.5 KiB
C++

#include <errno.h>
#include <netinet/in.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <unistd.h>

#include <algorithm>

#include <openssl/ssl.h>
#include <bsock/bsock.h>

#include "logger.h"
#include "msg.hh"
#include "io.hh"
/*
 * bsock read callback backed by an OpenSSL connection.
 *
 * _ctx: a struct ppd_bsock_io_ssl_ctx * whose ssl member is read from.
 * returns: number of bytes read (> 0) on success,
 *          -1 on failure with errno set to the SSL_get_error() code
 *          (an SSL_ERROR_* value, not a system errno).
 */
static ssize_t
ppd_read_ssl(void * _ctx, void *buf, size_t len)
{
	struct ppd_bsock_io_ssl_ctx *conn_ctx =
	    static_cast<struct ppd_bsock_io_ssl_ctx *>(_ctx);
	const int nread = SSL_read(conn_ctx->ssl, buf, len);
	if (nread <= 0) {
		errno = SSL_get_error(conn_ctx->ssl, nread);
		return -1;
	}
	return nread;
}
/*
 * bsock write callback backed by an OpenSSL connection.
 *
 * _ctx: a struct ppd_bsock_io_ssl_ctx * whose ssl member is written to.
 * returns: number of bytes written (> 0) on success,
 *          -1 on failure with errno set to the SSL_get_error() code
 *          (an SSL_ERROR_* value, not a system errno).
 */
static ssize_t
ppd_write_ssl(void * _ctx, void *buf, size_t len)
{
	struct ppd_bsock_io_ssl_ctx *conn_ctx =
	    static_cast<struct ppd_bsock_io_ssl_ctx *>(_ctx);
	const int nwritten = SSL_write(conn_ctx->ssl, buf, len);
	if (nwritten <= 0) {
		errno = SSL_get_error(conn_ctx->ssl, nwritten);
		return -1;
	}
	return nwritten;
}
/*
 * bsock readv callback: performs a single SSL_read() into the context's
 * staging buffer (ssl_readbuf) and scatters the bytes across vec in order.
 *
 * At most min(sum of vec[i].iov_len, ssl_readbuf_len) bytes are requested.
 * returns: number of bytes read and scattered (> 0) on success,
 *          -1 on failure with errno set to the SSL_get_error() code
 *          (an SSL_ERROR_* value, not a system errno).
 */
static ssize_t
ppd_readv_ssl(void * _ctx, const struct iovec * vec, int nvec)
{
	struct ppd_bsock_io_ssl_ctx * ctx = (struct ppd_bsock_io_ssl_ctx *)_ctx;
	size_t total_sz = 0;
	for (int i = 0; i < nvec; i++) {
		total_sz += vec[i].iov_len;
	}
	if (total_sz > ctx->ssl_readbuf_len) {
		total_sz = ctx->ssl_readbuf_len;
	}
	int read_size = SSL_read(ctx->ssl, ctx->ssl_readbuf, total_sz);
	if (read_size <= 0) {
		errno = SSL_get_error(ctx->ssl, read_size);
		return -1;
	}
	// Scatter using size_t arithmetic: casting iov_len to int could yield a
	// negative length for large iovecs, which memcpy would see as huge.
	size_t remaining = (size_t)read_size;
	size_t copied = 0;
	for (int i = 0; i < nvec && remaining > 0; i++) {
		size_t cur_cpy = std::min(vec[i].iov_len, remaining);
		memcpy(vec[i].iov_base, ctx->ssl_readbuf + copied, cur_cpy);
		copied += cur_cpy;
		remaining -= cur_cpy;
	}
	return read_size;
}
/*
 * bsock writev callback: gathers vec into the context's staging buffer and
 * sends it with a single SSL_write().
 *
 * Only the first ssl_readbuf_len bytes of the combined iovecs are gathered;
 * any remainder is left for the caller to retry (the return value tells how
 * much was written).
 * NOTE(review): this reuses ssl_readbuf, the same buffer ppd_readv_ssl reads
 * into; confirm reads and writes on one ctx never overlap in time.
 *
 * returns: number of bytes written (> 0) on success,
 *          -1 on failure with errno set to the SSL_get_error() code
 *          (an SSL_ERROR_* value, not a system errno).
 */
static ssize_t
ppd_writev_ssl(void * _ctx, const struct iovec * vec, int nvec)
{
	struct ppd_bsock_io_ssl_ctx * ctx = (struct ppd_bsock_io_ssl_ctx *)_ctx;
	const size_t cap = (size_t)ctx->ssl_readbuf_len;
	size_t copied = 0;
	for (int i = 0; i < nvec && copied < cap; i++) {
		// Copy from the i-th iovec (not always the first) and never more
		// than the space left in the staging buffer.
		size_t cur_copy = std::min(vec[i].iov_len, cap - copied);
		memcpy(ctx->ssl_readbuf + copied, vec[i].iov_base, cur_copy);
		copied += cur_copy;
	}
	int write_size = SSL_write(ctx->ssl, ctx->ssl_readbuf, copied);
	if (write_size <= 0) {
		errno = SSL_get_error(ctx->ssl, write_size);
		return -1;
	}
	return write_size;
}
struct bsock_ringbuf_io ppd_bsock_io_ssl()
{
struct bsock_ringbuf_io io;
io.read = &ppd_read_ssl;
io.readv = &ppd_readv_ssl;
io.write = &ppd_write_ssl;
io.writev = &ppd_writev_ssl;
return io;
}
/*
 * Reads one complete length-prefixed message from bsock into buf.
 *
 * The header is peeked first, so nothing is consumed from bsock unless the
 * whole message (header + payload) is already available and fits in buf.
 * On success the message is consumed and msg->size is left in host byte
 * order.
 *
 * returns: 0 on success
 *          -1 + errno on fail; a negative bsock_peek()/bsock_read() result
 *          is returned unchanged (errno presumably set by libbsock)
 * recoverable errnos:
 *   ENOMEM  - buffer is not big enough to hold message
 *   ENODATA - bsock does not have enough data for the entire message
 * caller errors:
 *   EINVAL  - len is smaller than PPD_MSG_HDR_SZ
 * unrecoverable errnos:
 *   EIO     - libbsock reported enough data but didn't give us all of it
 */
int
ppd_readmsg(struct bsock * bsock, char *buf, size_t len)
{
	struct ppd_msg *msg = (struct ppd_msg *)buf;
	if (len < PPD_MSG_HDR_SZ) {
		errno = EINVAL;
		return -1;
	}
	int status = bsock_peek(bsock, buf, PPD_MSG_HDR_SZ);
	if (status < 0) {
		return status;
	} else if (status != PPD_MSG_HDR_SZ) {
		errno = ENODATA;
		return -1;
	}
	// The size is untrusted wire data: keep it unsigned so a value >= 2^31
	// cannot turn into a negative int and slip past the bound check below.
	const uint32_t sz = ntohl(msg->size);
	if (sz > len - PPD_MSG_HDR_SZ) {
		errno = ENOMEM;
		return -1;
	}
	// sz <= len - PPD_MSG_HDR_SZ, so this cannot exceed len.
	const size_t total = PPD_MSG_HDR_SZ + (size_t)sz;
	if (bsock_read_avail_size(bsock) < (int)total) {
		errno = ENODATA;
		return -1;
	}
	status = bsock_read(bsock, buf, total);
	if (status < 0) {
		return status;
	} else if ((size_t)status != total) {
		// this should never happen unless there is a bug with libbsock
		errno = EIO;
		return -1;
	}
	msg->size = sz;
	return 0;
}
/*
 * Writes one length-prefixed message (header + msg->size payload bytes) to
 * bsock.
 *
 * msg->size must be in host byte order on entry. It is converted to network
 * byte order only for the duration of bsock_write() and restored before
 * returning, so the caller's message is unchanged on every path.
 * NOTE(review): restoring assumes bsock_write() has finished with the bytes
 * (copied or sent them) by the time it returns — confirm against libbsock.
 *
 * returns: 0 on success
 *          -1 + errno on fail; a negative bsock_write() result is returned
 *          unchanged (errno presumably set by libbsock)
 * caller errors:
 *   EINVAL - msg->size does not fit in a non-negative int
 * unrecoverable errnos:
 *   EAGAIN - not all data is written. How much data is written is unknown.
 */
int
ppd_writemsg(struct bsock * bsock, struct ppd_msg *msg)
{
	int sz = msg->size;
	if (sz < 0) {
		// A negative length would make PPD_MSG_HDR_SZ + sz a bogus size.
		errno = EINVAL;
		return -1;
	}
	msg->size = htonl(sz);
	int status = bsock_write(bsock, (char *)msg, PPD_MSG_HDR_SZ + sz);
	msg->size = sz;
	if (status < 0) {
		// not all message was sent
		return status;
	} else if (status != PPD_MSG_HDR_SZ + sz) {
		errno = EAGAIN;
		return -1;
	}
	return 0;
}
/*
 * Reads exactly len bytes from socket fd into buf, looping over short reads
 * and retrying recv() calls interrupted by a signal (EINTR).
 *
 * returns: 0 once all len bytes have been read (immediately if len <= 0)
 *          -1 + errno on fail: ECONNRESET if the peer closed the connection
 *          before len bytes arrived, otherwise the errno left by recv()
 */
int
ppd_readbuf(int fd, void *buf, int len)
{
	char *cursor = (char *)buf;
	int remaining = len;
	while (remaining > 0) {
		int got = recv(fd, cursor, remaining, 0);
		if (got < 0) {
			if (errno == EINTR) {
				continue;
			}
			return -1;
		}
		if (got == 0) {
			errno = ECONNRESET;
			return -1;
		}
		cursor += got;
		remaining -= got;
	}
	return 0;
}
/*
 * Writes exactly len bytes from buf to socket fd, looping over short writes.
 * send() calls interrupted by a signal (EINTR) are retried, matching
 * ppd_readbuf; previously an EINTR aborted a partially sent buffer.
 *
 * returns: 0 once all len bytes have been sent (immediately if len <= 0)
 *          -1 + errno on fail: ECONNRESET if send() reports 0 bytes written,
 *          otherwise the errno left by send()
 * NOTE(review): send() to a peer that has closed may raise SIGPIPE; confirm
 * the process ignores SIGPIPE, or consider MSG_NOSIGNAL where available.
 */
int
ppd_writebuf(int fd, void *buf, int len)
{
	int status;
	while (len > 0) {
		if ((status = send(fd, buf, len, 0)) > 0) {
			buf = (char *)buf + status;
			len -= status;
		} else if (status == 0) {
			errno = ECONNRESET;
			return -1;
		} else if (errno != EINTR) {
			return -1;
		}
	}
	return 0;
}