From 31f555a1c593bf83c273033527ed5fe5aec02d28 Mon Sep 17 00:00:00 2001
From: Robert Watson <rwatson@FreeBSD.org>
Date: Sat, 19 Jun 2004 03:23:14 +0000
Subject: [PATCH] Assert socket buffer lock in sb_lock() to protect socket
 buffer sleep lock state.  Convert tsleep() into msleep() with socket buffer
 mutex as argument.  Hold socket buffer lock over sbunlock() to protect sleep
 lock state.

Assert socket buffer lock in sbwait() to protect the socket buffer
wait state.  Convert tsleep() into msleep() with socket buffer mutex
as argument.

Modify sofree(), sosend(), and soreceive() to acquire SOCKBUF_LOCK()
in order to call into these functions with the lock, as well as to
start protecting other socket buffer use in their implementation.  Drop
the socket buffer mutexes around calls into the protocol layer, around
potentially blocking operations, for copying to/from user space, and
VM operations relating to zero-copy.  Assert the socket buffer mutex
strategically after code sections or at the beginning of loops.  In
some cases, modify return code to ensure locks are properly dropped.

Convert the potentially blocking allocation of storage for the remote
address in soreceive() into a non-blocking allocation; we may wish to
move the allocation earlier so that it can block prior to acquisition
of the socket buffer lock.

Drop some spl use.

NOTE: Some races exist in the current structuring of sosend() and
soreceive().  This commit only merges basic socket locking in this
code; follow-up commits will close additional races.  As merged,
these changes are not sufficient to run without Giant safely.

Reviewed by:	juli, tjr
---
 sys/kern/uipc_sockbuf.c  |  8 +++-
 sys/kern/uipc_socket.c   | 85 ++++++++++++++++++++++++++++------------
 sys/kern/uipc_socket2.c  |  8 +++-
 sys/kern/uipc_syscalls.c | 21 ++++++++++
 4 files changed, 93 insertions(+), 29 deletions(-)

