Skip to content

Commit

Permalink
Merge branch 'net-deduplicate-cookie-logic'
Browse files Browse the repository at this point in the history
Willem de Bruijn says:

====================
net: deduplicate cookie logic

Reuse standard sk, ip and ipv6 cookie init handlers where possible.

Avoid repeated open coding of the same logic.
Harmonize feature sets across protocols.
Make IPv4 and IPv6 logic more alike.
Simplify adding future new fields with a single init point.
====================

Link: https://patch.msgid.link/20250214222720.3205500-1-willemdebruijn.kernel@gmail.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
  • Loading branch information
Jakub Kicinski committed Feb 19, 2025
2 parents 3a03f9e + 5cd2f78 commit aefd232
Show file tree
Hide file tree
Showing 14 changed files with 30 additions and 71 deletions.
16 changes: 5 additions & 11 deletions include/net/ip.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,12 @@ static inline void ipcm_init(struct ipcm_cookie *ipcm)
static inline void ipcm_init_sk(struct ipcm_cookie *ipcm,
const struct inet_sock *inet)
{
ipcm_init(ipcm);
*ipcm = (struct ipcm_cookie) {
.tos = READ_ONCE(inet->tos),
};

sockcm_init(&ipcm->sockc, &inet->sk);

ipcm->sockc.mark = READ_ONCE(inet->sk.sk_mark);
ipcm->sockc.priority = READ_ONCE(inet->sk.sk_priority);
ipcm->sockc.tsflags = READ_ONCE(inet->sk.sk_tsflags);
ipcm->oif = READ_ONCE(inet->sk.sk_bound_dev_if);
ipcm->addr = inet->inet_saddr;
ipcm->protocol = inet->inet_num;
Expand Down Expand Up @@ -257,13 +258,6 @@ static inline u8 ip_sendmsg_scope(const struct inet_sock *inet,
return RT_SCOPE_UNIVERSE;
}

static inline __u8 get_rttos(struct ipcm_cookie* ipc, struct inet_sock *inet)
{
u8 dsfield = ipc->tos != -1 ? ipc->tos : READ_ONCE(inet->tos);

return dsfield & INET_DSCP_MASK;
}

/* datagram.c */
int __ip4_datagram_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len);
int ip4_datagram_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len);
Expand Down
11 changes: 2 additions & 9 deletions include/net/ipv6.h
Original file line number Diff line number Diff line change
Expand Up @@ -363,15 +363,6 @@ struct ipcm6_cookie {
struct ipv6_txoptions *opt;
};

static inline void ipcm6_init(struct ipcm6_cookie *ipc6)
{
*ipc6 = (struct ipcm6_cookie) {
.hlimit = -1,
.tclass = -1,
.dontfrag = -1,
};
}

static inline void ipcm6_init_sk(struct ipcm6_cookie *ipc6,
const struct sock *sk)
{
Expand All @@ -380,6 +371,8 @@ static inline void ipcm6_init_sk(struct ipcm6_cookie *ipc6,
.tclass = inet6_sk(sk)->tclass,
.dontfrag = inet6_test_bit(DONTFRAG, sk),
};

sockcm_init(&ipc6->sockc, sk);
}

