Skip to content

Commit

Permalink
l2tp: fix race in l2tp_recv_common()
Browse files Browse the repository at this point in the history
Taking a reference on sessions in l2tp_recv_common() is racy; this
has to be done by the callers.

To this end, a new function is required (l2tp_session_get()) to
atomically lookup a session and take a reference on it. Callers then
have to manually drop this reference.

Fixes: fd558d1 ("l2tp: Split pppol2tp patch into separate l2tp and ppp parts")
Signed-off-by: Guillaume Nault <g.nault@alphalink.fr>
Signed-off-by: David S. Miller <davem@davemloft.net>
  • Loading branch information
Guillaume Nault authored and David S. Miller committed Apr 2, 2017
1 parent afe8996 commit 61b9a04
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 23 deletions.
73 changes: 60 additions & 13 deletions net/l2tp/l2tp_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,55 @@ struct l2tp_session *l2tp_session_find(struct net *net, struct l2tp_tunnel *tunn
}
EXPORT_SYMBOL_GPL(l2tp_session_find);

/* Like l2tp_session_find() but takes a reference on the returned session.
* Optionally calls session->ref() too if do_ref is true.
*/
struct l2tp_session *l2tp_session_get(struct net *net,
struct l2tp_tunnel *tunnel,
u32 session_id, bool do_ref)
{
struct hlist_head *session_list;
struct l2tp_session *session;

if (!tunnel) {
struct l2tp_net *pn = l2tp_pernet(net);

session_list = l2tp_session_id_hash_2(pn, session_id);

rcu_read_lock_bh();
hlist_for_each_entry_rcu(session, session_list, global_hlist) {
if (session->session_id == session_id) {
l2tp_session_inc_refcount(session);
if (do_ref && session->ref)
session->ref(session);
rcu_read_unlock_bh();

return session;
}
}
rcu_read_unlock_bh();

return NULL;
}

session_list = l2tp_session_id_hash(tunnel, session_id);
read_lock_bh(&tunnel->hlist_lock);
hlist_for_each_entry(session, session_list, hlist) {
if (session->session_id == session_id) {
l2tp_session_inc_refcount(session);
if (do_ref && session->ref)
session->ref(session);
read_unlock_bh(&tunnel->hlist_lock);

return session;
}
}
read_unlock_bh(&tunnel->hlist_lock);

return NULL;
}
EXPORT_SYMBOL_GPL(l2tp_session_get);

struct l2tp_session *l2tp_session_find_nth(struct l2tp_tunnel *tunnel, int nth)
{
int hash;
Expand Down Expand Up @@ -633,6 +682,9 @@ static int l2tp_recv_data_seq(struct l2tp_session *session, struct sk_buff *skb)
* a data (not control) frame before coming here. Fields up to the
* session-id have already been parsed and ptr points to the data
* after the session-id.
*
* session->ref() must have been called prior to l2tp_recv_common().
* session->deref() will be called automatically after skb is processed.
*/
void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
unsigned char *ptr, unsigned char *optr, u16 hdrflags,
Expand All @@ -642,14 +694,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
int offset;
u32 ns, nr;

/* The ref count is increased since we now hold a pointer to
* the session. Take care to decrement the refcnt when exiting
* this function from now on...
*/
l2tp_session_inc_refcount(session);
if (session->ref)
(*session->ref)(session);

/* Parse and check optional cookie */
if (session->peer_cookie_len > 0) {
if (memcmp(ptr, &session->peer_cookie[0], session->peer_cookie_len)) {
Expand Down Expand Up @@ -802,8 +846,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
/* Try to dequeue as many skbs from reorder_q as we can. */
l2tp_recv_dequeue(session);

l2tp_session_dec_refcount(session);

return;

discard:
Expand All @@ -812,8 +854,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,

if (session->deref)
(*session->deref)(session);

l2tp_session_dec_refcount(session);
}
EXPORT_SYMBOL(l2tp_recv_common);

Expand Down Expand Up @@ -920,8 +960,14 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb,
}

/* Find the session context */
session = l2tp_session_find(tunnel->l2tp_net, tunnel, session_id);
session = l2tp_session_get(tunnel->l2tp_net, tunnel, session_id, true);
if (!session || !session->recv_skb) {
if (session) {
if (session->deref)
session->deref(session);
l2tp_session_dec_refcount(session);
}

/* Not found? Pass to userspace to deal with */
l2tp_info(tunnel, L2TP_MSG_DATA,
"%s: no session found (%u/%u). Passing up.\n",
Expand All @@ -930,6 +976,7 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb,
}

l2tp_recv_common(session, skb, ptr, optr, hdrflags, length, payload_hook);
l2tp_session_dec_refcount(session);

return 0;

Expand Down
3 changes: 3 additions & 0 deletions net/l2tp/l2tp_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,9 @@ static inline struct l2tp_tunnel *l2tp_sock_to_tunnel(struct sock *sk)
return tunnel;
}

struct l2tp_session *l2tp_session_get(struct net *net,
struct l2tp_tunnel *tunnel,
u32 session_id, bool do_ref);
struct l2tp_session *l2tp_session_find(struct net *net,
struct l2tp_tunnel *tunnel,
u32 session_id);
Expand Down
17 changes: 12 additions & 5 deletions net/l2tp/l2tp_ip.c
Original file line number Diff line number Diff line change
Expand Up @@ -143,19 +143,19 @@ static int l2tp_ip_recv(struct sk_buff *skb)
}