diff --git a/sys/kern/uipc_sockbuf.c b/sys/kern/uipc_sockbuf.c
index c3852535c5c9..0e4e99bb684e 100644
--- a/sys/kern/uipc_sockbuf.c
+++ b/sys/kern/uipc_sockbuf.c
@@ -323,8 +323,10 @@ sbwait(sb)
 	struct sockbuf *sb;
 {
 
+	SOCKBUF_LOCK_ASSERT(sb);
+
 	sb->sb_flags |= SB_WAIT;
-	return (tsleep(&sb->sb_cc,
+	return (msleep(&sb->sb_cc, &sb->sb_mtx,
 	    (sb->sb_flags & SB_NOINTR) ? PSOCK : PSOCK | PCATCH, "sbwait",
 	    sb->sb_timeo));
 }
@@ -339,9 +341,11 @@ sb_lock(sb)
 {
 	int error;
 
+	SOCKBUF_LOCK_ASSERT(sb);
+
 	while (sb->sb_flags & SB_LOCK) {
 		sb->sb_flags |= SB_WANT;
-		error = tsleep(&sb->sb_flags,
+		error = msleep(&sb->sb_flags, &sb->sb_mtx,
 		    (sb->sb_flags & SB_NOINTR) ? PSOCK : PSOCK|PCATCH,
 		    "sblock", 0);
 		if (error)
diff --git a/sys/kern/uipc_socket.c b/sys/kern/uipc_socket.c
index f2f7d1d55e0d..5d5d1e36f703 100644
--- a/sys/kern/uipc_socket.c
+++ b/sys/kern/uipc_socket.c
@@ -295,7 +295,6 @@ sofree(so)
 	struct socket *so;
 {
 	struct socket *head;
-	int s;
 
 	KASSERT(so->so_count == 0, ("socket %p so_count not 0", so));
 	SOCK_LOCK_ASSERT(so);
@@ -341,13 +340,13 @@ sofree(so)
 	    ("sofree: so_head == NULL, but still SQ_COMP(%d) or SQ_INCOMP(%d)",
 	    so->so_qstate & SQ_COMP, so->so_qstate & SQ_INCOMP));
 	ACCEPT_UNLOCK();
+	SOCKBUF_LOCK(&so->so_snd);
 	so->so_snd.sb_flags |= SB_NOINTR;
 	(void)sblock(&so->so_snd, M_WAITOK);
-	s = splimp();
 	socantsendmore(so);
-	splx(s);
 	sbunlock(&so->so_snd);
 	sbrelease(&so->so_snd, so);
+	SOCKBUF_UNLOCK(&so->so_snd);
 	sorflush(so);
 	sodealloc(so);
 }
@@ -597,11 +596,14 @@ sosend(so, addr, uio, top, control, flags, td)
 		clen = control->m_len;
 #define	snderr(errno)	{ error = (errno); splx(s); goto release; }
 
+	SOCKBUF_LOCK(&so->so_snd);
 restart:
+	SOCKBUF_LOCK_ASSERT(&so->so_snd);
 	error = sblock(&so->so_snd, SBLOCKWAIT(flags));
 	if (error)
-		goto out;
+		goto out_locked;
 	do {
+		SOCKBUF_LOCK_ASSERT(&so->so_snd);
 		s = splnet();
 		if (so->so_snd.sb_state & SBS_CANTSENDMORE)
 			snderr(EPIPE);
@@ -641,9 +643,10 @@ restart:
 			error = sbwait(&so->so_snd);
 			splx(s);
 			if (error)
-				goto out;
+				goto out_locked;
 			goto restart;
 		}
+		SOCKBUF_UNLOCK(&so->so_snd);
 		splx(s);
 		mp = &top;
 		space -= clen;
@@ -665,6 +668,7 @@ restart:
 					MGETHDR(m, M_TRYWAIT, MT_DATA);
 					if (m == NULL) {
 						error = ENOBUFS;
+						SOCKBUF_LOCK(&so->so_snd);
 						goto release;
 					}
 					m->m_pkthdr.len = 0;
@@ -673,6 +677,7 @@ restart:
 					MGET(m, M_TRYWAIT, MT_DATA);
 					if (m == NULL) {
 						error = ENOBUFS;
+						SOCKBUF_LOCK(&so->so_snd);
 						goto release;
 					}
 				}
@@ -726,6 +731,7 @@ restart:
 			}
 			if (m == NULL) {
 				error = ENOBUFS;
+				SOCKBUF_LOCK(&so->so_snd);
 				goto release;
 			}
 
@@ -740,8 +746,10 @@ restart:
 			m->m_len = len;
 			*mp = m;
 			top->m_pkthdr.len += len;
-			if (error)
+			if (error) {
+				SOCKBUF_LOCK(&so->so_snd);
 				goto release;
+			}
 			mp = &m->m_next;
 			if (resid <= 0) {
 				if (flags & MSG_EOR)
@@ -787,13 +795,20 @@ restart:
 		    control = NULL;
 		    top = NULL;
 		    mp = &top;
-		    if (error)
+		    if (error) {
+			SOCKBUF_LOCK(&so->so_snd);
 			goto release;
+		    }
 		} while (resid && space > 0);
+		SOCKBUF_LOCK(&so->so_snd);
 	} while (resid);
 
 release:
+	SOCKBUF_LOCK_ASSERT(&so->so_snd);
 	sbunlock(&so->so_snd);
+out_locked:
+	SOCKBUF_LOCK_ASSERT(&so->so_snd);
+	SOCKBUF_UNLOCK(&so->so_snd);
 out:
 	if (top != NULL)
 		m_freem(top);
@@ -886,10 +901,12 @@ bad:
 	if (so->so_state & SS_ISCONFIRMING && uio->uio_resid)
 		(*pr->pr_usrreqs->pru_rcvd)(so, 0);
 
+	SOCKBUF_LOCK(&so->so_rcv);
 restart:
+	SOCKBUF_LOCK_ASSERT(&so->so_rcv);
 	error = sblock(&so->so_rcv, SBLOCKWAIT(flags));
 	if (error)
-		return (error);
+		goto out;
 	s = splnet();
 
 	m = so->so_rcv.sb_mb;
@@ -949,10 +966,11 @@ restart:
 		error = sbwait(&so->so_rcv);
 		splx(s);
 		if (error)
-			return (error);
+			goto out;
 		goto restart;
 	}
 dontblock:
+	SOCKBUF_LOCK_ASSERT(&so->so_rcv);
 	if (uio->uio_td)
 		uio->uio_td->td_proc->p_stats->p_ru.ru_msgrcv++;
 	SBLASTRECORDCHK(&so->so_rcv);
