Skip to content

Commit

Permalink
Merge branch 'l2tp_session_find-fixes'
Browse files Browse the repository at this point in the history
Guillaume Nault says:

====================
l2tp: fix usage of l2tp_session_find()

l2tp_session_find() doesn't take a reference on the session returned to
its caller. Virtually all l2tp_session_find() users are racy, either
because the session can disappear from under them or because they take
a reference too late. This leads to bugs like 'use after free' or
failure to notice duplicate session creations.

In some cases, taking a reference on the session is not enough. The
special callbacks .ref() and .deref() also have to be called in cases
where the PPP pseudo-wire uses the socket associated with the session.
Therefore, when looking up a session, we also have to pass a flag
indicating if the .ref() callback has to be called.

In the future, we probably could drop the .ref() and .deref() callbacks
entirely by protecting the .sock field of struct pppol2tp_session with
RCU, thus allowing it to be freed and set to NULL even if the L2TP
session is still alive.
====================

Signed-off-by: David S. Miller <davem@davemloft.net>
  • Loading branch information
David S. Miller committed Apr 2, 2017
2 parents afe8996 + 2777e2a commit e5c1e51
Show file tree
Hide file tree
Showing 7 changed files with 222 additions and 101 deletions.
152 changes: 120 additions & 32 deletions net/l2tp/l2tp_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,55 @@ struct l2tp_session *l2tp_session_find(struct net *net, struct l2tp_tunnel *tunn
}
EXPORT_SYMBOL_GPL(l2tp_session_find);

/* Like l2tp_session_find() but takes a reference on the returned session.
* Optionally calls session->ref() too if do_ref is true.
*/
struct l2tp_session *l2tp_session_get(struct net *net,
struct l2tp_tunnel *tunnel,
u32 session_id, bool do_ref)
{
struct hlist_head *session_list;
struct l2tp_session *session;

if (!tunnel) {
struct l2tp_net *pn = l2tp_pernet(net);

session_list = l2tp_session_id_hash_2(pn, session_id);

rcu_read_lock_bh();
hlist_for_each_entry_rcu(session, session_list, global_hlist) {
if (session->session_id == session_id) {
l2tp_session_inc_refcount(session);
if (do_ref && session->ref)
session->ref(session);
rcu_read_unlock_bh();

return session;
}
}
rcu_read_unlock_bh();

return NULL;
}

session_list = l2tp_session_id_hash(tunnel, session_id);
read_lock_bh(&tunnel->hlist_lock);
hlist_for_each_entry(session, session_list, hlist) {
if (session->session_id == session_id) {
l2tp_session_inc_refcount(session);
if (do_ref && session->ref)
session->ref(session);
read_unlock_bh(&tunnel->hlist_lock);

return session;
}
}
read_unlock_bh(&tunnel->hlist_lock);

return NULL;
}
EXPORT_SYMBOL_GPL(l2tp_session_get);