/* Ok, this is a data packet. Lookup the session. */
session = l2tp_session_find(net, NULL, session_id);
if (session == NULL)
session = l2tp_session_get(net, NULL, session_id, true);
if (!session)
goto discard;

tunnel = session->tunnel;
if (tunnel == NULL)
goto discard;
if (!tunnel)
goto discard_sess;

/* Trace packet contents, if enabled */
if (tunnel->debug & L2TP_MSG_DATA) {
length = min(32u, skb->len);
if (!pskb_may_pull(skb, length))
goto discard;
goto discard_sess;

/* Point to L2TP header */
optr = ptr = skb->data;
Expand All @@ -165,6 +165,7 @@ static int l2tp_ip_recv(struct sk_buff *skb)
}

l2tp_recv_common(session, skb, ptr, optr, 0, skb->len, tunnel->recv_payload_hook);
l2tp_session_dec_refcount(session);

return 0;

Expand Down Expand Up @@ -203,6 +204,12 @@ static int l2tp_ip_recv(struct sk_buff *skb)

return sk_receive_skb(sk, skb, 1);

discard_sess:
if (session->deref)
session->deref(session);
l2tp_session_dec_refcount(session);
goto discard;

discard_put:
sock_put(sk);

Expand Down
18 changes: 13 additions & 5 deletions net/l2tp/l2tp_ip6.c
Original file line number Diff line number Diff line change
Expand Up @@ -156,19 +156,19 @@ static int l2tp_ip6_recv(struct sk_buff *skb)
}

/* Ok, this is a data packet. Lookup the session. */
session = l2tp_session_find(net, NULL, session_id);
if (session == NULL)
session = l2tp_session_get(net, NULL, session_id, true);
if (!session)
goto discard;

tunnel = session->tunnel;
if (tunnel == NULL)
goto discard;
if (!tunnel)
goto discard_sess;

/* Trace packet contents, if enabled */
if (tunnel->debug & L2TP_MSG_DATA) {
length = min(32u, skb->len);
if (!pskb_may_pull(skb, length))
goto discard;
goto discard_sess;

/* Point to L2TP header */
optr = ptr = skb->data;
Expand All @@ -179,6 +179,8 @@ static int l2tp_ip6_recv(struct sk_buff *skb)

l2tp_recv_common(session, skb, ptr, optr, 0, skb->len,
tunnel->recv_payload_hook);
l2tp_session_dec_refcount(session);

return 0;

pass_up:
Expand Down Expand Up @@ -216,6 +218,12 @@ static int l2tp_ip6_recv(struct sk_buff *skb)

return sk_receive_skb(sk, skb, 1);

discard_sess:
if (session->deref)
session->deref(session);
l2tp_session_dec_refcount(session);
goto discard;

discard_put:
sock_put(sk);

Expand Down

0 comments on commit 61b9a04

Please sign in to comment.