@@ -964,7 +982,7 @@ dontblock:
 		orig_resid = 0;
 		if (psa != NULL)
 			*psa = sodupsockaddr(mtod(m, struct sockaddr *),
-			    mp0 == NULL ? M_WAITOK : M_NOWAIT);
+			    M_NOWAIT);
 		if (flags & MSG_PEEK) {
 			m = m->m_next;
 		} else {
@@ -982,10 +1000,12 @@ dontblock:
 			sbfree(&so->so_rcv, m);
 			so->so_rcv.sb_mb = m->m_next;
 			m->m_next = NULL;
-			if (pr->pr_domain->dom_externalize)
-				error =
-				(*pr->pr_domain->dom_externalize)(m, controlp);
-			else if (controlp != NULL)
+			if (pr->pr_domain->dom_externalize) {
+				SOCKBUF_UNLOCK(&so->so_rcv);
+				error = (*pr->pr_domain->dom_externalize)
+				    (m, controlp);
+				SOCKBUF_LOCK(&so->so_rcv);
+			} else if (controlp != NULL)
 				*controlp = m;
 			else
 				m_freem(m);
@@ -1021,12 +1041,14 @@ dontblock:
 			SB_EMPTY_FIXUP(&so->so_rcv);
 		}
 	}
+	SOCKBUF_LOCK_ASSERT(&so->so_rcv);
 	SBLASTRECORDCHK(&so->so_rcv);
 	SBLASTMBUFCHK(&so->so_rcv);
 
 	moff = 0;
 	offset = 0;
 	while (m != NULL && uio->uio_resid > 0 && error == 0) {
+		SOCKBUF_LOCK_ASSERT(&so->so_rcv);
 		if (m->m_type == MT_OOBDATA) {
 			if (type != MT_OOBDATA)
 				break;
@@ -1050,8 +1072,10 @@ dontblock:
 		 * block interrupts again.
 		 */
 		if (mp == NULL) {
+			SOCKBUF_LOCK_ASSERT(&so->so_rcv);
 			SBLASTRECORDCHK(&so->so_rcv);
 			SBLASTMBUFCHK(&so->so_rcv);
+			SOCKBUF_UNLOCK(&so->so_rcv);
 			splx(s);
 #ifdef ZERO_COPY_SOCKETS
 			if (so_zero_copy_receive) {
@@ -1076,6 +1100,7 @@ dontblock:
 			} else
 #endif /* ZERO_COPY_SOCKETS */
 			error = uiomove(mtod(m, char *) + moff, (int)len, uio);
+			SOCKBUF_LOCK(&so->so_rcv);
 			s = splnet();
 			if (error)
 				goto release;
@@ -1125,6 +1150,7 @@ dontblock:
 			if ((flags & MSG_PEEK) == 0) {
 				so->so_oobmark -= len;
 				if (so->so_oobmark == 0) {
+					SOCKBUF_LOCK_ASSERT(&so->so_rcv);
 					so->so_rcv.sb_state |= SBS_RCVATMARK;
 					break;
 				}
@@ -1145,22 +1171,23 @@ dontblock:
 		 */
 		while (flags & MSG_WAITALL && m == NULL && uio->uio_resid > 0 &&
 		    !sosendallatonce(so) && nextrecord == NULL) {
+			SOCKBUF_LOCK_ASSERT(&so->so_rcv);
 			if (so->so_error || so->so_rcv.sb_state & SBS_CANTRCVMORE)
 				break;
 			/*
 			 * Notify the protocol that some data has been
 			 * drained before blocking.
 			 */
-			if (pr->pr_flags & PR_WANTRCVD && so->so_pcb != NULL)
+			if (pr->pr_flags & PR_WANTRCVD && so->so_pcb != NULL) {
+				SOCKBUF_UNLOCK(&so->so_rcv);
 				(*pr->pr_usrreqs->pru_rcvd)(so, flags);
+				SOCKBUF_LOCK(&so->so_rcv);
+			}
 			SBLASTRECORDCHK(&so->so_rcv);
 			SBLASTMBUFCHK(&so->so_rcv);
 			error = sbwait(&so->so_rcv);
-			if (error) {
-				sbunlock(&so->so_rcv);
-				splx(s);
-				return (0);
-			}
+			if (error)
+				goto release;
 			m = so->so_rcv.sb_mb;
 			if (m != NULL)
 				nextrecord = m->m_nextpkt;
@@ -1169,8 +1196,10 @@ dontblock:
 
 	if (m != NULL && pr->pr_flags & PR_ATOMIC) {
 		flags |= MSG_TRUNC;
-		if ((flags & MSG_PEEK) == 0)
+		if ((flags & MSG_PEEK) == 0) {
+			SOCKBUF_LOCK_ASSERT(&so->so_rcv);
 			(void) sbdroprecord(&so->so_rcv);
+		}
 	}
 	if ((flags & MSG_PEEK) == 0) {
 		if (m == NULL) {
@@ -1188,9 +1217,13 @@ dontblock:
 		}
 		SBLASTRECORDCHK(&so->so_rcv);
 		SBLASTMBUFCHK(&so->so_rcv);
-		if (pr->pr_flags & PR_WANTRCVD && so->so_pcb)
+		if (pr->pr_flags & PR_WANTRCVD && so->so_pcb) {
+			SOCKBUF_UNLOCK(&so->so_rcv);
 			(*pr->pr_usrreqs->pru_rcvd)(so, flags);
+			SOCKBUF_LOCK(&so->so_rcv);
+		}
 	}
+	SOCKBUF_LOCK_ASSERT(&so->so_rcv);
 	if (orig_resid == uio->uio_resid && orig_resid &&
 	    (flags & MSG_EOR) == 0 && (so->so_rcv.sb_state & SBS_CANTRCVMORE) == 0) {
 		sbunlock(&so->so_rcv);
@@ -1201,7 +1234,10 @@ dontblock:
 	if (flagsp != NULL)
 		*flagsp |= flags;
 release:
+	SOCKBUF_LOCK_ASSERT(&so->so_rcv);
 	sbunlock(&so->so_rcv);
+out:
+	SOCKBUF_UNLOCK(&so->so_rcv);
 	splx(s);
 	return (error);
 }
@@ -1229,12 +1265,11 @@ sorflush(so)
 {
 	struct sockbuf *sb = &so->so_rcv;
 	struct protosw *pr = so->so_proto;
-	int s;
 	struct sockbuf asb;
 
+	SOCKBUF_LOCK(sb);
 	sb->sb_flags |= SB_NOINTR;
 	(void) sblock(sb, M_WAITOK);
-	s = splimp();
 	socantrcvmore(so);
 	sbunlock(sb);
 	asb = *sb;
@@ -1244,7 +1279,7 @@ sorflush(so)
 	 */
 	bzero(&sb->sb_startzero,
 	    sizeof(*sb) - offsetof(struct sockbuf, sb_startzero));
-	splx(s);
+	SOCKBUF_UNLOCK(sb);
 
 	if (pr->pr_flags & PR_RIGHTS && pr->pr_domain->dom_dispose != NULL)
 		(*pr->pr_domain->dom_dispose)(asb.sb_mb);
diff --git a/sys/kern/uipc_socket2.c b/sys/kern/uipc_socket2.c
index c3852535c5c9..0e4e99bb684e 100644
--- a/sys/kern/uipc_socket2.c
+++ b/sys/kern/uipc_socket2.c
@@ -323,8 +323,10 @@ sbwait(sb)
 	struct sockbuf *sb;
 {
 
+	SOCKBUF_LOCK_ASSERT(sb);
+
 	sb->sb_flags |= SB_WAIT;
-	return (tsleep(&sb->sb_cc,
+	return (msleep(&sb->sb_cc, &sb->sb_mtx,
 	    (sb->sb_flags & SB_NOINTR) ? PSOCK : PSOCK | PCATCH, "sbwait",
 	    sb->sb_timeo));
 }
@@ -339,9 +341,11 @@ sb_lock(sb)
 {
 	int error;
 
+	SOCKBUF_LOCK_ASSERT(sb);
+
 	while (sb->sb_flags & SB_LOCK) {
 		sb->sb_flags |= SB_WANT;
-		error = tsleep(&sb->sb_flags,
+		error = msleep(&sb->sb_flags, &sb->sb_mtx,
 		    (sb->sb_flags & SB_NOINTR) ? PSOCK : PSOCK|PCATCH,
 		    "sblock", 0);
 		if (error)
diff --git a/sys/kern/uipc_syscalls.c b/sys/kern/uipc_syscalls.c
index 5ebb7ddf51b6..30d23438267d 100644
--- a/sys/kern/uipc_syscalls.c
+++ b/sys/kern/uipc_syscalls.c
@@ -1801,7 +1801,9 @@ do_sendfile(struct thread *td, struct sendfile_args *uap, int compat)
 	/*
 	 * Protect against multiple writers to the socket.
 	 */
+	SOCKBUF_LOCK(&so->so_snd);
 	(void) sblock(&so->so_snd, M_WAITOK);
+	SOCKBUF_UNLOCK(&so->so_snd);
 
 	/*
 	 * Loop through the pages in the file, starting with the requested
@@ -1841,14 +1843,17 @@ retry_lookup:
 		 * Optimize the non-blocking case by looking at the socket space
 		 * before going to the extra work of constituting the sf_buf.
 		 */
+		SOCKBUF_LOCK(&so->so_snd);
 		if ((so->so_state & SS_NBIO) && sbspace(&so->so_snd) <= 0) {
 			if (so->so_snd.sb_state & SBS_CANTSENDMORE)
 				error = EPIPE;
 			else
 				error = EAGAIN;
 			sbunlock(&so->so_snd);
+			SOCKBUF_UNLOCK(&so->so_snd);
 			goto done;
 		}
+		SOCKBUF_UNLOCK(&so->so_snd);
 		VM_OBJECT_LOCK(obj);
 		/*
 		 * Attempt to look up the page.
@@ -1936,7 +1941,9 @@ retry_lookup:
 			}
 			vm_page_unlock_queues();
 			VM_OBJECT_UNLOCK(obj);
+			SOCKBUF_LOCK(&so->so_snd);
 			sbunlock(&so->so_snd);
+			SOCKBUF_UNLOCK(&so->so_snd);
 			goto done;
 		}
 		vm_page_unlock_queues();
@@ -1952,7 +1959,9 @@ retry_lookup:
 			if (pg->wire_count == 0 && pg->object == NULL)
 				vm_page_free(pg);
 			vm_page_unlock_queues();
+			SOCKBUF_LOCK(&so->so_snd);
 			sbunlock(&so->so_snd);
+			SOCKBUF_UNLOCK(&so->so_snd);
 			error = EINTR;
 			goto done;
 		}
@@ -1967,7 +1976,9 @@ retry_lookup:
 		if (m == NULL) {
 			error = ENOBUFS;
 			sf_buf_mext((void *)sf_buf_kva(sf), sf);
+			SOCKBUF_LOCK(&so->so_snd);
 			sbunlock(&so->so_snd);
+			SOCKBUF_UNLOCK(&so->so_snd);
 			goto done;
 		}
 		/*
@@ -1989,6 +2000,7 @@ retry_lookup:
 		 * Add the buffer to the socket buffer chain.
 		 */
 		s = splnet();
+		SOCKBUF_LOCK(&so->so_snd);
 retry_space:
 		/*
 		 * Make sure that the socket is still able to take more data.
@@ -2001,6 +2013,7 @@ retry_space:
 		 * blocks before the pru_send (or more accurately, any blocking
 		 * results in a loop back to here to re-check).
 		 */
+		SOCKBUF_LOCK_ASSERT(&so->so_snd);
 		if ((so->so_snd.sb_state & SBS_CANTSENDMORE) || so->so_error) {
 			if (so->so_snd.sb_state & SBS_CANTSENDMORE) {
 				error = EPIPE;
@@ -2010,6 +2023,7 @@ retry_space:
 			}
 			m_freem(m);
 			sbunlock(&so->so_snd);
+			SOCKBUF_UNLOCK(&so->so_snd);
 			splx(s);
 			goto done;
 		}
@@ -2022,6 +2036,7 @@ retry_space:
 			if (so->so_state & SS_NBIO) {
 				m_freem(m);
 				sbunlock(&so->so_snd);
+				SOCKBUF_UNLOCK(&so->so_snd);
 				splx(s);
 				error = EAGAIN;
 				goto done;
@@ -2035,20 +2050,26 @@ retry_space:
 			if (error) {
 				m_freem(m);
 				sbunlock(&so->so_snd);
+				SOCKBUF_UNLOCK(&so->so_snd);
 				splx(s);
 				goto done;
 			}
 			goto retry_space;
 		}
+		SOCKBUF_UNLOCK(&so->so_snd);
 		error = (*so->so_proto->pr_usrreqs->pru_send)(so, 0, m, 0, 0, td);
 		splx(s);
 		if (error) {
+			SOCKBUF_LOCK(&so->so_snd);
 			sbunlock(&so->so_snd);
+			SOCKBUF_UNLOCK(&so->so_snd);
 			goto done;
 		}
 		headersent = 1;
 	}
+	SOCKBUF_LOCK(&so->so_snd);
 	sbunlock(&so->so_snd);
+	SOCKBUF_UNLOCK(&so->so_snd);
 
 	/*
 	 * Send trailers. Wimp out and use writev(2).