diff --git a/net/l2tp/l2tp_core.c b/net/l2tp/l2tp_core.c
index 8adab6335ced9..e927422d8c585 100644
--- a/net/l2tp/l2tp_core.c
+++ b/net/l2tp/l2tp_core.c
@@ -278,6 +278,55 @@ struct l2tp_session *l2tp_session_find(struct net *net, struct l2tp_tunnel *tunn
 }
 EXPORT_SYMBOL_GPL(l2tp_session_find);
 
+/* Like l2tp_session_find() but takes a reference on the returned session.
+ * Optionally calls session->ref() too if do_ref is true.
+ */
+struct l2tp_session *l2tp_session_get(struct net *net,
+				      struct l2tp_tunnel *tunnel,
+				      u32 session_id, bool do_ref)
+{
+	struct hlist_head *session_list;
+	struct l2tp_session *session;
+
+	if (!tunnel) {
+		struct l2tp_net *pn = l2tp_pernet(net);
+
+		session_list = l2tp_session_id_hash_2(pn, session_id);
+
+		rcu_read_lock_bh();
+		hlist_for_each_entry_rcu(session, session_list, global_hlist) {
+			if (session->session_id == session_id) {
+				l2tp_session_inc_refcount(session);
+				if (do_ref && session->ref)
+					session->ref(session);
+				rcu_read_unlock_bh();
+
+				return session;
+			}
+		}
+		rcu_read_unlock_bh();
+
+		return NULL;
+	}
+
+	session_list = l2tp_session_id_hash(tunnel, session_id);
+	read_lock_bh(&tunnel->hlist_lock);
+	hlist_for_each_entry(session, session_list, hlist) {
+		if (session->session_id == session_id) {
+			l2tp_session_inc_refcount(session);
+			if (do_ref && session->ref)
+				session->ref(session);
+			read_unlock_bh(&tunnel->hlist_lock);
+
+			return session;
+		}
+	}
+	read_unlock_bh(&tunnel->hlist_lock);
+
+	return NULL;
+}
+EXPORT_SYMBOL_GPL(l2tp_session_get);
+
 struct l2tp_session *l2tp_session_find_nth(struct l2tp_tunnel *tunnel, int nth)
 {
 	int hash;
@@ -303,7 +352,8 @@ EXPORT_SYMBOL_GPL(l2tp_session_find_nth);
 /* Lookup a session by interface name.
  * This is very inefficient but is only used by management interfaces.
  */
-struct l2tp_session *l2tp_session_find_by_ifname(struct net *net, char *ifname)
+struct l2tp_session *l2tp_session_get_by_ifname(struct net *net, char *ifname,
+						bool do_ref)
 {
 	struct l2tp_net *pn = l2tp_pernet(net);
 	int hash;
@@ -313,7 +363,11 @@ struct l2tp_session *l2tp_session_find_by_ifname(struct net *net, char *ifname)
 	for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++) {
 		hlist_for_each_entry_rcu(session, &pn->l2tp_session_hlist[hash], global_hlist) {
 			if (!strcmp(session->ifname, ifname)) {
+				l2tp_session_inc_refcount(session);
+				if (do_ref && session->ref)
+					session->ref(session);
 				rcu_read_unlock_bh();
+
 				return session;
 			}
 		}
@@ -323,7 +377,49 @@ struct l2tp_session *l2tp_session_find_by_ifname(struct net *net, char *ifname)
 
 	return NULL;
 }
-EXPORT_SYMBOL_GPL(l2tp_session_find_by_ifname);
+EXPORT_SYMBOL_GPL(l2tp_session_get_by_ifname);
+
+static int l2tp_session_add_to_tunnel(struct l2tp_tunnel *tunnel,
+				      struct l2tp_session *session)
+{
+	struct l2tp_session *session_walk;
+	struct hlist_head *g_head;
+	struct hlist_head *head;
+	struct l2tp_net *pn;
+
+	head = l2tp_session_id_hash(tunnel, session->session_id);
+
+	write_lock_bh(&tunnel->hlist_lock);
+	hlist_for_each_entry(session_walk, head, hlist)
+		if (session_walk->session_id == session->session_id)
+			goto exist;
+
+	if (tunnel->version == L2TP_HDR_VER_3) {
+		pn = l2tp_pernet(tunnel->l2tp_net);
+		g_head = l2tp_session_id_hash_2(l2tp_pernet(tunnel->l2tp_net),
+						session->session_id);
+
+		spin_lock_bh(&pn->l2tp_session_hlist_lock);
+		hlist_for_each_entry(session_walk, g_head, global_hlist)
+			if (session_walk->session_id == session->session_id)
+				goto exist_glob;
+
+		hlist_add_head_rcu(&session->global_hlist, g_head);
+		spin_unlock_bh(&pn->l2tp_session_hlist_lock);
+	}
+
+	hlist_add_head(&session->hlist, head);
+	write_unlock_bh(&tunnel->hlist_lock);
+
+	return 0;
+
+exist_glob:
+	spin_unlock_bh(&pn->l2tp_session_hlist_lock);
+exist:
+	write_unlock_bh(&tunnel->hlist_lock);
+
+	return -EEXIST;
+}
 
 /* Lookup a tunnel by id
  */
@@ -633,6 +729,9 @@ static int l2tp_recv_data_seq(struct l2tp_session *session, struct sk_buff *skb)
  * a data (not control) frame before coming here. Fields up to the
  * session-id have already been parsed and ptr points to the data
  * after the session-id.
+ *
+ * session->ref() must have been called prior to l2tp_recv_common().
+ * session->deref() will be called automatically after skb is processed.
  */
 void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
 		      unsigned char *ptr, unsigned char *optr, u16 hdrflags,
@@ -642,14 +741,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
 	int offset;
 	u32 ns, nr;
 
-	/* The ref count is increased since we now hold a pointer to
-	 * the session. Take care to decrement the refcnt when exiting
-	 * this function from now on...
-	 */
-	l2tp_session_inc_refcount(session);
-	if (session->ref)
-		(*session->ref)(session);
-
 	/* Parse and check optional cookie */
 	if (session->peer_cookie_len > 0) {
 		if (memcmp(ptr, &session->peer_cookie[0], session->peer_cookie_len)) {
@@ -802,8 +893,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
 	/* Try to dequeue as many skbs from reorder_q as we can. */
 	l2tp_recv_dequeue(session);
 
-	l2tp_session_dec_refcount(session);
-
 	return;
 
 discard:
@@ -812,8 +901,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
 
 	if (session->deref)
 		(*session->deref)(session);
-
-	l2tp_session_dec_refcount(session);
 }
 EXPORT_SYMBOL(l2tp_recv_common);
 
@@ -920,8 +1007,14 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb,
 	}
 
 	/* Find the session context */
-	session = l2tp_session_find(tunnel->l2tp_net, tunnel, session_id);
+	session = l2tp_session_get(tunnel->l2tp_net, tunnel, session_id, true);
 	if (!session || !session->recv_skb) {
+		if (session) {
+			if (session->deref)
+				session->deref(session);
+			l2tp_session_dec_refcount(session);
+		}
+
 		/* Not found? Pass to userspace to deal with */
 		l2tp_info(tunnel, L2TP_MSG_DATA,
 			  "%s: no session found (%u/%u). Passing up.\n",
@@ -930,6 +1023,7 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb,
 	}
 
 	l2tp_recv_common(session, skb, ptr, optr, hdrflags, length, payload_hook);
+	l2tp_session_dec_refcount(session);
 
 	return 0;
 
@@ -1738,6 +1832,7 @@ EXPORT_SYMBOL_GPL(l2tp_session_set_header_len);
 struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunnel, u32 session_id, u32 peer_session_id, struct l2tp_session_cfg *cfg)
 {
 	struct l2tp_session *session;
+	int err;
 
 	session = kzalloc(sizeof(struct l2tp_session) + priv_size, GFP_KERNEL);
 	if (session != NULL) {
@@ -1793,6 +1888,13 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn
 
 		l2tp_session_set_header_len(session, tunnel->version);
 
+		err = l2tp_session_add_to_tunnel(tunnel, session);
+		if (err) {
+			kfree(session);
+
+			return ERR_PTR(err);
+		}
+
 		/* Bump the reference count. The session context is deleted
 		 * only when this drops to zero.
 		 */
@@ -1802,28 +1904,14 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn
 		/* Ensure tunnel socket isn't deleted */
 		sock_hold(tunnel->sock);
 
-		/* Add session to the tunnel's hash list */
-		write_lock_bh(&tunnel->hlist_lock);
-		hlist_add_head(&session->hlist,
-			       l2tp_session_id_hash(tunnel, session_id));
-		write_unlock_bh(&tunnel->hlist_lock);
-
-		/* And to the global session list if L2TPv3 */
-		if (tunnel->version != L2TP_HDR_VER_2) {
-			struct l2tp_net *pn = l2tp_pernet(tunnel->l2tp_net);
-
-			spin_lock_bh(&pn->l2tp_session_hlist_lock);
-			hlist_add_head_rcu(&session->global_hlist,
-					   l2tp_session_id_hash_2(pn, session_id));
-			spin_unlock_bh(&pn->l2tp_session_hlist_lock);
-		}
-
 		/* Ignore management session in session count value */
 		if (session->session_id != 0)
 			atomic_inc(&l2tp_session_count);
+
+		return session;
 	}
 
-	return session;
+	return ERR_PTR(-ENOMEM);
 }
 EXPORT_SYMBOL_GPL(l2tp_session_create);
 
diff --git a/net/l2tp/l2tp_core.h b/net/l2tp/l2tp_core.h
index aebf281d09eeb..3b9b704a84e44 100644
--- a/net/l2tp/l2tp_core.h
+++ b/net/l2tp/l2tp_core.h
@@ -230,11 +230,15 @@ static inline struct l2tp_tunnel *l2tp_sock_to_tunnel(struct sock *sk)
 	return tunnel;
 }
 
+struct l2tp_session *l2tp_session_get(struct net *net,
+				      struct l2tp_tunnel *tunnel,
+				      u32 session_id, bool do_ref);
 struct l2tp_session *l2tp_session_find(struct net *net,
 				       struct l2tp_tunnel *tunnel,
 				       u32 session_id);
 struct l2tp_session *l2tp_session_find_nth(struct l2tp_tunnel *tunnel, int nth);
-struct l2tp_session *l2tp_session_find_by_ifname(struct net *net, char *ifname);
+struct l2tp_session *l2tp_session_get_by_ifname(struct net *net, char *ifname,
+						bool do_ref);
 struct l2tp_tunnel *l2tp_tunnel_find(struct net *net, u32 tunnel_id);
 struct l2tp_tunnel *l2tp_tunnel_find_nth(struct net *net, int nth);
 
diff --git a/net/l2tp/l2tp_eth.c b/net/l2tp/l2tp_eth.c
index 8bf18a5f66e0c..6fd41d7afe1ef 100644
--- a/net/l2tp/l2tp_eth.c
+++ b/net/l2tp/l2tp_eth.c
@@ -221,12 +221,6 @@ static int l2tp_eth_create(struct net *net, u32 tunnel_id, u32 session_id, u32 p
 		goto out;
 	}
 
