From 8046c499abd461a51517067641dc7a13de88a1cb Mon Sep 17 00:00:00 2001 From: Pawel Jakub Dawidek Date: Mon, 31 Jan 2011 18:35:17 +0000 Subject: [PATCH] Implement two new functions for sending descriptor and receving descriptor over UNIX domain sockets and socket pairs. This is in preparation for capsicum. MFC after: 1 week --- sbin/hastd/proto.c | 36 +++++++++++++++++ sbin/hastd/proto.h | 2 + sbin/hastd/proto_common.c | 75 ++++++++++++++++++++++++++++++++--- sbin/hastd/proto_impl.h | 10 ++++- sbin/hastd/proto_socketpair.c | 30 ++++++++++++++ sbin/hastd/proto_uds.c | 28 +++++++++++++ 6 files changed, 173 insertions(+), 8 deletions(-) diff --git a/sbin/hastd/proto.c b/sbin/hastd/proto.c index 7fa90853a914..a00a8ec5bbf7 100644 --- a/sbin/hastd/proto.c +++ b/sbin/hastd/proto.c @@ -216,6 +216,42 @@ proto_recv(const struct proto_conn *conn, void *data, size_t size) return (0); } +int +proto_descriptor_send(const struct proto_conn *conn, int fd) +{ + int ret; + + PJDLOG_ASSERT(conn != NULL); + PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC); + PJDLOG_ASSERT(conn->pc_proto != NULL); + PJDLOG_ASSERT(conn->pc_proto->hp_descriptor_send != NULL); + + ret = conn->pc_proto->hp_descriptor_send(conn->pc_ctx, fd); + if (ret != 0) { + errno = ret; + return (-1); + } + return (0); +} + +int +proto_descriptor_recv(const struct proto_conn *conn, int *fdp) +{ + int ret; + + PJDLOG_ASSERT(conn != NULL); + PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC); + PJDLOG_ASSERT(conn->pc_proto != NULL); + PJDLOG_ASSERT(conn->pc_proto->hp_descriptor_recv != NULL); + + ret = conn->pc_proto->hp_descriptor_recv(conn->pc_ctx, fdp); + if (ret != 0) { + errno = ret; + return (-1); + } + return (0); +} + int proto_descriptor(const struct proto_conn *conn) { diff --git a/sbin/hastd/proto.h b/sbin/hastd/proto.h index 8d1046caed52..ad44b1ab57c5 100644 --- a/sbin/hastd/proto.h +++ b/sbin/hastd/proto.h @@ -43,6 +43,8 @@ int proto_server(const char *addr, struct proto_conn **connp); int proto_accept(struct proto_conn *conn, struct proto_conn **newconnp); int proto_send(const struct proto_conn *conn, const void *data, size_t size); int proto_recv(const struct proto_conn *conn, void *data, size_t size); +int proto_descriptor_send(const struct proto_conn *conn, int fd); +int proto_descriptor_recv(const struct proto_conn *conn, int *fdp); int proto_descriptor(const struct proto_conn *conn); bool proto_address_match(const struct proto_conn *conn, const char *addr); void proto_local_address(const struct proto_conn *conn, char *addr, diff --git a/sbin/hastd/proto_common.c b/sbin/hastd/proto_common.c index 89c637d7151c..d638d3b82b7a 100644 --- a/sbin/hastd/proto_common.c +++ b/sbin/hastd/proto_common.c @@ -46,18 +46,18 @@ __FBSDID("$FreeBSD$"); #endif int -proto_common_send(int fd, const unsigned char *data, size_t size) +proto_common_send(int sock, const unsigned char *data, size_t size) { ssize_t done; size_t sendsize; - PJDLOG_ASSERT(fd >= 0); + PJDLOG_ASSERT(sock >= 0); PJDLOG_ASSERT(data != NULL); PJDLOG_ASSERT(size > 0); do { sendsize = size < MAX_SEND_SIZE ? size : MAX_SEND_SIZE; - done = send(fd, data, sendsize, MSG_NOSIGNAL); + done = send(sock, data, sendsize, MSG_NOSIGNAL); if (done == 0) return (ENOTCONN); else if (done < 0) { @@ -73,16 +73,16 @@ proto_common_send(int fd, const unsigned char *data, size_t size) } int -proto_common_recv(int fd, unsigned char *data, size_t size) +proto_common_recv(int sock, unsigned char *data, size_t size) { ssize_t done; - PJDLOG_ASSERT(fd >= 0); + PJDLOG_ASSERT(sock >= 0); PJDLOG_ASSERT(data != NULL); PJDLOG_ASSERT(size > 0); do { - done = recv(fd, data, size, MSG_WAITALL); + done = recv(sock, data, size, MSG_WAITALL); } while (done == -1 && errno == EINTR); if (done == 0) return (ENOTCONN); @@ -90,3 +90,66 @@ proto_common_recv(int fd, unsigned char *data, size_t size) return (errno); return (0); } + +int +proto_common_descriptor_send(int sock, int fd) +{ + unsigned char ctrl[CMSG_SPACE(sizeof(fd))]; + struct msghdr msg; + struct cmsghdr *cmsg; + + PJDLOG_ASSERT(sock >= 0); + PJDLOG_ASSERT(fd >= 0); + + bzero(&msg, sizeof(msg)); + bzero(&ctrl, sizeof(ctrl)); + + msg.msg_iov = NULL; + msg.msg_iovlen = 0; + msg.msg_control = ctrl; + msg.msg_controllen = sizeof(ctrl); + + cmsg = CMSG_FIRSTHDR(&msg); + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = SCM_RIGHTS; + cmsg->cmsg_len = CMSG_LEN(sizeof(fd)); + *((int *)CMSG_DATA(cmsg)) = fd; + + if (sendmsg(sock, &msg, 0) == -1) + return (errno); + + return (0); +} + +int +proto_common_descriptor_recv(int sock, int *fdp) +{ + unsigned char ctrl[CMSG_SPACE(sizeof(*fdp))]; + struct msghdr msg; + struct cmsghdr *cmsg; + + PJDLOG_ASSERT(sock >= 0); + PJDLOG_ASSERT(fdp != NULL); + + bzero(&msg, sizeof(msg)); + bzero(&ctrl, sizeof(ctrl)); + + msg.msg_iov = NULL; + msg.msg_iovlen = 0; + msg.msg_control = ctrl; + msg.msg_controllen = sizeof(ctrl); + + if (recvmsg(sock, &msg, 0) == -1) + return (errno); + + for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL; + cmsg = CMSG_NXTHDR(&msg, cmsg)) { + if (cmsg->cmsg_level == SOL_SOCKET && + cmsg->cmsg_type == SCM_RIGHTS) { + *fdp = *((int *)CMSG_DATA(cmsg)); + return (0); + } + } + + return (ENOENT); +} diff --git a/sbin/hastd/proto_impl.h b/sbin/hastd/proto_impl.h index f0dfadd5836b..fcf3d44559ce 100644 --- a/sbin/hastd/proto_impl.h +++ b/sbin/hastd/proto_impl.h @@ -45,6 +45,8 @@ typedef int hp_server_t(const char *, void **); typedef int hp_accept_t(void *, void **); typedef int hp_send_t(void *, const unsigned char *, size_t); typedef int hp_recv_t(void *, unsigned char *, size_t); +typedef int hp_descriptor_send_t(void *, int); +typedef int hp_descriptor_recv_t(void *, int *); typedef int hp_descriptor_t(const void *); typedef bool hp_address_match_t(const void *, const char *); typedef void hp_local_address_t(const void *, char *, size_t); @@ -59,6 +61,8 @@ struct hast_proto { hp_accept_t *hp_accept; hp_send_t *hp_send; hp_recv_t *hp_recv; + hp_descriptor_send_t *hp_descriptor_send; + hp_descriptor_recv_t *hp_descriptor_recv; hp_descriptor_t *hp_descriptor; hp_address_match_t *hp_address_match; hp_local_address_t *hp_local_address; @@ -69,7 +73,9 @@ struct hast_proto { void proto_register(struct hast_proto *proto, bool isdefault); -int proto_common_send(int fd, const unsigned char *data, size_t size); -int proto_common_recv(int fd, unsigned char *data, size_t size); +int proto_common_send(int sock, const unsigned char *data, size_t size); +int proto_common_recv(int sock, unsigned char *data, size_t size); +int proto_common_descriptor_send(int sock, int fd); +int proto_common_descriptor_recv(int sock, int *fdp); #endif /* !_PROTO_IMPL_H_ */ diff --git a/sbin/hastd/proto_socketpair.c b/sbin/hastd/proto_socketpair.c index 34f28b7ebfcc..1bb02f6690f8 100644 --- a/sbin/hastd/proto_socketpair.c +++ b/sbin/hastd/proto_socketpair.c @@ -160,6 +160,34 @@ sp_recv(void *ctx, unsigned char *data, size_t size) return (proto_common_recv(fd, data, size)); } +static int +sp_descriptor_send(void *ctx, int fd) +{ + struct sp_ctx *spctx = ctx; + + PJDLOG_ASSERT(spctx != NULL); + PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC); + PJDLOG_ASSERT(spctx->sp_side == SP_SIDE_CLIENT); + PJDLOG_ASSERT(spctx->sp_fd[0] >= 0); + PJDLOG_ASSERT(fd > 0); + + return (proto_common_descriptor_send(spctx->sp_fd[0], fd)); +} + +static int +sp_descriptor_recv(void *ctx, int *fdp) +{ + struct sp_ctx *spctx = ctx; + + PJDLOG_ASSERT(spctx != NULL); + PJDLOG_ASSERT(spctx->sp_magic == SP_CTX_MAGIC); + PJDLOG_ASSERT(spctx->sp_side == SP_SIDE_SERVER); + PJDLOG_ASSERT(spctx->sp_fd[1] >= 0); + PJDLOG_ASSERT(fdp != NULL); + + return (proto_common_descriptor_recv(spctx->sp_fd[1], fdp)); +} + static int sp_descriptor(const void *ctx) { @@ -224,6 +252,8 @@ static struct hast_proto sp_proto = { .hp_client = sp_client, .hp_send = sp_send, .hp_recv = sp_recv, + .hp_descriptor_send = sp_descriptor_send, + .hp_descriptor_recv = sp_descriptor_recv, .hp_descriptor = sp_descriptor, .hp_close = sp_close }; diff --git a/sbin/hastd/proto_uds.c b/sbin/hastd/proto_uds.c index a114c21531d6..262d0c01a9a5 100644 --- a/sbin/hastd/proto_uds.c +++ b/sbin/hastd/proto_uds.c @@ -225,6 +225,32 @@ uds_recv(void *ctx, unsigned char *data, size_t size) return (proto_common_recv(uctx->uc_fd, data, size)); } +static int +uds_descriptor_send(void *ctx, int fd) +{ + struct uds_ctx *uctx = ctx; + + PJDLOG_ASSERT(uctx != NULL); + PJDLOG_ASSERT(uctx->uc_magic == UDS_CTX_MAGIC); + PJDLOG_ASSERT(uctx->uc_fd >= 0); + PJDLOG_ASSERT(fd >= 0); + + return (proto_common_descriptor_send(uctx->uc_fd, fd)); +} + +static int +uds_descriptor_recv(void *ctx, int *fdp) +{ + struct uds_ctx *uctx = ctx; + + PJDLOG_ASSERT(uctx != NULL); + PJDLOG_ASSERT(uctx->uc_magic == UDS_CTX_MAGIC); + PJDLOG_ASSERT(uctx->uc_fd >= 0); + PJDLOG_ASSERT(fdp != NULL); + + return (proto_common_descriptor_recv(uctx->uc_fd, fdp)); +} + static int uds_descriptor(const void *ctx) { @@ -307,6 +333,8 @@ static struct hast_proto uds_proto = { .hp_accept = uds_accept, .hp_send = uds_send, .hp_recv = uds_recv, + .hp_descriptor_send = uds_descriptor_send, + .hp_descriptor_recv = uds_descriptor_recv, .hp_descriptor = uds_descriptor, .hp_local_address = uds_local_address, .hp_remote_address = uds_remote_address,