static inline struct ipv6_txoptions *txopt_get(const struct ipv6_pinfo *np)
Expand Down
1 change: 1 addition & 0 deletions include/net/sock.h
Original file line number Diff line number Diff line change
Expand Up @@ -1829,6 +1829,7 @@ static inline void sockcm_init(struct sockcm_cookie *sockc,
const struct sock *sk)
{
*sockc = (struct sockcm_cookie) {
.mark = READ_ONCE(sk->sk_mark),
.tsflags = READ_ONCE(sk->sk_tsflags),
.priority = READ_ONCE(sk->sk_priority),
};
Expand Down
2 changes: 1 addition & 1 deletion net/can/raw.c
Original file line number Diff line number Diff line change
Expand Up @@ -963,7 +963,7 @@ static int raw_sendmsg(struct socket *sock, struct msghdr *msg, size_t size)

skb->dev = dev;
skb->priority = sockc.priority;
skb->mark = READ_ONCE(sk->sk_mark);
skb->mark = sockc.mark;
skb->tstamp = sockc.transmit_time;

skb_setup_tx_timestamp(skb, &sockc);
Expand Down
6 changes: 2 additions & 4 deletions net/ipv4/icmp.c
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,6 @@ static void icmp_reply(struct icmp_bxm *icmp_param, struct sk_buff *skb)
struct ipcm_cookie ipc;
struct flowi4 fl4;
struct sock *sk;
struct inet_sock *inet;
__be32 daddr, saddr;
u32 mark = IP4_REPLY_MARK(net, skb->mark);
int type = icmp_param->data.icmph.type;
Expand All @@ -424,12 +423,11 @@ static void icmp_reply(struct icmp_bxm *icmp_param, struct sk_buff *skb)
sk = icmp_xmit_lock(net);
if (!sk)
goto out_bh_enable;
inet = inet_sk(sk);

icmp_param->data.icmph.checksum = 0;

ipcm_init(&ipc);
inet->tos = ip_hdr(skb)->tos;
ipc.tos = ip_hdr(skb)->tos;
ipc.sockc.mark = mark;
daddr = ipc.addr = ip_hdr(skb)->saddr;
saddr = fib_compute_spec_dst(skb);
Expand Down Expand Up @@ -737,8 +735,8 @@ void __icmp_send(struct sk_buff *skb_in, int type, int code, __be32 info,
icmp_param.data.icmph.checksum = 0;
icmp_param.skb = skb_in;
icmp_param.offset = skb_network_offset(skb_in);
inet_sk(sk)->tos = tos;
ipcm_init(&ipc);
ipc.tos = tos;
ipc.addr = iph->saddr;
ipc.opt = &icmp_param.replyopts.opt;
ipc.sockc.mark = mark;
Expand Down
6 changes: 3 additions & 3 deletions net/ipv4/ping.c
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,7 @@ static int ping_v4_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
struct ip_options_data opt_copy;
int free = 0;
__be32 saddr, daddr, faddr;
u8 tos, scope;
u8 scope;
int err;

pr_debug("ping_v4_sendmsg(sk=%p,sk->num=%u)\n", inet, inet->inet_num);
Expand Down Expand Up @@ -768,7 +768,6 @@ static int ping_v4_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
}
faddr = ipc.opt->opt.faddr;
}
tos = get_rttos(&ipc, inet);
scope = ip_sendmsg_scope(inet, &ipc, msg);

if (ipv4_is_multicast(daddr)) {
Expand All @@ -779,7 +778,8 @@ static int ping_v4_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
} else if (!ipc.oif)
ipc.oif = READ_ONCE(inet->uc_index);

flowi4_init_output(&fl4, ipc.oif, ipc.sockc.mark, tos, scope,
flowi4_init_output(&fl4, ipc.oif, ipc.sockc.mark,
ipc.tos & INET_DSCP_MASK, scope,
sk->sk_protocol, inet_sk_flowi_flags(sk), faddr,
saddr, 0, 0, sk->sk_uid);

Expand Down
6 changes: 3 additions & 3 deletions net/ipv4/raw.c
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ static int raw_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
struct ipcm_cookie ipc;
struct rtable *rt = NULL;
struct flowi4 fl4;
u8 tos, scope;
u8 scope;
int free = 0;
__be32 daddr;
__be32 saddr;
Expand Down Expand Up @@ -581,7 +581,6 @@ static int raw_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
daddr = ipc.opt->opt.faddr;
}
}
tos = get_rttos(&ipc, inet);
scope = ip_sendmsg_scope(inet, &ipc, msg);

uc_index = READ_ONCE(inet->uc_index);
Expand All @@ -606,7 +605,8 @@ static int raw_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
}
}

flowi4_init_output(&fl4, ipc.oif, ipc.sockc.mark, tos, scope,
flowi4_init_output(&fl4, ipc.oif, ipc.sockc.mark,
ipc.tos & INET_DSCP_MASK, scope,
hdrincl ? ipc.protocol : sk->sk_protocol,
inet_sk_flowi_flags(sk) |
(hdrincl ? FLOWI_FLAG_KNOWN_NH : 0),
Expand Down
2 changes: 1 addition & 1 deletion net/ipv4/tcp.c
Original file line number Diff line number Diff line change
Expand Up @@ -1127,7 +1127,7 @@ int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size)
/* 'common' sending to sendq */
}