struct l2tp_session *l2tp_session_find_nth(struct l2tp_tunnel *tunnel, int nth)
{
int hash;
Expand All @@ -303,7 +352,8 @@ EXPORT_SYMBOL_GPL(l2tp_session_find_nth);
/* Lookup a session by interface name.
* This is very inefficient but is only used by management interfaces.
*/
struct l2tp_session *l2tp_session_find_by_ifname(struct net *net, char *ifname)
struct l2tp_session *l2tp_session_get_by_ifname(struct net *net, char *ifname,
bool do_ref)
{
struct l2tp_net *pn = l2tp_pernet(net);
int hash;
Expand All @@ -313,7 +363,11 @@ struct l2tp_session *l2tp_session_find_by_ifname(struct net *net, char *ifname)
for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++) {
hlist_for_each_entry_rcu(session, &pn->l2tp_session_hlist[hash], global_hlist) {
if (!strcmp(session->ifname, ifname)) {
l2tp_session_inc_refcount(session);
if (do_ref && session->ref)
session->ref(session);
rcu_read_unlock_bh();

return session;
}
}
Expand All @@ -323,7 +377,49 @@ struct l2tp_session *l2tp_session_find_by_ifname(struct net *net, char *ifname)

return NULL;
}
EXPORT_SYMBOL_GPL(l2tp_session_find_by_ifname);
EXPORT_SYMBOL_GPL(l2tp_session_get_by_ifname);

static int l2tp_session_add_to_tunnel(struct l2tp_tunnel *tunnel,
struct l2tp_session *session)
{
struct l2tp_session *session_walk;
struct hlist_head *g_head;
struct hlist_head *head;
struct l2tp_net *pn;

head = l2tp_session_id_hash(tunnel, session->session_id);

write_lock_bh(&tunnel->hlist_lock);
hlist_for_each_entry(session_walk, head, hlist)
if (session_walk->session_id == session->session_id)
goto exist;

if (tunnel->version == L2TP_HDR_VER_3) {
pn = l2tp_pernet(tunnel->l2tp_net);
g_head = l2tp_session_id_hash_2(l2tp_pernet(tunnel->l2tp_net),
session->session_id);

spin_lock_bh(&pn->l2tp_session_hlist_lock);
hlist_for_each_entry(session_walk, g_head, global_hlist)
if (session_walk->session_id == session->session_id)
goto exist_glob;

hlist_add_head_rcu(&session->global_hlist, g_head);
spin_unlock_bh(&pn->l2tp_session_hlist_lock);
}

hlist_add_head(&session->hlist, head);
write_unlock_bh(&tunnel->hlist_lock);

return 0;

exist_glob:
spin_unlock_bh(&pn->l2tp_session_hlist_lock);
exist:
write_unlock_bh(&tunnel->hlist_lock);

return -EEXIST;
}

/* Lookup a tunnel by id
*/
Expand Down Expand Up @@ -633,6 +729,9 @@ static int l2tp_recv_data_seq(struct l2tp_session *session, struct sk_buff *skb)
* a data (not control) frame before coming here. Fields up to the
* session-id have already been parsed and ptr points to the data
* after the session-id.
*
* session->ref() must have been called prior to l2tp_recv_common().
* session->deref() will be called automatically after skb is processed.
*/
void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
unsigned char *ptr, unsigned char *optr, u16 hdrflags,
Expand All @@ -642,14 +741,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
int offset;
u32 ns, nr;

/* The ref count is increased since we now hold a pointer to
* the session. Take care to decrement the refcnt when exiting
* this function from now on...
*/
l2tp_session_inc_refcount(session);
if (session->ref)
(*session->ref)(session);

/* Parse and check optional cookie */
if (session->peer_cookie_len > 0) {
if (memcmp(ptr, &session->peer_cookie[0], session->peer_cookie_len)) {
Expand Down Expand Up @@ -802,8 +893,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
/* Try to dequeue as many skbs from reorder_q as we can. */
l2tp_recv_dequeue(session);

l2tp_session_dec_refcount(session);

return;

discard:
Expand All @@ -812,8 +901,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,

if (session->deref)
(*session->deref)(session);

l2tp_session_dec_refcount(session);
}
EXPORT_SYMBOL(l2tp_recv_common);

Expand Down Expand Up @@ -920,8 +1007,14 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb,
}

/* Find the session context */
session = l2tp_session_find(tunnel->l2tp_net, tunnel, session_id);
session = l2tp_session_get(tunnel->l2tp_net, tunnel, session_id, true);
if (!session || !session->recv_skb) {
if (session) {
if (session->deref)
session->deref(session);
l2tp_session_dec_refcount(session);
}

/* Not found? Pass to userspace to deal with */
l2tp_info(tunnel, L2TP_MSG_DATA,
"%s: no session found (%u/%u). Passing up.\n",
Expand All @@ -930,6 +1023,7 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb,
}

l2tp_recv_common(session, skb, ptr, optr, hdrflags, length, payload_hook);
l2tp_session_dec_refcount(session);

return 0;

Expand Down Expand Up @@ -1738,6 +1832,7 @@ EXPORT_SYMBOL_GPL(l2tp_session_set_header_len);
struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunnel, u32 session_id, u32 peer_session_id, struct l2tp_session_cfg *cfg)
{
struct l2tp_session *session;
int err;

session = kzalloc(sizeof(struct l2tp_session) + priv_size, GFP_KERNEL);
if (session != NULL) {
Expand Down Expand Up @@ -1793,6 +1888,13 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn

l2tp_session_set_header_len(session, tunnel->version);

err = l2tp_session_add_to_tunnel(tunnel, session);
if (err) {
kfree(session);

return ERR_PTR(err);
}

/* Bump the reference count. The session context is deleted
* only when this drops to zero.
*/
Expand All @@ -1802,28 +1904,14 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn
/* Ensure tunnel socket isn't deleted */
sock_hold(tunnel->sock);

/* Add session to the tunnel's hash list */
write_lock_bh(&tunnel->hlist_lock);
hlist_add_head(&session->hlist,
l2tp_session_id_hash(tunnel, session_id));
write_unlock_bh(&tunnel->hlist_lock);

/* And to the global session list if L2TPv3 */
if (tunnel->version != L2TP_HDR_VER_2) {
struct l2tp_net *pn = l2tp_pernet(tunnel->l2tp_net);

spin_lock_bh(&pn->l2tp_session_hlist_lock);
hlist_add_head_rcu(&session->global_hlist,
l2tp_session_id_hash_2(pn, session_id));
spin_unlock_bh(&pn->l2tp_session_hlist_lock);
}

/* Ignore management session in session count value */
if (session->session_id != 0)
atomic_inc(&l2tp_session_count);

return session;
}

return session;
return ERR_PTR(-ENOMEM);
}
EXPORT_SYMBOL_GPL(l2tp_session_create);

Expand Down
6 changes: 5 additions & 1 deletion net/l2tp/l2tp_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,11 +230,15 @@ static inline struct l2tp_tunnel *l2tp_sock_to_tunnel(struct sock *sk)
return tunnel;
}

struct l2tp_session *l2tp_session_get(struct net *net,
struct l2tp_tunnel *tunnel,
u32 session_id, bool do_ref);
struct l2tp_session *l2tp_session_find(struct net *net,
struct l2tp_tunnel *tunnel,
u32 session_id);
struct l2tp_session *l2tp_session_find_nth(struct l2tp_tunnel *tunnel, int nth);
struct l2tp_session *l2tp_session_find_by_ifname(struct net *net, char *ifname);
struct l2tp_session *l2tp_session_get_by_ifname(struct net *net, char *ifname,
bool do_ref);
struct l2tp_tunnel *l2tp_tunnel_find(struct net *net, u32 tunnel_id);
struct l2tp_tunnel *l2tp_tunnel_find_nth(struct net *net, int nth);

Expand Down
10 changes: 2 additions & 8 deletions net/l2tp/l2tp_eth.c
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,6 @@ static int l2tp_eth_create(struct net *net, u32 tunnel_id, u32 session_id, u32 p
goto out;
}

session = l2tp_session_find(net, tunnel, session_id);
if (session) {
rc = -EEXIST;
goto out;
}

if (cfg->ifname) {
dev = dev_get_by_name(net, cfg->ifname);
if (dev) {
Expand All @@ -240,8 +234,8 @@ static int l2tp_eth_create(struct net *net, u32 tunnel_id, u32 session_id, u32 p

session = l2tp_session_create(sizeof(*spriv), tunnel, session_id,
peer_session_id, cfg);
if (!session) {
rc = -ENOMEM;
if (IS_ERR(session)) {
rc = PTR_ERR(session);
goto out;
}

Expand Down
17 changes: 12 additions & 5 deletions net/l2tp/l2tp_ip.c
Original file line number Diff line number Diff line change
Expand Up @@ -143,19 +143,19 @@ static int l2tp_ip_recv(struct sk_buff *skb)
}

/* Ok, this is a data packet. Lookup the session. */
session = l2tp_session_find(net, NULL, session_id);
if (session == NULL)
session = l2tp_session_get(net, NULL, session_id, true);
if (!session)
goto discard;

tunnel = session->tunnel;
if (tunnel == NULL)
goto discard;
if (!tunnel)
goto discard_sess;

/* Trace packet contents, if enabled */
if (tunnel->debug & L2TP_MSG_DATA) {
length = min(32u, skb->len);
if (!pskb_may_pull(skb, length))
goto discard;
goto discard_sess;

/* Point to L2TP header */
optr = ptr = skb->data;
Expand All @@ -165,6 +165,7 @@ static int l2tp_ip_recv(struct sk_buff *skb)
}

l2tp_recv_common(session, skb, ptr, optr, 0, skb->len, tunnel->recv_payload_hook);
l2tp_session_dec_refcount(session);

return 0;

Expand Down Expand Up @@ -203,6 +204,12 @@ static int l2tp_ip_recv(struct sk_buff *skb)

return sk_receive_skb(sk, skb, 1);

discard_sess:
if (session->deref)
session->deref(session);
l2tp_session_dec_refcount(session);
goto discard;

discard_put:
sock_put(sk);

Expand Down
18 changes: 13 additions & 5 deletions net/l2tp/l2tp_ip6.c
Original file line number Diff line number Diff line change
Expand Up @@ -156,19 +156,19 @@ static int l2tp_ip6_recv(struct sk_buff *skb)
}

/* Ok, this is a data packet. Lookup the session. */
session = l2tp_session_find(net, NULL, session_id);
if (session == NULL)
session = l2tp_session_get(net, NULL, session_id, true);
if (!session)
goto discard;

tunnel = session->tunnel;
if (tunnel == NULL)
goto discard;
if (!tunnel)
goto discard_sess;

/* Trace packet contents, if enabled */
if (tunnel->debug & L2TP_MSG_DATA) {
length = min(32u, skb->len);
if (!pskb_may_pull(skb, length))
goto discard;
goto discard_sess;

/* Point to L2TP header */
optr = ptr = skb->data;
Expand All @@ -179,6 +179,8 @@ static int l2tp_ip6_recv(struct sk_buff *skb)

l2tp_recv_common(session, skb, ptr, optr, 0, skb->len,
tunnel->recv_payload_hook);
l2tp_session_dec_refcount(session);

return 0;

pass_up:
Expand Down Expand Up @@ -216,6 +218,12 @@ static int l2tp_ip6_recv(struct sk_buff *skb)

return sk_receive_skb(sk, skb, 1);

discard_sess:
if (session->deref)
session->deref(session);
l2tp_session_dec_refcount(session);
goto discard;

discard_put:
sock_put(sk);

Expand Down
Loading

0 comments on commit e5c1e51

Please sign in to comment.