diff --git a/sbin/hastd/proto.c b/sbin/hastd/proto.c index 6c8382fe4126..9791c795348a 100644 --- a/sbin/hastd/proto.c +++ b/sbin/hastd/proto.c @@ -36,6 +36,7 @@ __FBSDID("$FreeBSD$"); #include #include +#include #include "pjdlog.h" #include "proto.h" @@ -68,6 +69,40 @@ proto_register(struct hast_proto *proto, bool isdefault) } } +static struct proto_conn * +proto_alloc(struct hast_proto *proto, int side) +{ + struct proto_conn *conn; + + PJDLOG_ASSERT(proto != NULL); + PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT || + side == PROTO_SIDE_SERVER_LISTEN || + side == PROTO_SIDE_SERVER_WORK); + + conn = malloc(sizeof(*conn)); + if (conn != NULL) { + conn->pc_proto = proto; + conn->pc_side = side; + conn->pc_magic = PROTO_CONN_MAGIC; + } + return (conn); +} + +static void +proto_free(struct proto_conn *conn) +{ + + PJDLOG_ASSERT(conn != NULL); + PJDLOG_ASSERT(conn->pc_magic == PROTO_CONN_MAGIC); + PJDLOG_ASSERT(conn->pc_side == PROTO_SIDE_CLIENT || + conn->pc_side == PROTO_SIDE_SERVER_LISTEN || + conn->pc_side == PROTO_SIDE_SERVER_WORK); + PJDLOG_ASSERT(conn->pc_proto != NULL); + + bzero(conn, sizeof(*conn)); + free(conn); +} + static int proto_common_setup(const char *addr, struct proto_conn **connp, int side) { @@ -76,11 +111,8 @@ proto_common_setup(const char *addr, struct proto_conn **connp, int side) void *ctx; int ret; - PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT || side == PROTO_SIDE_SERVER_LISTEN); - - conn = malloc(sizeof(*conn)); - if (conn == NULL) - return (-1); + PJDLOG_ASSERT(side == PROTO_SIDE_CLIENT || + side == PROTO_SIDE_SERVER_LISTEN); TAILQ_FOREACH(proto, &protos, hp_next) { if (side == PROTO_SIDE_CLIENT) { @@ -104,21 +136,24 @@ proto_common_setup(const char *addr, struct proto_conn **connp, int side) } if (proto == NULL) { /* Unrecognized address. */ - free(conn); errno = EINVAL; return (-1); } if (ret > 0) { /* An error occured. */ - free(conn); errno = ret; return (-1); } - conn->pc_proto = proto; + conn = proto_alloc(proto, side); + if (conn == NULL) { + if (proto->hp_close != NULL) + proto->hp_close(ctx); + errno = ENOMEM; + return (-1); + } conn->pc_ctx = ctx; - conn->pc_side = side; - conn->pc_magic = PROTO_CONN_MAGIC; *connp = conn; + return (0); } @@ -168,20 +203,17 @@ proto_accept(struct proto_conn *conn, struct proto_conn **newconnp) PJDLOG_ASSERT(conn->pc_proto != NULL); PJDLOG_ASSERT(conn->pc_proto->hp_accept != NULL); - newconn = malloc(sizeof(*newconn)); + newconn = proto_alloc(conn->pc_proto, PROTO_SIDE_SERVER_WORK); if (newconn == NULL) return (-1); ret = conn->pc_proto->hp_accept(conn->pc_ctx, &newconn->pc_ctx); if (ret != 0) { - free(newconn); + proto_free(newconn); errno = ret; return (-1); } - newconn->pc_proto = conn->pc_proto; - newconn->pc_side = PROTO_SIDE_SERVER_WORK; - newconn->pc_magic = PROTO_CONN_MAGIC; *newconnp = newconn; return (0); @@ -341,6 +373,5 @@ proto_close(struct proto_conn *conn) PJDLOG_ASSERT(conn->pc_proto->hp_close != NULL); conn->pc_proto->hp_close(conn->pc_ctx); - conn->pc_magic = 0; - free(conn); + proto_free(conn); }