sockcm_init(&sockc, sk);
sockc = (struct sockcm_cookie) { .tsflags = READ_ONCE(sk->sk_tsflags)};
if (msg->msg_controllen) {
err = sock_cmsg_send(sk, msg, &sockc);
if (unlikely(err)) {
Expand Down
6 changes: 3 additions & 3 deletions net/ipv4/udp.c
Original file line number Diff line number Diff line change
Expand Up @@ -1280,7 +1280,7 @@ int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
int free = 0;
int connected = 0;
__be32 daddr, faddr, saddr;
u8 tos, scope;
u8 scope;
__be16 dport;
int err, is_udplite = IS_UDPLITE(sk);
int corkreq = udp_test_bit(CORK, sk) || msg->msg_flags & MSG_MORE;
Expand Down Expand Up @@ -1404,7 +1404,6 @@ int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
faddr = ipc.opt->opt.faddr;
connected = 0;
}
tos = get_rttos(&ipc, inet);
scope = ip_sendmsg_scope(inet, &ipc, msg);
if (scope == RT_SCOPE_LINK)
connected = 0;
Expand Down Expand Up @@ -1441,7 +1440,8 @@ int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)

fl4 = &fl4_stack;

flowi4_init_output(fl4, ipc.oif, ipc.sockc.mark, tos, scope,
flowi4_init_output(fl4, ipc.oif, ipc.sockc.mark,
ipc.tos & INET_DSCP_MASK, scope,
sk->sk_protocol, flow_flags, faddr, saddr,
dport, inet->inet_sport, sk->sk_uid);

Expand Down
3 changes: 0 additions & 3 deletions net/ipv6/ping.c
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,6 @@ static int ping_v6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
return -EINVAL;

ipcm6_init_sk(&ipc6, sk);
ipc6.sockc.priority = READ_ONCE(sk->sk_priority);
ipc6.sockc.tsflags = READ_ONCE(sk->sk_tsflags);
ipc6.sockc.mark = READ_ONCE(sk->sk_mark);

fl6.flowi6_oif = oif;

Expand Down
15 changes: 3 additions & 12 deletions net/ipv6/raw.c
Original file line number Diff line number Diff line change
Expand Up @@ -769,19 +769,16 @@ static int rawv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)

hdrincl = inet_test_bit(HDRINCL, sk);

ipcm6_init_sk(&ipc6, sk);

/*
* Get and verify the address.
*/
memset(&fl6, 0, sizeof(fl6));

fl6.flowi6_mark = READ_ONCE(sk->sk_mark);
fl6.flowi6_mark = ipc6.sockc.mark;
fl6.flowi6_uid = sk->sk_uid;

ipcm6_init(&ipc6);
ipc6.sockc.tsflags = READ_ONCE(sk->sk_tsflags);
ipc6.sockc.mark = fl6.flowi6_mark;
ipc6.sockc.priority = READ_ONCE(sk->sk_priority);

