Skip to content
Snippets Groups Projects
Commit a8852b24 authored by Maciej Żenczykowski's avatar Maciej Żenczykowski
Browse files

type safety for 'bool egress'


Test: TreeHugger
Signed-off-by: default avatarMaciej Żenczykowski <maze@google.com>
Change-Id: I3f0a12f139478bc94d351d58a08d4a9bd19fa320
parent f0608499
No related branches found
No related tags found
No related merge requests found
...@@ -87,9 +87,9 @@ static inline __always_inline void try_make_writable(struct __sk_buff* skb, int ...@@ -87,9 +87,9 @@ static inline __always_inline void try_make_writable(struct __sk_buff* skb, int
if (skb->data_end - skb->data < len) bpf_skb_pull_data(skb, len); if (skb->data_end - skb->data < len) bpf_skb_pull_data(skb, len);
} }
// constants for passing in to 'bool egress' struct egress_bool { bool egress; };
static const bool INGRESS = false; #define INGRESS ((struct egress_bool){ .egress = false })
static const bool EGRESS = true; #define EGRESS ((struct egress_bool){ .egress = true })
// constants for passing in to 'bool downstream' // constants for passing in to 'bool downstream'
static const bool UPSTREAM = false; static const bool UPSTREAM = false;
......
...@@ -178,7 +178,7 @@ static __always_inline int is_system_uid(uint32_t uid) { ...@@ -178,7 +178,7 @@ static __always_inline int is_system_uid(uint32_t uid) {
#define DEFINE_UPDATE_STATS(the_stats_map, TypeOfKey) \ #define DEFINE_UPDATE_STATS(the_stats_map, TypeOfKey) \
static __always_inline inline void update_##the_stats_map(const struct __sk_buff* const skb, \ static __always_inline inline void update_##the_stats_map(const struct __sk_buff* const skb, \
const TypeOfKey* const key, \ const TypeOfKey* const key, \
const bool egress, \ const struct egress_bool egress, \
const struct kver_uint kver) { \ const struct kver_uint kver) { \
StatsValue* value = bpf_##the_stats_map##_lookup_elem(key); \ StatsValue* value = bpf_##the_stats_map##_lookup_elem(key); \
if (!value) { \ if (!value) { \
...@@ -199,7 +199,7 @@ static __always_inline int is_system_uid(uint32_t uid) { ...@@ -199,7 +199,7 @@ static __always_inline int is_system_uid(uint32_t uid) {
packets = (payload + mss - 1) / mss; \ packets = (payload + mss - 1) / mss; \
bytes = tcp_overhead * packets + payload; \ bytes = tcp_overhead * packets + payload; \
} \ } \
if (egress) { \ if (egress.egress) { \
__sync_fetch_and_add(&value->txPackets, packets); \ __sync_fetch_and_add(&value->txPackets, packets); \
__sync_fetch_and_add(&value->txBytes, bytes); \ __sync_fetch_and_add(&value->txBytes, bytes); \
} else { \ } else { \
...@@ -242,7 +242,7 @@ static __always_inline inline int bpf_skb_load_bytes_net(const struct __sk_buff* ...@@ -242,7 +242,7 @@ static __always_inline inline int bpf_skb_load_bytes_net(const struct __sk_buff*
} }
static __always_inline inline void do_packet_tracing( static __always_inline inline void do_packet_tracing(
const struct __sk_buff* const skb, const bool egress, const uint32_t uid, const struct __sk_buff* const skb, const struct egress_bool egress, const uint32_t uid,
const uint32_t tag, const bool enable_tracing, const struct kver_uint kver) { const uint32_t tag, const bool enable_tracing, const struct kver_uint kver) {
if (!enable_tracing) return; if (!enable_tracing) return;
if (!KVER_IS_AT_LEAST(kver, 5, 8, 0)) return; if (!KVER_IS_AT_LEAST(kver, 5, 8, 0)) return;
...@@ -317,8 +317,8 @@ static __always_inline inline void do_packet_tracing( ...@@ -317,8 +317,8 @@ static __always_inline inline void do_packet_tracing(
pkt->sport = sport; pkt->sport = sport;
pkt->dport = dport; pkt->dport = dport;
pkt->egress = egress; pkt->egress = egress.egress;
pkt->wakeup = !egress && (skb->mark & 0x80000000); // Fwmark.ingress_cpu_wakeup pkt->wakeup = !egress.egress && (skb->mark & 0x80000000); // Fwmark.ingress_cpu_wakeup
pkt->ipProto = proto; pkt->ipProto = proto;
pkt->tcpFlags = flags; pkt->tcpFlags = flags;
pkt->ipVersion = ipVersion; pkt->ipVersion = ipVersion;
...@@ -326,7 +326,8 @@ static __always_inline inline void do_packet_tracing( ...@@ -326,7 +326,8 @@ static __always_inline inline void do_packet_tracing(
bpf_packet_trace_ringbuf_submit(pkt); bpf_packet_trace_ringbuf_submit(pkt);
} }
static __always_inline inline bool skip_owner_match(struct __sk_buff* skb, bool egress, static __always_inline inline bool skip_owner_match(struct __sk_buff* skb,
const struct egress_bool egress,
const struct kver_uint kver) { const struct kver_uint kver) {
uint32_t flag = 0; uint32_t flag = 0;
if (skb->protocol == htons(ETH_P_IP)) { if (skb->protocol == htons(ETH_P_IP)) {
...@@ -358,7 +359,7 @@ static __always_inline inline bool skip_owner_match(struct __sk_buff* skb, bool ...@@ -358,7 +359,7 @@ static __always_inline inline bool skip_owner_match(struct __sk_buff* skb, bool
return false; return false;
} }
// Always allow RST's, and additionally allow ingress FINs // Always allow RST's, and additionally allow ingress FINs
return flag & (TCP_FLAG_RST | (egress ? 0 : TCP_FLAG_FIN)); // false on read failure return flag & (TCP_FLAG_RST | (egress.egress ? 0 : TCP_FLAG_FIN)); // false on read failure
} }
static __always_inline inline BpfConfig getConfig(uint32_t configKey) { static __always_inline inline BpfConfig getConfig(uint32_t configKey) {
...@@ -401,7 +402,8 @@ static __always_inline inline bool ingress_should_discard(struct __sk_buff* skb, ...@@ -401,7 +402,8 @@ static __always_inline inline bool ingress_should_discard(struct __sk_buff* skb,
} }
static __always_inline inline int bpf_owner_match(struct __sk_buff* skb, uint32_t uid, static __always_inline inline int bpf_owner_match(struct __sk_buff* skb, uint32_t uid,
bool egress, const struct kver_uint kver) { const struct egress_bool egress,
const struct kver_uint kver) {
if (is_system_uid(uid)) return PASS; if (is_system_uid(uid)) return PASS;
if (skip_owner_match(skb, egress, kver)) return PASS; if (skip_owner_match(skb, egress, kver)) return PASS;
...@@ -414,7 +416,7 @@ static __always_inline inline int bpf_owner_match(struct __sk_buff* skb, uint32_ ...@@ -414,7 +416,7 @@ static __always_inline inline int bpf_owner_match(struct __sk_buff* skb, uint32_
if (isBlockedByUidRules(enabledRules, uidRules)) return DROP; if (isBlockedByUidRules(enabledRules, uidRules)) return DROP;
if (!egress && skb->ifindex != 1) { if (!egress.egress && skb->ifindex != 1) {
if (ingress_should_discard(skb, kver)) return DROP; if (ingress_should_discard(skb, kver)) return DROP;
if (uidRules & IIF_MATCH) { if (uidRules & IIF_MATCH) {
if (allowed_iif && skb->ifindex != allowed_iif) { if (allowed_iif && skb->ifindex != allowed_iif) {
...@@ -434,7 +436,7 @@ static __always_inline inline int bpf_owner_match(struct __sk_buff* skb, uint32_ ...@@ -434,7 +436,7 @@ static __always_inline inline int bpf_owner_match(struct __sk_buff* skb, uint32_
static __always_inline inline void update_stats_with_config(const uint32_t selectedMap, static __always_inline inline void update_stats_with_config(const uint32_t selectedMap,
const struct __sk_buff* const skb, const struct __sk_buff* const skb,
const StatsKey* const key, const StatsKey* const key,
const bool egress, const struct egress_bool egress,
const struct kver_uint kver) { const struct kver_uint kver) {
if (selectedMap == SELECT_MAP_A) { if (selectedMap == SELECT_MAP_A) {
update_stats_map_A(skb, key, egress, kver); update_stats_map_A(skb, key, egress, kver);
...@@ -443,7 +445,8 @@ static __always_inline inline void update_stats_with_config(const uint32_t selec ...@@ -443,7 +445,8 @@ static __always_inline inline void update_stats_with_config(const uint32_t selec
} }
} }
static __always_inline inline int bpf_traffic_account(struct __sk_buff* skb, bool egress, static __always_inline inline int bpf_traffic_account(struct __sk_buff* skb,
const struct egress_bool egress,
const bool enable_tracing, const bool enable_tracing,
const struct kver_uint kver) { const struct kver_uint kver) {
uint32_t sock_uid = bpf_get_socket_uid(skb); uint32_t sock_uid = bpf_get_socket_uid(skb);
...@@ -462,7 +465,7 @@ static __always_inline inline int bpf_traffic_account(struct __sk_buff* skb, boo ...@@ -462,7 +465,7 @@ static __always_inline inline int bpf_traffic_account(struct __sk_buff* skb, boo
// interface is accounted for and subject to usage restrictions. // interface is accounted for and subject to usage restrictions.
// CLAT IPv6 TX sockets are *always* tagged with CLAT uid, see tagSocketAsClat() // CLAT IPv6 TX sockets are *always* tagged with CLAT uid, see tagSocketAsClat()
// CLAT daemon receives via an untagged AF_PACKET socket. // CLAT daemon receives via an untagged AF_PACKET socket.
if (egress && uid == AID_CLAT) return PASS; if (egress.egress && uid == AID_CLAT) return PASS;
int match = bpf_owner_match(skb, sock_uid, egress, kver); int match = bpf_owner_match(skb, sock_uid, egress, kver);
...@@ -478,7 +481,7 @@ static __always_inline inline int bpf_traffic_account(struct __sk_buff* skb, boo ...@@ -478,7 +481,7 @@ static __always_inline inline int bpf_traffic_account(struct __sk_buff* skb, boo
} }
// If an outbound packet is going to be dropped, we do not count that traffic. // If an outbound packet is going to be dropped, we do not count that traffic.
if (egress && (match == DROP)) return DROP; if (egress.egress && (match == DROP)) return DROP;
StatsKey key = {.uid = uid, .tag = tag, .counterSet = 0, .ifaceIndex = skb->ifindex}; StatsKey key = {.uid = uid, .tag = tag, .counterSet = 0, .ifaceIndex = skb->ifindex};
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment