inpcb: use family specific sockaddr argument for bind functions

Do the cast from sockaddr to either IPv4 or IPv6 sockaddr in the
protocol's pr_bind method and from there on go down the call
stack with family specific argument.

Reviewed by:		zlei, melifaro, markj
Differential Revision:	https://reviews.freebsd.org/D38601
This commit is contained in:
Gleb Smirnoff 2023-02-15 10:30:16 -08:00
parent 5dc00f00b7
commit 96871af013
7 changed files with 26 additions and 34 deletions

View File

@ -655,21 +655,21 @@ in_pcballoc(struct socket *so, struct inpcbinfo *pcbinfo)
#ifdef INET #ifdef INET
int int
in_pcbbind(struct inpcb *inp, struct sockaddr *nam, struct ucred *cred) in_pcbbind(struct inpcb *inp, struct sockaddr_in *sin, struct ucred *cred)
{ {
int anonport, error; int anonport, error;
KASSERT(nam == NULL || nam->sa_family == AF_INET, KASSERT(sin == NULL || sin->sin_family == AF_INET,
("%s: invalid address family for %p", __func__, nam)); ("%s: invalid address family for %p", __func__, sin));
KASSERT(nam == NULL || nam->sa_len == sizeof(struct sockaddr_in), KASSERT(sin == NULL || sin->sin_len == sizeof(struct sockaddr_in),
("%s: invalid address length for %p", __func__, nam)); ("%s: invalid address length for %p", __func__, sin));
INP_WLOCK_ASSERT(inp); INP_WLOCK_ASSERT(inp);
INP_HASH_WLOCK_ASSERT(inp->inp_pcbinfo); INP_HASH_WLOCK_ASSERT(inp->inp_pcbinfo);
if (inp->inp_lport != 0 || inp->inp_laddr.s_addr != INADDR_ANY) if (inp->inp_lport != 0 || inp->inp_laddr.s_addr != INADDR_ANY)
return (EINVAL); return (EINVAL);
anonport = nam == NULL || ((struct sockaddr_in *)nam)->sin_port == 0; anonport = sin == NULL || sin->sin_port == 0;
error = in_pcbbind_setup(inp, nam, &inp->inp_laddr.s_addr, error = in_pcbbind_setup(inp, sin, &inp->inp_laddr.s_addr,
&inp->inp_lport, cred); &inp->inp_lport, cred);
if (error) if (error)
return (error); return (error);
@ -901,11 +901,10 @@ in_pcbbind_check_bindmulti(const struct inpcb *ni, const struct inpcb *oi)
* On error, the values of *laddrp and *lportp are not changed. * On error, the values of *laddrp and *lportp are not changed.
*/ */
int int
in_pcbbind_setup(struct inpcb *inp, struct sockaddr *nam, in_addr_t *laddrp, in_pcbbind_setup(struct inpcb *inp, struct sockaddr_in *sin, in_addr_t *laddrp,
u_short *lportp, struct ucred *cred) u_short *lportp, struct ucred *cred)
{ {
struct socket *so = inp->inp_socket; struct socket *so = inp->inp_socket;
struct sockaddr_in *sin;
struct inpcbinfo *pcbinfo = inp->inp_pcbinfo; struct inpcbinfo *pcbinfo = inp->inp_pcbinfo;
struct in_addr laddr; struct in_addr laddr;
u_short lport = 0; u_short lport = 0;
@ -925,15 +924,14 @@ in_pcbbind_setup(struct inpcb *inp, struct sockaddr *nam, in_addr_t *laddrp,
INP_HASH_LOCK_ASSERT(pcbinfo); INP_HASH_LOCK_ASSERT(pcbinfo);
laddr.s_addr = *laddrp; laddr.s_addr = *laddrp;
if (nam != NULL && laddr.s_addr != INADDR_ANY) if (sin != NULL && laddr.s_addr != INADDR_ANY)
return (EINVAL); return (EINVAL);
if ((so->so_options & (SO_REUSEADDR|SO_REUSEPORT|SO_REUSEPORT_LB)) == 0) if ((so->so_options & (SO_REUSEADDR|SO_REUSEPORT|SO_REUSEPORT_LB)) == 0)
lookupflags = INPLOOKUP_WILDCARD; lookupflags = INPLOOKUP_WILDCARD;
if (nam == NULL) { if (sin == NULL) {
if ((error = prison_local_ip4(cred, &laddr)) != 0) if ((error = prison_local_ip4(cred, &laddr)) != 0)
return (error); return (error);
} else { } else {
sin = (struct sockaddr_in *)nam;
KASSERT(sin->sin_family == AF_INET, KASSERT(sin->sin_family == AF_INET,
("%s: invalid family for address %p", __func__, sin)); ("%s: invalid family for address %p", __func__, sin));
KASSERT(sin->sin_len == sizeof(*sin), KASSERT(sin->sin_len == sizeof(*sin),

View File

@ -739,8 +739,8 @@ int in_pcbbind_check_bindmulti(const struct inpcb *ni,
void in_pcbpurgeif0(struct inpcbinfo *, struct ifnet *); void in_pcbpurgeif0(struct inpcbinfo *, struct ifnet *);
int in_pcballoc(struct socket *, struct inpcbinfo *); int in_pcballoc(struct socket *, struct inpcbinfo *);
int in_pcbbind(struct inpcb *, struct sockaddr *, struct ucred *); int in_pcbbind(struct inpcb *, struct sockaddr_in *, struct ucred *);
int in_pcbbind_setup(struct inpcb *, struct sockaddr *, in_addr_t *, int in_pcbbind_setup(struct inpcb *, struct sockaddr_in *, in_addr_t *,
u_short *, struct ucred *); u_short *, struct ucred *);
int in_pcbconnect(struct inpcb *, struct sockaddr_in *, struct ucred *, int in_pcbconnect(struct inpcb *, struct sockaddr_in *, struct ucred *,
bool); bool);

View File

@ -245,7 +245,7 @@ tcp_usr_bind(struct socket *so, struct sockaddr *nam, struct thread *td)
tp = intotcpcb(inp); tp = intotcpcb(inp);
#endif #endif
INP_HASH_WLOCK(&V_tcbinfo); INP_HASH_WLOCK(&V_tcbinfo);
error = in_pcbbind(inp, nam, td->td_ucred); error = in_pcbbind(inp, sinp, td->td_ucred);
INP_HASH_WUNLOCK(&V_tcbinfo); INP_HASH_WUNLOCK(&V_tcbinfo);
out: out:
TCP_PROBE2(debug__user, tp, PRU_BIND); TCP_PROBE2(debug__user, tp, PRU_BIND);
@ -309,14 +309,13 @@ tcp6_usr_bind(struct socket *so, struct sockaddr *nam, struct thread *td)
} }
inp->inp_vflag |= INP_IPV4; inp->inp_vflag |= INP_IPV4;
inp->inp_vflag &= ~INP_IPV6; inp->inp_vflag &= ~INP_IPV6;
error = in_pcbbind(inp, (struct sockaddr *)&sin, error = in_pcbbind(inp, &sin, td->td_ucred);
td->td_ucred);
INP_HASH_WUNLOCK(&V_tcbinfo); INP_HASH_WUNLOCK(&V_tcbinfo);
goto out; goto out;
} }
} }
#endif #endif
error = in6_pcbbind(inp, nam, td->td_ucred); error = in6_pcbbind(inp, sin6, td->td_ucred);
INP_HASH_WUNLOCK(&V_tcbinfo); INP_HASH_WUNLOCK(&V_tcbinfo);
out: out:
if (error != 0) if (error != 0)

View File

@ -1206,8 +1206,8 @@ udp_send(struct socket *so, int flags, struct mbuf *m, struct sockaddr *addr,
goto release; goto release;
} }
INP_HASH_WLOCK(pcbinfo); INP_HASH_WLOCK(pcbinfo);
error = in_pcbbind_setup(inp, (struct sockaddr *)&src, error = in_pcbbind_setup(inp, &src, &laddr.s_addr, &lport,
&laddr.s_addr, &lport, td->td_ucred); td->td_ucred);
INP_HASH_WUNLOCK(pcbinfo); INP_HASH_WUNLOCK(pcbinfo);
if (error) if (error)
goto release; goto release;
@ -1546,7 +1546,7 @@ udp_bind(struct socket *so, struct sockaddr *nam, struct thread *td)
INP_WLOCK(inp); INP_WLOCK(inp);
INP_HASH_WLOCK(pcbinfo); INP_HASH_WLOCK(pcbinfo);
error = in_pcbbind(inp, nam, td->td_ucred); error = in_pcbbind(inp, sinp, td->td_ucred);
INP_HASH_WUNLOCK(pcbinfo); INP_HASH_WUNLOCK(pcbinfo);
INP_WUNLOCK(inp); INP_WUNLOCK(inp);
return (error); return (error);

View File

@ -153,11 +153,9 @@ in6_pcbsetport(struct in6_addr *laddr, struct inpcb *inp, struct ucred *cred)
} }
int int
in6_pcbbind(struct inpcb *inp, struct sockaddr *nam, in6_pcbbind(struct inpcb *inp, struct sockaddr_in6 *sin6, struct ucred *cred)
struct ucred *cred)
{ {
struct socket *so = inp->inp_socket; struct socket *so = inp->inp_socket;
struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)NULL;
struct inpcbinfo *pcbinfo = inp->inp_pcbinfo; struct inpcbinfo *pcbinfo = inp->inp_pcbinfo;
u_short lport = 0; u_short lport = 0;
int error, lookupflags = 0; int error, lookupflags = 0;
@ -176,12 +174,11 @@ in6_pcbbind(struct inpcb *inp, struct sockaddr *nam,
return (EINVAL); return (EINVAL);
if ((so->so_options & (SO_REUSEADDR|SO_REUSEPORT|SO_REUSEPORT_LB)) == 0) if ((so->so_options & (SO_REUSEADDR|SO_REUSEPORT|SO_REUSEPORT_LB)) == 0)
lookupflags = INPLOOKUP_WILDCARD; lookupflags = INPLOOKUP_WILDCARD;
if (nam == NULL) { if (sin6 == NULL) {
if ((error = prison_local_ip6(cred, &inp->in6p_laddr, if ((error = prison_local_ip6(cred, &inp->in6p_laddr,
((inp->inp_flags & IN6P_IPV6_V6ONLY) != 0))) != 0) ((inp->inp_flags & IN6P_IPV6_V6ONLY) != 0))) != 0)
return (error); return (error);
} else { } else {
sin6 = (struct sockaddr_in6 *)nam;
KASSERT(sin6->sin6_family == AF_INET6, KASSERT(sin6->sin6_family == AF_INET6,
("%s: invalid address family for %p", __func__, sin6)); ("%s: invalid address family for %p", __func__, sin6));
KASSERT(sin6->sin6_len == sizeof(*sin6), KASSERT(sin6->sin6_len == sizeof(*sin6),

View File

@ -73,7 +73,7 @@
void in6_pcbpurgeif0(struct inpcbinfo *, struct ifnet *); void in6_pcbpurgeif0(struct inpcbinfo *, struct ifnet *);
void in6_losing(struct inpcb *); void in6_losing(struct inpcb *);
int in6_pcbbind(struct inpcb *, struct sockaddr *, struct ucred *); int in6_pcbbind(struct inpcb *, struct sockaddr_in6 *, struct ucred *);
int in6_pcbconnect(struct inpcb *, struct sockaddr_in6 *, struct ucred *, int in6_pcbconnect(struct inpcb *, struct sockaddr_in6 *, struct ucred *,
bool); bool);
void in6_pcbdisconnect(struct inpcb *); void in6_pcbdisconnect(struct inpcb *);

View File

@ -1020,6 +1020,7 @@ udp6_attach(struct socket *so, int proto, struct thread *td)
static int static int
udp6_bind(struct socket *so, struct sockaddr *nam, struct thread *td) udp6_bind(struct socket *so, struct sockaddr *nam, struct thread *td)
{ {
struct sockaddr_in6 *sin6_p;
struct inpcb *inp; struct inpcb *inp;
struct inpcbinfo *pcbinfo; struct inpcbinfo *pcbinfo;
int error; int error;
@ -1034,16 +1035,14 @@ udp6_bind(struct socket *so, struct sockaddr *nam, struct thread *td)
if (nam->sa_len != sizeof(struct sockaddr_in6)) if (nam->sa_len != sizeof(struct sockaddr_in6))
return (EINVAL); return (EINVAL);
sin6_p = (struct sockaddr_in6 *)nam;
INP_WLOCK(inp); INP_WLOCK(inp);
INP_HASH_WLOCK(pcbinfo); INP_HASH_WLOCK(pcbinfo);
vflagsav = inp->inp_vflag; vflagsav = inp->inp_vflag;
inp->inp_vflag &= ~INP_IPV4; inp->inp_vflag &= ~INP_IPV4;
inp->inp_vflag |= INP_IPV6; inp->inp_vflag |= INP_IPV6;
if ((inp->inp_flags & IN6P_IPV6_V6ONLY) == 0) { if ((inp->inp_flags & IN6P_IPV6_V6ONLY) == 0) {
struct sockaddr_in6 *sin6_p;
sin6_p = (struct sockaddr_in6 *)nam;
if (IN6_IS_ADDR_UNSPECIFIED(&sin6_p->sin6_addr)) if (IN6_IS_ADDR_UNSPECIFIED(&sin6_p->sin6_addr))
inp->inp_vflag |= INP_IPV4; inp->inp_vflag |= INP_IPV4;
#ifdef INET #ifdef INET
@ -1053,14 +1052,13 @@ udp6_bind(struct socket *so, struct sockaddr *nam, struct thread *td)
in6_sin6_2_sin(&sin, sin6_p); in6_sin6_2_sin(&sin, sin6_p);
inp->inp_vflag |= INP_IPV4; inp->inp_vflag |= INP_IPV4;
inp->inp_vflag &= ~INP_IPV6; inp->inp_vflag &= ~INP_IPV6;
error = in_pcbbind(inp, (struct sockaddr *)&sin, error = in_pcbbind(inp, &sin, td->td_ucred);
td->td_ucred);
goto out; goto out;
} }
#endif #endif
} }
error = in6_pcbbind(inp, nam, td->td_ucred); error = in6_pcbbind(inp, sin6_p, td->td_ucred);
#ifdef INET #ifdef INET
out: out:
#endif #endif