if (sin6) {
if (addr_len < SIN6_LEN_RFC2133)
return -EINVAL;
Expand Down Expand Up @@ -891,9 +888,6 @@ static int rawv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
if (hdrincl)
fl6.flowi6_flags |= FLOWI_FLAG_KNOWN_NH;

if (ipc6.tclass < 0)
ipc6.tclass = np->tclass;

fl6.flowlabel = ip6_make_flowinfo(ipc6.tclass, fl6.flowlabel);

dst = ip6_dst_lookup_flow(sock_net(sk), sk, &fl6, final_p);
Expand All @@ -904,9 +898,6 @@ static int rawv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
if (ipc6.hlimit < 0)
ipc6.hlimit = ip6_sk_dst_hoplimit(np, &fl6, dst);

if (ipc6.dontfrag < 0)
ipc6.dontfrag = inet6_test_bit(DONTFRAG, sk);

if (msg->msg_flags&MSG_CONFIRM)
goto do_confirm;

Expand Down
10 changes: 1 addition & 9 deletions net/ipv6/udp.c
Original file line number Diff line number Diff line change
Expand Up @@ -1494,11 +1494,8 @@ int udpv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
int is_udplite = IS_UDPLITE(sk);
int (*getfrag)(void *, char *, int, int, int, struct sk_buff *);

ipcm6_init(&ipc6);
ipcm6_init_sk(&ipc6, sk);
ipc6.gso_size = READ_ONCE(up->gso_size);
ipc6.sockc.tsflags = READ_ONCE(sk->sk_tsflags);
ipc6.sockc.mark = READ_ONCE(sk->sk_mark);
ipc6.sockc.priority = READ_ONCE(sk->sk_priority);

/* destination address check */
if (sin6) {
Expand Down Expand Up @@ -1704,9 +1701,6 @@ int udpv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)

security_sk_classify_flow(sk, flowi6_to_flowi_common(fl6));

if (ipc6.tclass < 0)
ipc6.tclass = np->tclass;

fl6->flowlabel = ip6_make_flowinfo(ipc6.tclass, fl6->flowlabel);

dst = ip6_sk_dst_lookup_flow(sk, fl6, final_p, connected);
Expand Down Expand Up @@ -1752,8 +1746,6 @@ int udpv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
WRITE_ONCE(up->pending, AF_INET6);

do_append_data:
if (ipc6.dontfrag < 0)
ipc6.dontfrag = inet6_test_bit(DONTFRAG, sk);
up->len += ulen;
err = ip6_append_data(sk, getfrag, msg, ulen, sizeof(struct udphdr),
&ipc6, fl6, dst_rt6_info(dst),
Expand Down
8 changes: 1 addition & 7 deletions net/l2tp/l2tp_ip6.c
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ static int l2tp_ip6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
fl6.flowi6_mark = READ_ONCE(sk->sk_mark);
fl6.flowi6_uid = sk->sk_uid;

ipcm6_init(&ipc6);
ipcm6_init_sk(&ipc6, sk);

if (lsa) {
if (addr_len < SIN6_LEN_RFC2133)
Expand Down Expand Up @@ -634,9 +634,6 @@ static int l2tp_ip6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)

security_sk_classify_flow(sk, flowi6_to_flowi_common(&fl6));

if (ipc6.tclass < 0)
ipc6.tclass = np->tclass;

fl6.flowlabel = ip6_make_flowinfo(ipc6.tclass, fl6.flowlabel);

dst = ip6_dst_lookup_flow(sock_net(sk), sk, &fl6, final_p);
Expand All @@ -648,9 +645,6 @@ static int l2tp_ip6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
if (ipc6.hlimit < 0)
ipc6.hlimit = ip6_sk_dst_hoplimit(np, &fl6, dst);

if (ipc6.dontfrag < 0)
ipc6.dontfrag = inet6_test_bit(DONTFRAG, sk);

if (msg->msg_flags & MSG_CONFIRM)
goto do_confirm;

Expand Down
9 changes: 4 additions & 5 deletions net/packet/af_packet.c
Original file line number Diff line number Diff line change
Expand Up @@ -2102,8 +2102,8 @@ static int packet_sendmsg_spkt(struct socket *sock, struct msghdr *msg,

skb->protocol = proto;
skb->dev = dev;
skb->priority = READ_ONCE(sk->sk_priority);
skb->mark = READ_ONCE(sk->sk_mark);
skb->priority = sockc.priority;
skb->mark = sockc.mark;
skb_set_delivery_type_by_clockid(skb, sockc.transmit_time, sk->sk_clockid);
skb_setup_tx_timestamp(skb, &sockc);

Expand Down Expand Up @@ -2634,8 +2634,8 @@ static int tpacket_fill_skb(struct packet_sock *po, struct sk_buff *skb,

skb->protocol = proto;
skb->dev = dev;
skb->priority = READ_ONCE(po->sk.sk_priority);
skb->mark = READ_ONCE(po->sk.sk_mark);
skb->priority = sockc->priority;
skb->mark = sockc->mark;
skb_set_delivery_type_by_clockid(skb, sockc->transmit_time, po->sk.sk_clockid);
skb_setup_tx_timestamp(skb, sockc);
skb_zcopy_set_nouarg(skb, ph.raw);
Expand Down Expand Up @@ -3039,7 +3039,6 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
goto out_unlock;

sockcm_init(&sockc, sk);
sockc.mark = READ_ONCE(sk->sk_mark);
if (msg->msg_controllen) {
err = sock_cmsg_send(sk, msg, &sockc);
if (unlikely(err))
Expand Down

0 comments on commit aefd232

Please sign in to comment.