-	session = l2tp_session_find(net, tunnel, session_id);
-	if (session) {
-		rc = -EEXIST;
-		goto out;
-	}
-
 	if (cfg->ifname) {
 		dev = dev_get_by_name(net, cfg->ifname);
 		if (dev) {
@@ -240,8 +234,8 @@ static int l2tp_eth_create(struct net *net, u32 tunnel_id, u32 session_id, u32 p
 
 	session = l2tp_session_create(sizeof(*spriv), tunnel, session_id,
 				      peer_session_id, cfg);
-	if (!session) {
-		rc = -ENOMEM;
+	if (IS_ERR(session)) {
+		rc = PTR_ERR(session);
 		goto out;
 	}
 
diff --git a/net/l2tp/l2tp_ip.c b/net/l2tp/l2tp_ip.c
index 7208fbe5856bb..4d322c1b7233e 100644
--- a/net/l2tp/l2tp_ip.c
+++ b/net/l2tp/l2tp_ip.c
@@ -143,19 +143,19 @@ static int l2tp_ip_recv(struct sk_buff *skb)
 	}
 
 	/* Ok, this is a data packet. Lookup the session. */
-	session = l2tp_session_find(net, NULL, session_id);
-	if (session == NULL)
+	session = l2tp_session_get(net, NULL, session_id, true);
+	if (!session)
 		goto discard;
 
 	tunnel = session->tunnel;
-	if (tunnel == NULL)
-		goto discard;
+	if (!tunnel)
+		goto discard_sess;
 
 	/* Trace packet contents, if enabled */
 	if (tunnel->debug & L2TP_MSG_DATA) {
 		length = min(32u, skb->len);
 		if (!pskb_may_pull(skb, length))
-			goto discard;
+			goto discard_sess;
 
 		/* Point to L2TP header */
 		optr = ptr = skb->data;
@@ -165,6 +165,7 @@ static int l2tp_ip_recv(struct sk_buff *skb)
 	}
 
 	l2tp_recv_common(session, skb, ptr, optr, 0, skb->len, tunnel->recv_payload_hook);
+	l2tp_session_dec_refcount(session);
 
 	return 0;
 
@@ -203,6 +204,12 @@ static int l2tp_ip_recv(struct sk_buff *skb)
 
 	return sk_receive_skb(sk, skb, 1);
 
+discard_sess:
+	if (session->deref)
+		session->deref(session);
+	l2tp_session_dec_refcount(session);
+	goto discard;
+
 discard_put:
 	sock_put(sk);
 
diff --git a/net/l2tp/l2tp_ip6.c b/net/l2tp/l2tp_ip6.c
index 516d7ce24ba7b..88b397c30d86a 100644
--- a/net/l2tp/l2tp_ip6.c
+++ b/net/l2tp/l2tp_ip6.c
@@ -156,19 +156,19 @@ static int l2tp_ip6_recv(struct sk_buff *skb)
 	}
 
 	/* Ok, this is a data packet. Lookup the session. */
-	session = l2tp_session_find(net, NULL, session_id);
-	if (session == NULL)
+	session = l2tp_session_get(net, NULL, session_id, true);
+	if (!session)
 		goto discard;
 
 	tunnel = session->tunnel;
-	if (tunnel == NULL)
-		goto discard;
+	if (!tunnel)
+		goto discard_sess;
 
 	/* Trace packet contents, if enabled */
 	if (tunnel->debug & L2TP_MSG_DATA) {
 		length = min(32u, skb->len);
 		if (!pskb_may_pull(skb, length))
-			goto discard;
+			goto discard_sess;
 
 		/* Point to L2TP header */
 		optr = ptr = skb->data;
@@ -179,6 +179,8 @@ static int l2tp_ip6_recv(struct sk_buff *skb)
 
 	l2tp_recv_common(session, skb, ptr, optr, 0, skb->len,
 			 tunnel->recv_payload_hook);
+	l2tp_session_dec_refcount(session);
+
 	return 0;
 
 pass_up:
@@ -216,6 +218,12 @@ static int l2tp_ip6_recv(struct sk_buff *skb)
 
 	return sk_receive_skb(sk, skb, 1);
 
+discard_sess:
+	if (session->deref)
+		session->deref(session);
+	l2tp_session_dec_refcount(session);
+	goto discard;
+
 discard_put:
 	sock_put(sk);
 
diff --git a/net/l2tp/l2tp_netlink.c b/net/l2tp/l2tp_netlink.c
index 3620fba317863..93e317377c665 100644
--- a/net/l2tp/l2tp_netlink.c
+++ b/net/l2tp/l2tp_netlink.c
@@ -48,7 +48,8 @@ static int l2tp_nl_session_send(struct sk_buff *skb, u32 portid, u32 seq,
 /* Accessed under genl lock */
 static const struct l2tp_nl_cmd_ops *l2tp_nl_cmd_ops[__L2TP_PWTYPE_MAX];
 
-static struct l2tp_session *l2tp_nl_session_find(struct genl_info *info)
+static struct l2tp_session *l2tp_nl_session_get(struct genl_info *info,
+						bool do_ref)
 {
 	u32 tunnel_id;
 	u32 session_id;
@@ -59,14 +60,15 @@ static struct l2tp_session *l2tp_nl_session_find(struct genl_info *info)
 
 	if (info->attrs[L2TP_ATTR_IFNAME]) {
 		ifname = nla_data(info->attrs[L2TP_ATTR_IFNAME]);
-		session = l2tp_session_find_by_ifname(net, ifname);
+		session = l2tp_session_get_by_ifname(net, ifname, do_ref);
 	} else if ((info->attrs[L2TP_ATTR_SESSION_ID]) &&
 		   (info->attrs[L2TP_ATTR_CONN_ID])) {
 		tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]);
 		session_id = nla_get_u32(info->attrs[L2TP_ATTR_SESSION_ID]);
 		tunnel = l2tp_tunnel_find(net, tunnel_id);
 		if (tunnel)
-			session = l2tp_session_find(net, tunnel, session_id);
+			session = l2tp_session_get(net, tunnel, session_id,
+						   do_ref);
 	}
 
 	return session;
@@ -642,10 +644,12 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf
 			session_id, peer_session_id, &cfg);
 
 	if (ret >= 0) {
-		session = l2tp_session_find(net, tunnel, session_id);
-		if (session)
+		session = l2tp_session_get(net, tunnel, session_id, false);
+		if (session) {
 			ret = l2tp_session_notify(&l2tp_nl_family, info, session,
 						  L2TP_CMD_SESSION_CREATE);
+			l2tp_session_dec_refcount(session);
+		}
 	}
 
 out:
@@ -658,7 +662,7 @@ static int l2tp_nl_cmd_session_delete(struct sk_buff *skb, struct genl_info *inf
 	struct l2tp_session *session;
 	u16 pw_type;
 
-	session = l2tp_nl_session_find(info);
+	session = l2tp_nl_session_get(info, true);
 	if (session == NULL) {
 		ret = -ENODEV;
 		goto out;
@@ -672,6 +676,10 @@ static int l2tp_nl_cmd_session_delete(struct sk_buff *skb, struct genl_info *inf
 		if (l2tp_nl_cmd_ops[pw_type] && l2tp_nl_cmd_ops[pw_type]->session_delete)
 			ret = (*l2tp_nl_cmd_ops[pw_type]->session_delete)(session);
 
+	if (session->deref)
+		session->deref(session);
+	l2tp_session_dec_refcount(session);
+
 out:
 	return ret;
 }
@@ -681,7 +689,7 @@ static int l2tp_nl_cmd_session_modify(struct sk_buff *skb, struct genl_info *inf
 	int ret = 0;
 	struct l2tp_session *session;
 
-	session = l2tp_nl_session_find(info);
+	session = l2tp_nl_session_get(info, false);
 	if (session == NULL) {
 		ret = -ENODEV;
 		goto out;
@@ -716,6 +724,8 @@ static int l2tp_nl_cmd_session_modify(struct sk_buff *skb, struct genl_info *inf
 	ret = l2tp_session_notify(&l2tp_nl_family, info,
 				  session, L2TP_CMD_SESSION_MODIFY);
 
+	l2tp_session_dec_refcount(session);
+
 out:
 	return ret;
 }
@@ -811,29 +821,34 @@ static int l2tp_nl_cmd_session_get(struct sk_buff *skb, struct genl_info *info)
 	struct sk_buff *msg;
 	int ret;
 
-	session = l2tp_nl_session_find(info);
+	session = l2tp_nl_session_get(info, false);
 	if (session == NULL) {
 		ret = -ENODEV;
-		goto out;
+		goto err;
 	}
 
 	msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
 	if (!msg) {
 		ret = -ENOMEM;
-		goto out;
+		goto err_ref;
 	}
 
 	ret = l2tp_nl_session_send(msg, info->snd_portid, info->snd_seq,
 				   0, session, L2TP_CMD_SESSION_GET);
 	if (ret < 0)
-		goto err_out;
+		goto err_ref_msg;
 
-	return genlmsg_unicast(genl_info_net(info), msg, info->snd_portid);
+	ret = genlmsg_unicast(genl_info_net(info), msg, info->snd_portid);
 
-err_out:
-	nlmsg_free(msg);
+	l2tp_session_dec_refcount(session);
 
-out:
+	return ret;
+
+err_ref_msg:
+	nlmsg_free(msg);
+err_ref:
+	l2tp_session_dec_refcount(session);
+err:
 	return ret;
 }
 
diff --git a/net/l2tp/l2tp_ppp.c b/net/l2tp/l2tp_ppp.c
index 123b6a2411a03..26501902d1a7f 100644
--- a/net/l2tp/l2tp_ppp.c
+++ b/net/l2tp/l2tp_ppp.c
@@ -583,6 +583,7 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
 	int error = 0;
 	u32 tunnel_id, peer_tunnel_id;
 	u32 session_id, peer_session_id;
+	bool drop_refcnt = false;
 	int ver = 2;
 	int fd;
 
@@ -684,36 +685,36 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
 	if (tunnel->peer_tunnel_id == 0)
 		tunnel->peer_tunnel_id = peer_tunnel_id;
 
-	/* Create session if it doesn't already exist. We handle the
-	 * case where a session was previously created by the netlink
-	 * interface by checking that the session doesn't already have
-	 * a socket and its tunnel socket are what we expect. If any
-	 * of those checks fail, return EEXIST to the caller.
-	 */
-	session = l2tp_session_find(sock_net(sk), tunnel, session_id);
-	if (session == NULL) {
-		/* Default MTU must allow space for UDP/L2TP/PPP
-		 * headers.
+	session = l2tp_session_get(sock_net(sk), tunnel, session_id, false);
+	if (session) {
+		drop_refcnt = true;
+		ps = l2tp_session_priv(session);
+
+		/* Using a pre-existing session is fine as long as it hasn't
+		 * been connected yet.
 		 */
-		cfg.mtu = cfg.mru = 1500 - PPPOL2TP_HEADER_OVERHEAD;
+		if (ps->sock) {
+			error = -EEXIST;
+			goto end;
+		}
 
-		/* Allocate and initialize a new session context. */
-		session = l2tp_session_create(sizeof(struct pppol2tp_session),
-					      tunnel, session_id,
-					      peer_session_id, &cfg);
-		if (session == NULL) {
-			error = -ENOMEM;
+		/* consistency checks */
+		if (ps->tunnel_sock != tunnel->sock) {
+			error = -EEXIST;
 			goto end;
 		}
 	} else {
-		ps = l2tp_session_priv(session);
-		error = -EEXIST;
-		if (ps->sock != NULL)
-			goto end;
+		/* Default MTU must allow space for UDP/L2TP/PPP headers */
+		cfg.mtu = 1500 - PPPOL2TP_HEADER_OVERHEAD;
+		cfg.mru = cfg.mtu;
 
-		/* consistency checks */
-		if (ps->tunnel_sock != tunnel->sock)
+		session = l2tp_session_create(sizeof(struct pppol2tp_session),
+					      tunnel, session_id,
+					      peer_session_id, &cfg);
+		if (IS_ERR(session)) {
+			error = PTR_ERR(session);
 			goto end;
+		}
 	}
 
 	/* Associate session with its PPPoL2TP socket */
@@ -778,6 +779,8 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
 		  session->name);
 
 end:
+	if (drop_refcnt)
+		l2tp_session_dec_refcount(session);
 	release_sock(sk);
 
 	return error;
@@ -805,12 +808,6 @@ static int pppol2tp_session_create(struct net *net, u32 tunnel_id, u32 session_i
 	if (tunnel->sock == NULL)
 		goto out;
 
-	/* Check that this session doesn't already exist */
-	error = -EEXIST;
-	session = l2tp_session_find(net, tunnel, session_id);
-	if (session != NULL)
-		goto out;
-
 	/* Default MTU values. */
 	if (cfg->mtu == 0)
 		cfg->mtu = 1500 - PPPOL2TP_HEADER_OVERHEAD;
@@ -818,12 +815,13 @@ static int pppol2tp_session_create(struct net *net, u32 tunnel_id, u32 session_i
 		cfg->mru = cfg->mtu;
 
 	/* Allocate and initialize a new session context. */
-	error = -ENOMEM;
 	session = l2tp_session_create(sizeof(struct pppol2tp_session),
 				      tunnel, session_id,
 				      peer_session_id, cfg);
-	if (session == NULL)
+	if (IS_ERR(session)) {
+		error = PTR_ERR(session);
 		goto out;
+	}
 
 	ps = l2tp_session_priv(session);
 	ps->tunnel_sock = tunnel->sock;
@@ -1141,11 +1139,18 @@ static int pppol2tp_tunnel_ioctl(struct l2tp_tunnel *tunnel,
 		if (stats.session_id != 0) {
 			/* resend to session ioctl handler */
 			struct l2tp_session *session =
-				l2tp_session_find(sock_net(sk), tunnel, stats.session_id);
-			if (session != NULL)
-				err = pppol2tp_session_ioctl(session, cmd, arg);
-			else
+				l2tp_session_get(sock_net(sk), tunnel,
+						 stats.session_id, true);
+
+			if (session) {
+				err = pppol2tp_session_ioctl(session, cmd,
+							     arg);
+				if (session->deref)
+					session->deref(session);
+				l2tp_session_dec_refcount(session);
+			} else {
 				err = -EBADR;
+			}
 			break;
 		}
 #ifdef CONFIG_XFRM