From 86156495f43b985fde53c3b2528227d5a1f9170a Mon Sep 17 00:00:00 2001
From: Hans Petter Selasky <hselasky@FreeBSD.org>
Date: Fri, 7 Sep 2018 18:05:09 +0000
Subject: [PATCH] Implement get network interface by params function in ipoib.

Also fix the validate_ipv4_net_dev() and validate_ipv6_net_dev() functions
which had source and destination addresses swapped, and didn't set the
scope ID for IPv6 link-local addresses.

This allows applications like krping to work using IPoIB devices.

MFC after:		3 days
Approved by:		re (gjb)
Sponsored by:		Mellanox Technologies
---
 sys/ofed/drivers/infiniband/core/ib_cma.c     |  49 +++---
 .../infiniband/core/ib_roce_gid_mgmt.c        |  28 +--
 .../drivers/infiniband/ulp/ipoib/ipoib_main.c | 159 +++++++++++++++++-
 3 files changed, 197 insertions(+), 39 deletions(-)

diff --git a/sys/ofed/drivers/infiniband/core/ib_cma.c b/sys/ofed/drivers/infiniband/core/ib_cma.c
index 42e7b3ad811f..eb2053b6d4c3 100644
--- a/sys/ofed/drivers/infiniband/core/ib_cma.c
+++ b/sys/ofed/drivers/infiniband/core/ib_cma.c
@@ -1263,10 +1263,10 @@ static bool validate_ipv4_net_dev(struct net_device *net_dev,
 				  const struct sockaddr_in *src_addr)
 {
 #ifdef INET
-	struct sockaddr_in dst_tmp = *dst_addr;
+	struct sockaddr_in src_tmp = *src_addr;
 	__be32 daddr = dst_addr->sin_addr.s_addr,
 	       saddr = src_addr->sin_addr.s_addr;
-	struct net_device *src_dev;
+	struct net_device *dst_dev;
 	struct rtentry *rte;
 	bool ret;
 
@@ -1276,29 +1276,29 @@ static bool validate_ipv4_net_dev(struct net_device *net_dev,
 	    ipv4_is_loopback(saddr))
 		return false;
 
-	src_dev = ip_dev_find(net_dev->if_vnet, saddr);
-	if (src_dev != net_dev) {
-		if (src_dev != NULL)
-			dev_put(src_dev);
+	dst_dev = ip_dev_find(net_dev->if_vnet, daddr);
+	if (dst_dev != net_dev) {
+		if (dst_dev != NULL)
+			dev_put(dst_dev);
 		return false;
 	}
-	dev_put(src_dev);
+	dev_put(dst_dev);
 
 	/*
 	 * Make sure the socket address length field
 	 * is set, else rtalloc1() will fail.
 	 */
-	dst_tmp.sin_len = sizeof(dst_tmp);
+	src_tmp.sin_len = sizeof(src_tmp);
 
 	CURVNET_SET(net_dev->if_vnet);
-	rte = rtalloc1((struct sockaddr *)&dst_tmp, 1, 0);
-	CURVNET_RESTORE();
+	rte = rtalloc1((struct sockaddr *)&src_tmp, 1, 0);
 	if (rte != NULL) {
 		ret = (rte->rt_ifp == net_dev);
 		RTFREE_LOCKED(rte);
 	} else {
 		ret = false;
 	}
+	CURVNET_RESTORE();
 	return ret;
 #else
 	return false;
@@ -1310,31 +1310,42 @@ static bool validate_ipv6_net_dev(struct net_device *net_dev,
 				  const struct sockaddr_in6 *src_addr)
 {
 #ifdef INET6
-	struct sockaddr_in6 dst_tmp = *dst_addr;
-	struct in6_addr in6_addr = src_addr->sin6_addr;
-	struct net_device *src_dev;
+	struct sockaddr_in6 src_tmp = *src_addr;
+	struct in6_addr in6_addr = dst_addr->sin6_addr;
+	struct net_device *dst_dev;
 	struct rtentry *rte;
 	bool ret;
 
-	src_dev = ip6_dev_find(net_dev->if_vnet, in6_addr);
-	if (src_dev != net_dev)
+	dst_dev = ip6_dev_find(net_dev->if_vnet, in6_addr);
+	if (dst_dev != net_dev) {
+		if (dst_dev != NULL)
+			dev_put(dst_dev);
 		return false;
+	}
+
+	CURVNET_SET(net_dev->if_vnet);
 
 	/*
 	 * Make sure the socket address length field
 	 * is set, else rtalloc1() will fail.
 	 */
-	dst_tmp.sin6_len = sizeof(dst_tmp);
+	src_tmp.sin6_len = sizeof(src_tmp);
 
-	CURVNET_SET(net_dev->if_vnet);
-	rte = rtalloc1((struct sockaddr *)&dst_tmp, 1, 0);
-	CURVNET_RESTORE();
+	/*
+	 * Make sure the scope ID gets embedded, else rtalloc1() will
+	 * resolve to the loopback interface.
+	 */
+	src_tmp.sin6_scope_id = net_dev->if_index;
+	sa6_embedscope(&src_tmp, 0);
+
+	rte = rtalloc1((struct sockaddr *)&src_tmp, 1, 0);
 	if (rte != NULL) {
 		ret = (rte->rt_ifp == net_dev);
 		RTFREE_LOCKED(rte);
 	} else {
 		ret = false;
 	}
+	CURVNET_RESTORE();
 	return ret;
 #else
 	return false;
diff --git a/sys/ofed/drivers/infiniband/core/ib_roce_gid_mgmt.c b/sys/ofed/drivers/infiniband/core/ib_roce_gid_mgmt.c
index 5a65207a82f3..1dae52bac08a 100644
--- a/sys/ofed/drivers/infiniband/core/ib_roce_gid_mgmt.c
+++ b/sys/ofed/drivers/infiniband/core/ib_roce_gid_mgmt.c
@@ -149,16 +149,6 @@ roce_gid_enum_netdev_default(struct ib_device *ib_dev,
 	return (hweight_long(gid_type_mask));
 }
 
-#define ETH_IPOIB_DRV_NAME	"ib"
-
-static inline int
-is_eth_ipoib_intf(struct net_device *dev)
-{
-	if (strcmp(dev->if_dname, ETH_IPOIB_DRV_NAME))
-		return 0;
-	return 1;
-}
-
 static void
 roce_gid_update_addr_callback(struct ib_device *device, u8 port,
     struct net_device *ndev, void *cookie)
@@ -322,15 +312,15 @@ roce_gid_queue_scan_event(struct net_device *ndev)
 	struct roce_netdev_event_work *work;
 
 retry:
-	if (is_eth_ipoib_intf(ndev))
-		return;
-
-	if (ndev->if_type != IFT_ETHER) {
-		if (ndev->if_type == IFT_L2VLAN) {
-			ndev = rdma_vlan_dev_real_dev(ndev);
-			if (ndev != NULL)
-				goto retry;
-		}
+	switch (ndev->if_type) {
+	case IFT_ETHER:
+		break;
+	case IFT_L2VLAN:
+		ndev = rdma_vlan_dev_real_dev(ndev);
+		if (ndev != NULL)
+			goto retry;
+		/* FALLTHROUGH */
+	default:
 		return;
 	}
 
diff --git a/sys/ofed/drivers/infiniband/ulp/ipoib/ipoib_main.c b/sys/ofed/drivers/infiniband/ulp/ipoib/ipoib_main.c
index aee91c29eb0a..6a8995a98d09 100644
--- a/sys/ofed/drivers/infiniband/ulp/ipoib/ipoib_main.c
+++ b/sys/ofed/drivers/infiniband/ulp/ipoib/ipoib_main.c
@@ -54,6 +54,8 @@ static	int ipoib_resolvemulti(struct ifnet *, struct sockaddr **,
 #include <net/ip.h>
 #include <net/ipv6.h>
 
+#include <rdma/ib_cache.h>
+
 MODULE_AUTHOR("Roland Dreier");
 MODULE_DESCRIPTION("IP-over-InfiniBand net driver");
 MODULE_LICENSE("Dual BSD/GPL");
@@ -90,6 +92,10 @@ struct ib_sa_client ipoib_sa_client;
 
 static void ipoib_add_one(struct ib_device *device);
 static void ipoib_remove_one(struct ib_device *device, void *client_data);
+static struct net_device *ipoib_get_net_dev_by_params(
+		struct ib_device *dev, u8 port, u16 pkey,
+		const union ib_gid *gid, const struct sockaddr *addr,
+		void *client_data);
 static void ipoib_start(struct ifnet *dev);
 static int ipoib_output(struct ifnet *ifp, struct mbuf *m,
 	    const struct sockaddr *dst, struct route *ro);
@@ -163,7 +169,8 @@ ipoib_mtap_proto(struct ifnet *ifp, struct mbuf *mb, uint16_t proto)
 static struct ib_client ipoib_client = {
 	.name   = "ipoib",
 	.add    = ipoib_add_one,
-	.remove = ipoib_remove_one
+	.remove = ipoib_remove_one,
+	.get_net_dev_by_params = ipoib_get_net_dev_by_params,
 };
 
 int
@@ -1113,6 +1120,156 @@ ipoib_remove_one(struct ib_device *device, void *client_data)
 	kfree(dev_list);
 }
 
+static int
+ipoib_match_dev_addr(const struct sockaddr *addr, struct net_device *dev)
+{
+	struct ifaddr *ifa;
+	int retval = 0;		/* becomes 1 when "addr" is found on "dev" */
+
+	CURVNET_SET(dev->if_vnet);
+	IF_ADDR_RLOCK(dev);
+	CK_STAILQ_FOREACH(ifa, &dev->if_addrhead, ifa_link) {
+		if (ifa->ifa_addr == NULL ||
+		    ifa->ifa_addr->sa_family != addr->sa_family ||
+		    ifa->ifa_addr->sa_len != addr->sa_len) {
+			continue;	/* no address, or family/length differs */
+		}
+		if (memcmp(ifa->ifa_addr, addr, addr->sa_len) == 0) {
+			retval = 1;	/* all sa_len bytes are identical */
+			break;
+		}
+	}
+	IF_ADDR_RUNLOCK(dev);
+	CURVNET_RESTORE();
+
+	return (retval);
+}
+
+/*
+ * ipoib_match_gid_pkey_addr - count the IPoIB netdevs, "priv" and its
+ * children, which match the pkey_index, the optional gid and the
+ * optional address. A NULL gid or a NULL addr matches any value.
+ *
+ * @found_net_dev: set to the first match (or its parent, if any) with
+ * a reference held, unless it was already non-NULL on entry.
+ */
+static int
+ipoib_match_gid_pkey_addr(struct ipoib_dev_priv *priv,
+    const union ib_gid *gid, u16 pkey_index, const struct sockaddr *addr,
+    struct net_device **found_net_dev)
+{
+	struct ipoib_dev_priv *child_priv;
+	int matches = 0;
+
+	if (priv->pkey_index == pkey_index &&
+	    (!gid || !memcmp(gid, &priv->local_gid, sizeof(*gid)))) {
+		if (addr == NULL || ipoib_match_dev_addr(addr, priv->dev) != 0) {
+			if (*found_net_dev == NULL) {
+				struct net_device *net_dev;
+
+				if (priv->parent != NULL)
+					net_dev = priv->parent;
+				else
+					net_dev = priv->dev;
+				*found_net_dev = net_dev;
+				dev_hold(net_dev);
+			}
+			matches++;
+		}
+	}
+
+	/* Check child interfaces; stop once the match is ambiguous */
+	mutex_lock(&priv->vlan_mutex);
+	list_for_each_entry(child_priv, &priv->child_intfs, list) {
+		matches += ipoib_match_gid_pkey_addr(child_priv, gid,
+		    pkey_index, addr, found_net_dev);
+		if (matches > 1)
+			break;
+	}
+	mutex_unlock(&priv->vlan_mutex);
+
+	return matches;
+}
+
+/*
+ * __ipoib_get_net_dev_by_params - count the net_devs on "port" which
+ * match the parameters. The scan stops once the count exceeds one, so
+ * only 0, 1 and "more than 1" are meaningful. When the count is >= 1,
+ * @net_dev is the first match, with a reference held, else NULL.
+ */
+static int
+__ipoib_get_net_dev_by_params(struct list_head *dev_list, u8 port,
+    u16 pkey_index, const union ib_gid *gid,
+    const struct sockaddr *addr, struct net_device **net_dev)
+{
+	struct ipoib_dev_priv *priv;
+	int matches = 0;
+
+	*net_dev = NULL;
+
+	list_for_each_entry(priv, dev_list, list) {
+		if (priv->port != port)
+			continue;
+
+		matches += ipoib_match_gid_pkey_addr(priv, gid, pkey_index,
+		    addr, net_dev);
+
+		if (matches > 1)
+			break;
+	}
+
+	return matches;
+}
+
+static struct net_device *
+ipoib_get_net_dev_by_params(struct ib_device *dev, u8 port, u16 pkey,
+    const union ib_gid *gid, const struct sockaddr *addr, void *client_data)
+{
+	struct net_device *net_dev;
+	struct list_head *dev_list = client_data;
+	u16 pkey_index;
+	int matches;
+	int ret;
+
+	if (!rdma_protocol_ib(dev, port))
+		return NULL;
+
+	ret = ib_find_cached_pkey(dev, port, pkey, &pkey_index);
+	if (ret)
+		return NULL;
+
+	if (!dev_list)
+		return NULL;
+
+	/* See if we can find a unique device matching the L2 parameters */
+	matches = __ipoib_get_net_dev_by_params(dev_list, port, pkey_index,
+						gid, NULL, &net_dev);
+
+	switch (matches) {
+	case 0:
+		return NULL;
+	case 1:
+		return net_dev;
+	}
+	/* More than one match: drop the reference from the L2-only scan. */
+	dev_put(net_dev);
+
+	/* The L2 parameters alone are ambiguous. Use the L3 address to
+	 * pick the net device; on duplicates the first match is used. */
+	matches = __ipoib_get_net_dev_by_params(dev_list, port, pkey_index,
+						gid, addr, &net_dev);
+	switch (matches) {
+	case 0:
+		return NULL;
+	default:
+		dev_warn_ratelimited(&dev->dev,
+				     "duplicate IP address detected\n");
+		/* Fall through */
+	case 1:
+		return net_dev;
+	}
+}
+
 static void
 ipoib_config_vlan(void *arg, struct ifnet *ifp, u_int16_t vtag)
 {