tools/tcprtt: Allow to filter on IPv6 addresses

Signed-off-by: Jerome Marchand <jmarchan@redhat.com>
diff --git a/tools/tcprtt.py b/tools/tcprtt.py
index 847501b..d6f81a6 100755
--- a/tools/tcprtt.py
+++ b/tools/tcprtt.py
@@ -15,7 +15,7 @@
 from __future__ import print_function
 from bcc import BPF
 from time import sleep, strftime
-from socket import inet_ntop, AF_INET
+from socket import inet_ntop, inet_pton, AF_INET, AF_INET6
 import socket, struct
 import argparse
 import ctypes
@@ -110,6 +110,8 @@
     u16 dport = 0;
     u32 saddr = 0;
     u32 daddr = 0;
+    __u8 saddr6[16];
+    __u8 daddr6[16];
     u16 family = 0;
 
     /* for histogram */
@@ -120,9 +122,16 @@
 
     bpf_probe_read_kernel(&sport, sizeof(sport), (void *)&inet->inet_sport);
     bpf_probe_read_kernel(&dport, sizeof(dport), (void *)&inet->inet_dport);
-    bpf_probe_read_kernel(&saddr, sizeof(saddr), (void *)&inet->inet_saddr);
-    bpf_probe_read_kernel(&daddr, sizeof(daddr), (void *)&inet->inet_daddr);
     bpf_probe_read_kernel(&family, sizeof(family), (void *)&sk->__sk_common.skc_family);
+    if (family == AF_INET6) {
+        bpf_probe_read_kernel(&saddr6, sizeof(saddr6),
+                              (void *)&sk->__sk_common.skc_v6_rcv_saddr.s6_addr);
+        bpf_probe_read_kernel(&daddr6, sizeof(daddr6),
+                              (void *)&sk->__sk_common.skc_v6_daddr.s6_addr);
+    } else {
+        bpf_probe_read_kernel(&saddr, sizeof(saddr), (void *)&inet->inet_saddr);
+        bpf_probe_read_kernel(&daddr, sizeof(daddr), (void *)&inet->inet_daddr);
+    }
 
     LPORTFILTER
     RPORTFILTER
@@ -158,19 +167,26 @@
 else:
     bpf_text = bpf_text.replace('RPORTFILTER', '')
 
+def addrfilter(addr, src_or_dest):
+    try:
+        naddr = socket.inet_pton(AF_INET, addr)
+    except:
+        naddr = socket.inet_pton(AF_INET6, addr)
+        s = ('\\' + struct.unpack("=16s", naddr)[0].hex('\\')).replace('\\', '\\x')
+        filter = "if (memcmp(%s6, \"%s\", 16)) return 0;" % (src_or_dest, s)
+    else:
+        filter = "if (%s != %d) return 0;" % (src_or_dest, struct.unpack("=I", naddr)[0])
+    return filter
+
 # filter for local address
 if args.laddr:
-    bpf_text = bpf_text.replace('LADDRFILTER',
-        """if (saddr != %d)
-        return 0;""" % struct.unpack("=I", socket.inet_aton(args.laddr))[0])
+    bpf_text = bpf_text.replace('LADDRFILTER', addrfilter(args.laddr, 'saddr'))
 else:
     bpf_text = bpf_text.replace('LADDRFILTER', '')
 
 # filter for remote address
 if args.raddr:
-    bpf_text = bpf_text.replace('RADDRFILTER',
-        """if (daddr != %d)
-        return 0;""" % struct.unpack("=I", socket.inet_aton(args.raddr))[0])
+    bpf_text = bpf_text.replace('RADDRFILTER', addrfilter(args.raddr, 'daddr'))
 else:
     bpf_text = bpf_text.replace('RADDRFILTER', '')
 if args.ipv4: