Merge "cts: android.kernel.config"
diff --git a/ANRdaemon/Android.mk b/ANRdaemon/Android.mk
index 51bebc5..535eb8c 100644
--- a/ANRdaemon/Android.mk
+++ b/ANRdaemon/Android.mk
@@ -10,6 +10,7 @@
LOCAL_MODULE_TAGS:= optional
LOCAL_SHARED_LIBRARIES := \
+ liblog \
libbinder \
libcutils \
libutils \
diff --git a/ext4_utils/ext4fixup.c b/ext4_utils/ext4fixup.c
index 184cd0d..4b40207 100644
--- a/ext4_utils/ext4fixup.c
+++ b/ext4_utils/ext4fixup.c
@@ -806,6 +806,7 @@
}
close(fd);
+ free(dirbuf);
return 0;
}
diff --git a/f2fs_utils/Android.mk b/f2fs_utils/Android.mk
index 647c390..82c3ee0 100644
--- a/f2fs_utils/Android.mk
+++ b/f2fs_utils/Android.mk
@@ -76,7 +76,7 @@
include $(CLEAR_VARS)
LOCAL_MODULE := libf2fs_sparseblock
LOCAL_SRC_FILES := f2fs_sparseblock.c
-LOCAL_SHARED_LIBRARIES := libcutils
+LOCAL_SHARED_LIBRARIES := liblog libcutils
LOCAL_C_INCLUDES := external/f2fs-tools/include \
system/core/include/log
include $(BUILD_SHARED_LIBRARY)
@@ -84,7 +84,7 @@
include $(CLEAR_VARS)
LOCAL_MODULE := f2fs_sparseblock
LOCAL_SRC_FILES := f2fs_sparseblock.c
-LOCAL_SHARED_LIBRARIES := libcutils
+LOCAL_SHARED_LIBRARIES := liblog libcutils
LOCAL_C_INCLUDES := external/f2fs-tools/include \
system/core/include/log
include $(BUILD_EXECUTABLE)
diff --git a/libfec/fec_read.cpp b/libfec/fec_read.cpp
index 2d29da8..0f5ec99 100644
--- a/libfec/fec_read.cpp
+++ b/libfec/fec_read.cpp
@@ -47,7 +47,9 @@
for (size_t m = 0; m < bytes_per_line; ++m) {
if (n + m < size) {
- sprintf(&hex[m * 3], "%02x ", data[n + m]);
+ ptrdiff_t offset = &hex[m * 3] - hex;
+ snprintf(hex + offset, sizeof(hex) - offset, "%02x ",
+ data[n + m]);
if (isprint(data[n + m])) {
prn[m] = data[n + m];
diff --git a/libpagemap/include/pagemap/pagemap.h b/libpagemap/include/pagemap/pagemap.h
index 9063b1e..4de2b4b 100644
--- a/libpagemap/include/pagemap/pagemap.h
+++ b/libpagemap/include/pagemap/pagemap.h
@@ -21,9 +21,19 @@
#include <stdio.h>
#include <sys/cdefs.h>
#include <sys/types.h>
+#include <sys/queue.h>
__BEGIN_DECLS
+typedef struct pm_proportional_swap pm_proportional_swap_t;
+
+typedef struct pm_swap_offset pm_swap_offset_t;
+
+struct pm_swap_offset { /* one node of a pm_memusage's swap_offset_list */
+ unsigned int offset; /* swap offset passed to pm_memusage_pswap_add_offset() */
+ SIMPLEQ_ENTRY(pm_swap_offset) simpleqe; /* <sys/queue.h> simple-queue linkage */
+};
+
typedef struct pm_memusage pm_memusage_t;
/* Holds the various metrics for memory usage of a process or a mapping. */
@@ -33,12 +43,32 @@
size_t pss;
size_t uss;
size_t swap;
+ /* if non NULL then use swap_offset_list to compute proportional swap */
+ pm_proportional_swap_t *p_swap;
+ SIMPLEQ_HEAD(simpleqhead, pm_swap_offset) swap_offset_list;
+};
+
+typedef struct pm_swapusage pm_swapusage_t;
+struct pm_swapusage { /* result filled in by pm_memusage_pswap_get_usage() */
+ size_t proportional; /* sum over recorded swap pages of pagesize / reference count */
+ size_t unique; /* pagesize for every recorded swap page whose reference count is 1 */
};
/* Clears a memusage. */
void pm_memusage_zero(pm_memusage_t *mu);
/* Adds one memusage (a) to another (b). */
void pm_memusage_add(pm_memusage_t *a, pm_memusage_t *b);
+/* Adds a swap offset */
+void pm_memusage_pswap_add_offset(pm_memusage_t *mu, unsigned int offset);
+/* Enable proportional swap computing. */
+void pm_memusage_pswap_init_handle(pm_memusage_t *mu, pm_proportional_swap_t *p_swap);
+/* Computes and returns the proportional and unique swap usage in su */
+void pm_memusage_pswap_get_usage(pm_memusage_t *mu, pm_swapusage_t *su);
+void pm_memusage_pswap_free(pm_memusage_t *mu);
+/* Initialize a proportional swap computing handle:
+ assumes only 1 swap device, total swap size of this device in bytes to be given as argument */
+pm_proportional_swap_t * pm_memusage_pswap_create(int swap_size);
+void pm_memusage_pswap_destroy(pm_proportional_swap_t *p_swap);
typedef struct pm_kernel pm_kernel_t;
typedef struct pm_process pm_process_t;
diff --git a/libpagemap/pagemap_test.cpp b/libpagemap/pagemap_test.cpp
index ccbc211..592072c 100644
--- a/libpagemap/pagemap_test.cpp
+++ b/libpagemap/pagemap_test.cpp
@@ -14,10 +14,12 @@
* limitations under the License.
*/
-#include <gtest/gtest.h>
-
#include <pagemap/pagemap.h>
+#include <string>
+
+#include <gtest/gtest.h>
+
TEST(pagemap, maps) {
pm_kernel_t* kernel;
ASSERT_EQ(0, pm_kernel_create(&kernel));
@@ -32,8 +34,9 @@
bool found_heap = false;
bool found_stack = false;
for (size_t i = 0; i < num_maps; i++) {
- if (strcmp(maps[i]->name, "[heap]") == 0) found_heap = true;
- if (strcmp(maps[i]->name, "[stack]") == 0) found_stack = true;
+ std::string name(maps[i]->name);
+ if (name == "[heap]" || name == "[anon:libc_malloc]") found_heap = true;
+ if (name == "[stack]") found_stack = true;
}
ASSERT_TRUE(found_heap);
diff --git a/libpagemap/pm_map.c b/libpagemap/pm_map.c
index c6a1798..301a1cc 100644
--- a/libpagemap/pm_map.c
+++ b/libpagemap/pm_map.c
@@ -42,12 +42,13 @@
if (error) return error;
pm_memusage_zero(&usage);
+ pm_memusage_pswap_init_handle(&usage, usage_out->p_swap);
for (i = 0; i < len; i++) {
usage.vss += map->proc->ker->pagesize;
if (!PM_PAGEMAP_PRESENT(pagemap[i]) &&
- !PM_PAGEMAP_SWAPPED(pagemap[i]))
+ !PM_PAGEMAP_SWAPPED(pagemap[i]))
continue;
if (!PM_PAGEMAP_SWAPPED(pagemap[i])) {
@@ -70,6 +71,7 @@
usage.uss += (count == 1) ? (map->proc->ker->pagesize) : (0);
} else {
usage.swap += map->proc->ker->pagesize;
+ pm_memusage_pswap_add_offset(&usage, PM_PAGEMAP_SWAP_OFFSET(pagemap[i]));
}
}
@@ -77,7 +79,7 @@
error = 0;
-out:
+out:
free(pagemap);
return error;
@@ -101,13 +103,13 @@
if (error) return error;
pm_memusage_zero(&ws);
-
+
for (i = 0; i < len; i++) {
error = pm_kernel_flags(map->proc->ker, PM_PAGEMAP_PFN(pagemap[i]),
&flags);
if (error) goto out;
- if (!(flags & PM_PAGE_REFERENCED))
+ if (!(flags & PM_PAGE_REFERENCED))
continue;
error = pm_kernel_count(map->proc->ker, PM_PAGEMAP_PFN(pagemap[i]),
diff --git a/libpagemap/pm_memusage.c b/libpagemap/pm_memusage.c
index ea2a003..70cfede 100644
--- a/libpagemap/pm_memusage.c
+++ b/libpagemap/pm_memusage.c
@@ -14,10 +14,37 @@
* limitations under the License.
*/
+#include <stdlib.h>
+#include <unistd.h>
+
#include <pagemap/pagemap.h>
+#define SIMPLEQ_INSERT_SIMPLEQ_TAIL(head_a, head_b) /* append all nodes of queue head_b to queue head_a */ \
+ do { \
+ if (!SIMPLEQ_EMPTY(head_b)) { /* an empty head_b leaves head_a untouched */ \
+ if ((head_a)->sqh_first == NULL) \
+ (head_a)->sqh_first = (head_b)->sqh_first; \
+ *(head_a)->sqh_last = (head_b)->sqh_first; /* link head_a's current tail slot to head_b's first node */ \
+ (head_a)->sqh_last = (head_b)->sqh_last; /* head_a now ends where head_b ends; head_b itself is not reset */ \
+ } \
+ } while (/*CONSTCOND*/0)
+
+/* Reference counts for swap offsets are kept in an array of unsigned short,
+   indexed by swap offset. With 4 KiB pages, 1 GiB of swap needs a 512 KiB array. */
+typedef unsigned short pm_pswap_refcount_t;
+struct pm_proportional_swap {
+ unsigned int array_size; /* number of entries in offset_array */
+ pm_pswap_refcount_t *offset_array; /* per-offset reference counters, zeroed by calloc() at creation */
+};
+
void pm_memusage_zero(pm_memusage_t *mu) {
mu->vss = mu->rss = mu->pss = mu->uss = mu->swap = 0;
+ mu->p_swap = NULL;
+ SIMPLEQ_INIT(&mu->swap_offset_list);
+}
+
+void pm_memusage_pswap_init_handle(pm_memusage_t *mu, pm_proportional_swap_t *p_swap) {
+ mu->p_swap = p_swap;
}
void pm_memusage_add(pm_memusage_t *a, pm_memusage_t *b) {
@@ -26,4 +53,80 @@
a->pss += b->pss;
a->uss += b->uss;
a->swap += b->swap;
+ SIMPLEQ_INSERT_SIMPLEQ_TAIL(&a->swap_offset_list, &b->swap_offset_list);
+}
+
+/* Allocates a proportional-swap handle holding one zeroed reference counter per
+   page of swap_size bytes; logs to stderr and returns NULL if allocation fails. */
+pm_proportional_swap_t * pm_memusage_pswap_create(int swap_size)
+{
+    pm_proportional_swap_t *handle = malloc(sizeof(*handle));
+    if (!handle) {
+        fprintf(stderr, "Error allocating proportional swap.\n");
+        return NULL;
+    }
+    handle->array_size = swap_size / getpagesize();
+    handle->offset_array = calloc(handle->array_size, sizeof(*handle->offset_array));
+    if (!handle->offset_array) {
+        fprintf(stderr, "Error allocating proportional swap offset array.\n");
+        free(handle);
+        return NULL;
+    }
+
+    return handle;
+}
+
+void pm_memusage_pswap_destroy(pm_proportional_swap_t *p_swap) {
+    if (p_swap == NULL) /* nothing to release */
+        return;
+    free(p_swap->offset_array); /* counters first, then the handle itself */
+    free(p_swap);
+}
+
+/* Counts one more reference to swap slot `offset` in mu's shared handle and
+   appends the slot to mu->swap_offset_list. No-op when mu has no handle. */
+void pm_memusage_pswap_add_offset(pm_memusage_t *mu, unsigned int offset) {
+    if (mu->p_swap == NULL)
+        return;
+
+    /* Valid indices are 0 .. array_size - 1, so offset == array_size is out of bounds too. */
+    if (offset >= mu->p_swap->array_size) {
+        fprintf(stderr, "SWAP offset %u is out of swap bounds.\n", offset);
+        return;
+    }
+    if (mu->p_swap->offset_array[offset] == USHRT_MAX) {
+        fprintf(stderr, "SWAP offset %u ref. count is overflowing ushort type.\n", offset);
+    } else {
+        mu->p_swap->offset_array[offset]++;
+    }
+
+    pm_swap_offset_t *soff = malloc(sizeof(*soff));
+    if (soff) { /* best effort: on allocation failure the slot is simply not listed */
+        soff->offset = offset;
+        SIMPLEQ_INSERT_TAIL(&mu->swap_offset_list, soff, simpleqe);
+    }
+}
+
+/* Fills *su in bytes: each recorded swap page adds pagesize / refcount to `proportional`
+   and a full page to `unique` when its refcount is exactly 1. No-op when su is NULL. */
+void pm_memusage_pswap_get_usage(pm_memusage_t *mu, pm_swapusage_t *su) {
+    const int page_bytes = getpagesize();
+    pm_swap_offset_t *node;
+    if (!su) return;
+    su->unique = 0;
+    su->proportional = 0;
+    SIMPLEQ_FOREACH(node, &mu->swap_offset_list, simpleqe) {
+        pm_pswap_refcount_t refs = mu->p_swap->offset_array[node->offset];
+        su->proportional += page_bytes / refs;
+        if (refs == 1) su->unique += page_bytes;
+    }
+}
+
+void pm_memusage_pswap_free(pm_memusage_t *mu) {
+    /* Pop and release nodes from the front until swap_offset_list is empty. */
+    pm_swap_offset_t *head;
+    while ((head = SIMPLEQ_FIRST(&mu->swap_offset_list)) != NULL) {
+        SIMPLEQ_REMOVE_HEAD(&mu->swap_offset_list, simpleqe);
+        free(head);
+    }
}
diff --git a/libpagemap/pm_process.c b/libpagemap/pm_process.c
index b8e06c1..3c5c391 100644
--- a/libpagemap/pm_process.c
+++ b/libpagemap/pm_process.c
@@ -81,6 +81,10 @@
return -1;
pm_memusage_zero(&usage);
+ pm_memusage_pswap_init_handle(&usage, usage_out->p_swap);
+
+ pm_memusage_zero(&map_usage);
+ pm_memusage_pswap_init_handle(&map_usage, usage_out->p_swap);
for (i = 0; i < proc->num_maps; i++) {
error = pm_map_usage_flags(proc->maps[i], &map_usage, flags_mask,
@@ -185,6 +189,11 @@
if (ws_out) {
pm_memusage_zero(&ws);
+ pm_memusage_pswap_init_handle(&ws, ws_out->p_swap);
+
+ pm_memusage_zero(&map_ws);
+ pm_memusage_pswap_init_handle(&map_ws, ws_out->p_swap);
+
for (i = 0; i < proc->num_maps; i++) {
error = pm_map_workingset(proc->maps[i], &map_ws);
if (error) return error;
diff --git a/librank/librank.c b/librank/librank.c
index 28322b9..9d1c026 100644
--- a/librank/librank.c
+++ b/librank/librank.c
@@ -135,7 +135,7 @@
if (library->mappings_count >= library->mappings_size) {
library->mappings = realloc(library->mappings,
- 2 * library->mappings_size * sizeof(struct mapping*));
+ 2 * library->mappings_size * sizeof(struct mapping_info*));
if (!library->mappings) {
fprintf(stderr, "Couldn't resize mappings array: %s\n", strerror(errno));
exit(EXIT_FAILURE);
@@ -316,6 +316,7 @@
libraries = malloc(INIT_LIBRARIES * sizeof(struct library_info *));
libraries_count = 0; libraries_size = INIT_LIBRARIES;
+ pm_memusage_zero(&map_usage);
error = pm_kernel_create(&ker);
if (error) {
@@ -376,7 +377,7 @@
}
}
- printf(" %6s %6s %6s %6s %6s ", "RSStot", "VSS", "RSS", "PSS", "USS");
+ printf(" %6s %7s %6s %6s %6s ", "RSStot", "VSS", "RSS", "PSS", "USS");
if (has_swap) {
printf(" %6s ", "Swap");
@@ -390,7 +391,7 @@
for (i = 0; i < libraries_count; i++) {
li = libraries[i];
- printf("%6zdK %6s %6s %6s %6s ", li->total_usage.pss / 1024, "", "", "", "");
+ printf("%6zdK %7s %6s %6s %6s ", li->total_usage.pss / 1024, "", "", "", "");
if (has_swap) {
printf(" %6s ", "");
}
@@ -402,7 +403,7 @@
for (j = 0; j < li->mappings_count; j++) {
mi = li->mappings[j];
pi = mi->proc;
- printf( " %6s %6zdK %6zdK %6zdK %6zdK ", "",
+ printf( " %6s %7zdK %6zdK %6zdK %6zdK ", "",
mi->usage.vss / 1024,
mi->usage.rss / 1024,
mi->usage.pss / 1024,
diff --git a/postinst/Android.mk b/postinst/Android.mk
index b0dec05..c804cfc 100644
--- a/postinst/Android.mk
+++ b/postinst/Android.mk
@@ -21,10 +21,4 @@
LOCAL_MODULE_TAGS := optional
LOCAL_MODULE_CLASS := EXECUTABLES
LOCAL_SRC_FILES := postinst.sh
-
-# Create a symlink from /postinst to our default post-install script in the
-# same filesystem as /postinst.
-# TODO(deymo): Remove this symlink and add the path to the product config.
-LOCAL_POST_INSTALL_CMD := \
- $(hide) ln -sf bin/postinst_example $(TARGET_OUT)/postinst
include $(BUILD_PREBUILT)
diff --git a/postinst/postinst.sh b/postinst/postinst.sh
index eb98e79..b6a4d20 100644
--- a/postinst/postinst.sh
+++ b/postinst/postinst.sh
@@ -18,18 +18,33 @@
# This is an example post-install script. This script will be executed by the
# update_engine right after finishing writing all the partitions, but before
-# marking the new slot as active.
-#
+# marking the new slot as active. To enable running this program, insert these
+# lines in your product's .mk file (without the # at the beginning):
+
+# AB_OTA_POSTINSTALL_CONFIG += \
+# RUN_POSTINSTALL_system=true \
+# POSTINSTALL_PATH_system=bin/postinst_example \
+# FILESYSTEM_TYPE_system=ext4 \
+
# This script receives no arguments. argv[0] will include the absolute path to
-# the script, including the temporary directory where the new partition was
-# mounted.
+# the script, including the directory where the new partition was mounted.
#
-# This script will run in the context of the old kernel and old system. Note
-# that the absolute path used in first line of this script (/system/bin/sh) is
-# indeed the old system's sh binary. If you use a compiled program, you might
-# want to link it statically or use a wrapper script to use the new ldso to run
-# your program (see the --generate-wrappers option in lddtree.py for example).
-#
+# The script will run from the "postinstall" SELinux domain, from the old system
+# environment (kernel, SELinux rules, etc). New rules and domains introduced by
+# the new system won't be available when this script runs; instead, all the
+# files in the mounted directory will have the attribute "postinstall_file". All
+# the files accessed from here would need to be allowed in the old system or
+# those accesses will fail. For example, the absolute path used in the first
+# line of this script (/system/bin/sh) is indeed the old system's sh binary. If
+# you use a compiled program, you might want to link it statically or use a
+# wrapper script to use the new ldso to run your program (see the
+# --generate-wrappers option in lddtree.py for an example).
+
+my_dir=$(dirname "$0")
+
+echo "The output of this program will show up in the logs." >&2
+echo "Note that this program runs from ${my_dir}"
+
# If the exit code of this program is an error code (different from 0), the
# update will fail and the new slot will not be marked as active.
diff --git a/procmem/procmem.c b/procmem/procmem.c
index 28055d8..17a7212 100644
--- a/procmem/procmem.c
+++ b/procmem/procmem.c
@@ -50,12 +50,12 @@
/* maps and such */
pm_map_t **maps; size_t num_maps;
- struct map_info **mis;
+ struct map_info **mis = NULL;
struct map_info *mi;
/* pagemap information */
uint64_t *pagemap; size_t num_pages;
- unsigned long address; uint64_t mapentry;
+ uint64_t mapentry;
uint64_t count, flags;
/* totals */
@@ -190,7 +190,6 @@
mi->shared_clean = mi->shared_dirty = mi->private_clean = mi->private_dirty = 0;
for (j = 0; j < num_pages; j++) {
- address = pm_map_start(mi->map) + j * ker->pagesize;
mapentry = pagemap[j];
if (PM_PAGEMAP_PRESENT(mapentry) && !PM_PAGEMAP_SWAPPED(mapentry)) {
@@ -298,6 +297,7 @@
);
}
+ free(mis);
return 0;
}
diff --git a/procrank/procrank.c b/procrank/procrank.c
index 1728467..881f110 100644
--- a/procrank/procrank.c
+++ b/procrank/procrank.c
@@ -48,9 +48,26 @@
int (*compfn)(const void *a, const void *b);
static int order;
-void print_mem_info() {
+enum { /* indices into the mem[] array declared in main(); NOTE(review): order appears to mirror tags[] in get_mem_info() -- confirm */
+ MEMINFO_TOTAL,
+ MEMINFO_FREE,
+ MEMINFO_BUFFERS,
+ MEMINFO_CACHED,
+ MEMINFO_SHMEM,
+ MEMINFO_SLAB,
+ MEMINFO_SWAP_TOTAL,
+ MEMINFO_SWAP_FREE,
+ MEMINFO_ZRAM_TOTAL, /* main() stores /sys/block/zram0/mem_used_total / 1024 here */
+ MEMINFO_MAPPED,
+ MEMINFO_VMALLOC_USED,
+ MEMINFO_PAGE_TABLES,
+ MEMINFO_KERNEL_STACK,
+ MEMINFO_COUNT /* number of entries above; sizes mem[] in main() */
+};
+
+void get_mem_info(uint64_t mem[]) {
char buffer[1024];
- int numFound = 0;
+ unsigned int numFound = 0;
int fd = open("/proc/meminfo", O_RDONLY);
@@ -75,6 +92,13 @@
"Cached:",
"Shmem:",
"Slab:",
+ "SwapTotal:",
+ "SwapFree:",
+ "ZRam:", /* not read from meminfo but from /sys/block/zram0 */
+ "Mapped:",
+ "VmallocUsed:",
+ "PageTables:",
+ "KernelStack:",
NULL
};
static const int tagsLen[] = {
@@ -84,12 +108,18 @@
7,
6,
5,
+ 10,
+ 9,
+ 5,
+ 7,
+ 12,
+ 11,
+ 12,
0
};
- uint64_t mem[] = { 0, 0, 0, 0, 0, 0 };
char* p = buffer;
- while (*p && numFound < 6) {
+ while (*p && (numFound < (sizeof(tagsLen) / sizeof(tagsLen[0])))) {
int i = 0;
while (tags[i]) {
if (strncmp(p, tags[i], tagsLen[i]) == 0) {
@@ -112,10 +142,6 @@
}
if (*p) p++;
}
-
- printf("RAM: %" PRIu64 "K total, %" PRIu64 "K free, %" PRIu64 "K buffers, "
- "%" PRIu64 "K cached, %" PRIu64 "K shmem, %" PRIu64 "K slab\n",
- mem[0], mem[1], mem[2], mem[3], mem[4], mem[5]);
}
int main(int argc, char *argv[]) {
@@ -127,9 +153,12 @@
uint64_t total_pss;
uint64_t total_uss;
uint64_t total_swap;
+ uint64_t total_pswap;
+ uint64_t total_uswap;
+ uint64_t total_zswap;
char cmdline[256]; // this must be within the range of int
int error;
- bool has_swap = false;
+ bool has_swap = false, has_zram = false;
uint64_t required_flags = 0;
uint64_t flags_mask = 0;
@@ -141,6 +170,12 @@
int arg;
size_t i, j;
+ uint64_t mem[MEMINFO_COUNT] = { };
+ pm_proportional_swap_t *p_swap;
+ int fd, len;
+ char buffer[1024];
+ float zram_cr = 0.0;
+
signal(SIGPIPE, SIG_IGN);
compfn = &sort_by_pss;
order = -1;
@@ -164,6 +199,9 @@
exit(EXIT_FAILURE);
}
+ get_mem_info(mem);
+ p_swap = pm_memusage_pswap_create(mem[MEMINFO_SWAP_TOTAL] * 1024);
+
error = pm_kernel_create(&ker);
if (error) {
fprintf(stderr, "Error creating kernel interface -- "
@@ -191,6 +229,7 @@
}
procs[i]->pid = pids[i];
pm_memusage_zero(&procs[i]->usage);
+ pm_memusage_pswap_init_handle(&procs[i]->usage, p_swap);
error = pm_process_create(ker, pids[i], &proc);
if (error) {
fprintf(stderr, "warning: could not create process interface for %d\n", pids[i]);
@@ -237,16 +276,37 @@
qsort(procs, num_procs, sizeof(procs[0]), compfn);
+ if (has_swap) {
+ fd = open("/sys/block/zram0/mem_used_total", O_RDONLY);
+ if (fd >= 0) {
+ len = read(fd, buffer, sizeof(buffer)-1);
+ close(fd);
+ if (len > 0) {
+ buffer[len] = 0;
+ mem[MEMINFO_ZRAM_TOTAL] = atoll(buffer)/1024;
+ zram_cr = (float) mem[MEMINFO_ZRAM_TOTAL] /
+ (mem[MEMINFO_SWAP_TOTAL] - mem[MEMINFO_SWAP_FREE]);
+ has_zram = true;
+ }
+ }
+ }
+
printf("%5s ", "PID");
if (ws) {
- printf("%s %7s %7s ", "WRss", "WPss", "WUss");
+ printf("%7s %7s %7s ", "WRss", "WPss", "WUss");
if (has_swap) {
- printf("%7s ", "WSwap");
+ printf("%7s %7s %7s ", "WSwap", "WPSwap", "WUSwap");
+ if (has_zram) {
+ printf("%7s ", "WZSwap");
+ }
}
} else {
printf("%8s %7s %7s %7s ", "Vss", "Rss", "Pss", "Uss");
if (has_swap) {
- printf("%7s ", "Swap");
+ printf("%7s %7s %7s ", "Swap", "PSwap", "USwap");
+ if (has_zram) {
+ printf("%7s ", "ZSwap");
+ }
}
}
@@ -255,6 +315,9 @@
total_pss = 0;
total_uss = 0;
total_swap = 0;
+ total_pswap = 0;
+ total_uswap = 0;
+ total_zswap = 0;
for (i = 0; i < num_procs; i++) {
if (getprocname(procs[i]->pid, cmdline, (int)sizeof(cmdline)) < 0) {
@@ -288,7 +351,20 @@
}
if (has_swap) {
+ pm_swapusage_t su;
+
+ pm_memusage_pswap_get_usage(&procs[i]->usage, &su);
printf("%6zuK ", procs[i]->usage.swap / 1024);
+ printf("%6zuK ", su.proportional / 1024);
+ printf("%6zuK ", su.unique / 1024);
+ total_pswap += su.proportional;
+ total_uswap += su.unique;
+ pm_memusage_pswap_free(&procs[i]->usage);
+ if (has_zram) {
+ size_t zpswap = su.proportional * zram_cr;
+ printf("%6zuK ", zpswap / 1024);
+ total_zswap += zpswap;
+ }
}
printf("%s\n", cmdline);
@@ -297,6 +373,7 @@
}
free(procs);
+ pm_memusage_pswap_destroy(p_swap);
/* Print the separator line */
printf("%5s ", "");
@@ -308,7 +385,10 @@
}
if (has_swap) {
- printf("%7s ", "------");
+ printf("%7s %7s %7s ", "------", "------", "------");
+ if (has_zram) {
+ printf("%7s ", "------");
+ }
}
printf("%s\n", "------");
@@ -316,7 +396,7 @@
/* Print the total line */
printf("%5s ", "");
if (ws) {
- printf("%7s %6" PRIu64 "K %" PRIu64 "K ",
+ printf("%7s %6" PRIu64 "K %6" PRIu64 "K ",
"", total_pss / 1024, total_uss / 1024);
} else {
printf("%8s %7s %6" PRIu64 "K %6" PRIu64 "K ",
@@ -325,12 +405,27 @@
if (has_swap) {
printf("%6" PRIu64 "K ", total_swap / 1024);
+ printf("%6" PRIu64 "K ", total_pswap / 1024);
+ printf("%6" PRIu64 "K ", total_uswap / 1024);
+ if (has_zram) {
+ printf("%6" PRIu64 "K ", total_zswap / 1024);
+ }
}
printf("TOTAL\n");
printf("\n");
- print_mem_info();
+
+ if (has_swap) {
+ printf("ZRAM: %" PRIu64 "K physical used for %" PRIu64 "K in swap "
+ "(%" PRIu64 "K total swap)\n",
+ mem[MEMINFO_ZRAM_TOTAL], (mem[MEMINFO_SWAP_TOTAL] - mem[MEMINFO_SWAP_FREE]),
+ mem[MEMINFO_SWAP_TOTAL]);
+ }
+ printf(" RAM: %" PRIu64 "K total, %" PRIu64 "K free, %" PRIu64 "K buffers, "
+ "%" PRIu64 "K cached, %" PRIu64 "K shmem, %" PRIu64 "K slab\n",
+ mem[MEMINFO_TOTAL], mem[MEMINFO_FREE], mem[MEMINFO_BUFFERS],
+ mem[MEMINFO_CACHED], mem[MEMINFO_SHMEM], mem[MEMINFO_SLAB]);
return 0;
}
diff --git a/simpleperf/Android.mk b/simpleperf/Android.mk
index 7e2eae4..1ec496e 100644
--- a/simpleperf/Android.mk
+++ b/simpleperf/Android.mk
@@ -182,8 +182,7 @@
LOCAL_MODULE_PATH := $(TARGET_OUT_OPTIONAL_EXECUTABLES)
LOCAL_CPPFLAGS := $(simpleperf_cppflags_target)
LOCAL_SRC_FILES := main.cpp
-LOCAL_WHOLE_STATIC_LIBRARIES := libsimpleperf
-LOCAL_STATIC_LIBRARIES := $(simpleperf_static_libraries_target)
+LOCAL_STATIC_LIBRARIES := libsimpleperf $(simpleperf_static_libraries_target)
LOCAL_SHARED_LIBRARIES := $(simpleperf_shared_libraries_target)
LOCAL_MULTILIB := first
include $(BUILD_EXECUTABLE)
@@ -194,8 +193,7 @@
LOCAL_MODULE := simpleperf_static
LOCAL_CPPFLAGS := $(simpleperf_cppflags_target)
LOCAL_SRC_FILES := main.cpp
-LOCAL_WHOLE_STATIC_LIBRARIES := libsimpleperf_static
-LOCAL_STATIC_LIBRARIES := $(static_simpleperf_static_libraries_target)
+LOCAL_STATIC_LIBRARIES := libsimpleperf_static $(static_simpleperf_static_libraries_target)
LOCAL_MULTILIB := first
LOCAL_FORCE_STATIC_EXECUTABLE := true
include $(LLVM_DEVICE_BUILD_MK)
@@ -210,8 +208,7 @@
LOCAL_CPPFLAGS_linux := $(simpleperf_cppflags_host_linux)
LOCAL_CPPFLAGS_windows := $(simpleperf_cppflags_host_windows)
LOCAL_SRC_FILES := main.cpp
-LOCAL_WHOLE_STATIC_LIBRARIES := libsimpleperf
-LOCAL_STATIC_LIBRARIES := $(simpleperf_static_libraries_host)
+LOCAL_STATIC_LIBRARIES := libsimpleperf $(simpleperf_static_libraries_host)
LOCAL_SHARED_LIBRARIES_darwin := $(simpleperf_shared_libraries_host_darwin)
LOCAL_SHARED_LIBRARIES_linux := $(simpleperf_shared_libraries_host_linux)
LOCAL_SHARED_LIBRARIES_windows := $(simpleperf_shared_libraries_host_windows)
@@ -223,6 +220,7 @@
# simpleperf_unit_test
# =========================================================
simpleperf_unit_test_src_files := \
+ cmd_report_test.cpp \
command_test.cpp \
gtest_main.cpp \
read_apk_test.cpp \
@@ -234,7 +232,6 @@
cmd_dumprecord_test.cpp \
cmd_list_test.cpp \
cmd_record_test.cpp \
- cmd_report_test.cpp \
cmd_stat_test.cpp \
environment_test.cpp \
record_file_test.cpp \
@@ -249,8 +246,7 @@
$(simpleperf_unit_test_src_files) \
$(simpleperf_unit_test_src_files_linux) \
-LOCAL_WHOLE_STATIC_LIBRARIES := libsimpleperf
-LOCAL_STATIC_LIBRARIES += $(simpleperf_static_libraries_target)
+LOCAL_STATIC_LIBRARIES += libsimpleperf $(simpleperf_static_libraries_target)
LOCAL_SHARED_LIBRARIES := $(simpleperf_shared_libraries_target)
LOCAL_MULTILIB := first
include $(BUILD_NATIVE_TEST)
@@ -265,8 +261,7 @@
LOCAL_CPPFLAGS_windows := $(simpleperf_cppflags_host_windows)
LOCAL_SRC_FILES := $(simpleperf_unit_test_src_files)
LOCAL_SRC_FILES_linux := $(simpleperf_unit_test_src_files_linux)
-LOCAL_WHOLE_STATIC_LIBRARIES := libsimpleperf
-LOCAL_STATIC_LIBRARIES := $(simpleperf_static_libraries_host)
+LOCAL_STATIC_LIBRARIES := libsimpleperf $(simpleperf_static_libraries_host)
LOCAL_SHARED_LIBRARIES_darwin := $(simpleperf_shared_libraries_host_darwin)
LOCAL_SHARED_LIBRARIES_linux := $(simpleperf_shared_libraries_host_linux)
LOCAL_SHARED_LIBRARIES_windows := $(simpleperf_shared_libraries_host_windows)
@@ -286,9 +281,8 @@
LOCAL_MODULE := simpleperf_cpu_hotplug_test
LOCAL_CPPFLAGS := $(simpleperf_cppflags_target)
LOCAL_SRC_FILES := $(simpleperf_cpu_hotplug_test_src_files)
-LOCAL_WHOLE_STATIC_LIBRARIES := libsimpleperf
+LOCAL_STATIC_LIBRARIES := libsimpleperf $(simpleperf_static_libraries_target)
LOCAL_SHARED_LIBRARIES := $(simpleperf_shared_libraries_target)
-LOCAL_STATIC_LIBRARIES := $(simpleperf_static_libraries_target)
LOCAL_MULTILIB := first
include $(BUILD_NATIVE_TEST)
@@ -300,11 +294,46 @@
LOCAL_CPPFLAGS := $(simpleperf_cppflags_host)
LOCAL_CPPFLAGS_linux := $(simpleperf_cppflags_host_linux)
LOCAL_SRC_FILES := $(simpleperf_cpu_hotplug_test_src_files)
-LOCAL_WHOLE_STATIC_LIBRARIES := libsimpleperf
-LOCAL_STATIC_LIBRARIES := $(simpleperf_static_libraries_host)
+LOCAL_STATIC_LIBRARIES := libsimpleperf $(simpleperf_static_libraries_host)
LOCAL_SHARED_LIBRARIES_linux := $(simpleperf_shared_libraries_host_linux)
LOCAL_LDLIBS_linux := $(simpleperf_ldlibs_host_linux)
LOCAL_MULTILIB := first
include $(BUILD_HOST_NATIVE_TEST)
+
+# libsimpleperf_cts_test
+# =========================================================
+libsimpleperf_cts_test_src_files := \
+ $(libsimpleperf_src_files) \
+ $(libsimpleperf_src_files_linux) \
+ $(simpleperf_unit_test_src_files) \
+ $(simpleperf_unit_test_src_files_linux) \
+
+# libsimpleperf_cts_test target
+include $(CLEAR_VARS)
+LOCAL_CLANG := true
+LOCAL_MODULE := libsimpleperf_cts_test
+LOCAL_CPPFLAGS := $(simpleperf_cppflags_target) -DIN_CTS_TEST
+LOCAL_SRC_FILES := $(libsimpleperf_cts_test_src_files)
+LOCAL_STATIC_LIBRARIES := $(simpleperf_static_libraries_target)
+LOCAL_SHARED_LIBRARIES := $(simpleperf_shared_libraries_target)
+LOCAL_MULTILIB := both
+include $(LLVM_DEVICE_BUILD_MK)
+include $(BUILD_STATIC_TEST_LIBRARY)
+
+# libsimpleperf_cts_test linux host
+include $(CLEAR_VARS)
+LOCAL_CLANG := true
+LOCAL_MODULE := libsimpleperf_cts_test
+LOCAL_MODULE_HOST_OS := linux
+LOCAL_CPPFLAGS := $(simpleperf_cppflags_host) -DIN_CTS_TEST
+LOCAL_CPPFLAGS_linux := $(simpleperf_cppflags_host_linux)
+LOCAL_SRC_FILES := $(libsimpleperf_cts_test_src_files)
+LOCAL_STATIC_LIBRARIES := $(simpleperf_static_libraries_host)
+LOCAL_SHARED_LIBRARIES_linux := $(simpleperf_shared_libraries_host_linux)
+LOCAL_LDLIBS_linux := $(simpleperf_ldlibs_host_linux)
+LOCAL_MULTILIB := both
+include $(LLVM_HOST_BUILD_MK)
+include $(BUILD_HOST_STATIC_TEST_LIBRARY)
+
include $(call first-makefiles-under,$(LOCAL_PATH))
diff --git a/simpleperf/build_id.h b/simpleperf/build_id.h
index 05c37d5..bbd13c4 100644
--- a/simpleperf/build_id.h
+++ b/simpleperf/build_id.h
@@ -33,10 +33,29 @@
memset(data_, '\0', BUILD_ID_SIZE);
}
- BuildId(const void* data, size_t len = BUILD_ID_SIZE) : BuildId() {
+ // Copy build id from a byte array, like {0x76, 0x00, 0x32,...}.
+ BuildId(const void* data, size_t len) : BuildId() {
memcpy(data_, data, std::min(len, BUILD_ID_SIZE));
}
+ // Read build id from a hex string, like "7600329e31058e12b145d153ef27cd40e1a5f7b9".
+ explicit BuildId(const std::string& s) : BuildId() {  // starts from the default-constructed id, then overwrites parsed bytes
+ for (size_t i = 0; i < s.size() && i < BUILD_ID_SIZE * 2; i += 2) {  // two hex digits per output byte
+ unsigned char ch = 0;
+ for (size_t j = i; j < i + 2; ++j) {  // NOTE(review): j can equal s.size() when s has odd length -- confirm intended
+ ch <<= 4;
+ if (s[j] >= '0' && s[j] <= '9') {
+ ch |= s[j] - '0';
+ } else if (s[j] >= 'a' && s[j] <= 'f') {
+ ch |= s[j] - 'a' + 10;
+ } else if (s[j] >= 'A' && s[j] <= 'F') {
+ ch |= s[j] - 'A' + 10;
+ }  // any other character leaves this nibble as 0
+ }
+ data_[i / 2] = ch;
+ }
+ }
+
const unsigned char* Data() const {
return data_;
}
@@ -53,6 +72,10 @@
return memcmp(data_, build_id.data_, BUILD_ID_SIZE) == 0;
}
+ bool operator!=(const BuildId& build_id) const {
+   return !operator==(build_id);  // exact negation of operator==
+ }
+
bool IsEmpty() const {
static BuildId empty_build_id;
return *this == empty_build_id;
diff --git a/simpleperf/cmd_dumprecord.cpp b/simpleperf/cmd_dumprecord.cpp
index c6af552..13f40d4 100644
--- a/simpleperf/cmd_dumprecord.cpp
+++ b/simpleperf/cmd_dumprecord.cpp
@@ -207,6 +207,6 @@
}
}
-__attribute__((constructor)) static void RegisterDumpRecordCommand() {
+void RegisterDumpRecordCommand() {  // Registers a DumpRecordCommand factory under the name "dump".
RegisterCommand("dump", [] { return std::unique_ptr<Command>(new DumpRecordCommand); });
}
diff --git a/simpleperf/cmd_dumprecord_test.cpp b/simpleperf/cmd_dumprecord_test.cpp
index f23ae16..441851f 100644
--- a/simpleperf/cmd_dumprecord_test.cpp
+++ b/simpleperf/cmd_dumprecord_test.cpp
@@ -17,26 +17,12 @@
#include <gtest/gtest.h>
#include "command.h"
+#include "get_test_data.h"
-class DumpRecordCommandTest : public ::testing::Test {
- protected:
- virtual void SetUp() {
- record_cmd = CreateCommandInstance("record");
- ASSERT_TRUE(record_cmd != nullptr);
- dumprecord_cmd = CreateCommandInstance("dump");
- ASSERT_TRUE(dumprecord_cmd != nullptr);
- }
-
- std::unique_ptr<Command> record_cmd;
- std::unique_ptr<Command> dumprecord_cmd;
-};
-
-TEST_F(DumpRecordCommandTest, no_options) {
- ASSERT_TRUE(record_cmd->Run({"-a", "sleep", "1"}));
- ASSERT_TRUE(dumprecord_cmd->Run({}));
+static std::unique_ptr<Command> DumpCmd() {  // Creates an instance of the command registered as "dump".
+ return CreateCommandInstance("dump");
}
-TEST_F(DumpRecordCommandTest, record_file_option) {
- ASSERT_TRUE(record_cmd->Run({"-a", "-o", "perf2.data", "sleep", "1"}));
- ASSERT_TRUE(dumprecord_cmd->Run({"perf2.data"}));
+TEST(cmd_dump, record_file_option) {  // Running "dump" on the perf.data test data file must succeed.
+ ASSERT_TRUE(DumpCmd()->Run({GetTestData("perf.data")}));
}
diff --git a/simpleperf/cmd_help.cpp b/simpleperf/cmd_help.cpp
index a29ef72..0bb6231 100644
--- a/simpleperf/cmd_help.cpp
+++ b/simpleperf/cmd_help.cpp
@@ -60,7 +60,8 @@
"common options:\n"
" -h/--help Print this help information.\n"
" --log <severity> Set the minimum severity of logging. Possible severities\n"
- " include debug, warning, error, fatal. Default is warning.\n"
+ " include verbose, debug, warning, error, fatal. Default is\n"
+ " warning.\n"
"subcommands:\n");
for (auto& cmd_name : GetAllCommandNames()) {
std::unique_ptr<Command> cmd = CreateCommandInstance(cmd_name);
@@ -72,6 +73,6 @@
printf("%s\n", command.LongHelpString().c_str());
}
-__attribute__((constructor)) static void RegisterHelpCommand() {
+void RegisterHelpCommand() {  // Registers a HelpCommand factory under the name "help".
RegisterCommand("help", [] { return std::unique_ptr<Command>(new HelpCommand); });
}
diff --git a/simpleperf/cmd_list.cpp b/simpleperf/cmd_list.cpp
index 01ac048..b6bf817 100644
--- a/simpleperf/cmd_list.cpp
+++ b/simpleperf/cmd_list.cpp
@@ -82,6 +82,6 @@
return true;
}
-__attribute__((constructor)) static void RegisterListCommand() {
+void RegisterListCommand() {  // Registers a ListCommand factory under the name "list".
RegisterCommand("list", [] { return std::unique_ptr<Command>(new ListCommand); });
}
diff --git a/simpleperf/cmd_record.cpp b/simpleperf/cmd_record.cpp
index edeb64c..f0a7cea 100644
--- a/simpleperf/cmd_record.cpp
+++ b/simpleperf/cmd_record.cpp
@@ -126,7 +126,8 @@
post_unwind_(false),
child_inherit_(true),
perf_mmap_pages_(16),
- record_filename_("perf.data") {
+ record_filename_("perf.data"),
+ sample_record_count_(0) {
signaled = false;
scoped_signal_handler_.reset(
new ScopedSignalHandler({SIGCHLD, SIGINT, SIGTERM}, signal_handler));
@@ -146,6 +147,7 @@
bool DumpThreadCommAndMmaps(bool all_threads, const std::vector<pid_t>& selected_threads);
bool CollectRecordsFromKernel(const char* data, size_t size);
bool ProcessRecord(Record* record);
+ void UpdateRecordForEmbeddedElfPath(Record* record);
void UnwindRecord(Record* record);
bool PostUnwind(const std::vector<std::string>& args);
bool DumpAdditionalFeatures(const std::vector<std::string>& args);
@@ -153,7 +155,6 @@
void CollectHitFileInfo(Record* record);
std::pair<std::string, uint64_t> TestForEmbeddedElf(Dso *dso, uint64_t pgoff);
-
bool use_sample_freq_; // Use sample_freq_ when true, otherwise using sample_period_.
uint64_t sample_freq_; // Sample 'sample_freq_' times per second.
uint64_t sample_period_; // Sample once when 'sample_period_' events occur.
@@ -180,10 +181,10 @@
std::unique_ptr<RecordFileWriter> record_file_writer_;
std::set<std::string> hit_kernel_modules_;
- std::set<std::pair<std::string, uint64_t> > hit_user_files_;
- ApkInspector apk_inspector_;
+ std::set<std::string> hit_user_files_;
std::unique_ptr<ScopedSignalHandler> scoped_signal_handler_;
+ uint64_t sample_record_count_;
};
bool RecordCommand::Run(const std::vector<std::string>& args) {
@@ -280,7 +281,7 @@
return false;
}
}
-
+ LOG(VERBOSE) << "Record " << sample_record_count_ << " samples.";
return true;
}
@@ -637,15 +638,52 @@
}
bool RecordCommand::ProcessRecord(Record* record) {
+ UpdateRecordForEmbeddedElfPath(record);
BuildThreadTree(*record, &thread_tree_);
CollectHitFileInfo(record);
if (unwind_dwarf_callchain_ && !post_unwind_) {
UnwindRecord(record);
}
+ if (record->type() == PERF_RECORD_SAMPLE) {
+ sample_record_count_++;
+ }
bool result = record_file_writer_->WriteData(record->BinaryFormat());
return result;
}
+template<class RecordType>
+void UpdateMmapRecordForEmbeddedElfPath(RecordType* record) {
+ RecordType& r = *record;
+ bool in_kernel = ((r.header.misc & PERF_RECORD_MISC_CPUMODE_MASK) == PERF_RECORD_MISC_KERNEL);
+ if (!in_kernel && r.data.pgoff != 0) {
+ // For the case of a shared library "foobar.so" embedded
+ // inside an APK, we rewrite the original MMAP from
+ // ["path.apk" offset=X] to ["path.apk!/foobar.so" offset=W]
+ // so as to make the library name explicit. This update is
+ // done here (as part of the record operation) as opposed to
+ // on the host during the report, since we want to report
+ // the correct library name even if the APK in question
+ // is not present on the host. The new offset W is
+ // calculated to be with respect to the start of foobar.so,
+ // not to the start of path.apk.
+ EmbeddedElf* ee = ApkInspector::FindElfInApkByOffset(r.filename, r.data.pgoff);
+ if (ee != nullptr) {
+ // Compute new offset relative to start of elf in APK.
+ r.data.pgoff -= ee->entry_offset();
+ r.filename = GetUrlInApk(r.filename, ee->entry_name());
+ r.AdjustSizeBasedOnData();
+ }
+ }
+}
+
+void RecordCommand::UpdateRecordForEmbeddedElfPath(Record* record) {
+ if (record->type() == PERF_RECORD_MMAP) {
+ UpdateMmapRecordForEmbeddedElfPath(static_cast<MmapRecord*>(record));
+ } else if (record->type() == PERF_RECORD_MMAP2) {
+ UpdateMmapRecordForEmbeddedElfPath(static_cast<Mmap2Record*>(record));
+ }
+}
+
void RecordCommand::UnwindRecord(Record* record) {
if (record->type() == PERF_RECORD_SAMPLE) {
SampleRecord& r = *static_cast<SampleRecord*>(record);
@@ -746,7 +784,7 @@
std::vector<BuildIdRecord> build_id_records;
BuildId build_id;
// Add build_ids for kernel/modules.
- for (auto& filename : hit_kernel_modules_) {
+ for (const auto& filename : hit_kernel_modules_) {
if (filename == DEFAULT_KERNEL_FILENAME_FOR_BUILD_ID) {
if (!GetKernelBuildId(&build_id)) {
LOG(DEBUG) << "can't read build_id for kernel";
@@ -768,29 +806,21 @@
}
}
// Add build_ids for user elf files.
- for (auto& dso_origin : hit_user_files_) {
- auto& filename = dso_origin.first;
- auto& offset = dso_origin.second;
+ for (const auto& filename : hit_user_files_) {
if (filename == DEFAULT_EXECNAME_FOR_THREAD_MMAP) {
continue;
}
- EmbeddedElf *ee = apk_inspector_.FindElfInApkByMmapOffset(filename, offset);
- if (ee) {
- if (!GetBuildIdFromEmbeddedElfFile(filename,
- ee->entry_offset(),
- ee->entry_size(),
- &build_id)) {
- LOG(DEBUG) << "can't read build_id from archive file " << filename
- << "entry " << ee->entry_name();
+ auto tuple = SplitUrlInApk(filename);
+ if (std::get<0>(tuple)) {
+ if (!GetBuildIdFromApkFile(std::get<1>(tuple), std::get<2>(tuple), &build_id)) {
+ LOG(DEBUG) << "can't read build_id from file " << filename;
continue;
}
- std::string ee_filename = filename + "!" + ee->entry_name();
- build_id_records.push_back(CreateBuildIdRecord(false, UINT_MAX, build_id, ee_filename));
- continue;
- }
- if (!GetBuildIdFromElfFile(filename, &build_id)) {
- LOG(DEBUG) << "can't read build_id from file " << filename;
- continue;
+ } else {
+ if (!GetBuildIdFromElfFile(filename, &build_id)) {
+ LOG(DEBUG) << "can't read build_id from file " << filename;
+ continue;
+ }
}
build_id_records.push_back(CreateBuildIdRecord(false, UINT_MAX, build_id, filename));
}
@@ -809,53 +839,11 @@
if (in_kernel) {
hit_kernel_modules_.insert(map->dso->Path());
} else {
- auto apair = std::make_pair(map->dso->Path(), map->pgoff);
- hit_user_files_.insert(apair);
- }
- }
- if (record->type() == PERF_RECORD_MMAP) {
- MmapRecord& r = *static_cast<MmapRecord*>(record);
- bool in_kernel = ((r.header.misc & PERF_RECORD_MISC_CPUMODE_MASK) == PERF_RECORD_MISC_KERNEL);
- if (!in_kernel) {
- const ThreadEntry* thread = thread_tree_.FindThreadOrNew(r.data.pid, r.data.tid);
- const MapEntry* map = thread_tree_.FindMap(thread, r.data.addr, in_kernel);
- if (map->pgoff != 0u) {
- std::pair<std::string, uint64_t> ee_info = TestForEmbeddedElf(map->dso, map->pgoff);
- if (!ee_info.first.empty()) {
- // For the case of a shared library "foobar.so" embedded
- // inside an APK, we rewrite the original MMAP from
- // ["path.apk" offset=X] to ["path.apk!foobar.so" offset=W]
- // so as to make the library name explicit. This update is
- // done here (as part of the record operation) as opposed to
- // on the host during the report, since we want to report
- // the correct library name even if the the APK in question
- // is not present on the host. The new offset W is
- // calculated to be with respect to the start of foobar.so,
- // not to the start of path.apk.
- const std::string& entry_name = ee_info.first;
- uint64_t new_offset = ee_info.second;
- std::string new_filename = r.filename + "!" + entry_name;
- UpdateMmapRecord(&r, new_filename, new_offset);
- }
- }
+ hit_user_files_.insert(map->dso->Path());
}
}
}
-std::pair<std::string, uint64_t> RecordCommand::TestForEmbeddedElf(Dso *dso, uint64_t pgoff)
-{
- // Examine the DSO to determine whether it corresponds to an ELF
- // file embedded in an APK.
- std::string ee_name;
- EmbeddedElf *ee = apk_inspector_.FindElfInApkByMmapOffset(dso->Path(), pgoff);
- if (ee) {
- // Compute new offset relative to start of elf in APK.
- uint64_t elf_offset = pgoff - ee->entry_offset();
- return std::make_pair(ee->entry_name(), elf_offset);
- }
- return std::make_pair(std::string(), 0u);
-}
-
-__attribute__((constructor)) static void RegisterRecordCommand() {
+void RegisterRecordCommand() {
RegisterCommand("record", [] { return std::unique_ptr<Command>(new RecordCommand()); });
}
diff --git a/simpleperf/cmd_record_test.cpp b/simpleperf/cmd_record_test.cpp
index ca38a05..a89febf 100644
--- a/simpleperf/cmd_record_test.cpp
+++ b/simpleperf/cmd_record_test.cpp
@@ -17,10 +17,14 @@
#include <gtest/gtest.h>
#include <android-base/stringprintf.h>
+#include <android-base/test_utils.h>
+
+#include <memory>
#include "command.h"
#include "environment.h"
#include "event_selection_set.h"
+#include "get_test_data.h"
#include "record.h"
#include "record_file.h"
#include "test_util.h"
@@ -31,34 +35,51 @@
return CreateCommandInstance("record");
}
+static bool RunRecordCmd(std::vector<std::string> v, const char* output_file = nullptr) {
+ std::unique_ptr<TemporaryFile> tmpfile;
+ std::string out_file;
+ if (output_file != nullptr) {
+ out_file = output_file;
+ } else {
+ tmpfile.reset(new TemporaryFile);
+ out_file = tmpfile->path;
+ }
+ v.insert(v.end(), {"-o", out_file, "sleep", SLEEP_SEC});
+ return RecordCmd()->Run(v);
+}
+
TEST(record_cmd, no_options) {
- ASSERT_TRUE(RecordCmd()->Run({"sleep", "1"}));
+ ASSERT_TRUE(RunRecordCmd({}));
}
TEST(record_cmd, system_wide_option) {
- ASSERT_TRUE(RecordCmd()->Run({"-a", "sleep", "1"}));
+ if (IsRoot()) {
+ ASSERT_TRUE(RunRecordCmd({"-a"}));
+ }
}
TEST(record_cmd, sample_period_option) {
- ASSERT_TRUE(RecordCmd()->Run({"-c", "100000", "sleep", "1"}));
+ ASSERT_TRUE(RunRecordCmd({"-c", "100000"}));
}
TEST(record_cmd, event_option) {
- ASSERT_TRUE(RecordCmd()->Run({"-e", "cpu-clock", "sleep", "1"}));
+ ASSERT_TRUE(RunRecordCmd({"-e", "cpu-clock"}));
}
TEST(record_cmd, freq_option) {
- ASSERT_TRUE(RecordCmd()->Run({"-f", "99", "sleep", "1"}));
- ASSERT_TRUE(RecordCmd()->Run({"-F", "99", "sleep", "1"}));
+ ASSERT_TRUE(RunRecordCmd({"-f", "99"}));
+ ASSERT_TRUE(RunRecordCmd({"-F", "99"}));
}
TEST(record_cmd, output_file_option) {
- ASSERT_TRUE(RecordCmd()->Run({"-o", "perf2.data", "sleep", "1"}));
+ TemporaryFile tmpfile;
+ ASSERT_TRUE(RecordCmd()->Run({"-o", tmpfile.path, "sleep", SLEEP_SEC}));
}
TEST(record_cmd, dump_kernel_mmap) {
- ASSERT_TRUE(RecordCmd()->Run({"sleep", "1"}));
- std::unique_ptr<RecordFileReader> reader = RecordFileReader::CreateInstance("perf.data");
+ TemporaryFile tmpfile;
+ ASSERT_TRUE(RunRecordCmd({}, tmpfile.path));
+ std::unique_ptr<RecordFileReader> reader = RecordFileReader::CreateInstance(tmpfile.path);
ASSERT_TRUE(reader != nullptr);
std::vector<std::unique_ptr<Record>> records = reader->DataSection();
ASSERT_GT(records.size(), 0U);
@@ -76,8 +97,9 @@
}
TEST(record_cmd, dump_build_id_feature) {
- ASSERT_TRUE(RecordCmd()->Run({"sleep", "1"}));
- std::unique_ptr<RecordFileReader> reader = RecordFileReader::CreateInstance("perf.data");
+ TemporaryFile tmpfile;
+ ASSERT_TRUE(RunRecordCmd({}, tmpfile.path));
+ std::unique_ptr<RecordFileReader> reader = RecordFileReader::CreateInstance(tmpfile.path);
ASSERT_TRUE(reader != nullptr);
const FileHeader& file_header = reader->FileHeader();
ASSERT_TRUE(file_header.features[FEAT_BUILD_ID / 8] & (1 << (FEAT_BUILD_ID % 8)));
@@ -85,16 +107,18 @@
}
TEST(record_cmd, tracepoint_event) {
- ASSERT_TRUE(RecordCmd()->Run({"-a", "-e", "sched:sched_switch", "sleep", "1"}));
+ if (IsRoot()) {
+ ASSERT_TRUE(RunRecordCmd({"-a", "-e", "sched:sched_switch"}));
+ }
}
TEST(record_cmd, branch_sampling) {
if (IsBranchSamplingSupported()) {
- ASSERT_TRUE(RecordCmd()->Run({"-a", "-b", "sleep", "1"}));
- ASSERT_TRUE(RecordCmd()->Run({"-j", "any,any_call,any_ret,ind_call", "sleep", "1"}));
- ASSERT_TRUE(RecordCmd()->Run({"-j", "any,k", "sleep", "1"}));
- ASSERT_TRUE(RecordCmd()->Run({"-j", "any,u", "sleep", "1"}));
- ASSERT_FALSE(RecordCmd()->Run({"-j", "u", "sleep", "1"}));
+ ASSERT_TRUE(RunRecordCmd({"-b"}));
+ ASSERT_TRUE(RunRecordCmd({"-j", "any,any_call,any_ret,ind_call"}));
+ ASSERT_TRUE(RunRecordCmd({"-j", "any,k"}));
+ ASSERT_TRUE(RunRecordCmd({"-j", "any,u"}));
+ ASSERT_FALSE(RunRecordCmd({"-j", "u"}));
} else {
GTEST_LOG_(INFO)
<< "This test does nothing as branch stack sampling is not supported on this device.";
@@ -102,18 +126,18 @@
}
TEST(record_cmd, event_modifier) {
- ASSERT_TRUE(RecordCmd()->Run({"-e", "cpu-cycles:u", "sleep", "1"}));
+ ASSERT_TRUE(RunRecordCmd({"-e", "cpu-cycles:u"}));
}
TEST(record_cmd, fp_callchain_sampling) {
- ASSERT_TRUE(RecordCmd()->Run({"--call-graph", "fp", "sleep", "1"}));
+ ASSERT_TRUE(RunRecordCmd({"--call-graph", "fp"}));
}
TEST(record_cmd, dwarf_callchain_sampling) {
if (IsDwarfCallChainSamplingSupported()) {
- ASSERT_TRUE(RecordCmd()->Run({"--call-graph", "dwarf", "sleep", "1"}));
- ASSERT_TRUE(RecordCmd()->Run({"--call-graph", "dwarf,16384", "sleep", "1"}));
- ASSERT_TRUE(RecordCmd()->Run({"-g", "sleep", "1"}));
+ ASSERT_TRUE(RunRecordCmd({"--call-graph", "dwarf"}));
+ ASSERT_TRUE(RunRecordCmd({"--call-graph", "dwarf,16384"}));
+ ASSERT_TRUE(RunRecordCmd({"-g"}));
} else {
GTEST_LOG_(INFO)
<< "This test does nothing as dwarf callchain sampling is not supported on this device.";
@@ -122,24 +146,24 @@
TEST(record_cmd, no_unwind_option) {
if (IsDwarfCallChainSamplingSupported()) {
- ASSERT_TRUE(RecordCmd()->Run({"--call-graph", "dwarf", "--no-unwind", "sleep", "1"}));
+ ASSERT_TRUE(RunRecordCmd({"--call-graph", "dwarf", "--no-unwind"}));
} else {
GTEST_LOG_(INFO)
<< "This test does nothing as dwarf callchain sampling is not supported on this device.";
}
- ASSERT_FALSE(RecordCmd()->Run({"--no-unwind", "sleep", "1"}));
+ ASSERT_FALSE(RunRecordCmd({"--no-unwind"}));
}
TEST(record_cmd, post_unwind_option) {
if (IsDwarfCallChainSamplingSupported()) {
- ASSERT_TRUE(RecordCmd()->Run({"--call-graph", "dwarf", "--post-unwind", "sleep", "1"}));
+ ASSERT_TRUE(RunRecordCmd({"--call-graph", "dwarf", "--post-unwind"}));
} else {
GTEST_LOG_(INFO)
<< "This test does nothing as dwarf callchain sampling is not supported on this device.";
}
- ASSERT_FALSE(RecordCmd()->Run({"--post-unwind", "sleep", "1"}));
+ ASSERT_FALSE(RunRecordCmd({"--post-unwind"}));
ASSERT_FALSE(
- RecordCmd()->Run({"--call-graph", "dwarf", "--no-unwind", "--post-unwind", "sleep", "1"}));
+ RunRecordCmd({"--call-graph", "dwarf", "--no-unwind", "--post-unwind"}));
}
TEST(record_cmd, existing_processes) {
@@ -147,7 +171,8 @@
CreateProcesses(2, &workloads);
std::string pid_list =
android::base::StringPrintf("%d,%d", workloads[0]->GetPid(), workloads[1]->GetPid());
- ASSERT_TRUE(RecordCmd()->Run({"-p", pid_list}));
+ TemporaryFile tmpfile;
+ ASSERT_TRUE(RecordCmd()->Run({"-p", pid_list, "-o", tmpfile.path}));
}
TEST(record_cmd, existing_threads) {
@@ -156,7 +181,8 @@
// Process id can also be used as thread id in linux.
std::string tid_list =
android::base::StringPrintf("%d,%d", workloads[0]->GetPid(), workloads[1]->GetPid());
- ASSERT_TRUE(RecordCmd()->Run({"-t", tid_list}));
+ TemporaryFile tmpfile;
+ ASSERT_TRUE(RecordCmd()->Run({"-t", tid_list, "-o", tmpfile.path}));
}
TEST(record_cmd, no_monitored_threads) {
@@ -164,17 +190,19 @@
}
TEST(record_cmd, more_than_one_event_types) {
- ASSERT_TRUE(RecordCmd()->Run({"-e", "cpu-cycles,cpu-clock", "sleep", "1"}));
- ASSERT_TRUE(RecordCmd()->Run({"-e", "cpu-cycles", "-e", "cpu-clock", "sleep", "1"}));
+ ASSERT_TRUE(RunRecordCmd({"-e", "cpu-cycles,cpu-clock"}));
+ ASSERT_TRUE(RunRecordCmd({"-e", "cpu-cycles", "-e", "cpu-clock"}));
}
TEST(record_cmd, cpu_option) {
- ASSERT_TRUE(RecordCmd()->Run({"--cpu", "0", "sleep", "1"}));
- ASSERT_TRUE(RecordCmd()->Run({"--cpu", "0", "-a", "sleep", "1"}));
+ ASSERT_TRUE(RunRecordCmd({"--cpu", "0"}));
+ if (IsRoot()) {
+ ASSERT_TRUE(RunRecordCmd({"--cpu", "0", "-a"}));
+ }
}
TEST(record_cmd, mmap_page_option) {
- ASSERT_TRUE(RecordCmd()->Run({"-m", "1", "sleep", "1"}));
- ASSERT_FALSE(RecordCmd()->Run({"-m", "0", "sleep", "1"}));
- ASSERT_FALSE(RecordCmd()->Run({"-m", "7", "sleep", "1"}));
+ ASSERT_TRUE(RunRecordCmd({"-m", "1"}));
+ ASSERT_FALSE(RunRecordCmd({"-m", "0"}));
+ ASSERT_FALSE(RunRecordCmd({"-m", "7"}));
}
diff --git a/simpleperf/cmd_report.cpp b/simpleperf/cmd_report.cpp
index b667de6..a89be67 100644
--- a/simpleperf/cmd_report.cpp
+++ b/simpleperf/cmd_report.cpp
@@ -256,7 +256,8 @@
" -i <file> Specify path of record file, default is perf.data.\n"
" -n Print the sample count for each item.\n"
" --no-demangle Don't demangle symbol names.\n"
- " --pid pid1,pid2,...\n"
+ " -o report_file_name Set report file name, default is stdout.\n"
+ " --pids pid1,pid2,...\n"
" Report only for selected pids.\n"
" --sort key1,key2,...\n"
" Select the keys to sort and print the report. Possible keys\n"
@@ -272,7 +273,8 @@
use_branch_address_(false),
accumulate_callchain_(false),
print_callgraph_(false),
- callgraph_show_callee_(true) {
+ callgraph_show_callee_(true),
+ report_fp_(nullptr) {
compare_sample_func_t compare_sample_callback = std::bind(
&ReportCommand::CompareSampleEntry, this, std::placeholders::_1, std::placeholders::_2);
sample_tree_ =
@@ -289,13 +291,15 @@
void ProcessSampleRecord(const SampleRecord& r);
bool ReadFeaturesFromRecordFile();
int CompareSampleEntry(const SampleEntry& sample1, const SampleEntry& sample2);
- void PrintReport();
+ bool PrintReport();
void PrintReportContext();
void CollectReportWidth();
void CollectReportEntryWidth(const SampleEntry& sample);
void PrintReportHeader();
void PrintReportEntry(const SampleEntry& sample);
void PrintCallGraph(const SampleEntry& sample);
+ void PrintCallGraphEntry(size_t depth, std::string prefix, const std::unique_ptr<CallChainNode>& node,
+ uint64_t parent_period, bool last);
std::string record_filename_;
std::unique_ptr<RecordFileReader> record_file_reader_;
@@ -309,6 +313,9 @@
bool accumulate_callchain_;
bool print_callgraph_;
bool callgraph_show_callee_;
+
+ std::string report_filename_;
+ FILE* report_fp_;
};
bool ReportCommand::Run(const std::vector<std::string>& args) {
@@ -332,7 +339,9 @@
ReadSampleTreeFromRecordFile();
// 3. Show collected information.
- PrintReport();
+ if (!PrintReport()) {
+ return false;
+ }
return true;
}
@@ -386,6 +395,11 @@
} else if (args[i] == "--no-demangle") {
demangle = false;
+ } else if (args[i] == "-o") {
+ if (!NextArgumentOrError(args, &i)) {
+ return false;
+ }
+ report_filename_ = args[i];
} else if (args[i] == "--pids" || args[i] == "--tids") {
if (!NextArgumentOrError(args, &i)) {
@@ -641,13 +655,29 @@
return 0;
}
-void ReportCommand::PrintReport() {
+bool ReportCommand::PrintReport() {
+ std::unique_ptr<FILE, decltype(&fclose)> file_handler(nullptr, fclose);
+ if (report_filename_.empty()) {
+ report_fp_ = stdout;
+ } else {
+ report_fp_ = fopen(report_filename_.c_str(), "w");
+ if (report_fp_ == nullptr) {
+ PLOG(ERROR) << "failed to open file " << report_filename_;
+ return false;
+ }
+ file_handler.reset(report_fp_);
+ }
PrintReportContext();
CollectReportWidth();
PrintReportHeader();
sample_tree_->VisitAllSamples(
std::bind(&ReportCommand::PrintReportEntry, this, std::placeholders::_1));
- fflush(stdout);
+ fflush(report_fp_);
+ if (ferror(report_fp_) != 0) {
+ PLOG(ERROR) << "print report failed";
+ return false;
+ }
+ return true;
}
void ReportCommand::PrintReportContext() {
@@ -660,11 +690,11 @@
android::base::StringPrintf("(type %u, config %llu)", event_attr_.type, event_attr_.config);
}
if (!record_cmdline_.empty()) {
- printf("Cmdline: %s\n", record_cmdline_.c_str());
+ fprintf(report_fp_, "Cmdline: %s\n", record_cmdline_.c_str());
}
- printf("Samples: %" PRIu64 " of event '%s'\n", sample_tree_->TotalSamples(),
- event_type_name.c_str());
- printf("Event count: %" PRIu64 "\n\n", sample_tree_->TotalPeriod());
+ fprintf(report_fp_, "Samples: %" PRIu64 " of event '%s'\n", sample_tree_->TotalSamples(),
+ event_type_name.c_str());
+ fprintf(report_fp_, "Event count: %" PRIu64 "\n\n", sample_tree_->TotalPeriod());
}
void ReportCommand::CollectReportWidth() {
@@ -682,9 +712,9 @@
for (size_t i = 0; i < displayable_items_.size(); ++i) {
auto& item = displayable_items_[i];
if (i != displayable_items_.size() - 1) {
- printf("%-*s ", static_cast<int>(item->Width()), item->Name().c_str());
+ fprintf(report_fp_, "%-*s ", static_cast<int>(item->Width()), item->Name().c_str());
} else {
- printf("%s\n", item->Name().c_str());
+ fprintf(report_fp_, "%s\n", item->Name().c_str());
}
}
}
@@ -693,9 +723,9 @@
for (size_t i = 0; i < displayable_items_.size(); ++i) {
auto& item = displayable_items_[i];
if (i != displayable_items_.size() - 1) {
- printf("%-*s ", static_cast<int>(item->Width()), item->Show(sample).c_str());
+ fprintf(report_fp_, "%-*s ", static_cast<int>(item->Width()), item->Show(sample).c_str());
} else {
- printf("%s\n", item->Show(sample).c_str());
+ fprintf(report_fp_, "%s\n", item->Show(sample).c_str());
}
}
if (print_callgraph_) {
@@ -703,15 +733,26 @@
}
}
-static void PrintCallGraphEntry(size_t depth, std::string prefix,
- const std::unique_ptr<CallChainNode>& node, uint64_t parent_period,
- bool last) {
+void ReportCommand::PrintCallGraph(const SampleEntry& sample) {
+ std::string prefix = " ";
+ fprintf(report_fp_, "%s|\n", prefix.c_str());
+ fprintf(report_fp_, "%s-- %s\n", prefix.c_str(), sample.symbol->DemangledName());
+ prefix.append(3, ' ');
+ for (size_t i = 0; i < sample.callchain.children.size(); ++i) {
+ PrintCallGraphEntry(1, prefix, sample.callchain.children[i], sample.callchain.children_period,
+ (i + 1 == sample.callchain.children.size()));
+ }
+}
+
+void ReportCommand::PrintCallGraphEntry(size_t depth, std::string prefix,
+ const std::unique_ptr<CallChainNode>& node,
+ uint64_t parent_period, bool last) {
if (depth > 20) {
LOG(WARNING) << "truncated callgraph at depth " << depth;
return;
}
prefix += "|";
- printf("%s\n", prefix.c_str());
+ fprintf(report_fp_, "%s\n", prefix.c_str());
if (last) {
prefix.back() = ' ';
}
@@ -720,10 +761,10 @@
double percentage = 100.0 * (node->period + node->children_period) / parent_period;
percentage_s = android::base::StringPrintf("--%.2lf%%-- ", percentage);
}
- printf("%s%s%s\n", prefix.c_str(), percentage_s.c_str(), node->chain[0]->symbol->DemangledName());
+ fprintf(report_fp_, "%s%s%s\n", prefix.c_str(), percentage_s.c_str(), node->chain[0]->symbol->DemangledName());
prefix.append(percentage_s.size(), ' ');
for (size_t i = 1; i < node->chain.size(); ++i) {
- printf("%s%s\n", prefix.c_str(), node->chain[i]->symbol->DemangledName());
+ fprintf(report_fp_, "%s%s\n", prefix.c_str(), node->chain[i]->symbol->DemangledName());
}
for (size_t i = 0; i < node->children.size(); ++i) {
@@ -732,17 +773,6 @@
}
}
-void ReportCommand::PrintCallGraph(const SampleEntry& sample) {
- std::string prefix = " ";
- printf("%s|\n", prefix.c_str());
- printf("%s-- %s\n", prefix.c_str(), sample.symbol->DemangledName());
- prefix.append(3, ' ');
- for (size_t i = 0; i < sample.callchain.children.size(); ++i) {
- PrintCallGraphEntry(1, prefix, sample.callchain.children[i], sample.callchain.children_period,
- (i + 1 == sample.callchain.children.size()));
- }
-}
-
-__attribute__((constructor)) static void RegisterReportCommand() {
+void RegisterReportCommand() {
RegisterCommand("report", [] { return std::unique_ptr<Command>(new ReportCommand()); });
}
diff --git a/simpleperf/cmd_report_test.cpp b/simpleperf/cmd_report_test.cpp
index 4feac19..a5ece01 100644
--- a/simpleperf/cmd_report_test.cpp
+++ b/simpleperf/cmd_report_test.cpp
@@ -16,12 +16,17 @@
#include <gtest/gtest.h>
-#include "command.h"
-#include "event_selection_set.h"
+#include <set>
+#include <unordered_map>
-static std::unique_ptr<Command> RecordCmd() {
- return CreateCommandInstance("record");
-}
+#include <android-base/file.h>
+#include <android-base/strings.h>
+#include <android-base/test_utils.h>
+
+#include "command.h"
+#include "get_test_data.h"
+#include "read_apk.h"
+#include "test_util.h"
static std::unique_ptr<Command> ReportCmd() {
return CreateCommandInstance("report");
@@ -29,76 +34,242 @@
class ReportCommandTest : public ::testing::Test {
protected:
- static void SetUpTestCase() {
- ASSERT_TRUE(RecordCmd()->Run({"-a", "sleep", "1"}));
- ASSERT_TRUE(RecordCmd()->Run({"-a", "-o", "perf2.data", "sleep", "1"}));
- ASSERT_TRUE(RecordCmd()->Run({"--call-graph", "fp", "-o", "perf_g.data", "sleep", "1"}));
+ void Report(const std::string perf_data,
+ const std::vector<std::string>& add_args = std::vector<std::string>()) {
+ ReportRaw(GetTestData(perf_data), add_args);
}
+
+ void ReportRaw(const std::string perf_data,
+ const std::vector<std::string>& add_args = std::vector<std::string>()) {
+ success = false;
+ std::vector<std::string> args = {"-i", perf_data,
+ "--symfs", GetTestDataDir(), "-o", tmp_file.path};
+ args.insert(args.end(), add_args.begin(), add_args.end());
+ ASSERT_TRUE(ReportCmd()->Run(args));
+ ASSERT_TRUE(android::base::ReadFileToString(tmp_file.path, &content));
+ ASSERT_TRUE(!content.empty());
+ std::vector<std::string> raw_lines = android::base::Split(content, "\n");
+ lines.clear();
+ for (const auto& line : raw_lines) {
+ std::string s = android::base::Trim(line);
+ if (!s.empty()) {
+ lines.push_back(s);
+ }
+ }
+ ASSERT_GE(lines.size(), 2u);
+ success = true;
+ }
+
+ TemporaryFile tmp_file;
+ std::string content;
+ std::vector<std::string> lines;
+ bool success;
};
-TEST_F(ReportCommandTest, no_options) {
- ASSERT_TRUE(ReportCmd()->Run({}));
-}
-
-TEST_F(ReportCommandTest, input_file_option) {
- ASSERT_TRUE(ReportCmd()->Run({"-i", "perf2.data"}));
+TEST_F(ReportCommandTest, no_option) {
+ Report(PERF_DATA);
+ ASSERT_TRUE(success);
+ ASSERT_NE(content.find("GlobalFunc"), std::string::npos);
}
TEST_F(ReportCommandTest, sort_option_pid) {
- ASSERT_TRUE(ReportCmd()->Run({"--sort", "pid"}));
+ Report(PERF_DATA, {"--sort", "pid"});
+ ASSERT_TRUE(success);
+ size_t line_index = 0;
+ while (line_index < lines.size() && lines[line_index].find("Pid") == std::string::npos) {
+ line_index++;
+ }
+ ASSERT_LT(line_index + 2, lines.size());
}
-TEST_F(ReportCommandTest, sort_option_all) {
- ASSERT_TRUE(ReportCmd()->Run({"--sort", "comm,pid,dso,symbol"}));
+TEST_F(ReportCommandTest, sort_option_more_than_one) {
+ Report(PERF_DATA, {"--sort", "comm,pid,dso,symbol"});
+ ASSERT_TRUE(success);
+ size_t line_index = 0;
+ while (line_index < lines.size() && lines[line_index].find("Overhead") == std::string::npos) {
+ line_index++;
+ }
+ ASSERT_LT(line_index + 1, lines.size());
+ ASSERT_NE(lines[line_index].find("Command"), std::string::npos);
+ ASSERT_NE(lines[line_index].find("Pid"), std::string::npos);
+ ASSERT_NE(lines[line_index].find("Shared Object"), std::string::npos);
+ ASSERT_NE(lines[line_index].find("Symbol"), std::string::npos);
+ ASSERT_EQ(lines[line_index].find("Tid"), std::string::npos);
}
TEST_F(ReportCommandTest, children_option) {
- ASSERT_TRUE(ReportCmd()->Run({"--children", "-i", "perf_g.data"}));
+ Report(CALLGRAPH_FP_PERF_DATA, {"--children", "--sort", "symbol"});
+ ASSERT_TRUE(success);
+ std::unordered_map<std::string, std::pair<double, double>> map;  // symbol -> the two leading percentage columns
+ for (size_t i = 0; i < lines.size(); ++i) {
+ char name[1024];
+ std::pair<double, double> pair;
+ if (sscanf(lines[i].c_str(), "%lf%%%lf%%%1023s", &pair.first, &pair.second, name) == 3) {  // width = sizeof(name) - 1
+ map.insert(std::make_pair(name, pair));
+ }
+ }
+ ASSERT_NE(map.find("GlobalFunc"), map.end());
+ ASSERT_NE(map.find("main"), map.end());
+ auto func_pair = map["GlobalFunc"];
+ auto main_pair = map["main"];
+ ASSERT_GE(main_pair.first, func_pair.first);
+ ASSERT_GE(func_pair.first, func_pair.second);
+ ASSERT_GE(func_pair.second, main_pair.second);
+}
+
+static bool CheckCalleeMode(std::vector<std::string>& lines) {
+ bool found = false;
+ for (size_t i = 0; i + 2 < lines.size(); ++i) {
+ if (lines[i].find("GlobalFunc") != std::string::npos &&
+ lines[i + 1].find("|") != std::string::npos &&
+ lines[i + 2].find("main") != std::string::npos) {
+ found = true;
+ break;
+ }
+ }
+ return found;
+}
+
+static bool CheckCallerMode(std::vector<std::string>& lines) {
+ bool found = false;
+ for (size_t i = 0; i + 2 < lines.size(); ++i) {
+ if (lines[i].find("main") != std::string::npos &&
+ lines[i + 1].find("|") != std::string::npos &&
+ lines[i + 2].find("GlobalFunc") != std::string::npos) {
+ found = true;
+ break;
+ }
+ }
+ return found;
}
TEST_F(ReportCommandTest, callgraph_option) {
- ASSERT_TRUE(ReportCmd()->Run({"-g", "-i", "perf_g.data"}));
- ASSERT_TRUE(ReportCmd()->Run({"-g", "callee", "-i", "perf_g.data"}));
- ASSERT_TRUE(ReportCmd()->Run({"-g", "caller", "-i", "perf_g.data"}));
+ Report(CALLGRAPH_FP_PERF_DATA, {"-g"});
+ ASSERT_TRUE(success);
+ ASSERT_TRUE(CheckCalleeMode(lines));
+ Report(CALLGRAPH_FP_PERF_DATA, {"-g", "callee"});
+ ASSERT_TRUE(success);
+ ASSERT_TRUE(CheckCalleeMode(lines));
+ Report(CALLGRAPH_FP_PERF_DATA, {"-g", "caller"});
+ ASSERT_TRUE(success);
+ ASSERT_TRUE(CheckCallerMode(lines));
+}
+
+static bool AllItemsWithString(std::vector<std::string>& lines, const std::vector<std::string>& strs) {
+ size_t line_index = 0;
+ while (line_index < lines.size() && lines[line_index].find("Overhead") == std::string::npos) {
+ line_index++;
+ }
+ if (line_index == lines.size() || line_index + 1 == lines.size()) {
+ return false;
+ }
+ line_index++;
+ for (; line_index < lines.size(); ++line_index) {
+ bool exist = false;
+ for (auto& s : strs) {
+ if (lines[line_index].find(s) != std::string::npos) {
+ exist = true;
+ break;
+ }
+ }
+ if (!exist) {
+ return false;
+ }
+ }
+ return true;
}
TEST_F(ReportCommandTest, pid_filter_option) {
- ASSERT_TRUE(ReportCmd()->Run({"--pids", "0"}));
- ASSERT_TRUE(ReportCmd()->Run({"--pids", "0,1"}));
+ Report(PERF_DATA);  // Unfiltered baseline: entries must not all belong to the selected pids.
+ ASSERT_TRUE(success);  // Flag set by ReportRaw() only after the report was produced and parsed.
+ ASSERT_FALSE(AllItemsWithString(lines, {"26083"}));
+ ASSERT_FALSE(AllItemsWithString(lines, {"26083", "26090"}));
+ Report(PERF_DATA, {"--pids", "26083"});
+ ASSERT_TRUE(success);
+ ASSERT_TRUE(AllItemsWithString(lines, {"26083"}));
+ Report(PERF_DATA, {"--pids", "26083,26090"});
+ ASSERT_TRUE(success);
+ ASSERT_TRUE(AllItemsWithString(lines, {"26083", "26090"}));
}
TEST_F(ReportCommandTest, tid_filter_option) {
- ASSERT_TRUE(ReportCmd()->Run({"--tids", "0"}));
- ASSERT_TRUE(ReportCmd()->Run({"--tids", "0,1"}));
+ Report(PERF_DATA);  // Unfiltered baseline: entries must not all belong to the selected tids.
+ ASSERT_TRUE(success);  // Flag set by ReportRaw() only after the report was produced and parsed.
+ ASSERT_FALSE(AllItemsWithString(lines, {"26083"}));
+ ASSERT_FALSE(AllItemsWithString(lines, {"26083", "26090"}));
+ Report(PERF_DATA, {"--tids", "26083"});
+ ASSERT_TRUE(success);
+ ASSERT_TRUE(AllItemsWithString(lines, {"26083"}));
+ Report(PERF_DATA, {"--tids", "26083,26090"});
+ ASSERT_TRUE(success);
+ ASSERT_TRUE(AllItemsWithString(lines, {"26083", "26090"}));
}
TEST_F(ReportCommandTest, comm_filter_option) {
- ASSERT_TRUE(ReportCmd()->Run({"--comms", "swapper"}));
- ASSERT_TRUE(ReportCmd()->Run({"--comms", "swapper,simpleperf"}));
+ Report(PERF_DATA, {"--sort", "comm"});
+ ASSERT_TRUE(success);
+ ASSERT_FALSE(AllItemsWithString(lines, {"t1"}));
+ ASSERT_FALSE(AllItemsWithString(lines, {"t1", "t2"}));
+ Report(PERF_DATA, {"--sort", "comm", "--comms", "t1"});
+ ASSERT_TRUE(success);
+ ASSERT_TRUE(AllItemsWithString(lines, {"t1"}));
+ Report(PERF_DATA, {"--sort", "comm", "--comms", "t1,t2"});
+ ASSERT_TRUE(success);
+ ASSERT_TRUE(AllItemsWithString(lines, {"t1", "t2"}));
}
TEST_F(ReportCommandTest, dso_filter_option) {
- ASSERT_TRUE(ReportCmd()->Run({"--dsos", "[kernel.kallsyms]"}));
- ASSERT_TRUE(ReportCmd()->Run({"--dsos", "[kernel.kallsyms],/init"}));
+ Report(PERF_DATA, {"--sort", "dso"});
+ ASSERT_TRUE(success);
+ ASSERT_FALSE(AllItemsWithString(lines, {"/t1"}));
+ ASSERT_FALSE(AllItemsWithString(lines, {"/t1", "/t2"}));
+ Report(PERF_DATA, {"--sort", "dso", "--dsos", "/t1"});
+ ASSERT_TRUE(success);
+ ASSERT_TRUE(AllItemsWithString(lines, {"/t1"}));
+ Report(PERF_DATA, {"--sort", "dso", "--dsos", "/t1,/t2"});
+ ASSERT_TRUE(success);
+ ASSERT_TRUE(AllItemsWithString(lines, {"/t1", "/t2"}));
}
-TEST(report_cmd, use_branch_address) {
- if (IsBranchSamplingSupported()) {
- ASSERT_TRUE(RecordCmd()->Run({"-b", "sleep", "1"}));
- ASSERT_TRUE(
- ReportCmd()->Run({"-b", "--sort", "comm,pid,dso_from,symbol_from,dso_to,symbol_to"}));
- } else {
- GTEST_LOG_(INFO)
- << "This test does nothing as branch stack sampling is not supported on this device.";
+TEST_F(ReportCommandTest, use_branch_address) {
+ Report(BRANCH_PERF_DATA, {"-b", "--sort", "symbol_from,symbol_to"});
+ std::set<std::pair<std::string, std::string>> hit_set;  // (symbol_from, symbol_to) pairs seen in the report
+ bool after_overhead = false;
+ for (const auto& line : lines) {
+ if (!after_overhead && line.find("Overhead") != std::string::npos) {
+ after_overhead = true;
+ } else if (after_overhead) {
+ char from[80];
+ char to[80];
+ if (sscanf(line.c_str(), "%*f%%%79s%79s", from, to) == 2) {  // widths = sizeof(buf) - 1, so long symbols cannot overflow
+ hit_set.insert(std::make_pair<std::string, std::string>(from, to));
+ }
+ }
}
+ ASSERT_NE(hit_set.find(std::make_pair<std::string, std::string>("GlobalFunc", "CalledFunc")),
+ hit_set.end());
+ ASSERT_NE(hit_set.find(std::make_pair<std::string, std::string>("CalledFunc", "GlobalFunc")),
+ hit_set.end());
}
-TEST(report_cmd, dwarf_callgraph) {
- if (IsDwarfCallChainSamplingSupported()) {
- ASSERT_TRUE(RecordCmd()->Run({"-g", "-o", "perf_dwarf.data", "sleep", "1"}));
- ASSERT_TRUE(ReportCmd()->Run({"-g", "-i", "perf_dwarf.data"}));
- } else {
- GTEST_LOG_(INFO)
- << "This test does nothing as dwarf callchain sampling is not supported on this device.";
- }
+#if defined(__ANDROID__) || defined(__linux__)
+
+static std::unique_ptr<Command> RecordCmd() {
+ return CreateCommandInstance("record");
+}
+
+TEST_F(ReportCommandTest, dwarf_callgraph) {
+ TemporaryFile tmp_file;
+ ASSERT_TRUE(RecordCmd()->Run({"-g", "-o", tmp_file.path, "sleep", SLEEP_SEC}));
+ ReportRaw(tmp_file.path, {"-g"});
+ ASSERT_TRUE(success);
+}
+
+#endif
+
+TEST_F(ReportCommandTest, report_symbols_of_nativelib_in_apk) {
+ Report(NATIVELIB_IN_APK_PERF_DATA);
+ ASSERT_TRUE(success);
+ ASSERT_NE(content.find(GetUrlInApk(APK_FILE, NATIVELIB_IN_APK)), std::string::npos);
+ ASSERT_NE(content.find("GlobalFunc"), std::string::npos);
}
diff --git a/simpleperf/cmd_stat.cpp b/simpleperf/cmd_stat.cpp
index 76ed145..228b4ed 100644
--- a/simpleperf/cmd_stat.cpp
+++ b/simpleperf/cmd_stat.cpp
@@ -411,6 +411,6 @@
return true;
}
-__attribute__((constructor)) static void RegisterStatCommand() {
+void RegisterStatCommand() {
RegisterCommand("stat", [] { return std::unique_ptr<Command>(new StatCommand); });
}
diff --git a/simpleperf/cmd_stat_test.cpp b/simpleperf/cmd_stat_test.cpp
index 3444806..27f1f09 100644
--- a/simpleperf/cmd_stat_test.cpp
+++ b/simpleperf/cmd_stat_test.cpp
@@ -19,6 +19,7 @@
#include <android-base/stringprintf.h>
#include "command.h"
+#include "get_test_data.h"
#include "test_util.h"
static std::unique_ptr<Command> StatCmd() {
@@ -34,7 +35,9 @@
}
TEST(stat_cmd, system_wide_option) {
- ASSERT_TRUE(StatCmd()->Run({"-a", "sleep", "1"}));
+ if (IsRoot()) {
+ ASSERT_TRUE(StatCmd()->Run({"-a", "sleep", "1"}));
+ }
}
TEST(stat_cmd, verbose_option) {
@@ -42,11 +45,23 @@
}
TEST(stat_cmd, tracepoint_event) {
- ASSERT_TRUE(StatCmd()->Run({"-a", "-e", "sched:sched_switch", "sleep", "1"}));
+ if (IsRoot()) {
+ ASSERT_TRUE(StatCmd()->Run({"-a", "-e", "sched:sched_switch", "sleep", "1"}));
+ }
}
TEST(stat_cmd, event_modifier) {
- ASSERT_TRUE(StatCmd()->Run({"-e", "cpu-cycles:u,sched:sched_switch:k", "sleep", "1"}));
+ ASSERT_TRUE(StatCmd()->Run({"-e", "cpu-cycles:u,cpu-cycles:k", "sleep", "1"}));
+}
+
+void CreateProcesses(size_t count, std::vector<std::unique_ptr<Workload>>* workloads) {
+ workloads->clear();
+ for (size_t i = 0; i < count; ++i) {
+ auto workload = Workload::CreateWorkload({"sleep", "1"});
+ ASSERT_TRUE(workload != nullptr);
+ ASSERT_TRUE(workload->Start());
+ workloads->push_back(std::move(workload));
+ }
}
TEST(stat_cmd, existing_processes) {
@@ -72,5 +87,7 @@
TEST(stat_cmd, cpu_option) {
ASSERT_TRUE(StatCmd()->Run({"--cpu", "0", "sleep", "1"}));
- ASSERT_TRUE(StatCmd()->Run({"--cpu", "0", "-a", "sleep", "1"}));
+ if (IsRoot()) {
+ ASSERT_TRUE(StatCmd()->Run({"--cpu", "0", "-a", "sleep", "1"}));
+ }
}
diff --git a/simpleperf/command.cpp b/simpleperf/command.cpp
index d4cfd65..3416653 100644
--- a/simpleperf/command.cpp
+++ b/simpleperf/command.cpp
@@ -68,3 +68,26 @@
}
return names;
}
+
+extern void RegisterDumpRecordCommand();
+extern void RegisterHelpCommand();
+extern void RegisterListCommand();
+extern void RegisterRecordCommand();
+extern void RegisterReportCommand();
+extern void RegisterStatCommand();
+
+class CommandRegister {
+ public:
+ CommandRegister() {
+ RegisterDumpRecordCommand();
+ RegisterHelpCommand();
+ RegisterReportCommand();
+#if defined(__linux__)
+ RegisterListCommand();
+ RegisterRecordCommand();
+ RegisterStatCommand();
+#endif
+ }
+};
+
+CommandRegister command_register;
diff --git a/simpleperf/dso.cpp b/simpleperf/dso.cpp
index 6397eaa..9c33667 100644
--- a/simpleperf/dso.cpp
+++ b/simpleperf/dso.cpp
@@ -26,6 +26,7 @@
#include <android-base/logging.h>
#include "environment.h"
+#include "read_apk.h"
#include "read_elf.h"
#include "utils.h"
@@ -199,9 +200,14 @@
case DSO_KERNEL_MODULE:
result = LoadKernelModule();
break;
- case DSO_ELF_FILE:
- result = LoadElfFile();
+ case DSO_ELF_FILE: {
+ if (std::get<0>(SplitUrlInApk(path_))) {
+ result = LoadEmbeddedElfFile();
+ } else {
+ result = LoadElfFile();
+ }
break;
+ }
}
if (result) {
std::sort(symbols_.begin(), symbols_.end(), SymbolComparator());
@@ -305,6 +311,16 @@
return loaded;
}
+bool Dso::LoadEmbeddedElfFile() {
+ std::string path = GetAccessiblePath();
+ BuildId build_id = GetExpectedBuildId(path);
+ auto tuple = SplitUrlInApk(path);
+ CHECK(std::get<0>(tuple));
+ return ParseSymbolsFromApkFile(std::get<1>(tuple), std::get<2>(tuple), build_id,
+ std::bind(ElfFileSymbolCallback, std::placeholders::_1,
+ this, SymbolFilterForDso));
+}
+
void Dso::InsertSymbol(const Symbol& symbol) {
symbols_.push_back(symbol);
}
diff --git a/simpleperf/dso.h b/simpleperf/dso.h
index a140e5e..9697319 100644
--- a/simpleperf/dso.h
+++ b/simpleperf/dso.h
@@ -93,6 +93,7 @@
bool LoadKernel();
bool LoadKernelModule();
bool LoadElfFile();
+ bool LoadEmbeddedElfFile();
void InsertSymbol(const Symbol& symbol);
void FixupSymbolLength();
diff --git a/simpleperf/environment_test.cpp b/simpleperf/environment_test.cpp
index 9a96530..6bca7b8 100644
--- a/simpleperf/environment_test.cpp
+++ b/simpleperf/environment_test.cpp
@@ -18,6 +18,7 @@
#include <functional>
#include <android-base/file.h>
+#include <android-base/test_utils.h>
#include "environment.h"
@@ -28,10 +29,21 @@
ASSERT_EQ(GetCpusFromString("1,0-3,3,4"), std::vector<int>({0, 1, 2, 3, 4}));
}
-static bool FindKernelSymbol(const KernelSymbol& sym1, const KernelSymbol& sym2) {
- return sym1.addr == sym2.addr && sym1.type == sym2.type && strcmp(sym1.name, sym2.name) == 0 &&
- ((sym1.module == nullptr && sym2.module == nullptr) ||
- (strcmp(sym1.module, sym2.module) == 0));
+static bool ModulesMatch(const char* p, const char* q) {
+ if (p == nullptr && q == nullptr) {
+ return true;
+ }
+ if (p != nullptr && q != nullptr) {
+ return strcmp(p, q) == 0;
+ }
+ return false;
+}
+
+static bool KernelSymbolsMatch(const KernelSymbol& sym1, const KernelSymbol& sym2) {
+ return sym1.addr == sym2.addr &&
+ sym1.type == sym2.type &&
+ strcmp(sym1.name, sym2.name) == 0 &&
+ ModulesMatch(sym1.module, sym2.module);
}
TEST(environment, ProcessKernelSymbols) {
@@ -39,25 +51,24 @@
"ffffffffa005c4e4 d __warned.41698 [libsas]\n"
"aaaaaaaaaaaaaaaa T _text\n"
"cccccccccccccccc c ccccc\n";
- const char* tempfile = "tempfile_process_kernel_symbols";
- ASSERT_TRUE(android::base::WriteStringToFile(data, tempfile));
+ TemporaryFile tempfile;
+ ASSERT_TRUE(android::base::WriteStringToFile(data, tempfile.path));
KernelSymbol expected_symbol;
expected_symbol.addr = 0xffffffffa005c4e4ULL;
expected_symbol.type = 'd';
expected_symbol.name = "__warned.41698";
expected_symbol.module = "libsas";
ASSERT_TRUE(ProcessKernelSymbols(
- tempfile, std::bind(&FindKernelSymbol, std::placeholders::_1, expected_symbol)));
+ tempfile.path, std::bind(&KernelSymbolsMatch, std::placeholders::_1, expected_symbol)));
expected_symbol.addr = 0xaaaaaaaaaaaaaaaaULL;
expected_symbol.type = 'T';
expected_symbol.name = "_text";
expected_symbol.module = nullptr;
ASSERT_TRUE(ProcessKernelSymbols(
- tempfile, std::bind(&FindKernelSymbol, std::placeholders::_1, expected_symbol)));
+ tempfile.path, std::bind(&KernelSymbolsMatch, std::placeholders::_1, expected_symbol)));
expected_symbol.name = "non_existent_symbol";
ASSERT_FALSE(ProcessKernelSymbols(
- tempfile, std::bind(&FindKernelSymbol, std::placeholders::_1, expected_symbol)));
- ASSERT_EQ(0, unlink(tempfile));
+ tempfile.path, std::bind(&KernelSymbolsMatch, std::placeholders::_1, expected_symbol)));
}
diff --git a/simpleperf/event_selection_set.cpp b/simpleperf/event_selection_set.cpp
index 038e577..df731f1 100644
--- a/simpleperf/event_selection_set.cpp
+++ b/simpleperf/event_selection_set.cpp
@@ -211,6 +211,7 @@
for (auto& cpu : cpus) {
auto event_fd = EventFd::OpenEventFile(selection.event_attr, tid, cpu);
if (event_fd != nullptr) {
+ LOG(VERBOSE) << "OpenEventFile for tid " << tid << ", cpu " << cpu;
selection.event_fds.push_back(std::move(event_fd));
++open_per_thread;
}
diff --git a/simpleperf/get_test_data.h b/simpleperf/get_test_data.h
index 313da04..02363d6 100644
--- a/simpleperf/get_test_data.h
+++ b/simpleperf/get_test_data.h
@@ -19,6 +19,28 @@
#include <string>
+#include "build_id.h"
+
std::string GetTestData(const std::string& filename);
+const std::string& GetTestDataDir();
+
+bool IsRoot();
+
+static const std::string PERF_DATA = "perf.data";
+static const std::string CALLGRAPH_FP_PERF_DATA = "perf_g_fp.data";
+static const std::string BRANCH_PERF_DATA = "perf_b.data";
+
+static const std::string ELF_FILE = "elf";
+
+static const std::string APK_FILE = "data/app/com.example.hellojni-1/base.apk";
+static const std::string NATIVELIB_IN_APK = "lib/arm64-v8a/libhello-jni.so";
+static const std::string NATIVELIB_IN_APK_PERF_DATA = "has_embedded_native_libs_apk_perf.data";
+
+constexpr size_t NATIVELIB_OFFSET_IN_APK = 0x8000;
+constexpr size_t NATIVELIB_SIZE_IN_APK = 0x15d8;
+
+static BuildId elf_file_build_id("0b12a384a9f4a3f3659b7171ca615dbec3a81f71");
+
+static BuildId native_lib_build_id("b46f51cb9c4b71fb08a2fdbefc2c187894f14008");
#endif // SIMPLE_PERF_GET_TEST_DATA_H_
diff --git a/simpleperf/gtest_main.cpp b/simpleperf/gtest_main.cpp
index 04412ea..8ff45c8 100644
--- a/simpleperf/gtest_main.cpp
+++ b/simpleperf/gtest_main.cpp
@@ -16,32 +16,131 @@
#include <gtest/gtest.h>
+#include <memory>
+
+#include <android-base/file.h>
#include <android-base/logging.h>
+#include <android-base/test_utils.h>
+#include <ziparchive/zip_archive.h>
#include "get_test_data.h"
+#include "read_elf.h"
#include "utils.h"
static std::string testdata_dir;
+#if defined(IN_CTS_TEST)
+static const std::string testdata_section = ".testzipdata";
+
+static bool ExtractTestDataFromElfSection() {
+ if (!MkdirWithParents(testdata_dir)) {
+ PLOG(ERROR) << "failed to create testdata_dir " << testdata_dir;
+ return false;
+ }
+ std::string content;
+ if (!ReadSectionFromElfFile("/proc/self/exe", testdata_section, &content)) {
+ LOG(ERROR) << "failed to read section " << testdata_section;
+ return false;
+ }
+ TemporaryFile tmp_file;
+ if (!android::base::WriteStringToFile(content, tmp_file.path)) {
+ PLOG(ERROR) << "failed to write file " << tmp_file.path;
+ return false;
+ }
+ ArchiveHelper ahelper(tmp_file.fd, tmp_file.path);
+ if (!ahelper) {
+ LOG(ERROR) << "failed to open archive " << tmp_file.path;
+ return false;
+ }
+ ZipArchiveHandle& handle = ahelper.archive_handle();
+ void* cookie;
+ int ret = StartIteration(handle, &cookie, nullptr, nullptr);
+ if (ret != 0) {
+ LOG(ERROR) << "failed to start iterating zip entries";
+ return false;
+ }
+ ZipEntry entry;
+ ZipString name;
+ while (Next(cookie, &entry, &name) == 0) {
+ std::string entry_name(name.name, name.name + name.name_length);
+ std::string path = testdata_dir + entry_name;
+ // Skip dir.
+ if (path.back() == '/') {
+ continue;
+ }
+ if (!MkdirWithParents(path)) {
+ LOG(ERROR) << "failed to create dir for " << path;
+ return false;
+ }
+ FileHelper fhelper = FileHelper::OpenWriteOnly(path);
+ if (!fhelper) {
+ PLOG(ERROR) << "failed to create file " << path;
+ return false;
+ }
+ std::vector<uint8_t> data(entry.uncompressed_length);
+ if (ExtractToMemory(handle, &entry, data.data(), data.size()) != 0) {
+ LOG(ERROR) << "failed to extract entry " << entry_name;
+ return false;
+ }
+ if (!android::base::WriteFully(fhelper.fd(), data.data(), data.size())) {
+ LOG(ERROR) << "failed to write file " << path;
+ return false;
+ }
+ }
+ EndIteration(cookie);
+ return true;
+}
+#endif // defined(IN_CTS_TEST)
+
int main(int argc, char** argv) {
InitLogging(argv, android::base::StderrLogger);
testing::InitGoogleTest(&argc, argv);
+
for (int i = 1; i < argc; ++i) {
if (strcmp(argv[i], "-t") == 0 && i + 1 < argc) {
testdata_dir = argv[i + 1];
- break;
+ i++;
}
}
- if (testdata_dir.empty()) {
- printf("Usage: simpleperf_unit_test -t <testdata_dir>\n");
+
+#if defined(IN_CTS_TEST)
+ std::unique_ptr<TemporaryDir> tmp_dir;
+ if (!::testing::GTEST_FLAG(list_tests) && testdata_dir.empty()) {
+ tmp_dir.reset(new TemporaryDir);
+ testdata_dir = std::string(tmp_dir->path) + "/";
+ if (!ExtractTestDataFromElfSection()) {
+ LOG(ERROR) << "failed to extract test data from elf section";
+ return 1;
+ }
+ }
+#endif
+ if (!::testing::GTEST_FLAG(list_tests) && testdata_dir.empty()) {
+ printf("Usage: %s -t <testdata_dir>\n", argv[0]);
return 1;
}
if (testdata_dir.back() != '/') {
testdata_dir.push_back('/');
}
+ LOG(INFO) << "testdata is in " << testdata_dir;
return RUN_ALL_TESTS();
}
std::string GetTestData(const std::string& filename) {
return testdata_dir + filename;
}
+
+const std::string& GetTestDataDir() {
+ return testdata_dir;
+}
+
+bool IsRoot() {
+ static int is_root = -1;
+ if (is_root == -1) {
+#if defined(__linux__)
+ is_root = (getuid() == 0) ? 1 : 0;
+#else
+ is_root = 0;
+#endif
+ }
+ return is_root == 1;
+}
diff --git a/simpleperf/main.cpp b/simpleperf/main.cpp
index 7fdae42..75313a5 100644
--- a/simpleperf/main.cpp
+++ b/simpleperf/main.cpp
@@ -24,6 +24,7 @@
#include "command.h"
static std::map<std::string, android::base::LogSeverity> log_severity_map = {
+ {"verbose", android::base::VERBOSE},
{"debug", android::base::DEBUG},
{"warning", android::base::WARNING},
{"error", android::base::ERROR},
@@ -35,33 +36,32 @@
std::vector<std::string> args;
android::base::LogSeverity log_severity = android::base::WARNING;
- if (argc == 1) {
- args.push_back("help");
- } else {
- for (int i = 1; i < argc; ++i) {
- if (strcmp(argv[i], "--help") == 0 || strcmp(argv[i], "-h") == 0) {
- args.insert(args.begin(), "help");
- } else if (strcmp(argv[i], "--log") == 0) {
- if (i + 1 < argc) {
- ++i;
- auto it = log_severity_map.find(argv[i]);
- if (it != log_severity_map.end()) {
- log_severity = it->second;
- } else {
- LOG(ERROR) << "Unknown log severity: " << argv[i];
- return 1;
- }
+ for (int i = 1; i < argc; ++i) {
+ if (strcmp(argv[i], "--help") == 0 || strcmp(argv[i], "-h") == 0) {
+ args.insert(args.begin(), "help");
+ } else if (strcmp(argv[i], "--log") == 0) {
+ if (i + 1 < argc) {
+ ++i;
+ auto it = log_severity_map.find(argv[i]);
+ if (it != log_severity_map.end()) {
+ log_severity = it->second;
} else {
- LOG(ERROR) << "Missing argument for --log option.\n";
+ LOG(ERROR) << "Unknown log severity: " << argv[i];
return 1;
}
} else {
- args.push_back(argv[i]);
+ LOG(ERROR) << "Missing argument for --log option.\n";
+ return 1;
}
+ } else {
+ args.push_back(argv[i]);
}
}
android::base::ScopedLogSeverity severity(log_severity);
+ if (args.empty()) {
+ args.push_back("help");
+ }
std::unique_ptr<Command> command = CreateCommandInstance(args[0]);
if (command == nullptr) {
LOG(ERROR) << "malformed command line: unknown command " << args[0];
diff --git a/simpleperf/read_apk.cpp b/simpleperf/read_apk.cpp
index 270b305..18c76c2 100644
--- a/simpleperf/read_apk.cpp
+++ b/simpleperf/read_apk.cpp
@@ -24,8 +24,6 @@
#include <sys/types.h>
#include <unistd.h>
-#include <map>
-
#include <android-base/file.h>
#include <android-base/logging.h>
#include <ziparchive/zip_archive.h>
@@ -33,82 +31,34 @@
#include "read_elf.h"
#include "utils.h"
-bool IsValidJarOrApkPath(const std::string& filename) {
- static const char zip_preamble[] = {0x50, 0x4b, 0x03, 0x04 };
- if (!IsRegularFile(filename)) {
- return false;
+std::map<ApkInspector::ApkOffset, std::unique_ptr<EmbeddedElf>> ApkInspector::embedded_elf_cache_;
+
+EmbeddedElf* ApkInspector::FindElfInApkByOffset(const std::string& apk_path, uint64_t file_offset) {
+ // Already in cache?
+ ApkOffset ami(apk_path, file_offset);
+ auto it = embedded_elf_cache_.find(ami);
+ if (it != embedded_elf_cache_.end()) {
+ return it->second.get();
}
- std::string mode = std::string("rb") + CLOSE_ON_EXEC_MODE;
- FILE* fp = fopen(filename.c_str(), mode.c_str());
- if (fp == nullptr) {
- return false;
- }
- char buf[4];
- if (fread(buf, 4, 1, fp) != 1) {
- fclose(fp);
- return false;
- }
- fclose(fp);
- return memcmp(buf, zip_preamble, 4) == 0;
+ std::unique_ptr<EmbeddedElf> elf = FindElfInApkByOffsetWithoutCache(apk_path, file_offset);
+ EmbeddedElf* result = elf.get();
+ embedded_elf_cache_[ami] = std::move(elf);
+ return result;
}
-class ArchiveHelper {
- public:
- explicit ArchiveHelper(int fd) : valid_(false) {
- int rc = OpenArchiveFd(fd, "", &handle_, false);
- if (rc == 0) {
- valid_ = true;
- }
- }
- ~ArchiveHelper() {
- if (valid_) {
- CloseArchive(handle_);
- }
- }
- bool valid() const { return valid_; }
- ZipArchiveHandle &archive_handle() { return handle_; }
-
- private:
- ZipArchiveHandle handle_;
- bool valid_;
-};
-
-// First component of pair is APK file path, second is offset into APK
-typedef std::pair<std::string, size_t> ApkOffset;
-
-class ApkInspectorImpl {
- public:
- EmbeddedElf *FindElfInApkByMmapOffset(const std::string& apk_path,
- size_t mmap_offset);
- private:
- std::vector<EmbeddedElf> embedded_elf_files_;
- // Value is either 0 (no elf) or 1-based slot in array above.
- std::map<ApkOffset, uint32_t> cache_;
-};
-
-EmbeddedElf *ApkInspectorImpl::FindElfInApkByMmapOffset(const std::string& apk_path,
- size_t mmap_offset)
-{
- // Already in cache?
- ApkOffset ami(apk_path, mmap_offset);
- auto it = cache_.find(ami);
- if (it != cache_.end()) {
- uint32_t idx = it->second;
- return (idx ? &embedded_elf_files_[idx-1] : nullptr);
- }
- cache_[ami] = 0u;
-
+std::unique_ptr<EmbeddedElf> ApkInspector::FindElfInApkByOffsetWithoutCache(const std::string& apk_path,
+ uint64_t file_offset) {
// Crack open the apk(zip) file and take a look.
- if (! IsValidJarOrApkPath(apk_path)) {
+ if (!IsValidApkPath(apk_path)) {
return nullptr;
}
- FileHelper fhelper(apk_path.c_str());
- if (fhelper.fd() == -1) {
+ FileHelper fhelper = FileHelper::OpenReadOnly(apk_path);
+ if (!fhelper) {
return nullptr;
}
- ArchiveHelper ahelper(fhelper.fd());
- if (!ahelper.valid()) {
+ ArchiveHelper ahelper(fhelper.fd(), apk_path);
+ if (!ahelper) {
return nullptr;
}
ZipArchiveHandle &handle = ahelper.archive_handle();
@@ -124,11 +74,10 @@
ZipString zname;
bool found = false;
int zrc;
- off64_t mmap_off64 = mmap_offset;
while ((zrc = Next(iteration_cookie, &zentry, &zname)) == 0) {
if (zentry.method == kCompressStored &&
- mmap_off64 >= zentry.offset &&
- mmap_off64 < zentry.offset + zentry.uncompressed_length) {
+ file_offset >= static_cast<uint64_t>(zentry.offset) &&
+ file_offset < static_cast<uint64_t>(zentry.offset + zentry.uncompressed_length)) {
// Found.
found = true;
break;
@@ -154,26 +103,87 @@
}
// Elf found: add EmbeddedElf to vector, update cache.
- EmbeddedElf ee(apk_path, entry_name, zentry.offset, zentry.uncompressed_length);
- embedded_elf_files_.push_back(ee);
- unsigned idx = embedded_elf_files_.size();
- cache_[ami] = idx;
- return &embedded_elf_files_[idx-1];
+ return std::unique_ptr<EmbeddedElf>(new EmbeddedElf(apk_path, entry_name, zentry.offset,
+ zentry.uncompressed_length));
}
-// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
-
-ApkInspector::ApkInspector()
- : impl_(new ApkInspectorImpl())
-{
+std::unique_ptr<EmbeddedElf> ApkInspector::FindElfInApkByName(const std::string& apk_path,
+ const std::string& elf_filename) {
+ if (!IsValidApkPath(apk_path)) {
+ return nullptr;
+ }
+ FileHelper fhelper = FileHelper::OpenReadOnly(apk_path);
+ if (!fhelper) {
+ return nullptr;
+ }
+ ArchiveHelper ahelper(fhelper.fd(), apk_path);
+ if (!ahelper) {
+ return nullptr;
+ }
+ ZipArchiveHandle& handle = ahelper.archive_handle();
+ ZipEntry zentry;
+ int32_t rc = FindEntry(handle, ZipString(elf_filename.c_str()), &zentry);
+ if (rc != 0) {
+ LOG(ERROR) << "failed to find " << elf_filename << " in " << apk_path
+ << ": " << ErrorCodeString(rc);
+ return nullptr;
+ }
+ if (zentry.method != kCompressStored || zentry.compressed_length != zentry.uncompressed_length) {
+ LOG(ERROR) << "shared library " << elf_filename << " in " << apk_path << " is compressed";
+ return nullptr;
+ }
+ return std::unique_ptr<EmbeddedElf>(new EmbeddedElf(apk_path, elf_filename, zentry.offset,
+ zentry.uncompressed_length));
}
-ApkInspector::~ApkInspector()
-{
+bool IsValidApkPath(const std::string& apk_path) {
+ static const char zip_preamble[] = {0x50, 0x4b, 0x03, 0x04 };
+ if (!IsRegularFile(apk_path)) {
+ return false;
+ }
+ std::string mode = std::string("rb") + CLOSE_ON_EXEC_MODE;
+ FILE* fp = fopen(apk_path.c_str(), mode.c_str());
+ if (fp == nullptr) {
+ return false;
+ }
+ char buf[4];
+ if (fread(buf, 4, 1, fp) != 1) {
+ fclose(fp);
+ return false;
+ }
+ fclose(fp);
+ return memcmp(buf, zip_preamble, 4) == 0;
}
-EmbeddedElf *ApkInspector::FindElfInApkByMmapOffset(const std::string& apk_path,
- size_t mmap_offset)
-{
- return impl_->FindElfInApkByMmapOffset(apk_path, mmap_offset);
+// Refer file in apk in compliance with http://developer.android.com/reference/java/net/JarURLConnection.html.
+std::string GetUrlInApk(const std::string& apk_path, const std::string& elf_filename) {
+ return apk_path + "!/" + elf_filename;
+}
+
+std::tuple<bool, std::string, std::string> SplitUrlInApk(const std::string& path) {
+ size_t pos = path.find("!/");
+ if (pos == std::string::npos) {
+ return std::make_tuple(false, "", "");
+ }
+ return std::make_tuple(true, path.substr(0, pos), path.substr(pos + 2));
+}
+
+bool GetBuildIdFromApkFile(const std::string& apk_path, const std::string& elf_filename,
+ BuildId* build_id) {
+ std::unique_ptr<EmbeddedElf> ee = ApkInspector::FindElfInApkByName(apk_path, elf_filename);
+ if (ee == nullptr) {
+ return false;
+ }
+ return GetBuildIdFromEmbeddedElfFile(apk_path, ee->entry_offset(), ee->entry_size(), build_id);
+}
+
+bool ParseSymbolsFromApkFile(const std::string& apk_path, const std::string& elf_filename,
+ const BuildId& expected_build_id,
+ std::function<void(const ElfFileSymbol&)> callback) {
+ std::unique_ptr<EmbeddedElf> ee = ApkInspector::FindElfInApkByName(apk_path, elf_filename);
+ if (ee == nullptr) {
+ return false;
+ }
+ return ParseSymbolsFromEmbeddedElfFile(apk_path, ee->entry_offset(), ee->entry_size(),
+ expected_build_id, callback);
}
diff --git a/simpleperf/read_apk.h b/simpleperf/read_apk.h
index c35cac7..82531f4 100644
--- a/simpleperf/read_apk.h
+++ b/simpleperf/read_apk.h
@@ -17,13 +17,14 @@
#ifndef SIMPLE_PERF_READ_APK_H_
#define SIMPLE_PERF_READ_APK_H_
+#include <stdint.h>
+
+#include <map>
#include <memory>
#include <string>
+#include <tuple>
-#include <android-base/file.h>
-
-// Exposed for unit testing
-bool IsValidJarOrApkPath(const std::string& filename);
+#include "read_elf.h"
// Container for info an on ELF file embedded into an APK file
class EmbeddedElf {
@@ -52,7 +53,7 @@
const std::string &entry_name() const { return entry_name_; }
// Offset of zip entry from start of containing APK file
- size_t entry_offset() const { return entry_offset_; }
+ uint64_t entry_offset() const { return entry_offset_; }
// Size of zip entry (length of embedded ELF)
uint32_t entry_size() const { return entry_size_; }
@@ -60,40 +61,42 @@
private:
std::string filepath_; // containing APK path
std::string entry_name_; // name of entry in zip index of embedded elf file
- size_t entry_offset_; // offset of ELF from start of containing APK file
+ uint64_t entry_offset_; // offset of ELF from start of containing APK file
uint32_t entry_size_; // size of ELF file in zip
};
-struct EmbeddedElfComparator {
- bool operator()(const EmbeddedElf& ee1, const EmbeddedElf& ee2) {
- int res1 = ee1.filepath().compare(ee2.filepath());
- if (res1 != 0) {
- return res1 < 0;
- }
- int res2 = ee1.entry_name().compare(ee2.entry_name());
- if (res2 != 0) {
- return res2 < 0;
- }
- return ee1.entry_offset() < ee2.entry_offset();
- }
-};
-
-class ApkInspectorImpl;
-
// APK inspector helper class
class ApkInspector {
public:
- ApkInspector();
- ~ApkInspector();
-
// Given an APK/ZIP/JAR file and an offset into that file, if the
// corresponding region of the APK corresponds to an uncompressed
// ELF file, then return pertinent info on the ELF.
- EmbeddedElf *FindElfInApkByMmapOffset(const std::string& apk_path,
- size_t mmap_offset);
+ static EmbeddedElf* FindElfInApkByOffset(const std::string& apk_path, uint64_t file_offset);
+ static std::unique_ptr<EmbeddedElf> FindElfInApkByName(const std::string& apk_path,
+ const std::string& elf_filename);
private:
- std::unique_ptr<ApkInspectorImpl> impl_;
+ static std::unique_ptr<EmbeddedElf> FindElfInApkByOffsetWithoutCache(const std::string& apk_path,
+ uint64_t file_offset);
+
+ // First component of pair is APK file path, second is offset into APK.
+ typedef std::pair<std::string, uint64_t> ApkOffset;
+
+ static std::map<ApkOffset, std::unique_ptr<EmbeddedElf>> embedded_elf_cache_;
};
+// Export for test only.
+bool IsValidApkPath(const std::string& apk_path);
+
+std::string GetUrlInApk(const std::string& apk_path, const std::string& elf_filename);
+std::tuple<bool, std::string, std::string> SplitUrlInApk(const std::string& path);
+
+bool GetBuildIdFromApkFile(const std::string& apk_path, const std::string& elf_filename,
+ BuildId* build_id);
+
+bool ParseSymbolsFromApkFile(const std::string& apk_path, const std::string& elf_filename,
+ const BuildId& expected_build_id,
+ std::function<void(const ElfFileSymbol&)> callback);
+
+
#endif // SIMPLE_PERF_READ_APK_H_
diff --git a/simpleperf/read_apk_test.cpp b/simpleperf/read_apk_test.cpp
index 5824f4b..e983a25 100644
--- a/simpleperf/read_apk_test.cpp
+++ b/simpleperf/read_apk_test.cpp
@@ -18,24 +18,44 @@
#include <gtest/gtest.h>
#include "get_test_data.h"
+#include "test_util.h"
-static const std::string fibjar = "fibonacci.jar";
-static const std::string jniapk = "has_embedded_native_libs.apk";
-TEST(read_apk, IsValidJarOrApkPath) {
- ASSERT_FALSE(IsValidJarOrApkPath("/dev/zero"));
- ASSERT_FALSE(IsValidJarOrApkPath(GetTestData("elf_file")));
- ASSERT_TRUE(IsValidJarOrApkPath(GetTestData(fibjar)));
+TEST(read_apk, IsValidApkPath) {
+ ASSERT_FALSE(IsValidApkPath("/dev/zero"));
+ ASSERT_FALSE(IsValidApkPath(GetTestData(ELF_FILE)));
+ ASSERT_TRUE(IsValidApkPath(GetTestData(APK_FILE)));
}
-TEST(read_apk, CollectEmbeddedElfInfoFromApk) {
+TEST(read_apk, FindElfInApkByOffset) {
ApkInspector inspector;
- ASSERT_TRUE(inspector.FindElfInApkByMmapOffset("/dev/null", 0) == nullptr);
- ASSERT_TRUE(inspector.FindElfInApkByMmapOffset(GetTestData(fibjar), 0) == nullptr);
- ASSERT_TRUE(inspector.FindElfInApkByMmapOffset(GetTestData(jniapk), 0) == nullptr);
- EmbeddedElf *ee1 = inspector.FindElfInApkByMmapOffset(GetTestData(jniapk), 0x91000);
- ASSERT_TRUE(ee1 != nullptr);
- ASSERT_EQ(ee1->entry_name(), "lib/armeabi-v7a/libframeworks_coretests_jni.so");
- ASSERT_TRUE(ee1->entry_offset() == 593920);
- ASSERT_TRUE(ee1->entry_size() == 13904);
+ ASSERT_TRUE(inspector.FindElfInApkByOffset("/dev/null", 0) == nullptr);
+ ASSERT_TRUE(inspector.FindElfInApkByOffset(GetTestData(APK_FILE), 0) == nullptr);
+ EmbeddedElf* ee = inspector.FindElfInApkByOffset(GetTestData(APK_FILE), 0x9000);
+ ASSERT_TRUE(ee != nullptr);
+ ASSERT_EQ(ee->entry_name(), NATIVELIB_IN_APK);
+ ASSERT_EQ(NATIVELIB_OFFSET_IN_APK, ee->entry_offset());
+ ASSERT_EQ(NATIVELIB_SIZE_IN_APK, ee->entry_size());
+}
+
+TEST(read_apk, FindElfInApkByName) {
+ ASSERT_TRUE(ApkInspector::FindElfInApkByName("/dev/null", "") == nullptr);
+ ASSERT_TRUE(ApkInspector::FindElfInApkByName(GetTestData(APK_FILE), "") == nullptr);
+ auto ee = ApkInspector::FindElfInApkByName(GetTestData(APK_FILE), NATIVELIB_IN_APK);
+ ASSERT_TRUE(ee != nullptr);
+ ASSERT_EQ(NATIVELIB_OFFSET_IN_APK, ee->entry_offset());
+ ASSERT_EQ(NATIVELIB_SIZE_IN_APK, ee->entry_size());
+}
+
+TEST(read_apk, GetBuildIdFromApkFile) {
+ BuildId build_id;
+ ASSERT_TRUE(GetBuildIdFromApkFile(GetTestData(APK_FILE), NATIVELIB_IN_APK, &build_id));
+ ASSERT_EQ(build_id, native_lib_build_id);
+}
+
+TEST(read_apk, ParseSymbolsFromApkFile) {
+ std::map<std::string, ElfFileSymbol> symbols;
+ ASSERT_TRUE(ParseSymbolsFromApkFile(GetTestData(APK_FILE), NATIVELIB_IN_APK, native_lib_build_id,
+ std::bind(ParseSymbol, std::placeholders::_1, &symbols)));
+ CheckElfFileSymbols(symbols);
}
diff --git a/simpleperf/read_elf.cpp b/simpleperf/read_elf.cpp
index 03bdcc5..db33e0e 100644
--- a/simpleperf/read_elf.cpp
+++ b/simpleperf/read_elf.cpp
@@ -19,7 +19,6 @@
#include <stdio.h>
#include <string.h>
-#include <fcntl.h>
#include <sys/stat.h>
#include <sys/types.h>
@@ -44,15 +43,6 @@
#define ELF_NOTE_GNU "GNU"
#define NT_GNU_BUILD_ID 3
-FileHelper::FileHelper(const char *filename) : fd_(-1)
-{
- fd_ = TEMP_FAILURE_RETRY(open(filename, O_RDONLY | O_BINARY));
-}
-
-FileHelper::~FileHelper()
-{
- if (fd_ != -1) { close(fd_); }
-}
bool IsValidElfFile(int fd) {
static const char elf_magic[] = {0x7f, 'E', 'L', 'F'};
@@ -145,62 +135,67 @@
return result;
}
-bool GetBuildIdFromEmbeddedElfFile(const std::string& filename,
- uint64_t offsetInFile,
- uint32_t sizeInFile,
- BuildId* build_id) {
- FileHelper opener(filename.c_str());
- if (opener.fd() == -1) {
- LOG(DEBUG) << "unable to open " << filename
- << "to collect embedded ELF build id";
- return false;
+struct BinaryRet {
+ llvm::object::OwningBinary<llvm::object::Binary> binary;
+ llvm::object::ObjectFile* obj;
+
+ BinaryRet() : obj(nullptr) {
}
- llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> bufferOrErr =
- llvm::MemoryBuffer::getOpenFileSlice(opener.fd(), filename, sizeInFile,
- offsetInFile);
- if (std::error_code EC = bufferOrErr.getError()) {
- LOG(DEBUG) << "MemoryBuffer::getOpenFileSlice failed opening "
- << filename << "while collecting embedded ELF build id: "
- << EC.message();
- return false;
+};
+
+static BinaryRet OpenObjectFile(const std::string& filename, uint64_t file_offset = 0,
+ uint64_t file_size = 0) {
+ BinaryRet ret;
+ FileHelper fhelper = FileHelper::OpenReadOnly(filename);
+ if (!fhelper) {
+ PLOG(DEBUG) << "failed to open " << filename;
+ return ret;
}
- std::unique_ptr<llvm::MemoryBuffer> buffer = std::move(bufferOrErr.get());
- llvm::LLVMContext *context = nullptr;
- llvm::ErrorOr<std::unique_ptr<llvm::object::Binary>> binaryOrErr =
- llvm::object::createBinary(buffer->getMemBufferRef(), context);
- if (std::error_code EC = binaryOrErr.getError()) {
- LOG(DEBUG) << "llvm::object::createBinary failed opening "
- << filename << "while collecting embedded ELF build id: "
- << EC.message();
- return false;
+ if (file_size == 0) {
+ file_size = GetFileSize(filename);
+ if (file_size == 0) {
+ PLOG(ERROR) << "failed to get size of file " << filename;
+ return ret;
+ }
}
- std::unique_ptr<llvm::object::Binary> binary = std::move(binaryOrErr.get());
- auto obj = llvm::dyn_cast<llvm::object::ObjectFile>(binary.get());
- if (obj == nullptr) {
- LOG(DEBUG) << "unable to cast to interpret contents of " << filename
- << "at offset " << offsetInFile
- << ": failed to cast to llvm::object::ObjectFile";
- return false;
+ auto buffer_or_err = llvm::MemoryBuffer::getOpenFileSlice(fhelper.fd(), filename, file_size, file_offset);
+ if (!buffer_or_err) {
+ LOG(ERROR) << "failed to read " << filename << " [" << file_offset << "-" << (file_offset + file_size)
+ << "]: " << buffer_or_err.getError().message();
+ return ret;
}
- return GetBuildIdFromObjectFile(obj, build_id);
+ auto binary_or_err = llvm::object::createBinary(buffer_or_err.get()->getMemBufferRef());
+ if (!binary_or_err) {
+ LOG(ERROR) << filename << " [" << file_offset << "-" << (file_offset + file_size)
+ << "] is not a binary file: " << binary_or_err.getError().message();
+ return ret;
+ }
+ ret.binary = llvm::object::OwningBinary<llvm::object::Binary>(std::move(binary_or_err.get()),
+ std::move(buffer_or_err.get()));
+ ret.obj = llvm::dyn_cast<llvm::object::ObjectFile>(ret.binary.getBinary());
+ if (ret.obj == nullptr) {
+ LOG(ERROR) << filename << " [" << file_offset << "-" << (file_offset + file_size)
+ << "] is not an object file";
+ }
+ return ret;
}
bool GetBuildIdFromElfFile(const std::string& filename, BuildId* build_id) {
if (!IsValidElfPath(filename)) {
return false;
}
- auto owning_binary = llvm::object::createBinary(llvm::StringRef(filename));
- if (owning_binary.getError()) {
- PLOG(DEBUG) << "can't open file " << filename;
+ bool result = GetBuildIdFromEmbeddedElfFile(filename, 0, 0, build_id);
+ LOG(VERBOSE) << "GetBuildIdFromElfFile(" << filename << ") => " << build_id->ToString();
+ return result;
+}
+
+bool GetBuildIdFromEmbeddedElfFile(const std::string& filename, uint64_t file_offset,
+ uint32_t file_size, BuildId* build_id) {
+ BinaryRet ret = OpenObjectFile(filename, file_offset, file_size);
+ if (ret.obj == nullptr) {
return false;
}
- llvm::object::Binary* binary = owning_binary.get().getBinary();
- auto obj = llvm::dyn_cast<llvm::object::ObjectFile>(binary);
- if (obj == nullptr) {
- LOG(DEBUG) << filename << " is not an object file";
- return false;
- }
- return GetBuildIdFromObjectFile(obj, build_id);
+ return GetBuildIdFromObjectFile(ret.obj, build_id);
}
bool IsArmMappingSymbol(const char* name) {
@@ -274,31 +269,22 @@
}
}
-static llvm::object::ObjectFile* GetObjectFile(
- llvm::ErrorOr<llvm::object::OwningBinary<llvm::object::Binary>>& owning_binary,
- const std::string& filename, const BuildId& expected_build_id) {
- if (owning_binary.getError()) {
- PLOG(DEBUG) << "can't open file '" << filename << "'";
- return nullptr;
+bool MatchBuildId(llvm::object::ObjectFile* obj, const BuildId& expected_build_id,
+ const std::string& debug_filename) {
+ if (expected_build_id.IsEmpty()) {
+ return true;
}
- llvm::object::Binary* binary = owning_binary.get().getBinary();
- auto obj = llvm::dyn_cast<llvm::object::ObjectFile>(binary);
- if (obj == nullptr) {
- LOG(DEBUG) << filename << " is not an object file";
- return nullptr;
+ BuildId real_build_id;
+ if (!GetBuildIdFromObjectFile(obj, &real_build_id)) {
+ return false;
}
- if (!expected_build_id.IsEmpty()) {
- BuildId real_build_id;
- GetBuildIdFromObjectFile(obj, &real_build_id);
- bool result = (expected_build_id == real_build_id);
- LOG(DEBUG) << "check build id for \"" << filename << "\" (" << (result ? "match" : "mismatch")
- << "): expected " << expected_build_id.ToString() << ", real "
- << real_build_id.ToString();
- if (!result) {
- return nullptr;
- }
+ if (expected_build_id != real_build_id) {
+ LOG(DEBUG) << "build id for " << debug_filename << " mismatch: "
+ << "expected " << expected_build_id.ToString()
+ << ", real " << real_build_id.ToString();
+ return false;
}
- return obj;
+ return true;
}
bool ParseSymbolsFromElfFile(const std::string& filename, const BuildId& expected_build_id,
@@ -306,18 +292,22 @@
if (!IsValidElfPath(filename)) {
return false;
}
- auto owning_binary = llvm::object::createBinary(llvm::StringRef(filename));
- llvm::object::ObjectFile* obj = GetObjectFile(owning_binary, filename, expected_build_id);
- if (obj == nullptr) {
+ return ParseSymbolsFromEmbeddedElfFile(filename, 0, 0, expected_build_id, callback);
+}
+
+bool ParseSymbolsFromEmbeddedElfFile(const std::string& filename, uint64_t file_offset,
+ uint32_t file_size, const BuildId& expected_build_id,
+ std::function<void(const ElfFileSymbol&)> callback) {
+ BinaryRet ret = OpenObjectFile(filename, file_offset, file_size);
+ if (ret.obj == nullptr || !MatchBuildId(ret.obj, expected_build_id, filename)) {
return false;
}
-
- if (auto elf = llvm::dyn_cast<llvm::object::ELF32LEObjectFile>(obj)) {
+ if (auto elf = llvm::dyn_cast<llvm::object::ELF32LEObjectFile>(ret.obj)) {
ParseSymbolsFromELFFile(elf->getELFFile(), callback);
- } else if (auto elf = llvm::dyn_cast<llvm::object::ELF64LEObjectFile>(obj)) {
+ } else if (auto elf = llvm::dyn_cast<llvm::object::ELF64LEObjectFile>(ret.obj)) {
ParseSymbolsFromELFFile(elf->getELFFile(), callback);
} else {
- LOG(ERROR) << "unknown elf format in file" << filename;
+ LOG(ERROR) << "unknown elf format in file " << filename;
return false;
}
return true;
@@ -347,16 +337,15 @@
if (!IsValidElfPath(filename)) {
return false;
}
- auto owning_binary = llvm::object::createBinary(llvm::StringRef(filename));
- llvm::object::ObjectFile* obj = GetObjectFile(owning_binary, filename, expected_build_id);
- if (obj == nullptr) {
+ BinaryRet ret = OpenObjectFile(filename);
+ if (ret.obj == nullptr || !MatchBuildId(ret.obj, expected_build_id, filename)) {
return false;
}
bool result = false;
- if (auto elf = llvm::dyn_cast<llvm::object::ELF32LEObjectFile>(obj)) {
+ if (auto elf = llvm::dyn_cast<llvm::object::ELF32LEObjectFile>(ret.obj)) {
result = ReadMinExecutableVirtualAddress(elf->getELFFile(), min_vaddr);
- } else if (auto elf = llvm::dyn_cast<llvm::object::ELF64LEObjectFile>(obj)) {
+ } else if (auto elf = llvm::dyn_cast<llvm::object::ELF64LEObjectFile>(ret.obj)) {
result = ReadMinExecutableVirtualAddress(elf->getELFFile(), min_vaddr);
} else {
LOG(ERROR) << "unknown elf format in file" << filename;
@@ -368,3 +357,43 @@
}
return result;
}
+
+template <class ELFT>
+bool ReadSectionFromELFFile(const llvm::object::ELFFile<ELFT>* elf, const std::string& section_name,
+ std::string* content) {
+ for (auto it = elf->begin_sections(); it != elf->end_sections(); ++it) {
+ auto name_or_err = elf->getSectionName(&*it);
+ if (name_or_err && *name_or_err == section_name) {
+ auto data_or_err = elf->getSectionContents(&*it);
+ if (!data_or_err) {
+ LOG(ERROR) << "failed to read section " << section_name;
+ return false;
+ }
+ content->append(data_or_err->begin(), data_or_err->end());
+ return true;
+ }
+ }
+ LOG(ERROR) << "can't find section " << section_name;
+ return false;
+}
+
+bool ReadSectionFromElfFile(const std::string& filename, const std::string& section_name,
+ std::string* content) {
+ if (!IsValidElfPath(filename)) {
+ return false;
+ }
+ BinaryRet ret = OpenObjectFile(filename);
+ if (ret.obj == nullptr) {
+ return false;
+ }
+ bool result = false;
+ if (auto elf = llvm::dyn_cast<llvm::object::ELF32LEObjectFile>(ret.obj)) {
+ result = ReadSectionFromELFFile(elf->getELFFile(), section_name, content);
+ } else if (auto elf = llvm::dyn_cast<llvm::object::ELF64LEObjectFile>(ret.obj)) {
+ result = ReadSectionFromELFFile(elf->getELFFile(), section_name, content);
+ } else {
+ LOG(ERROR) << "unknown elf format in file" << filename;
+ return false;
+ }
+ return result;
+}
diff --git a/simpleperf/read_elf.h b/simpleperf/read_elf.h
index d0626b2..a6c73c3 100644
--- a/simpleperf/read_elf.h
+++ b/simpleperf/read_elf.h
@@ -23,10 +23,8 @@
bool GetBuildIdFromNoteFile(const std::string& filename, BuildId* build_id);
bool GetBuildIdFromElfFile(const std::string& filename, BuildId* build_id);
-bool GetBuildIdFromEmbeddedElfFile(const std::string& filename,
- uint64_t offsetInFile,
- uint32_t sizeInFile,
- BuildId* build_id);
+bool GetBuildIdFromEmbeddedElfFile(const std::string& filename, uint64_t file_offset,
+ uint32_t file_size, BuildId* build_id);
// The symbol prefix used to indicate that the symbol belongs to android linker.
static const std::string linker_prefix = "__dl_";
@@ -45,21 +43,16 @@
bool ParseSymbolsFromElfFile(const std::string& filename, const BuildId& expected_build_id,
std::function<void(const ElfFileSymbol&)> callback);
+bool ParseSymbolsFromEmbeddedElfFile(const std::string& filename, uint64_t file_offset,
+ uint32_t file_size, const BuildId& expected_build_id,
+ std::function<void(const ElfFileSymbol&)> callback);
bool ReadMinExecutableVirtualAddressFromElfFile(const std::string& filename,
const BuildId& expected_build_id,
uint64_t* min_addr);
-// Opens file in constructor, then closes file when object is destroyed.
-class FileHelper {
- public:
- explicit FileHelper(const char *filename);
- ~FileHelper();
- int fd() const { return fd_; }
-
- private:
- int fd_;
-};
+bool ReadSectionFromElfFile(const std::string& filename, const std::string& section_name,
+ std::string* content);
// Expose the following functions for unit tests.
bool IsArmMappingSymbol(const char* name);
diff --git a/simpleperf/read_elf_test.cpp b/simpleperf/read_elf_test.cpp
index 7a5194e..929540f 100644
--- a/simpleperf/read_elf_test.cpp
+++ b/simpleperf/read_elf_test.cpp
@@ -21,22 +21,24 @@
#include <map>
#include "get_test_data.h"
-static const unsigned char elf_file_build_id[] = {
- 0x76, 0x00, 0x32, 0x9e, 0x31, 0x05, 0x8e, 0x12, 0xb1, 0x45,
- 0xd1, 0x53, 0xef, 0x27, 0xcd, 0x40, 0xe1, 0xa5, 0xf7, 0xb9
-};
-
TEST(read_elf, GetBuildIdFromElfFile) {
BuildId build_id;
- ASSERT_TRUE(GetBuildIdFromElfFile(GetTestData("elf_file"), &build_id));
+ ASSERT_TRUE(GetBuildIdFromElfFile(GetTestData(ELF_FILE), &build_id));
ASSERT_EQ(build_id, BuildId(elf_file_build_id));
}
-static void ParseSymbol(const ElfFileSymbol& symbol, std::map<std::string, ElfFileSymbol>* symbols) {
+TEST(read_elf, GetBuildIdFromEmbeddedElfFile) {
+ BuildId build_id;
+ ASSERT_TRUE(GetBuildIdFromEmbeddedElfFile(GetTestData(APK_FILE), NATIVELIB_OFFSET_IN_APK,
+ NATIVELIB_SIZE_IN_APK, &build_id));
+ ASSERT_EQ(build_id, native_lib_build_id);
+}
+
+void ParseSymbol(const ElfFileSymbol& symbol, std::map<std::string, ElfFileSymbol>* symbols) {
(*symbols)[symbol.name] = symbol;
}
-static void CheckElfFileSymbols(const std::map<std::string, ElfFileSymbol>& symbols) {
+void CheckElfFileSymbols(const std::map<std::string, ElfFileSymbol>& symbols) {
auto pos = symbols.find("GlobalVar");
ASSERT_NE(pos, symbols.end());
ASSERT_FALSE(pos->second.is_func);
@@ -47,28 +49,34 @@
}
TEST(read_elf, parse_symbols_from_elf_file_with_correct_build_id) {
- BuildId build_id(elf_file_build_id);
std::map<std::string, ElfFileSymbol> symbols;
- ASSERT_TRUE(ParseSymbolsFromElfFile(GetTestData("elf_file"), build_id,
+ ASSERT_TRUE(ParseSymbolsFromElfFile(GetTestData(ELF_FILE), elf_file_build_id,
std::bind(ParseSymbol, std::placeholders::_1, &symbols)));
CheckElfFileSymbols(symbols);
}
TEST(read_elf, parse_symbols_from_elf_file_without_build_id) {
- BuildId build_id;
std::map<std::string, ElfFileSymbol> symbols;
- ASSERT_TRUE(ParseSymbolsFromElfFile(GetTestData("elf_file"), build_id,
+ ASSERT_TRUE(ParseSymbolsFromElfFile(GetTestData(ELF_FILE), BuildId(),
std::bind(ParseSymbol, std::placeholders::_1, &symbols)));
CheckElfFileSymbols(symbols);
}
TEST(read_elf, parse_symbols_from_elf_file_with_wrong_build_id) {
- BuildId build_id("wrong_build_id");
+ BuildId build_id("01010101010101010101");
std::map<std::string, ElfFileSymbol> symbols;
- ASSERT_FALSE(ParseSymbolsFromElfFile(GetTestData("elf_file"), build_id,
+ ASSERT_FALSE(ParseSymbolsFromElfFile(GetTestData(ELF_FILE), build_id,
std::bind(ParseSymbol, std::placeholders::_1, &symbols)));
}
+TEST(read_elf, ParseSymbolsFromEmbeddedElfFile) {
+ std::map<std::string, ElfFileSymbol> symbols;
+ ASSERT_TRUE(ParseSymbolsFromEmbeddedElfFile(GetTestData(APK_FILE), NATIVELIB_OFFSET_IN_APK,
+ NATIVELIB_SIZE_IN_APK, native_lib_build_id,
+ std::bind(ParseSymbol, std::placeholders::_1, &symbols)));
+ CheckElfFileSymbols(symbols);
+}
+
TEST(read_elf, arm_mapping_symbol) {
ASSERT_TRUE(IsArmMappingSymbol("$a"));
ASSERT_FALSE(IsArmMappingSymbol("$b"));
@@ -79,5 +87,5 @@
TEST(read_elf, IsValidElfPath) {
ASSERT_FALSE(IsValidElfPath("/dev/zero"));
ASSERT_FALSE(IsValidElfPath("/sys/devices/system/cpu/online"));
- ASSERT_TRUE(IsValidElfPath(GetTestData("elf_file")));
+ ASSERT_TRUE(IsValidElfPath(GetTestData(ELF_FILE)));
}
diff --git a/simpleperf/record.cpp b/simpleperf/record.cpp
index 26bc588..f9e220a 100644
--- a/simpleperf/record.cpp
+++ b/simpleperf/record.cpp
@@ -73,25 +73,7 @@
sample_id_all = attr.sample_id_all;
sample_type = attr.sample_type;
// Other data are not necessary. TODO: Set missing SampleId data.
- size_t size = 0;
- if (sample_id_all) {
- if (sample_type & PERF_SAMPLE_TID) {
- size += sizeof(PerfSampleTidType);
- }
- if (sample_type & PERF_SAMPLE_TIME) {
- size += sizeof(PerfSampleTimeType);
- }
- if (sample_type & PERF_SAMPLE_ID) {
- size += sizeof(PerfSampleIdType);
- }
- if (sample_type & PERF_SAMPLE_STREAM_ID) {
- size += sizeof(PerfSampleStreamIdType);
- }
- if (sample_type & PERF_SAMPLE_CPU) {
- size += sizeof(PerfSampleCpuType);
- }
- }
- return size;
+ return Size();
}
void SampleId::ReadFromBinaryFormat(const perf_event_attr& attr, const char* p, const char* end) {
@@ -161,6 +143,28 @@
}
}
+size_t SampleId::Size() const {
+ size_t size = 0;
+ if (sample_id_all) {
+ if (sample_type & PERF_SAMPLE_TID) {
+ size += sizeof(PerfSampleTidType);
+ }
+ if (sample_type & PERF_SAMPLE_TIME) {
+ size += sizeof(PerfSampleTimeType);
+ }
+ if (sample_type & PERF_SAMPLE_ID) {
+ size += sizeof(PerfSampleIdType);
+ }
+ if (sample_type & PERF_SAMPLE_STREAM_ID) {
+ size += sizeof(PerfSampleStreamIdType);
+ }
+ if (sample_type & PERF_SAMPLE_CPU) {
+ size += sizeof(PerfSampleCpuType);
+ }
+ }
+ return size;
+}
+
Record::Record() {
memset(&header, 0, sizeof(header));
}
@@ -202,6 +206,10 @@
return buf;
}
+void MmapRecord::AdjustSizeBasedOnData() {
+ header.size = sizeof(header) + sizeof(data) + ALIGN(filename.size() + 1, 8) + sample_id.Size();
+}
+
void MmapRecord::DumpData(size_t indent) const {
PrintIndented(indent, "pid %u, tid %u, addr 0x%" PRIx64 ", len 0x%" PRIx64 "\n", data.pid,
data.tid, data.addr, data.len);
@@ -230,6 +238,10 @@
return buf;
}
+void Mmap2Record::AdjustSizeBasedOnData() {
+ header.size = sizeof(header) + sizeof(data) + ALIGN(filename.size() + 1, 8) + sample_id.Size();
+}
+
void Mmap2Record::DumpData(size_t indent) const {
PrintIndented(indent, "pid %u, tid %u, addr 0x%" PRIx64 ", len 0x%" PRIx64 "\n", data.pid,
data.tid, data.addr, data.len);
@@ -437,7 +449,8 @@
void SampleRecord::AdjustSizeBasedOnData() {
size_t size = BinaryFormat().size();
- LOG(DEBUG) << "SampleRecord size is changed from " << header.size << " to " << size;
+ LOG(DEBUG) << "Record (type " << RecordTypeToString(header.type) << ") size is changed from "
+ << header.size << " to " << size;
header.size = size;
}
@@ -521,7 +534,7 @@
const char* p = reinterpret_cast<const char*>(pheader + 1);
const char* end = reinterpret_cast<const char*>(pheader) + pheader->size;
MoveFromBinaryFormat(pid, p);
- build_id = BuildId(p);
+ build_id = BuildId(p, BUILD_ID_SIZE);
p += ALIGN(build_id.Size(), 8);
filename = p;
p += ALIGN(filename.size() + 1, 64);
@@ -536,7 +549,6 @@
memcpy(p, build_id.Data(), build_id.Size());
p += ALIGN(build_id.Size(), 8);
strcpy(p, filename.c_str());
- p += ALIGN(filename.size() + 1, 64);
return buf;
}
@@ -632,18 +644,6 @@
return record;
}
-void UpdateMmapRecord(MmapRecord *record, const std::string& new_filename, uint64_t new_pgoff)
-{
- size_t new_filename_size = ALIGN(new_filename.size() + 1, 8);
- size_t old_filename_size = ALIGN(record->filename.size() + 1, 8);
- record->data.pgoff = new_pgoff;
- record->filename = new_filename;
- if (new_filename_size > old_filename_size)
- record->header.size += (new_filename_size - old_filename_size);
- else if (new_filename_size < old_filename_size)
- record->header.size += (old_filename_size - new_filename_size);
-}
-
CommRecord CreateCommRecord(const perf_event_attr& attr, uint32_t pid, uint32_t tid,
const std::string& comm) {
CommRecord record;
diff --git a/simpleperf/record.h b/simpleperf/record.h
index 26e4599..a94a917 100644
--- a/simpleperf/record.h
+++ b/simpleperf/record.h
@@ -125,6 +125,7 @@
// Write the binary format of sample_id to the buffer pointed by p.
void WriteToBinaryFormat(char*& p) const;
void Dump(size_t indent) const;
+ size_t Size() const;
};
// Usually one record contains the following three parts in order in binary format:
@@ -173,6 +174,7 @@
MmapRecord(const perf_event_attr& attr, const perf_event_header* pheader);
std::vector<char> BinaryFormat() const override;
+ void AdjustSizeBasedOnData();
protected:
void DumpData(size_t indent) const override;
@@ -197,6 +199,7 @@
Mmap2Record(const perf_event_attr& attr, const perf_event_header* pheader);
std::vector<char> BinaryFormat() const override;
+ void AdjustSizeBasedOnData();
protected:
void DumpData(size_t indent) const override;
@@ -357,7 +360,6 @@
MmapRecord CreateMmapRecord(const perf_event_attr& attr, bool in_kernel, uint32_t pid, uint32_t tid,
uint64_t addr, uint64_t len, uint64_t pgoff,
const std::string& filename);
-void UpdateMmapRecord(MmapRecord *record, const std::string& new_filename, uint64_t new_pgoff);
CommRecord CreateCommRecord(const perf_event_attr& attr, uint32_t pid, uint32_t tid,
const std::string& comm);
ForkRecord CreateForkRecord(const perf_event_attr& attr, uint32_t pid, uint32_t tid, uint32_t ppid,
diff --git a/simpleperf/record_file_reader.cpp b/simpleperf/record_file_reader.cpp
index 1befc69..f126a6b 100644
--- a/simpleperf/record_file_reader.cpp
+++ b/simpleperf/record_file_reader.cpp
@@ -75,9 +75,8 @@
bool RecordFileReader::ReadAttrSection() {
size_t attr_count = header_.attrs.size / header_.attr_size;
if (header_.attr_size != sizeof(FileAttr)) {
- LOG(WARNING) << "attr size (" << header_.attr_size << ") in " << filename_
+ LOG(DEBUG) << "attr size (" << header_.attr_size << ") in " << filename_
<< " doesn't match expected size (" << sizeof(FileAttr) << ")";
- return false;
}
if (attr_count == 0) {
LOG(ERROR) << "no attr in file " << filename_;
diff --git a/simpleperf/record_file_test.cpp b/simpleperf/record_file_test.cpp
index e4f963e..4648a64 100644
--- a/simpleperf/record_file_test.cpp
+++ b/simpleperf/record_file_test.cpp
@@ -20,6 +20,8 @@
#include <memory>
+#include <android-base/test_utils.h>
+
#include "environment.h"
#include "event_attr.h"
#include "event_type.h"
@@ -32,10 +34,6 @@
class RecordFileTest : public ::testing::Test {
protected:
- virtual void SetUp() {
- filename_ = "temporary.record_file";
- }
-
void AddEventType(const std::string& event_type_str) {
std::unique_ptr<EventTypeAndModifier> event_type_modifier = ParseEventType(event_type_str);
ASSERT_TRUE(event_type_modifier != nullptr);
@@ -47,14 +45,14 @@
attr_ids_.push_back(attr_id);
}
- std::string filename_;
+ TemporaryFile tmpfile_;
std::vector<std::unique_ptr<perf_event_attr>> attrs_;
std::vector<AttrWithId> attr_ids_;
};
TEST_F(RecordFileTest, smoke) {
// Write to a record file.
- std::unique_ptr<RecordFileWriter> writer = RecordFileWriter::CreateInstance(filename_);
+ std::unique_ptr<RecordFileWriter> writer = RecordFileWriter::CreateInstance(tmpfile_.path);
ASSERT_TRUE(writer != nullptr);
// Write attr section.
@@ -78,7 +76,7 @@
ASSERT_TRUE(writer->Close());
// Read from a record file.
- std::unique_ptr<RecordFileReader> reader = RecordFileReader::CreateInstance(filename_);
+ std::unique_ptr<RecordFileReader> reader = RecordFileReader::CreateInstance(tmpfile_.path);
ASSERT_TRUE(reader != nullptr);
const std::vector<FileAttr>& file_attrs = reader->AttrSection();
ASSERT_EQ(1u, file_attrs.size());
@@ -102,7 +100,7 @@
TEST_F(RecordFileTest, records_sorted_by_time) {
// Write to a record file.
- std::unique_ptr<RecordFileWriter> writer = RecordFileWriter::CreateInstance(filename_);
+ std::unique_ptr<RecordFileWriter> writer = RecordFileWriter::CreateInstance(tmpfile_.path);
ASSERT_TRUE(writer != nullptr);
// Write attr section.
@@ -125,7 +123,7 @@
ASSERT_TRUE(writer->Close());
// Read from a record file.
- std::unique_ptr<RecordFileReader> reader = RecordFileReader::CreateInstance(filename_);
+ std::unique_ptr<RecordFileReader> reader = RecordFileReader::CreateInstance(tmpfile_.path);
ASSERT_TRUE(reader != nullptr);
std::vector<std::unique_ptr<Record>> records = reader->DataSection();
ASSERT_EQ(3u, records.size());
@@ -138,7 +136,7 @@
TEST_F(RecordFileTest, record_more_than_one_attr) {
// Write to a record file.
- std::unique_ptr<RecordFileWriter> writer = RecordFileWriter::CreateInstance(filename_);
+ std::unique_ptr<RecordFileWriter> writer = RecordFileWriter::CreateInstance(tmpfile_.path);
ASSERT_TRUE(writer != nullptr);
// Write attr section.
@@ -150,7 +148,7 @@
ASSERT_TRUE(writer->Close());
// Read from a record file.
- std::unique_ptr<RecordFileReader> reader = RecordFileReader::CreateInstance(filename_);
+ std::unique_ptr<RecordFileReader> reader = RecordFileReader::CreateInstance(tmpfile_.path);
ASSERT_TRUE(reader != nullptr);
const std::vector<FileAttr>& file_attrs = reader->AttrSection();
ASSERT_EQ(3u, file_attrs.size());
diff --git a/simpleperf/test_util.h b/simpleperf/test_util.h
index 34155a3..cfbe493 100644
--- a/simpleperf/test_util.h
+++ b/simpleperf/test_util.h
@@ -14,14 +14,15 @@
* limitations under the License.
*/
+#include <map>
+#include <string>
+
+#include "read_elf.h"
#include "workload.h"
-static void CreateProcesses(size_t count, std::vector<std::unique_ptr<Workload>>* workloads) {
- workloads->clear();
- for (size_t i = 0; i < count; ++i) {
- auto workload = Workload::CreateWorkload({"sleep", "1"});
- ASSERT_TRUE(workload != nullptr);
- ASSERT_TRUE(workload->Start());
- workloads->push_back(std::move(workload));
- }
-}
+static const std::string SLEEP_SEC = "0.001";
+
+void CreateProcesses(size_t count, std::vector<std::unique_ptr<Workload>>* workloads);
+
+void ParseSymbol(const ElfFileSymbol& symbol, std::map<std::string, ElfFileSymbol>* symbols);
+void CheckElfFileSymbols(const std::map<std::string, ElfFileSymbol>& symbols);
diff --git a/simpleperf/testdata/data/app/com.example.hellojni-1/base.apk b/simpleperf/testdata/data/app/com.example.hellojni-1/base.apk
new file mode 100644
index 0000000..c757e9e
--- /dev/null
+++ b/simpleperf/testdata/data/app/com.example.hellojni-1/base.apk
Binary files differ
diff --git a/simpleperf/testdata/elf b/simpleperf/testdata/elf
new file mode 100644
index 0000000..f63c25c
--- /dev/null
+++ b/simpleperf/testdata/elf
Binary files differ
diff --git a/simpleperf/testdata/elf_file b/simpleperf/testdata/elf_file
deleted file mode 100644
index 53b589a..0000000
--- a/simpleperf/testdata/elf_file
+++ /dev/null
Binary files differ
diff --git a/simpleperf/testdata/elf_file_source.cpp b/simpleperf/testdata/elf_file_source.cpp
new file mode 100644
index 0000000..3cfd00b
--- /dev/null
+++ b/simpleperf/testdata/elf_file_source.cpp
@@ -0,0 +1,20 @@
+#include <pthread.h>
+
+volatile int GlobalVar;
+
+extern "C" void CalledFunc() {
+ GlobalVar++;
+}
+
+extern "C" void GlobalFunc() {
+ for (int i = 0; i < 1000000; ++i) {
+ CalledFunc();
+ }
+}
+
+int main() {
+ while (true) {
+ GlobalFunc();
+ }
+ return 0;
+}
diff --git a/simpleperf/testdata/fibonacci.jar b/simpleperf/testdata/fibonacci.jar
deleted file mode 100644
index df57e40..0000000
--- a/simpleperf/testdata/fibonacci.jar
+++ /dev/null
Binary files differ
diff --git a/simpleperf/testdata/has_embedded_native_libs.apk b/simpleperf/testdata/has_embedded_native_libs.apk
deleted file mode 100644
index 2a1924c..0000000
--- a/simpleperf/testdata/has_embedded_native_libs.apk
+++ /dev/null
Binary files differ
diff --git a/simpleperf/testdata/has_embedded_native_libs_apk_perf.data b/simpleperf/testdata/has_embedded_native_libs_apk_perf.data
new file mode 100644
index 0000000..f85c9d3
--- /dev/null
+++ b/simpleperf/testdata/has_embedded_native_libs_apk_perf.data
Binary files differ
diff --git a/simpleperf/testdata/perf.data b/simpleperf/testdata/perf.data
new file mode 100644
index 0000000..64a59da
--- /dev/null
+++ b/simpleperf/testdata/perf.data
Binary files differ
diff --git a/simpleperf/testdata/perf_b.data b/simpleperf/testdata/perf_b.data
new file mode 100644
index 0000000..e514944
--- /dev/null
+++ b/simpleperf/testdata/perf_b.data
Binary files differ
diff --git a/simpleperf/testdata/perf_g_fp.data b/simpleperf/testdata/perf_g_fp.data
new file mode 100644
index 0000000..de9cf53
--- /dev/null
+++ b/simpleperf/testdata/perf_g_fp.data
Binary files differ
diff --git a/simpleperf/utils.cpp b/simpleperf/utils.cpp
index eabad29..99e1e98 100644
--- a/simpleperf/utils.cpp
+++ b/simpleperf/utils.cpp
@@ -18,6 +18,7 @@
#include <dirent.h>
#include <errno.h>
+#include <fcntl.h>
#include <stdarg.h>
#include <stdio.h>
#include <sys/stat.h>
@@ -26,6 +27,7 @@
#include <algorithm>
#include <string>
+#include <android-base/file.h>
#include <android-base/logging.h>
void OneTimeFreeAllocator::Clear() {
@@ -52,6 +54,38 @@
return result;
}
+
+FileHelper FileHelper::OpenReadOnly(const std::string& filename) {
+ int fd = TEMP_FAILURE_RETRY(open(filename.c_str(), O_RDONLY | O_BINARY));
+ return FileHelper(fd);
+}
+
+FileHelper FileHelper::OpenWriteOnly(const std::string& filename) {
+ int fd = TEMP_FAILURE_RETRY(open(filename.c_str(), O_WRONLY | O_BINARY | O_CREAT, 0644));
+ return FileHelper(fd);
+}
+
+FileHelper::~FileHelper() {
+ if (fd_ != -1) {
+ close(fd_);
+ }
+}
+
+ArchiveHelper::ArchiveHelper(int fd, const std::string& debug_filename) : valid_(false) {
+ int rc = OpenArchiveFd(fd, "", &handle_, false);
+ if (rc == 0) {
+ valid_ = true;
+ } else {
+ LOG(ERROR) << "Failed to open archive " << debug_filename << ": " << ErrorCodeString(rc);
+ }
+}
+
+ArchiveHelper::~ArchiveHelper() {
+ if (valid_) {
+ CloseArchive(handle_);
+ }
+}
+
void PrintIndented(size_t indent, const char* fmt, ...) {
va_list ap;
va_start(ap, fmt);
@@ -114,3 +148,35 @@
}
return false;
}
+
+uint64_t GetFileSize(const std::string& filename) {
+ struct stat st;
+ if (stat(filename.c_str(), &st) == 0) {
+ return static_cast<uint64_t>(st.st_size);
+ }
+ return 0;
+}
+
+bool MkdirWithParents(const std::string& path) {
+ size_t prev_end = 0;
+ while (prev_end < path.size()) {
+ size_t next_end = path.find('/', prev_end + 1);
+ if (next_end == std::string::npos) {
+ break;
+ }
+ std::string dir_path = path.substr(0, next_end);
+ if (!IsDir(dir_path)) {
+#if defined(_WIN32)
+ int ret = mkdir(dir_path.c_str());
+#else
+ int ret = mkdir(dir_path.c_str(), 0755);
+#endif
+ if (ret != 0) {
+ PLOG(ERROR) << "failed to create dir " << dir_path;
+ return false;
+ }
+ }
+ prev_end = next_end;
+ }
+ return true;
+}
diff --git a/simpleperf/utils.h b/simpleperf/utils.h
index 2ce0726..1164b1e 100644
--- a/simpleperf/utils.h
+++ b/simpleperf/utils.h
@@ -22,6 +22,9 @@
#include <string>
#include <vector>
+#include <android-base/macros.h>
+#include <ziparchive/zip_archive.h>
+
#define ALIGN(value, alignment) (((value) + (alignment)-1) & ~((alignment)-1))
#ifdef _WIN32
@@ -52,6 +55,53 @@
char* end_;
};
+class FileHelper {
+ public:
+ static FileHelper OpenReadOnly(const std::string& filename);
+ static FileHelper OpenWriteOnly(const std::string& filename);
+
+ FileHelper(FileHelper&& other) {
+ fd_ = other.fd_;
+ other.fd_ = -1;
+ }
+
+ ~FileHelper();
+
+ explicit operator bool() const {
+ return fd_ != -1;
+ }
+
+ int fd() const {
+ return fd_;
+ }
+
+ private:
+ FileHelper(int fd) : fd_(fd) {}
+ int fd_;
+
+ DISALLOW_COPY_AND_ASSIGN(FileHelper);
+};
+
+
+class ArchiveHelper {
+ public:
+ ArchiveHelper(int fd, const std::string& debug_filename);
+ ~ArchiveHelper();
+
+ explicit operator bool() const {
+ return valid_;
+ }
+ ZipArchiveHandle &archive_handle() {
+ return handle_;
+ }
+
+ private:
+ ZipArchiveHandle handle_;
+ bool valid_;
+
+ DISALLOW_COPY_AND_ASSIGN(ArchiveHelper);
+};
+
template <class T>
void MoveFromBinaryFormat(T& data, const char*& p) {
data = *reinterpret_cast<const T*>(p);
@@ -66,5 +116,7 @@
std::vector<std::string>* subdirs);
bool IsDir(const std::string& dirpath);
bool IsRegularFile(const std::string& filename);
+uint64_t GetFileSize(const std::string& filename);
+bool MkdirWithParents(const std::string& path);
#endif // SIMPLE_PERF_UTILS_H_
diff --git a/tests/ext4/rand_emmc_perf.c b/tests/ext4/rand_emmc_perf.c
index ebd10c8..fed7a54 100644
--- a/tests/ext4/rand_emmc_perf.c
+++ b/tests/ext4/rand_emmc_perf.c
@@ -51,8 +51,8 @@
{
int i;
struct timeval t;
- struct timeval sum = { 0 };
- struct timeval max = { 0 };
+ struct timeval sum = { 0, 0 };
+ struct timeval max = { 0, 0 };
long long total_usecs;
long long avg_usecs;
long long max_usecs;
@@ -217,6 +217,7 @@
break;
case 'f':
+ free(full_stats_file);
full_stats_file = strdup(optarg);
if (full_stats_file == NULL) {
fprintf(stderr, "Cannot get full stats filename\n");
@@ -258,6 +259,7 @@
} else {
perf_test(fd, write_mode, max_blocks);
}
+ free(full_stats_file);
exit(0);
}
diff --git a/tests/net_test/README b/tests/net_test/README
deleted file mode 100644
index f45c3d5..0000000
--- a/tests/net_test/README
+++ /dev/null
@@ -1,77 +0,0 @@
- net_test v0.1
- =============
-
-A simple framework for blackbox testing of kernel networking code.
-
-
-Why use it?
-===========
-
-- Fast test / boot cycle.
-- Access to host filesystem and networking via L2 bridging.
-- Full Linux userland including Python, etc.
-- Kernel bugs don't crash the system.
-
-
-How to use it
-=============
-
-cd <kerneldir>
-path/to/net_test/run_net_test.sh <test>
-
-where <test> is the name of a test binary in the net_test directory. This can
-be an x86 binary, a shell script, a Python script. etc.
-
-
-How it works
-============
-
-net_test compiles the kernel to a user-mode linux binary, which runs as a
-process on the host machine. It runs the binary to start a Linux "virtual
-machine" whose root filesystem is the supplied Debian disk image. The machine
-boots, mounts the root filesystem read-only, runs the specified test from init, and then drops to a shell.
-
-
-Access to host filesystem
-=========================
-
-The VM mounts the host filesystem at /host, so the test can be modified and
-re-run without rebooting the VM.
-
-
-Access to host networking
-=========================
-
-Access to host networking is provided by tap interfaces. On the host, the
-interfaces are named <user>TAP0, <user>TAP1, etc., where <user> is the first
-10 characters of the username running net_test. (10 characters because
-IFNAMSIZ = 16). On the guest, they are named eth0, eth1, etc.
-
-net_test does not do any networking setup beyond creating the tap interfaces.
-IP connectivity can be provided on the host side by setting up a DHCP server
-and NAT, sending IPv6 router advertisements, etc. By default, the VM has IPv6
-privacy addresses disabled, so its IPv6 addresses can be predicted using a tool
-such as ipv6calc.
-
-The provided filesystem contains a DHCPv4 client and simple networking
-utilities such as ping[6], traceroute[6], and wget.
-
-The number of tap interfaces is currently hardcoded to two. To change this
-number, modify run_net_test.sh.
-
-
-Logging into the VM, installing packages, etc.
-==============================================
-
-net_test mounts the root filesystem read-only, and runs the test from init, but
-since the filesystem contains a full Linux userland, it's possible to boot into
-userland and modify the filesystem, for example to install packages using
-apt-get install. Log in as root with no password. By default, the filesystem is
-configured to perform DHCPv4 on eth0 and listen to RAs.
-
-
-Bugs
-====
-
-Since the test mounts the filesystem read-only, tests cannot modify
-/etc/resolv.conf and the system resolver is hardcoded to 8.8.8.8.
diff --git a/tests/net_test/all_tests.sh b/tests/net_test/all_tests.sh
deleted file mode 100755
index ce147d3..0000000
--- a/tests/net_test/all_tests.sh
+++ /dev/null
@@ -1,48 +0,0 @@
-#!/bin/bash
-
-# Copyright 2014 The Android Open Source Project
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-readonly PREFIX="#####"
-
-function maybePlural() {
- # $1 = integer to use for plural check
- # $2 = singular string
- # $3 = plural string
- if [ $1 -ne 1 ]; then
- echo "$3"
- else
- echo "$2"
- fi
-}
-
-
-readonly tests=$(find . -name '*_test.py' -type f -executable)
-readonly count=$(echo $tests | wc -w)
-echo "$PREFIX Found $count $(maybePlural $count test tests)."
-
-exit_code=0
-
-i=0
-for test in $tests; do
- i=$((i + 1))
- echo ""
- echo "$PREFIX $test ($i/$count)"
- echo ""
- $test || exit_code=$(( exit_code + 1 ))
- echo ""
-done
-
-echo "$PREFIX $exit_code failed $(maybePlural $exit_code test tests)."
-exit $exit_code
diff --git a/tests/net_test/anycast_test.py b/tests/net_test/anycast_test.py
deleted file mode 100755
index 82130db..0000000
--- a/tests/net_test/anycast_test.py
+++ /dev/null
@@ -1,113 +0,0 @@
-#!/usr/bin/python
-#
-# Copyright 2014 The Android Open Source Project
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-from socket import * # pylint: disable=wildcard-import
-import threading
-import time
-import unittest
-
-import cstruct
-import multinetwork_base
-import net_test
-
-IPV6_JOIN_ANYCAST = 27
-IPV6_LEAVE_ANYCAST = 28
-
-# pylint: disable=invalid-name
-IPv6Mreq = cstruct.Struct("IPv6Mreq", "=16si", "multiaddr ifindex")
-
-
-_CLOSE_HUNG = False
-
-
-def CauseOops():
- open("/proc/sysrq-trigger", "w").write("c")
-
-
-class CloseFileDescriptorThread(threading.Thread):
-
- def __init__(self, fd):
- super(CloseFileDescriptorThread, self).__init__()
- self.daemon = True
- self._fd = fd
- self.finished = False
-
- def run(self):
- global _CLOSE_HUNG
- _CLOSE_HUNG = True
- self._fd.close()
- _CLOSE_HUNG = False
- self.finished = True
-
-
-class AnycastTest(multinetwork_base.MultiNetworkBaseTest):
- """Tests for IPv6 anycast addresses.
-
- Relevant kernel commits:
- upstream net-next:
- 381f4dc ipv6: clean up anycast when an interface is destroyed
-
- android-3.10:
- 86a47ad ipv6: clean up anycast when an interface is destroyed
- """
- _TEST_NETID = 123
-
- def AnycastSetsockopt(self, s, is_add, netid, addr):
- ifindex = self.ifindices[netid]
- self.assertTrue(ifindex)
- ipv6mreq = IPv6Mreq((addr, ifindex))
- option = IPV6_JOIN_ANYCAST if is_add else IPV6_LEAVE_ANYCAST
- s.setsockopt(IPPROTO_IPV6, option, ipv6mreq.Pack())
-
- def testAnycastNetdeviceUnregister(self):
- netid = self._TEST_NETID
- self.assertNotIn(netid, self.tuns)
- self.tuns[netid] = self.CreateTunInterface(netid)
- self.SendRA(netid)
- iface = self.GetInterfaceName(netid)
- self.ifindices[netid] = net_test.GetInterfaceIndex(iface)
-
- s = socket(AF_INET6, SOCK_DGRAM, 0)
- addr = self.MyAddress(6, netid)
- self.assertIsNotNone(addr)
-
- addr = inet_pton(AF_INET6, addr)
- addr = addr[:8] + os.urandom(8)
- self.AnycastSetsockopt(s, True, netid, addr)
-
- # Close the tun fd in the background.
- # This will hang if the kernel has the bug.
- thread = CloseFileDescriptorThread(self.tuns[netid])
- thread.start()
- time.sleep(0.1)
-
- # Make teardown work.
- del self.tuns[netid]
- # Check that the interface is gone.
- try:
- self.assertIsNone(self.MyAddress(6, netid))
- finally:
- # This doesn't seem to help, but still.
- self.AnycastSetsockopt(s, False, netid, addr)
- self.assertTrue(thread.finished)
-
-
-if __name__ == "__main__":
- unittest.main(exit=False)
- if _CLOSE_HUNG:
- time.sleep(3)
- CauseOops()
diff --git a/tests/net_test/csocket.py b/tests/net_test/csocket.py
deleted file mode 100644
index 5dc495c..0000000
--- a/tests/net_test/csocket.py
+++ /dev/null
@@ -1,182 +0,0 @@
-# Copyright 2014 The Android Open Source Project
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Python wrapper for C socket calls and data structures."""
-
-import ctypes
-import ctypes.util
-import os
-import socket
-import struct
-
-import cstruct
-
-
-# Data structures.
-# These aren't constants, they're classes. So, pylint: disable=invalid-name
-CMsgHdr = cstruct.Struct("cmsghdr", "@Lii", "len level type")
-Iovec = cstruct.Struct("iovec", "@LL", "base len")
-MsgHdr = cstruct.Struct("msghdr", "@LLLLLLi",
- "name namelen iov iovlen control msg_controllen flags")
-SockaddrIn = cstruct.Struct("sockaddr_in", "=HH4sxxxxxxxx", "family port addr")
-SockaddrIn6 = cstruct.Struct("sockaddr_in6", "=HHI16sI",
- "family port flowinfo addr scope_id")
-
-# Constants.
-CMSG_ALIGNTO = struct.calcsize("@L") # The kernel defines this as sizeof(long).
-MSG_CONFIRM = 0X800
-
-# Find the C library.
-libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)
-
-
-def PaddedLength(length):
- return CMSG_ALIGNTO * ((length / CMSG_ALIGNTO) + (length % CMSG_ALIGNTO != 0))
-
-
-def MaybeRaiseSocketError(ret):
- if ret < 0:
- errno = ctypes.get_errno()
- raise socket.error(errno, os.strerror(errno))
-
-
-def Sockaddr(addr):
- if ":" in addr[0]:
- family = socket.AF_INET6
- if len(addr) == 4:
- addr, port, flowinfo, scope_id = addr
- else:
- (addr, port), flowinfo, scope_id = addr, 0, 0
- addr = socket.inet_pton(family, addr)
- return SockaddrIn6((family, socket.ntohs(port), socket.ntohl(flowinfo),
- addr, scope_id))
- else:
- family = socket.AF_INET
- addr, port = addr
- addr = socket.inet_pton(family, addr)
- return SockaddrIn((family, socket.ntohs(port), addr))
-
-
-def _MakeMsgControl(optlist):
- """Creates a msg_control blob from a list of cmsg attributes.
-
- Takes a list of cmsg attributes. Each attribute is a tuple of:
- - level: An integer, e.g., SOL_IPV6.
- - type: An integer, the option identifier, e.g., IPV6_HOPLIMIT.
- - data: The option data. This is either a string or an integer. If it's an
- integer it will be written as an unsigned integer in host byte order. If
- it's a string, it's used as is.
-
- Data is padded to an integer multiple of CMSG_ALIGNTO.
-
- Args:
- optlist: A list of tuples describing cmsg options.
-
- Returns:
- A string, a binary blob usable as the control data for a sendmsg call.
-
- Raises:
- TypeError: Option data is neither an integer nor a string.
- """
- msg_control = ""
-
- for i, opt in enumerate(optlist):
- msg_level, msg_type, data = opt
- if isinstance(data, int):
- data = struct.pack("=I", data)
- elif not isinstance(data, str):
- raise TypeError("unknown data type for opt %i: %s" % (i, type(data)))
-
- datalen = len(data)
- msg_len = len(CMsgHdr) + datalen
- padding = "\x00" * (PaddedLength(datalen) - datalen)
- msg_control += CMsgHdr((msg_len, msg_level, msg_type)).Pack()
- msg_control += data + padding
-
- return msg_control
-
-
-def Bind(s, to):
- """Python wrapper for connect."""
- ret = libc.bind(s.fileno(), to.CPointer(), len(to))
- MaybeRaiseSocketError(ret)
- return ret
-
-
-def Connect(s, to):
- """Python wrapper for connect."""
- ret = libc.connect(s.fileno(), to.CPointer(), len(to))
- MaybeRaiseSocketError(ret)
- return ret
-
-
-def Sendmsg(s, to, data, control, flags):
- """Python wrapper for sendmsg.
-
- Args:
- s: A Python socket object. Becomes sockfd.
- to: An address tuple, or a SockaddrIn[6] struct. Becomes msg->msg_name.
- data: A string, the data to write. Goes into msg->msg_iov.
- control: A list of cmsg options. Becomes msg->msg_control.
- flags: An integer. Becomes msg->msg_flags.
-
- Returns:
- If sendmsg succeeds, returns the number of bytes written as an integer.
-
- Raises:
- socket.error: If sendmsg fails.
- """
- # Create ctypes buffers and pointers from our structures. We need to hang on
- # to the underlying Python objects, because we don't want them to be garbage
- # collected and freed while we have C pointers to them.
-
- # Convert the destination address into a struct sockaddr.
- if to:
- if isinstance(to, tuple):
- to = Sockaddr(to)
- msg_name = to.CPointer()
- msg_namelen = len(to)
- else:
- msg_name = 0
- msg_namelen = 0
-
- # Convert the data to a data buffer and a struct iovec pointing at it.
- if data:
- databuf = ctypes.create_string_buffer(data)
- iov = Iovec((ctypes.addressof(databuf), len(data)))
- msg_iov = iov.CPointer()
- msg_iovlen = 1
- else:
- msg_iov = 0
- msg_iovlen = 0
-
- # Marshal the cmsg options.
- if control:
- control = _MakeMsgControl(control)
- controlbuf = ctypes.create_string_buffer(control)
- msg_control = ctypes.addressof(controlbuf)
- msg_controllen = len(control)
- else:
- msg_control = 0
- msg_controllen = 0
-
- # Assemble the struct msghdr.
- msghdr = MsgHdr((msg_name, msg_namelen, msg_iov, msg_iovlen,
- msg_control, msg_controllen, flags)).Pack()
-
- # Call sendmsg.
- ret = libc.sendmsg(s.fileno(), msghdr, 0)
- MaybeRaiseSocketError(ret)
-
- return ret
diff --git a/tests/net_test/cstruct.py b/tests/net_test/cstruct.py
deleted file mode 100644
index 91cd72e..0000000
--- a/tests/net_test/cstruct.py
+++ /dev/null
@@ -1,194 +0,0 @@
-# Copyright 2014 The Android Open Source Project
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""A simple module for declaring C-like structures.
-
-Example usage:
-
->>> # Declare a struct type by specifying name, field formats and field names.
-... # Field formats are the same as those used in the struct module.
-... import cstruct
->>> NLMsgHdr = cstruct.Struct("NLMsgHdr", "=LHHLL", "length type flags seq pid")
->>>
->>>
->>> # Create instances from tuples or raw bytes. Data past the end is ignored.
-... n1 = NLMsgHdr((44, 32, 0x2, 0, 491))
->>> print n1
-NLMsgHdr(length=44, type=32, flags=2, seq=0, pid=491)
->>>
->>> n2 = NLMsgHdr("\x2c\x00\x00\x00\x21\x00\x02\x00"
-... "\x00\x00\x00\x00\xfe\x01\x00\x00" + "junk at end")
->>> print n2
-NLMsgHdr(length=44, type=33, flags=2, seq=0, pid=510)
->>>
->>> # Serialize to raw bytes.
-... print n1.Pack().encode("hex")
-2c0000002000020000000000eb010000
->>>
->>> # Parse the beginning of a byte stream as a struct, and return the struct
-... # and the remainder of the stream for further reading.
-... data = ("\x2c\x00\x00\x00\x21\x00\x02\x00"
-... "\x00\x00\x00\x00\xfe\x01\x00\x00"
-... "more data")
->>> cstruct.Read(data, NLMsgHdr)
-(NLMsgHdr(length=44, type=33, flags=2, seq=0, pid=510), 'more data')
->>>
-"""
-
-import ctypes
-import string
-import struct
-
-
-def CalcNumElements(fmt):
- size = struct.calcsize(fmt)
- elements = struct.unpack(fmt, "\x00" * size)
- return len(elements)
-
-
-def Struct(name, fmt, fieldnames, substructs={}):
- """Function that returns struct classes."""
-
- class Meta(type):
-
- def __len__(cls):
- return cls._length
-
- def __init__(cls, unused_name, unused_bases, namespace):
- # Make the class object have the name that's passed in.
- type.__init__(cls, namespace["_name"], unused_bases, namespace)
-
- class CStruct(object):
- """Class representing a C-like structure."""
-
- __metaclass__ = Meta
-
- # Name of the struct.
- _name = name
- # List of field names.
- _fieldnames = fieldnames
- # Dict mapping field indices to nested struct classes.
- _nested = {}
-
- if isinstance(_fieldnames, str):
- _fieldnames = _fieldnames.split(" ")
-
- # Parse fmt into _format, converting any S format characters to "XXs",
- # where XX is the length of the struct type's packed representation.
- _format = ""
- laststructindex = 0
- for i in xrange(len(fmt)):
- if fmt[i] == "S":
- # Nested struct. Record the index in our struct it should go into.
- index = CalcNumElements(fmt[:i])
- _nested[index] = substructs[laststructindex]
- laststructindex += 1
- _format += "%ds" % len(_nested[index])
- else:
- # Standard struct format character.
- _format += fmt[i]
-
- _length = struct.calcsize(_format)
-
- def _SetValues(self, values):
- super(CStruct, self).__setattr__("_values", list(values))
-
- def _Parse(self, data):
- data = data[:self._length]
- values = list(struct.unpack(self._format, data))
- for index, value in enumerate(values):
- if isinstance(value, str) and index in self._nested:
- values[index] = self._nested[index](value)
- self._SetValues(values)
-
- def __init__(self, values):
- # Initializing from a string.
- if isinstance(values, str):
- if len(values) < self._length:
- raise TypeError("%s requires string of length %d, got %d" %
- (self._name, self._length, len(values)))
- self._Parse(values)
- else:
- # Initializing from a tuple.
- if len(values) != len(self._fieldnames):
- raise TypeError("%s has exactly %d fieldnames (%d given)" %
- (self._name, len(self._fieldnames), len(values)))
- self._SetValues(values)
-
- def _FieldIndex(self, attr):
- try:
- return self._fieldnames.index(attr)
- except ValueError:
- raise AttributeError("'%s' has no attribute '%s'" %
- (self._name, attr))
-
- def __getattr__(self, name):
- return self._values[self._FieldIndex(name)]
-
- def __setattr__(self, name, value):
- self._values[self._FieldIndex(name)] = value
-
- @classmethod
- def __len__(cls):
- return cls._length
-
- def __ne__(self, other):
- return not self.__eq__(other)
-
- def __eq__(self, other):
- return (isinstance(other, self.__class__) and
- self._name == other._name and
- self._fieldnames == other._fieldnames and
- self._values == other._values)
-
- @staticmethod
- def _MaybePackStruct(value):
- if hasattr(value, "__metaclass__"):# and value.__metaclass__ == Meta:
- return value.Pack()
- else:
- return value
-
- def Pack(self):
- values = [self._MaybePackStruct(v) for v in self._values]
- return struct.pack(self._format, *values)
-
- def __str__(self):
- def FieldDesc(index, name, value):
- if isinstance(value, str) and any(
- c not in string.printable for c in value):
- value = value.encode("hex")
- return "%s=%s" % (name, value)
-
- descriptions = [
- FieldDesc(i, n, v) for i, (n, v) in
- enumerate(zip(self._fieldnames, self._values))]
-
- return "%s(%s)" % (self._name, ", ".join(descriptions))
-
- def __repr__(self):
- return str(self)
-
- def CPointer(self):
- """Returns a C pointer to the serialized structure."""
- buf = ctypes.create_string_buffer(self.Pack())
- # Store the C buffer in the object so it doesn't get garbage collected.
- super(CStruct, self).__setattr__("_buffer", buf)
- return ctypes.addressof(self._buffer)
-
- return CStruct
-
-
-def Read(data, struct_type):
- length = len(struct_type)
- return struct_type(data), data[length:]
diff --git a/tests/net_test/cstruct_test.py b/tests/net_test/cstruct_test.py
deleted file mode 100755
index 2d5a408..0000000
--- a/tests/net_test/cstruct_test.py
+++ /dev/null
@@ -1,60 +0,0 @@
-#!/usr/bin/python
-#
-# Copyright 2016 The Android Open Source Project
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import cstruct
-
-
-# These aren't constants, they're classes. So, pylint: disable=invalid-name
-TestStructA = cstruct.Struct("TestStructA", "=BI", "byte1 int2")
-TestStructB = cstruct.Struct("TestStructB", "=BI", "byte1 int2")
-
-
-class CstructTest(unittest.TestCase):
-
- def CheckEquals(self, a, b):
- self.assertEquals(a, b)
- self.assertEquals(b, a)
- assert a == b
- assert b == a
- assert not (a != b) # pylint: disable=g-comparison-negation,superfluous-parens
- assert not (b != a) # pylint: disable=g-comparison-negation,superfluous-parens
-
- def CheckNotEquals(self, a, b):
- self.assertNotEquals(a, b)
- self.assertNotEquals(b, a)
- assert a != b
- assert b != a
- assert not (a == b) # pylint: disable=g-comparison-negation,superfluous-parens
- assert not (b == a) # pylint: disable=g-comparison-negation,superfluous-parens
-
- def testEqAndNe(self):
- a1 = TestStructA((1, 2))
- a2 = TestStructA((2, 3))
- a3 = TestStructA((1, 2))
- b = TestStructB((1, 2))
- self.CheckNotEquals(a1, b)
- self.CheckNotEquals(a2, b)
- self.CheckNotEquals(a1, a2)
- self.CheckNotEquals(a2, a3)
- for i in [a1, a2, a3, b]:
- self.CheckEquals(i, i)
- self.CheckEquals(a1, a3)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/net_test/forwarding_test.py b/tests/net_test/forwarding_test.py
deleted file mode 100755
index 185e477..0000000
--- a/tests/net_test/forwarding_test.py
+++ /dev/null
@@ -1,109 +0,0 @@
-#!/usr/bin/python
-#
-# Copyright 2015 The Android Open Source Project
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import itertools
-import random
-import unittest
-
-from socket import *
-
-import iproute
-import multinetwork_base
-import net_test
-import packets
-
-
-class ForwardingTest(multinetwork_base.MultiNetworkBaseTest):
-
- TCP_TIME_WAIT = 6
-
- def ForwardBetweenInterfaces(self, enabled, iface1, iface2):
- for iif, oif in itertools.permutations([iface1, iface2]):
- self.iproute.IifRule(6, enabled, self.GetInterfaceName(iif),
- self._TableForNetid(oif), self.PRIORITY_IIF)
-
- def setUp(self):
- self.SetSysctl("/proc/sys/net/ipv6/conf/all/forwarding", 1)
-
- def tearDown(self):
- self.SetSysctl("/proc/sys/net/ipv6/conf/all/forwarding", 0)
-
- def CheckForwardingCrash(self, netid, iface1, iface2):
- listenport = packets.RandomPort()
- listensocket = net_test.IPv6TCPSocket()
- listensocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
- listensocket.bind(("::", listenport))
- listensocket.listen(100)
- self.SetSocketMark(listensocket, netid)
-
- version = 6
- remoteaddr = self.GetRemoteAddress(version)
- myaddr = self.MyAddress(version, netid)
-
- desc, syn = packets.SYN(listenport, version, remoteaddr, myaddr)
- synack_desc, synack = packets.SYNACK(version, myaddr, remoteaddr, syn)
- msg = "Sent %s, expected %s" % (desc, synack_desc)
- reply = self._ReceiveAndExpectResponse(netid, syn, synack, msg)
-
- establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]
- self.ReceivePacketOn(netid, establishing_ack)
- accepted, peer = listensocket.accept()
- remoteport = accepted.getpeername()[1]
-
- accepted.close()
- desc, fin = packets.FIN(version, myaddr, remoteaddr, establishing_ack)
- self.ExpectPacketOn(netid, msg + ": expecting %s after close" % desc, fin)
-
- desc, finack = packets.FIN(version, remoteaddr, myaddr, fin)
- self.ReceivePacketOn(netid, finack)
-
- # Check our socket is now in TIME_WAIT.
- sockets = self.ReadProcNetSocket("tcp6")
- mysrc = "%s:%04X" % (net_test.FormatSockStatAddress(myaddr), listenport)
- mydst = "%s:%04X" % (net_test.FormatSockStatAddress(remoteaddr), remoteport)
- state = None
- sockets = [s for s in sockets if s[0] == mysrc and s[1] == mydst]
- self.assertEquals(1, len(sockets))
- self.assertEquals("%02X" % self.TCP_TIME_WAIT, sockets[0][2])
-
- # Remove our IP address.
- try:
- self.iproute.DelAddress(myaddr, 64, self.ifindices[netid])
-
- self.ReceivePacketOn(iface1, finack)
- self.ReceivePacketOn(iface1, establishing_ack)
- self.ReceivePacketOn(iface1, establishing_ack)
- # No crashes? Good.
-
- finally:
- # Put back our IP address.
- self.SendRA(netid)
- listensocket.close()
-
- def testCrash(self):
- # Run the test a few times as it doesn't crash/hang the first time.
- for netids in itertools.permutations(self.tuns):
- # Pick an interface to send traffic on and two to forward traffic between.
- netid, iface1, iface2 = random.sample(netids, 3)
- self.ForwardBetweenInterfaces(True, iface1, iface2)
- try:
- self.CheckForwardingCrash(netid, iface1, iface2)
- finally:
- self.ForwardBetweenInterfaces(False, iface1, iface2)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/net_test/iproute.py b/tests/net_test/iproute.py
deleted file mode 100644
index 2c63993..0000000
--- a/tests/net_test/iproute.py
+++ /dev/null
@@ -1,541 +0,0 @@
-#!/usr/bin/python
-#
-# Copyright 2014 The Android Open Source Project
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Partial Python implementation of iproute functionality."""
-
-# pylint: disable=g-bad-todo
-
-import errno
-import os
-import socket
-import struct
-import sys
-
-import cstruct
-import netlink
-
-
-### Base netlink constants. See include/uapi/linux/netlink.h.
-NETLINK_ROUTE = 0
-
-# Request constants.
-NLM_F_REQUEST = 1
-NLM_F_ACK = 4
-NLM_F_REPLACE = 0x100
-NLM_F_EXCL = 0x200
-NLM_F_CREATE = 0x400
-NLM_F_DUMP = 0x300
-
-# Message types.
-NLMSG_ERROR = 2
-NLMSG_DONE = 3
-
-# Data structure formats.
-# These aren't constants, they're classes. So, pylint: disable=invalid-name
-NLMsgHdr = cstruct.Struct("NLMsgHdr", "=LHHLL", "length type flags seq pid")
-NLMsgErr = cstruct.Struct("NLMsgErr", "=i", "error")
-NLAttr = cstruct.Struct("NLAttr", "=HH", "nla_len nla_type")
-
-# Alignment / padding.
-NLA_ALIGNTO = 4
-
-
-### rtnetlink constants. See include/uapi/linux/rtnetlink.h.
-# Message types.
-RTM_NEWLINK = 16
-RTM_DELLINK = 17
-RTM_GETLINK = 18
-RTM_NEWADDR = 20
-RTM_DELADDR = 21
-RTM_GETADDR = 22
-RTM_NEWROUTE = 24
-RTM_DELROUTE = 25
-RTM_GETROUTE = 26
-RTM_NEWNEIGH = 28
-RTM_DELNEIGH = 29
-RTM_GETNEIGH = 30
-RTM_NEWRULE = 32
-RTM_DELRULE = 33
-RTM_GETRULE = 34
-
-# Routing message type values (rtm_type).
-RTN_UNSPEC = 0
-RTN_UNICAST = 1
-RTN_UNREACHABLE = 7
-
-# Routing protocol values (rtm_protocol).
-RTPROT_UNSPEC = 0
-RTPROT_STATIC = 4
-
-# Route scope values (rtm_scope).
-RT_SCOPE_UNIVERSE = 0
-RT_SCOPE_LINK = 253
-
-# Named routing tables.
-RT_TABLE_UNSPEC = 0
-
-# Routing attributes.
-RTA_DST = 1
-RTA_SRC = 2
-RTA_OIF = 4
-RTA_GATEWAY = 5
-RTA_PRIORITY = 6
-RTA_PREFSRC = 7
-RTA_METRICS = 8
-RTA_CACHEINFO = 12
-RTA_TABLE = 15
-RTA_MARK = 16
-RTA_UID = 18
-
-# Route metric attributes.
-RTAX_MTU = 2
-RTAX_HOPLIMIT = 10
-
-# Data structure formats.
-IfinfoMsg = cstruct.Struct(
- "IfinfoMsg", "=BBHiII", "family pad type index flags change")
-RTMsg = cstruct.Struct(
- "RTMsg", "=BBBBBBBBI",
- "family dst_len src_len tos table protocol scope type flags")
-RTACacheinfo = cstruct.Struct(
- "RTACacheinfo", "=IIiiI", "clntref lastuse expires error used")
-
-
-### Interface address constants. See include/uapi/linux/if_addr.h.
-# Interface address attributes.
-IFA_ADDRESS = 1
-IFA_LOCAL = 2
-IFA_CACHEINFO = 6
-
-# Address flags.
-IFA_F_SECONDARY = 0x01
-IFA_F_TEMPORARY = IFA_F_SECONDARY
-IFA_F_NODAD = 0x02
-IFA_F_OPTIMISTIC = 0x04
-IFA_F_DADFAILED = 0x08
-IFA_F_HOMEADDRESS = 0x10
-IFA_F_DEPRECATED = 0x20
-IFA_F_TENTATIVE = 0x40
-IFA_F_PERMANENT = 0x80
-
-# Data structure formats.
-IfAddrMsg = cstruct.Struct(
- "IfAddrMsg", "=BBBBI",
- "family prefixlen flags scope index")
-IFACacheinfo = cstruct.Struct(
- "IFACacheinfo", "=IIII", "prefered valid cstamp tstamp")
-NDACacheinfo = cstruct.Struct(
- "NDACacheinfo", "=IIII", "confirmed used updated refcnt")
-
-
-### Neighbour table entry constants. See include/uapi/linux/neighbour.h.
-# Neighbour cache entry attributes.
-NDA_DST = 1
-NDA_LLADDR = 2
-NDA_CACHEINFO = 3
-NDA_PROBES = 4
-
-# Neighbour cache entry states.
-NUD_PERMANENT = 0x80
-
-# Data structure formats.
-NdMsg = cstruct.Struct(
- "NdMsg", "=BxxxiHBB",
- "family ifindex state flags type")
-
-
-### FIB rule constants. See include/uapi/linux/fib_rules.h.
-FRA_IIFNAME = 3
-FRA_PRIORITY = 6
-FRA_FWMARK = 10
-FRA_SUPPRESS_PREFIXLEN = 14
-FRA_TABLE = 15
-FRA_FWMASK = 16
-FRA_OIFNAME = 17
-FRA_UID_START = 18
-FRA_UID_END = 19
-
-
-# Link constants. See include/uapi/linux/if_link.h.
-IFLA_ADDRESS = 1
-IFLA_BROADCAST = 2
-IFLA_IFNAME = 3
-IFLA_MTU = 4
-IFLA_QDISC = 6
-IFLA_STATS = 7
-IFLA_TXQLEN = 13
-IFLA_MAP = 14
-IFLA_OPERSTATE = 16
-IFLA_LINKMODE = 17
-IFLA_STATS64 = 23
-IFLA_AF_SPEC = 26
-IFLA_GROUP = 27
-IFLA_EXT_MASK = 29
-IFLA_PROMISCUITY = 30
-IFLA_NUM_TX_QUEUES = 31
-IFLA_NUM_RX_QUEUES = 32
-IFLA_CARRIER = 33
-
-
-def CommandVerb(command):
- return ["NEW", "DEL", "GET", "SET"][command % 4]
-
-
-def CommandSubject(command):
- return ["LINK", "ADDR", "ROUTE", "NEIGH", "RULE"][(command - 16) / 4]
-
-
-def CommandName(command):
- try:
- return "RTM_%s%s" % (CommandVerb(command), CommandSubject(command))
- except IndexError:
- return "RTM_%d" % command
-
-
-class IPRoute(netlink.NetlinkSocket):
- """Provides a tiny subset of iproute functionality."""
-
- FAMILY = NETLINK_ROUTE
-
- def _NlAttrIPAddress(self, nla_type, family, address):
- return self._NlAttr(nla_type, socket.inet_pton(family, address))
-
- def _NlAttrInterfaceName(self, nla_type, interface):
- return self._NlAttr(nla_type, interface + "\x00")
-
- def _GetConstantName(self, value, prefix):
- return super(IPRoute, self)._GetConstantName(__name__, value, prefix)
-
- def _Decode(self, command, msg, nla_type, nla_data):
- """Decodes netlink attributes to Python types.
-
- Values for which the code knows the type (e.g., the fwmark ID in a
- RTM_NEWRULE command) are decoded to Python integers, strings, etc. Values
- of unknown type are returned as raw byte strings.
-
- Args:
- command: An integer.
- - If positive, the number of the rtnetlink command being carried out.
- This is used to interpret the attributes. For example, for an
- RTM_NEWROUTE command, attribute type 3 is the incoming interface and
- is an integer, but for a RTM_NEWRULE command, attribute type 3 is the
- incoming interface name and is a string.
- - If negative, one of the following (negative) values:
- - RTA_METRICS: Interpret as nested route metrics.
- family: The address family. Used to convert IP addresses into strings.
- nla_type: An integer, then netlink attribute type.
- nla_data: A byte string, the netlink attribute data.
-
- Returns:
- A tuple (name, data):
- - name is a string (e.g., "FRA_PRIORITY") if we understood the attribute,
- or an integer if we didn't.
- - data can be an integer, a string, a nested dict of attributes as
- returned by _ParseAttributes (e.g., for RTA_METRICS), a cstruct.Struct
- (e.g., RTACacheinfo), etc. If we didn't understand the attribute, it
- will be the raw byte string.
- """
- if command == -RTA_METRICS:
- name = self._GetConstantName(nla_type, "RTAX_")
- elif CommandSubject(command) == "ADDR":
- name = self._GetConstantName(nla_type, "IFA_")
- elif CommandSubject(command) == "LINK":
- name = self._GetConstantName(nla_type, "IFLA_")
- elif CommandSubject(command) == "RULE":
- name = self._GetConstantName(nla_type, "FRA_")
- elif CommandSubject(command) == "ROUTE":
- name = self._GetConstantName(nla_type, "RTA_")
- elif CommandSubject(command) == "NEIGH":
- name = self._GetConstantName(nla_type, "NDA_")
- else:
- # Don't know what this is. Leave it as an integer.
- name = nla_type
-
- if name in ["FRA_PRIORITY", "FRA_FWMARK", "FRA_TABLE", "FRA_FWMASK",
- "FRA_UID_START", "FRA_UID_END",
- "RTA_OIF", "RTA_PRIORITY", "RTA_TABLE", "RTA_MARK",
- "IFLA_MTU", "IFLA_TXQLEN", "IFLA_GROUP", "IFLA_EXT_MASK",
- "IFLA_PROMISCUITY", "IFLA_NUM_RX_QUEUES",
- "IFLA_NUM_TX_QUEUES", "NDA_PROBES", "RTAX_MTU",
- "RTAX_HOPLIMIT"]:
- data = struct.unpack("=I", nla_data)[0]
- elif name == "FRA_SUPPRESS_PREFIXLEN":
- data = struct.unpack("=i", nla_data)[0]
- elif name in ["IFLA_LINKMODE", "IFLA_OPERSTATE", "IFLA_CARRIER"]:
- data = ord(nla_data)
- elif name in ["IFA_ADDRESS", "IFA_LOCAL", "RTA_DST", "RTA_SRC",
- "RTA_GATEWAY", "RTA_PREFSRC", "RTA_UID",
- "NDA_DST"]:
- data = socket.inet_ntop(msg.family, nla_data)
- elif name in ["FRA_IIFNAME", "FRA_OIFNAME", "IFLA_IFNAME", "IFLA_QDISC"]:
- data = nla_data.strip("\x00")
- elif name == "RTA_METRICS":
- data = self._ParseAttributes(-RTA_METRICS, msg.family, None, nla_data)
- elif name == "RTA_CACHEINFO":
- data = RTACacheinfo(nla_data)
- elif name == "IFA_CACHEINFO":
- data = IFACacheinfo(nla_data)
- elif name == "NDA_CACHEINFO":
- data = NDACacheinfo(nla_data)
- elif name in ["NDA_LLADDR", "IFLA_ADDRESS"]:
- data = ":".join(x.encode("hex") for x in nla_data)
- else:
- data = nla_data
-
- return name, data
-
- def __init__(self):
- super(IPRoute, self).__init__()
-
- def _AddressFamily(self, version):
- return {4: socket.AF_INET, 6: socket.AF_INET6}[version]
-
- def _SendNlRequest(self, command, data, flags=0):
- """Sends a netlink request and expects an ack."""
-
- flags |= NLM_F_REQUEST
- if CommandVerb(command) != "GET":
- flags |= NLM_F_ACK
- if CommandVerb(command) == "NEW":
- if not flags & NLM_F_REPLACE:
- flags |= (NLM_F_EXCL | NLM_F_CREATE)
-
- super(IPRoute, self)._SendNlRequest(command, data, flags)
-
- def _Rule(self, version, is_add, rule_type, table, match_nlattr, priority):
- """Python equivalent of "ip rule <add|del> <match_cond> lookup <table>".
-
- Args:
- version: An integer, 4 or 6.
- is_add: True to add a rule, False to delete it.
- rule_type: Type of rule, e.g., RTN_UNICAST or RTN_UNREACHABLE.
- table: If nonzero, rule looks up this table.
- match_nlattr: A blob of struct nlattrs that express the match condition.
- If None, match everything.
- priority: An integer, the priority.
-
- Raises:
- IOError: If the netlink request returns an error.
- ValueError: If the kernel's response could not be parsed.
- """
- # Create a struct rtmsg specifying the table and the given match attributes.
- family = self._AddressFamily(version)
- rtmsg = RTMsg((family, 0, 0, 0, RT_TABLE_UNSPEC,
- RTPROT_STATIC, RT_SCOPE_UNIVERSE, rule_type, 0)).Pack()
- rtmsg += self._NlAttrU32(FRA_PRIORITY, priority)
- if match_nlattr:
- rtmsg += match_nlattr
- if table:
- rtmsg += self._NlAttrU32(FRA_TABLE, table)
-
- # Create a netlink request containing the rtmsg.
- command = RTM_NEWRULE if is_add else RTM_DELRULE
- self._SendNlRequest(command, rtmsg)
-
- def DeleteRulesAtPriority(self, version, priority):
- family = self._AddressFamily(version)
- rtmsg = RTMsg((family, 0, 0, 0, RT_TABLE_UNSPEC,
- RTPROT_STATIC, RT_SCOPE_UNIVERSE, RTN_UNICAST, 0)).Pack()
- rtmsg += self._NlAttrU32(FRA_PRIORITY, priority)
- while True:
- try:
- self._SendNlRequest(RTM_DELRULE, rtmsg)
- except IOError, e:
- if e.errno == -errno.ENOENT:
- break
- else:
- raise
-
- def FwmarkRule(self, version, is_add, fwmark, table, priority):
- nlattr = self._NlAttrU32(FRA_FWMARK, fwmark)
- return self._Rule(version, is_add, RTN_UNICAST, table, nlattr, priority)
-
- def IifRule(self, version, is_add, iif, table, priority):
- nlattr = self._NlAttrInterfaceName(FRA_IIFNAME, iif)
- return self._Rule(version, is_add, RTN_UNICAST, table, nlattr, priority)
-
- def OifRule(self, version, is_add, oif, table, priority):
- nlattr = self._NlAttrInterfaceName(FRA_OIFNAME, oif)
- return self._Rule(version, is_add, RTN_UNICAST, table, nlattr, priority)
-
- def UidRangeRule(self, version, is_add, start, end, table, priority):
- nlattr = (self._NlAttrInterfaceName(FRA_IIFNAME, "lo") +
- self._NlAttrU32(FRA_UID_START, start) +
- self._NlAttrU32(FRA_UID_END, end))
- return self._Rule(version, is_add, RTN_UNICAST, table, nlattr, priority)
-
- def UnreachableRule(self, version, is_add, priority):
- return self._Rule(version, is_add, RTN_UNREACHABLE, None, None, priority)
-
- def DefaultRule(self, version, is_add, table, priority):
- return self.FwmarkRule(version, is_add, 0, table, priority)
-
- def CommandToString(self, command, data):
- try:
- name = CommandName(command)
- subject = CommandSubject(command)
- struct_type = {
- "ADDR": IfAddrMsg,
- "LINK": IfinfoMsg,
- "NEIGH": NdMsg,
- "ROUTE": RTMsg,
- "RULE": RTMsg,
- }[subject]
- parsed = self._ParseNLMsg(data, struct_type)
- return "%s %s" % (name, str(parsed))
- except IndexError:
- raise ValueError("Don't know how to print command type %s" % name)
-
- def MaybeDebugCommand(self, command, data):
- subject = CommandSubject(command)
- if "ALL" not in self.NL_DEBUG and subject not in self.NL_DEBUG:
- return
- print self.CommandToString(command, data)
-
- def MaybeDebugMessage(self, message):
- hdr = NLMsgHdr(message)
- self.MaybeDebugCommand(hdr.type, message)
-
- def PrintMessage(self, message):
- hdr = NLMsgHdr(message)
- print self.CommandToString(hdr.type, message)
-
- def DumpRules(self, version):
- """Returns the IP rules for the specified IP version."""
- # Create a struct rtmsg specifying the table and the given match attributes.
- family = self._AddressFamily(version)
- rtmsg = RTMsg((family, 0, 0, 0, 0, 0, 0, 0, 0))
- return self._Dump(RTM_GETRULE, rtmsg, RTMsg, "")
-
- def DumpLinks(self):
- ifinfomsg = IfinfoMsg((0, 0, 0, 0, 0, 0))
- return self._Dump(RTM_GETLINK, ifinfomsg, IfinfoMsg, "")
-
- def _Address(self, version, command, addr, prefixlen, flags, scope, ifindex):
- """Adds or deletes an IP address."""
- family = self._AddressFamily(version)
- ifaddrmsg = IfAddrMsg((family, prefixlen, flags, scope, ifindex)).Pack()
- ifaddrmsg += self._NlAttrIPAddress(IFA_ADDRESS, family, addr)
- if version == 4:
- ifaddrmsg += self._NlAttrIPAddress(IFA_LOCAL, family, addr)
- self._SendNlRequest(command, ifaddrmsg)
-
- def AddAddress(self, address, prefixlen, ifindex):
- self._Address(6 if ":" in address else 4,
- RTM_NEWADDR, address, prefixlen,
- IFA_F_PERMANENT, RT_SCOPE_UNIVERSE, ifindex)
-
- def DelAddress(self, address, prefixlen, ifindex):
- self._Address(6 if ":" in address else 4,
- RTM_DELADDR, address, prefixlen, 0, 0, ifindex)
-
- def GetAddress(self, address, ifindex=0):
- """Returns an ifaddrmsg for the requested address."""
- if ":" not in address:
- # The address is likely an IPv4 address. RTM_GETADDR without the
- # NLM_F_DUMP flag is not supported by the kernel. We do not currently
- # implement parsing dump results.
- raise NotImplementedError("IPv4 RTM_GETADDR not implemented.")
- self._Address(6, RTM_GETADDR, address, 0, 0, RT_SCOPE_UNIVERSE, ifindex)
- return self._GetMsg(IfAddrMsg)
-
- def _Route(self, version, command, table, dest, prefixlen, nexthop, dev,
- mark, uid):
- """Adds, deletes, or queries a route."""
- family = self._AddressFamily(version)
- scope = RT_SCOPE_UNIVERSE if nexthop else RT_SCOPE_LINK
- rtmsg = RTMsg((family, prefixlen, 0, 0, RT_TABLE_UNSPEC,
- RTPROT_STATIC, scope, RTN_UNICAST, 0)).Pack()
- if command == RTM_NEWROUTE and not table:
- # Don't allow setting routes in table 0, since its behaviour is confusing
- # and differs between IPv4 and IPv6.
- raise ValueError("Cowardly refusing to add a route to table 0")
- if table:
- rtmsg += self._NlAttrU32(FRA_TABLE, table)
- if dest != "default": # The default is the default route.
- rtmsg += self._NlAttrIPAddress(RTA_DST, family, dest)
- if nexthop:
- rtmsg += self._NlAttrIPAddress(RTA_GATEWAY, family, nexthop)
- if dev:
- rtmsg += self._NlAttrU32(RTA_OIF, dev)
- if mark is not None:
- rtmsg += self._NlAttrU32(RTA_MARK, mark)
- if uid is not None:
- rtmsg += self._NlAttrU32(RTA_UID, uid)
- self._SendNlRequest(command, rtmsg)
-
- def AddRoute(self, version, table, dest, prefixlen, nexthop, dev):
- self._Route(version, RTM_NEWROUTE, table, dest, prefixlen, nexthop, dev,
- None, None)
-
- def DelRoute(self, version, table, dest, prefixlen, nexthop, dev):
- self._Route(version, RTM_DELROUTE, table, dest, prefixlen, nexthop, dev,
- None, None)
-
- def GetRoutes(self, dest, oif, mark, uid):
- version = 6 if ":" in dest else 4
- prefixlen = {4: 32, 6: 128}[version]
- self._Route(version, RTM_GETROUTE, 0, dest, prefixlen, None, oif, mark, uid)
- data = self._Recv()
- # The response will either be an error or a list of routes.
- if NLMsgHdr(data).type == NLMSG_ERROR:
- self._ParseAck(data)
- routes = self._GetMsgList(RTMsg, data, False)
- return routes
-
- def _Neighbour(self, version, is_add, addr, lladdr, dev, state, flags=0):
- """Adds or deletes a neighbour cache entry."""
- family = self._AddressFamily(version)
-
- # Convert the link-layer address to a raw byte string.
- if is_add and lladdr:
- lladdr = lladdr.split(":")
- if len(lladdr) != 6:
- raise ValueError("Invalid lladdr %s" % ":".join(lladdr))
- lladdr = "".join(chr(int(hexbyte, 16)) for hexbyte in lladdr)
-
- ndmsg = NdMsg((family, dev, state, 0, RTN_UNICAST)).Pack()
- ndmsg += self._NlAttrIPAddress(NDA_DST, family, addr)
- if is_add and lladdr:
- ndmsg += self._NlAttr(NDA_LLADDR, lladdr)
- command = RTM_NEWNEIGH if is_add else RTM_DELNEIGH
- self._SendNlRequest(command, ndmsg, flags)
-
- def AddNeighbour(self, version, addr, lladdr, dev):
- self._Neighbour(version, True, addr, lladdr, dev, NUD_PERMANENT)
-
- def DelNeighbour(self, version, addr, lladdr, dev):
- self._Neighbour(version, False, addr, lladdr, dev, 0)
-
- def UpdateNeighbour(self, version, addr, lladdr, dev, state):
- self._Neighbour(version, True, addr, lladdr, dev, state,
- flags=NLM_F_REPLACE)
-
- def DumpNeighbours(self, version):
- ndmsg = NdMsg((self._AddressFamily(version), 0, 0, 0, 0))
- return self._Dump(RTM_GETNEIGH, ndmsg, NdMsg, "")
-
- def ParseNeighbourMessage(self, msg):
- msg, _ = self._ParseNLMsg(msg, NdMsg)
- return msg
-
-
-if __name__ == "__main__":
- iproute = IPRoute()
- iproute.DEBUG = True
- iproute.DumpRules(6)
- iproute.DumpLinks()
- print iproute.GetRoutes("2001:4860:4860::8888", 0, 0, None)
diff --git a/tests/net_test/multinetwork_base.py b/tests/net_test/multinetwork_base.py
deleted file mode 100644
index 31fcc4c..0000000
--- a/tests/net_test/multinetwork_base.py
+++ /dev/null
@@ -1,642 +0,0 @@
-#!/usr/bin/python
-#
-# Copyright 2014 The Android Open Source Project
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Base module for multinetwork tests."""
-
-import errno
-import fcntl
-import os
-import posix
-import random
-import re
-from socket import * # pylint: disable=wildcard-import
-import struct
-
-from scapy import all as scapy
-
-import csocket
-import cstruct
-import iproute
-import net_test
-
-
-IFF_TUN = 1
-IFF_TAP = 2
-IFF_NO_PI = 0x1000
-TUNSETIFF = 0x400454ca
-
-SO_BINDTODEVICE = 25
-
-# Setsockopt values.
-IP_UNICAST_IF = 50
-IPV6_MULTICAST_IF = 17
-IPV6_UNICAST_IF = 76
-
-# Cmsg values.
-IP_TTL = 2
-IP_PKTINFO = 8
-IPV6_2292PKTOPTIONS = 6
-IPV6_FLOWINFO = 11
-IPV6_PKTINFO = 50
-IPV6_HOPLIMIT = 52 # Different from IPV6_UNICAST_HOPS, this is cmsg only.
-
-# Data structures.
-# These aren't constants, they're classes. So, pylint: disable=invalid-name
-InPktinfo = cstruct.Struct("in_pktinfo", "@i4s4s", "ifindex spec_dst addr")
-In6Pktinfo = cstruct.Struct("in6_pktinfo", "@16si", "addr ifindex")
-
-
-def HaveUidRouting():
- """Checks whether the kernel supports UID routing."""
- # Create a rule with the UID range selector. If the kernel doesn't understand
- # the selector, it will create a rule with no selectors.
- try:
- iproute.IPRoute().UidRangeRule(6, True, 1000, 2000, 100, 10000)
- except IOError:
- return False
-
- # Dump all the rules. If we find a rule using the UID range selector, then the
- # kernel supports UID range routing.
- rules = iproute.IPRoute().DumpRules(6)
- result = any("FRA_UID_START" in attrs for rule, attrs in rules)
-
- # Delete the rule.
- iproute.IPRoute().UidRangeRule(6, False, 1000, 2000, 100, 10000)
- return result
-
-AUTOCONF_TABLE_SYSCTL = "/proc/sys/net/ipv6/conf/default/accept_ra_rt_table"
-
-HAVE_AUTOCONF_TABLE = os.path.isfile(AUTOCONF_TABLE_SYSCTL)
-HAVE_UID_ROUTING = HaveUidRouting()
-
-
-class UnexpectedPacketError(AssertionError):
- pass
-
-
-def MakePktInfo(version, addr, ifindex):
- family = {4: AF_INET, 6: AF_INET6}[version]
- if not addr:
- addr = {4: "0.0.0.0", 6: "::"}[version]
- if addr:
- addr = inet_pton(family, addr)
- if version == 6:
- return In6Pktinfo((addr, ifindex)).Pack()
- else:
- return InPktinfo((ifindex, addr, "\x00" * 4)).Pack()
-
-
-class MultiNetworkBaseTest(net_test.NetworkTest):
- """Base class for all multinetwork tests.
-
- This class does not contain any test code, but contains code to set up and
- tear a multi-network environment using multiple tun interfaces. The
- environment is designed to be similar to a real Android device in terms of
- rules and routes, and supports IPv4 and IPv6.
-
- Tests wishing to use this environment should inherit from this class and
- ensure that any setupClass, tearDownClass, setUp, and tearDown methods they
- implement also call the superclass versions.
- """
-
- # Must be between 1 and 256, since we put them in MAC addresses and IIDs.
- NETIDS = [100, 150, 200, 250]
-
- # Stores sysctl values to write back when the test completes.
- saved_sysctls = {}
-
- # Wether to output setup commands.
- DEBUG = False
-
- # The size of our UID ranges.
- UID_RANGE_SIZE = 1000
-
- # Rule priorities.
- PRIORITY_UID = 100
- PRIORITY_OIF = 200
- PRIORITY_FWMARK = 300
- PRIORITY_IIF = 400
- PRIORITY_DEFAULT = 999
- PRIORITY_UNREACHABLE = 1000
-
- # For convenience.
- IPV4_ADDR = net_test.IPV4_ADDR
- IPV6_ADDR = net_test.IPV6_ADDR
- IPV4_PING = net_test.IPV4_PING
- IPV6_PING = net_test.IPV6_PING
-
- @classmethod
- def UidRangeForNetid(cls, netid):
- return (
- cls.UID_RANGE_SIZE * netid,
- cls.UID_RANGE_SIZE * (netid + 1) - 1
- )
-
- @classmethod
- def UidForNetid(cls, netid):
- return random.randint(*cls.UidRangeForNetid(netid))
-
- @classmethod
- def _TableForNetid(cls, netid):
- if cls.AUTOCONF_TABLE_OFFSET and netid in cls.ifindices:
- return cls.ifindices[netid] + (-cls.AUTOCONF_TABLE_OFFSET)
- else:
- return netid
-
- @staticmethod
- def GetInterfaceName(netid):
- return "nettest%d" % netid
-
- @staticmethod
- def RouterMacAddress(netid):
- return "02:00:00:00:%02x:00" % netid
-
- @staticmethod
- def MyMacAddress(netid):
- return "02:00:00:00:%02x:01" % netid
-
- @staticmethod
- def _RouterAddress(netid, version):
- if version == 6:
- return "fe80::%02x00" % netid
- elif version == 4:
- return "10.0.%d.1" % netid
- else:
- raise ValueError("Don't support IPv%s" % version)
-
- @classmethod
- def _MyIPv4Address(cls, netid):
- return "10.0.%d.2" % netid
-
- @classmethod
- def _MyIPv6Address(cls, netid):
- return net_test.GetLinkAddress(cls.GetInterfaceName(netid), False)
-
- @classmethod
- def MyAddress(cls, version, netid):
- return {4: cls._MyIPv4Address(netid),
- 5: "::ffff:" + cls._MyIPv4Address(netid),
- 6: cls._MyIPv6Address(netid)}[version]
-
- @classmethod
- def MyLinkLocalAddress(cls, netid):
- return net_test.GetLinkAddress(cls.GetInterfaceName(netid), True)
-
- @staticmethod
- def IPv6Prefix(netid):
- return "2001:db8:%02x::" % netid
-
- @staticmethod
- def GetRandomDestination(prefix):
- if "." in prefix:
- return prefix + "%d.%d" % (random.randint(0, 31), random.randint(0, 255))
- else:
- return prefix + "%x:%x" % (random.randint(0, 65535),
- random.randint(0, 65535))
-
- def GetProtocolFamily(self, version):
- return {4: AF_INET, 6: AF_INET6}[version]
-
- @classmethod
- def CreateTunInterface(cls, netid):
- iface = cls.GetInterfaceName(netid)
- f = open("/dev/net/tun", "r+b")
- ifr = struct.pack("16sH", iface, IFF_TAP | IFF_NO_PI)
- ifr += "\x00" * (40 - len(ifr))
- fcntl.ioctl(f, TUNSETIFF, ifr)
- # Give ourselves a predictable MAC address.
- net_test.SetInterfaceHWAddr(iface, cls.MyMacAddress(netid))
- # Disable DAD so we don't have to wait for it.
- cls.SetSysctl("/proc/sys/net/ipv6/conf/%s/accept_dad" % iface, 0)
- # Set accept_ra to 2, because that's what we use.
- cls.SetSysctl("/proc/sys/net/ipv6/conf/%s/accept_ra" % iface, 2)
- net_test.SetInterfaceUp(iface)
- net_test.SetNonBlocking(f)
- return f
-
- @classmethod
- def SendRA(cls, netid, retranstimer=None, reachabletime=0):
- validity = 300 # seconds
- macaddr = cls.RouterMacAddress(netid)
- lladdr = cls._RouterAddress(netid, 6)
-
- if retranstimer is None:
- # If no retrans timer was specified, pick one that's as long as the
- # router lifetime. This ensures that no spurious ND retransmits
- # will interfere with test expectations.
- retranstimer = validity
-
- # We don't want any routes in the main table. If the kernel doesn't support
- # putting RA routes into per-interface tables, configure routing manually.
- routerlifetime = validity if HAVE_AUTOCONF_TABLE else 0
-
- ra = (scapy.Ether(src=macaddr, dst="33:33:00:00:00:01") /
- scapy.IPv6(src=lladdr, hlim=255) /
- scapy.ICMPv6ND_RA(reachabletime=reachabletime,
- retranstimer=retranstimer,
- routerlifetime=routerlifetime) /
- scapy.ICMPv6NDOptSrcLLAddr(lladdr=macaddr) /
- scapy.ICMPv6NDOptPrefixInfo(prefix=cls.IPv6Prefix(netid),
- prefixlen=64,
- L=1, A=1,
- validlifetime=validity,
- preferredlifetime=validity))
- posix.write(cls.tuns[netid].fileno(), str(ra))
-
- @classmethod
- def _RunSetupCommands(cls, netid, is_add):
- for version in [4, 6]:
- # Find out how to configure things.
- iface = cls.GetInterfaceName(netid)
- ifindex = cls.ifindices[netid]
- macaddr = cls.RouterMacAddress(netid)
- router = cls._RouterAddress(netid, version)
- table = cls._TableForNetid(netid)
-
- # Set up routing rules.
- if HAVE_UID_ROUTING:
- start, end = cls.UidRangeForNetid(netid)
- cls.iproute.UidRangeRule(version, is_add, start, end, table,
- cls.PRIORITY_UID)
- cls.iproute.OifRule(version, is_add, iface, table, cls.PRIORITY_OIF)
- cls.iproute.FwmarkRule(version, is_add, netid, table,
- cls.PRIORITY_FWMARK)
-
- # Configure routing and addressing.
- #
- # IPv6 uses autoconf for everything, except if per-device autoconf routing
- # tables are not supported, in which case the default route (only) is
- # configured manually. For IPv4 we have to manually configure addresses,
- # routes, and neighbour cache entries (since we don't reply to ARP or ND).
- #
- # Since deleting addresses also causes routes to be deleted, we need to
- # be careful with ordering or the delete commands will fail with ENOENT.
- do_routing = (version == 4 or cls.AUTOCONF_TABLE_OFFSET is None)
- if is_add:
- if version == 4:
- cls.iproute.AddAddress(cls._MyIPv4Address(netid), 24, ifindex)
- cls.iproute.AddNeighbour(version, router, macaddr, ifindex)
- if do_routing:
- cls.iproute.AddRoute(version, table, "default", 0, router, ifindex)
- if version == 6:
- cls.iproute.AddRoute(version, table,
- cls.IPv6Prefix(netid), 64, None, ifindex)
- else:
- if do_routing:
- cls.iproute.DelRoute(version, table, "default", 0, router, ifindex)
- if version == 6:
- cls.iproute.DelRoute(version, table,
- cls.IPv6Prefix(netid), 64, None, ifindex)
- if version == 4:
- cls.iproute.DelNeighbour(version, router, macaddr, ifindex)
- cls.iproute.DelAddress(cls._MyIPv4Address(netid), 24, ifindex)
-
- @classmethod
- def SetDefaultNetwork(cls, netid):
- table = cls._TableForNetid(netid) if netid else None
- for version in [4, 6]:
- is_add = table is not None
- cls.iproute.DefaultRule(version, is_add, table, cls.PRIORITY_DEFAULT)
-
- @classmethod
- def ClearDefaultNetwork(cls):
- cls.SetDefaultNetwork(None)
-
- @classmethod
- def GetSysctl(cls, sysctl):
- return open(sysctl, "r").read()
-
- @classmethod
- def SetSysctl(cls, sysctl, value):
- # Only save each sysctl value the first time we set it. This is so we can
- # set it to arbitrary values multiple times and still write it back
- # correctly at the end.
- if sysctl not in cls.saved_sysctls:
- cls.saved_sysctls[sysctl] = cls.GetSysctl(sysctl)
- open(sysctl, "w").write(str(value) + "\n")
-
- @classmethod
- def SetIPv6SysctlOnAllIfaces(cls, sysctl, value):
- for netid in cls.tuns:
- iface = cls.GetInterfaceName(netid)
- name = "/proc/sys/net/ipv6/conf/%s/%s" % (iface, sysctl)
- cls.SetSysctl(name, value)
-
- @classmethod
- def _RestoreSysctls(cls):
- for sysctl, value in cls.saved_sysctls.iteritems():
- try:
- open(sysctl, "w").write(value)
- except IOError:
- pass
-
- @classmethod
- def _ICMPRatelimitFilename(cls, version):
- return "/proc/sys/net/" + {4: "ipv4/icmp_ratelimit",
- 6: "ipv6/icmp/ratelimit"}[version]
-
- @classmethod
- def _SetICMPRatelimit(cls, version, limit):
- cls.SetSysctl(cls._ICMPRatelimitFilename(version), limit)
-
- @classmethod
- def setUpClass(cls):
- # This is per-class setup instead of per-testcase setup because shelling out
- # to ip and iptables is slow, and because routing configuration doesn't
- # change during the test.
- cls.iproute = iproute.IPRoute()
- cls.tuns = {}
- cls.ifindices = {}
- if HAVE_AUTOCONF_TABLE:
- cls.SetSysctl(AUTOCONF_TABLE_SYSCTL, -1000)
- cls.AUTOCONF_TABLE_OFFSET = -1000
- else:
- cls.AUTOCONF_TABLE_OFFSET = None
-
- # Disable ICMP rate limits. These will be restored by _RestoreSysctls.
- for version in [4, 6]:
- cls._SetICMPRatelimit(version, 0)
-
- for netid in cls.NETIDS:
- cls.tuns[netid] = cls.CreateTunInterface(netid)
- iface = cls.GetInterfaceName(netid)
- cls.ifindices[netid] = net_test.GetInterfaceIndex(iface)
-
- cls.SendRA(netid)
- cls._RunSetupCommands(netid, True)
-
- for version in [4, 6]:
- cls.iproute.UnreachableRule(version, True, 1000)
-
- # Uncomment to look around at interface and rule configuration while
- # running in the background. (Once the test finishes running, all the
- # interfaces and rules are gone.)
- # time.sleep(30)
-
- @classmethod
- def tearDownClass(cls):
- for version in [4, 6]:
- try:
- cls.iproute.UnreachableRule(version, False, 1000)
- except IOError:
- pass
-
- for netid in cls.tuns:
- cls._RunSetupCommands(netid, False)
- cls.tuns[netid].close()
- cls._RestoreSysctls()
-
- def setUp(self):
- self.ClearTunQueues()
-
- def SetSocketMark(self, s, netid):
- if netid is None:
- netid = 0
- s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid)
-
- def GetSocketMark(self, s):
- return s.getsockopt(SOL_SOCKET, net_test.SO_MARK)
-
- def ClearSocketMark(self, s):
- self.SetSocketMark(s, 0)
-
- def BindToDevice(self, s, iface):
- if not iface:
- iface = ""
- s.setsockopt(SOL_SOCKET, SO_BINDTODEVICE, iface)
-
- def SetUnicastInterface(self, s, ifindex):
- # Otherwise, Python thinks it's a 1-byte option.
- ifindex = struct.pack("!I", ifindex)
-
- # Always set the IPv4 interface, because it will be used even on IPv6
- # sockets if the destination address is a mapped address.
- s.setsockopt(net_test.SOL_IP, IP_UNICAST_IF, ifindex)
- if s.family == AF_INET6:
- s.setsockopt(net_test.SOL_IPV6, IPV6_UNICAST_IF, ifindex)
-
- def GetRemoteAddress(self, version):
- return {4: self.IPV4_ADDR,
- 5: "::ffff:" + self.IPV4_ADDR,
- 6: self.IPV6_ADDR}[version]
-
- def SelectInterface(self, s, netid, mode):
- if mode == "uid":
- raise ValueError("Can't change UID on an existing socket")
- elif mode == "mark":
- self.SetSocketMark(s, netid)
- elif mode == "oif":
- iface = self.GetInterfaceName(netid) if netid else ""
- self.BindToDevice(s, iface)
- elif mode == "ucast_oif":
- self.SetUnicastInterface(s, self.ifindices.get(netid, 0))
- else:
- raise ValueError("Unknown interface selection mode %s" % mode)
-
- def BuildSocket(self, version, constructor, netid, routing_mode):
- s = constructor(self.GetProtocolFamily(version))
-
- if routing_mode not in [None, "uid"]:
- self.SelectInterface(s, netid, routing_mode)
- elif routing_mode == "uid":
- os.fchown(s.fileno(), self.UidForNetid(netid), -1)
-
- return s
-
- def SendOnNetid(self, version, s, dstaddr, dstport, netid, payload, cmsgs):
- if netid is not None:
- pktinfo = MakePktInfo(version, None, self.ifindices[netid])
- cmsg_level, cmsg_name = {
- 4: (net_test.SOL_IP, IP_PKTINFO),
- 6: (net_test.SOL_IPV6, IPV6_PKTINFO)}[version]
- cmsgs.append((cmsg_level, cmsg_name, pktinfo))
- csocket.Sendmsg(s, (dstaddr, dstport), payload, cmsgs, csocket.MSG_CONFIRM)
-
- def ReceiveEtherPacketOn(self, netid, packet):
- posix.write(self.tuns[netid].fileno(), str(packet))
-
- def ReceivePacketOn(self, netid, ip_packet):
- routermac = self.RouterMacAddress(netid)
- mymac = self.MyMacAddress(netid)
- packet = scapy.Ether(src=routermac, dst=mymac) / ip_packet
- self.ReceiveEtherPacketOn(netid, packet)
-
- def ReadAllPacketsOn(self, netid, include_multicast=False):
- packets = []
- while True:
- try:
- packet = posix.read(self.tuns[netid].fileno(), 4096)
- if not packet:
- break
- ether = scapy.Ether(packet)
- # Multicast frames are frames where the first byte of the destination
- # MAC address has 1 in the least-significant bit.
- if include_multicast or not int(ether.dst.split(":")[0], 16) & 0x1:
- packets.append(ether.payload)
- except OSError, e:
- # EAGAIN means there are no more packets waiting.
- if re.match(e.message, os.strerror(errno.EAGAIN)):
- break
- # Anything else is unexpected.
- else:
- raise e
- return packets
-
- def ClearTunQueues(self):
- # Keep reading packets on all netids until we get no packets on any of them.
- waiting = None
- while waiting != 0:
- waiting = sum(len(self.ReadAllPacketsOn(netid)) for netid in self.NETIDS)
-
- def assertPacketMatches(self, expected, actual):
- # The expected packet is just a rough sketch of the packet we expect to
- # receive. For example, it doesn't contain fields we can't predict, such as
- # initial TCP sequence numbers, or that depend on the host implementation
- # and settings, such as TCP options. To check whether the packet matches
- # what we expect, instead of just checking all the known fields one by one,
- # we blank out fields in the actual packet and then compare the whole
- # packets to each other as strings. Because we modify the actual packet,
- # make a copy here.
- actual = actual.copy()
-
- # Blank out IPv4 fields that we can't predict, like ID and the DF bit.
- actualip = actual.getlayer("IP")
- expectedip = expected.getlayer("IP")
- if actualip and expectedip:
- actualip.id = expectedip.id
- actualip.flags &= 5
- actualip.chksum = None # Change the header, recalculate the checksum.
-
- # Blank out the flow label, since new kernels randomize it by default.
- actualipv6 = actual.getlayer("IPv6")
- expectedipv6 = expected.getlayer("IPv6")
- if actualipv6 and expectedipv6:
- actualipv6.fl = expectedipv6.fl
-
- # Blank out UDP fields that we can't predict (e.g., the source port for
- # kernel-originated packets).
- actualudp = actual.getlayer("UDP")
- expectedudp = expected.getlayer("UDP")
- if actualudp and expectedudp:
- if expectedudp.sport is None:
- actualudp.sport = None
- actualudp.chksum = None
-
- # Since the TCP code below messes with options, recalculate the length.
- if actualip:
- actualip.len = None
- if actualipv6:
- actualipv6.plen = None
-
- # Blank out TCP fields that we can't predict.
- actualtcp = actual.getlayer("TCP")
- expectedtcp = expected.getlayer("TCP")
- if actualtcp and expectedtcp:
- actualtcp.dataofs = expectedtcp.dataofs
- actualtcp.options = expectedtcp.options
- actualtcp.window = expectedtcp.window
- if expectedtcp.sport is None:
- actualtcp.sport = None
- if expectedtcp.seq is None:
- actualtcp.seq = None
- if expectedtcp.ack is None:
- actualtcp.ack = None
- actualtcp.chksum = None
-
- # Serialize the packet so that expected packet fields that are only set when
- # a packet is serialized e.g., the checksum) are filled in.
- expected_real = expected.__class__(str(expected))
- actual_real = actual.__class__(str(actual))
- # repr() can be expensive. Call it only if the test is going to fail and we
- # want to see the error.
- if expected_real != actual_real:
- self.assertEquals(repr(expected_real), repr(actual_real))
-
- def PacketMatches(self, expected, actual):
- try:
- self.assertPacketMatches(expected, actual)
- return True
- except AssertionError:
- return False
-
- def ExpectNoPacketsOn(self, netid, msg):
- packets = self.ReadAllPacketsOn(netid)
- if packets:
- firstpacket = repr(packets[0])
- else:
- firstpacket = ""
- self.assertFalse(packets, msg + ": unexpected packet: " + firstpacket)
-
- def ExpectPacketOn(self, netid, msg, expected):
- # To avoid confusion due to lots of ICMPv6 ND going on all the time, drop
- # multicast packets unless the packet we expect to see is a multicast
- # packet. For now the only tests that use this are IPv6.
- ipv6 = expected.getlayer("IPv6")
- if ipv6 and ipv6.dst.startswith("ff"):
- include_multicast = True
- else:
- include_multicast = False
-
- packets = self.ReadAllPacketsOn(netid, include_multicast=include_multicast)
- self.assertTrue(packets, msg + ": received no packets")
-
- # If we receive a packet that matches what we expected, return it.
- for packet in packets:
- if self.PacketMatches(expected, packet):
- return packet
-
- # None of the packets matched. Call assertPacketMatches to output a diff
- # between the expected packet and the last packet we received. In theory,
- # we'd output a diff to the packet that's the best match for what we
- # expected, but this is good enough for now.
- try:
- self.assertPacketMatches(expected, packets[-1])
- except Exception, e:
- raise UnexpectedPacketError(
- "%s: diff with last packet:\n%s" % (msg, e.message))
-
- def Combinations(self, version):
- """Produces a list of combinations to test."""
- combinations = []
-
- # Check packets addressed to the IP addresses of all our interfaces...
- for dest_ip_netid in self.tuns:
- ip_if = self.GetInterfaceName(dest_ip_netid)
- myaddr = self.MyAddress(version, dest_ip_netid)
- remoteaddr = self.GetRemoteAddress(version)
-
- # ... coming in on all our interfaces.
- for netid in self.tuns:
- iif = self.GetInterfaceName(netid)
- combinations.append((netid, iif, ip_if, myaddr, remoteaddr))
-
- return combinations
-
- def _FormatMessage(self, iif, ip_if, extra, desc, reply_desc):
- msg = "Receiving %s on %s to %s IP, %s" % (desc, iif, ip_if, extra)
- if reply_desc:
- msg += ": Expecting %s on %s" % (reply_desc, iif)
- else:
- msg += ": Expecting no packets on %s" % iif
- return msg
-
- def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
- self.ReceivePacketOn(netid, packet)
- if reply:
- return self.ExpectPacketOn(netid, msg, reply)
- else:
- self.ExpectNoPacketsOn(netid, msg)
- return None
diff --git a/tests/net_test/multinetwork_test.py b/tests/net_test/multinetwork_test.py
deleted file mode 100755
index 660fdf6..0000000
--- a/tests/net_test/multinetwork_test.py
+++ /dev/null
@@ -1,926 +0,0 @@
-#!/usr/bin/python
-#
-# Copyright 2014 The Android Open Source Project
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import errno
-import os
-import random
-from socket import * # pylint: disable=wildcard-import
-import struct
-import time # pylint: disable=unused-import
-import unittest
-
-from scapy import all as scapy
-
-import iproute
-import multinetwork_base
-import net_test
-import packets
-
-# For brevity.
-UDP_PAYLOAD = net_test.UDP_PAYLOAD
-
-IPV6_FLOWINFO = 11
-
-IPV4_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv4/fwmark_reflect"
-IPV6_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv6/fwmark_reflect"
-SYNCOOKIES_SYSCTL = "/proc/sys/net/ipv4/tcp_syncookies"
-TCP_MARK_ACCEPT_SYSCTL = "/proc/sys/net/ipv4/tcp_fwmark_accept"
-
-# The IP[V6]UNICAST_IF socket option was added between 3.1 and 3.4.
-HAVE_UNICAST_IF = net_test.LINUX_VERSION >= (3, 4, 0)
-
-
-class ConfigurationError(AssertionError):
- pass
-
-
-class InboundMarkingTest(multinetwork_base.MultiNetworkBaseTest):
-
- @classmethod
- def _SetInboundMarking(cls, netid, is_add):
- for version in [4, 6]:
- # Run iptables to set up incoming packet marking.
- iface = cls.GetInterfaceName(netid)
- add_del = "-A" if is_add else "-D"
- iptables = {4: "iptables", 6: "ip6tables"}[version]
- args = "%s %s INPUT -t mangle -i %s -j MARK --set-mark %d" % (
- iptables, add_del, iface, netid)
- iptables = "/sbin/" + iptables
- ret = os.spawnvp(os.P_WAIT, iptables, args.split(" "))
- if ret:
- raise ConfigurationError("Setup command failed: %s" % args)
-
- @classmethod
- def setUpClass(cls):
- super(InboundMarkingTest, cls).setUpClass()
- for netid in cls.tuns:
- cls._SetInboundMarking(netid, True)
-
- @classmethod
- def tearDownClass(cls):
- for netid in cls.tuns:
- cls._SetInboundMarking(netid, False)
- super(InboundMarkingTest, cls).tearDownClass()
-
- @classmethod
- def SetMarkReflectSysctls(cls, value):
- cls.SetSysctl(IPV4_MARK_REFLECT_SYSCTL, value)
- try:
- cls.SetSysctl(IPV6_MARK_REFLECT_SYSCTL, value)
- except IOError:
- # This does not exist if we use the version of the patch that uses a
- # common sysctl for IPv4 and IPv6.
- pass
-
-
-class OutgoingTest(multinetwork_base.MultiNetworkBaseTest):
-
- # How many times to run outgoing packet tests.
- ITERATIONS = 5
-
- def CheckPingPacket(self, version, netid, routing_mode, dstaddr, packet):
- s = self.BuildSocket(version, net_test.PingSocket, netid, routing_mode)
-
- myaddr = self.MyAddress(version, netid)
- s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
- s.bind((myaddr, packets.PING_IDENT))
- net_test.SetSocketTos(s, packets.PING_TOS)
-
- desc, expected = packets.ICMPEcho(version, myaddr, dstaddr)
- msg = "IPv%d ping: expected %s on %s" % (
- version, desc, self.GetInterfaceName(netid))
-
- s.sendto(packet + packets.PING_PAYLOAD, (dstaddr, 19321))
-
- self.ExpectPacketOn(netid, msg, expected)
-
- def CheckTCPSYNPacket(self, version, netid, routing_mode, dstaddr):
- s = self.BuildSocket(version, net_test.TCPSocket, netid, routing_mode)
-
- if version == 6 and dstaddr.startswith("::ffff"):
- version = 4
- myaddr = self.MyAddress(version, netid)
- desc, expected = packets.SYN(53, version, myaddr, dstaddr,
- sport=None, seq=None)
-
- # Non-blocking TCP connects always return EINPROGRESS.
- self.assertRaisesErrno(errno.EINPROGRESS, s.connect, (dstaddr, 53))
- msg = "IPv%s TCP connect: expected %s on %s" % (
- version, desc, self.GetInterfaceName(netid))
- self.ExpectPacketOn(netid, msg, expected)
- s.close()
-
- def CheckUDPPacket(self, version, netid, routing_mode, dstaddr):
- s = self.BuildSocket(version, net_test.UDPSocket, netid, routing_mode)
-
- if version == 6 and dstaddr.startswith("::ffff"):
- version = 4
- myaddr = self.MyAddress(version, netid)
- desc, expected = packets.UDP(version, myaddr, dstaddr, sport=None)
- msg = "IPv%s UDP %%s: expected %s on %s" % (
- version, desc, self.GetInterfaceName(netid))
-
- s.sendto(UDP_PAYLOAD, (dstaddr, 53))
- self.ExpectPacketOn(netid, msg % "sendto", expected)
-
- # IP_UNICAST_IF doesn't seem to work on connected sockets, so no TCP.
- if routing_mode != "ucast_oif":
- s.connect((dstaddr, 53))
- s.send(UDP_PAYLOAD)
- self.ExpectPacketOn(netid, msg % "connect/send", expected)
- s.close()
-
- def CheckRawGrePacket(self, version, netid, routing_mode, dstaddr):
- s = self.BuildSocket(version, net_test.RawGRESocket, netid, routing_mode)
-
- inner_version = {4: 6, 6: 4}[version]
- inner_src = self.MyAddress(inner_version, netid)
- inner_dst = self.GetRemoteAddress(inner_version)
- inner = str(packets.UDP(inner_version, inner_src, inner_dst, sport=None)[1])
-
- ethertype = {4: net_test.ETH_P_IP, 6: net_test.ETH_P_IPV6}[inner_version]
- # A GRE header can be as simple as two zero bytes and the ethertype.
- packet = struct.pack("!i", ethertype) + inner
- myaddr = self.MyAddress(version, netid)
-
- s.sendto(packet, (dstaddr, IPPROTO_GRE))
- desc, expected = packets.GRE(version, myaddr, dstaddr, ethertype, inner)
- msg = "Raw IPv%d GRE with inner IPv%d UDP: expected %s on %s" % (
- version, inner_version, desc, self.GetInterfaceName(netid))
- self.ExpectPacketOn(netid, msg, expected)
-
- def CheckOutgoingPackets(self, routing_mode):
- v4addr = self.IPV4_ADDR
- v6addr = self.IPV6_ADDR
- v4mapped = "::ffff:" + v4addr
-
- for _ in xrange(self.ITERATIONS):
- for netid in self.tuns:
-
- self.CheckPingPacket(4, netid, routing_mode, v4addr, self.IPV4_PING)
- # Kernel bug.
- if routing_mode != "oif":
- self.CheckPingPacket(6, netid, routing_mode, v6addr, self.IPV6_PING)
-
- # IP_UNICAST_IF doesn't seem to work on connected sockets, so no TCP.
- if routing_mode != "ucast_oif":
- self.CheckTCPSYNPacket(4, netid, routing_mode, v4addr)
- self.CheckTCPSYNPacket(6, netid, routing_mode, v6addr)
- self.CheckTCPSYNPacket(6, netid, routing_mode, v4mapped)
-
- self.CheckUDPPacket(4, netid, routing_mode, v4addr)
- self.CheckUDPPacket(6, netid, routing_mode, v6addr)
- self.CheckUDPPacket(6, netid, routing_mode, v4mapped)
-
- # Creating raw sockets on non-root UIDs requires properly setting
- # capabilities, which is hard to do from Python.
- # IP_UNICAST_IF is not supported on raw sockets.
- if routing_mode not in ["uid", "ucast_oif"]:
- self.CheckRawGrePacket(4, netid, routing_mode, v4addr)
- self.CheckRawGrePacket(6, netid, routing_mode, v6addr)
-
- def testMarkRouting(self):
- """Checks that socket marking selects the right outgoing interface."""
- self.CheckOutgoingPackets("mark")
-
- @unittest.skipUnless(multinetwork_base.HAVE_UID_ROUTING, "no UID routes")
- def testUidRouting(self):
- """Checks that UID routing selects the right outgoing interface."""
- self.CheckOutgoingPackets("uid")
-
- def testOifRouting(self):
- """Checks that oif routing selects the right outgoing interface."""
- self.CheckOutgoingPackets("oif")
-
- @unittest.skipUnless(HAVE_UNICAST_IF, "no support for UNICAST_IF")
- def testUcastOifRouting(self):
- """Checks that ucast oif routing selects the right outgoing interface."""
- self.CheckOutgoingPackets("ucast_oif")
-
- def CheckRemarking(self, version, use_connect):
- # Remarking or resetting UNICAST_IF on connected sockets does not work.
- if use_connect:
- modes = ["oif"]
- else:
- modes = ["mark", "oif"]
- if HAVE_UNICAST_IF:
- modes += ["ucast_oif"]
-
- for mode in modes:
- s = net_test.UDPSocket(self.GetProtocolFamily(version))
-
- # Figure out what packets to expect.
- unspec = {4: "0.0.0.0", 6: "::"}[version]
- sport = packets.RandomPort()
- s.bind((unspec, sport))
- dstaddr = {4: self.IPV4_ADDR, 6: self.IPV6_ADDR}[version]
- desc, expected = packets.UDP(version, unspec, dstaddr, sport)
-
- # If we're testing connected sockets, connect the socket on the first
- # netid now.
- if use_connect:
- netid = self.tuns.keys()[0]
- self.SelectInterface(s, netid, mode)
- s.connect((dstaddr, 53))
- expected.src = self.MyAddress(version, netid)
-
- # For each netid, select that network without closing the socket, and
- # check that the packets sent on that socket go out on the right network.
- for netid in self.tuns:
- self.SelectInterface(s, netid, mode)
- if not use_connect:
- expected.src = self.MyAddress(version, netid)
- s.sendto(UDP_PAYLOAD, (dstaddr, 53))
- connected_str = "Connected" if use_connect else "Unconnected"
- msg = "%s UDPv%d socket remarked using %s: expecting %s on %s" % (
- connected_str, version, mode, desc, self.GetInterfaceName(netid))
- self.ExpectPacketOn(netid, msg, expected)
- self.SelectInterface(s, None, mode)
-
- def testIPv4Remarking(self):
- """Checks that updating the mark on an IPv4 socket changes routing."""
- self.CheckRemarking(4, False)
- self.CheckRemarking(4, True)
-
- def testIPv6Remarking(self):
- """Checks that updating the mark on an IPv6 socket changes routing."""
- self.CheckRemarking(6, False)
- self.CheckRemarking(6, True)
-
- def testIPv6StickyPktinfo(self):
- for _ in xrange(self.ITERATIONS):
- for netid in self.tuns:
- s = net_test.UDPSocket(AF_INET6)
-
- # Set a flowlabel.
- net_test.SetFlowLabel(s, net_test.IPV6_ADDR, 0xdead)
- s.setsockopt(net_test.SOL_IPV6, net_test.IPV6_FLOWINFO_SEND, 1)
-
- # Set some destination options.
- nonce = "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c"
- dstopts = "".join([
- "\x11\x02", # Next header=UDP, 24 bytes of options.
- "\x01\x06", "\x00" * 6, # PadN, 6 bytes of padding.
- "\x8b\x0c", # ILNP nonce, 12 bytes.
- nonce
- ])
- s.setsockopt(net_test.SOL_IPV6, IPV6_DSTOPTS, dstopts)
- s.setsockopt(net_test.SOL_IPV6, IPV6_UNICAST_HOPS, 255)
-
- pktinfo = multinetwork_base.MakePktInfo(6, None, self.ifindices[netid])
-
- # Set the sticky pktinfo option.
- s.setsockopt(net_test.SOL_IPV6, IPV6_PKTINFO, pktinfo)
-
- # Specify the flowlabel in the destination address.
- s.sendto(UDP_PAYLOAD, (net_test.IPV6_ADDR, 53, 0xdead, 0))
-
- sport = s.getsockname()[1]
- srcaddr = self.MyAddress(6, netid)
- expected = (scapy.IPv6(src=srcaddr, dst=net_test.IPV6_ADDR,
- fl=0xdead, hlim=255) /
- scapy.IPv6ExtHdrDestOpt(
- options=[scapy.PadN(optdata="\x00\x00\x00\x00\x00\x00"),
- scapy.HBHOptUnknown(otype=0x8b,
- optdata=nonce)]) /
- scapy.UDP(sport=sport, dport=53) /
- UDP_PAYLOAD)
- msg = "IPv6 UDP using sticky pktinfo: expected UDP packet on %s" % (
- self.GetInterfaceName(netid))
- self.ExpectPacketOn(netid, msg, expected)
-
- def CheckPktinfoRouting(self, version):
- for _ in xrange(self.ITERATIONS):
- for netid in self.tuns:
- family = self.GetProtocolFamily(version)
- s = net_test.UDPSocket(family)
-
- if version == 6:
- # Create a flowlabel so we can use it.
- net_test.SetFlowLabel(s, net_test.IPV6_ADDR, 0xbeef)
-
- # Specify some arbitrary options.
- cmsgs = [
- (net_test.SOL_IPV6, IPV6_HOPLIMIT, 39),
- (net_test.SOL_IPV6, IPV6_TCLASS, 0x83),
- (net_test.SOL_IPV6, IPV6_FLOWINFO, int(htonl(0xbeef))),
- ]
- else:
- # Support for setting IPv4 TOS and TTL via cmsg only appeared in 3.13.
- cmsgs = []
- s.setsockopt(net_test.SOL_IP, IP_TTL, 39)
- s.setsockopt(net_test.SOL_IP, IP_TOS, 0x83)
-
- dstaddr = self.GetRemoteAddress(version)
- self.SendOnNetid(version, s, dstaddr, 53, netid, UDP_PAYLOAD, cmsgs)
-
- sport = s.getsockname()[1]
- srcaddr = self.MyAddress(version, netid)
-
- desc, expected = packets.UDPWithOptions(version, srcaddr, dstaddr,
- sport=sport)
-
- msg = "IPv%d UDP using pktinfo routing: expected %s on %s" % (
- version, desc, self.GetInterfaceName(netid))
- self.ExpectPacketOn(netid, msg, expected)
-
- def testIPv4PktinfoRouting(self):
- self.CheckPktinfoRouting(4)
-
- def testIPv6PktinfoRouting(self):
- self.CheckPktinfoRouting(6)
-
-
-class MarkTest(InboundMarkingTest):
-
- def CheckReflection(self, version, gen_packet, gen_reply):
- """Checks that replies go out on the same interface as the original.
-
- For each combination:
- - Calls gen_packet to generate a packet to that IP address.
- - Writes the packet generated by gen_packet on the given tun
- interface, causing the kernel to receive it.
- - Checks that the kernel's reply matches the packet generated by
- gen_reply.
-
- Args:
- version: An integer, 4 or 6.
- gen_packet: A function taking an IP version (an integer), a source
- address and a destination address (strings), and returning a scapy
- packet.
- gen_reply: A function taking the same arguments as gen_packet,
- plus a scapy packet, and returning a scapy packet.
- """
- for netid, iif, ip_if, myaddr, remoteaddr in self.Combinations(version):
- # Generate a test packet.
- desc, packet = gen_packet(version, remoteaddr, myaddr)
-
- # Test with mark reflection enabled and disabled.
- for reflect in [0, 1]:
- self.SetMarkReflectSysctls(reflect)
- # HACK: IPv6 ping replies always do a routing lookup with the
- # interface the ping came in on. So even if mark reflection is not
- # working, IPv6 ping replies will be properly reflected. Don't
- # fail when that happens.
- if reflect or desc == "ICMPv6 echo":
- reply_desc, reply = gen_reply(version, myaddr, remoteaddr, packet)
- else:
- reply_desc, reply = None, None
-
- msg = self._FormatMessage(iif, ip_if, "reflect=%d" % reflect,
- desc, reply_desc)
- self._ReceiveAndExpectResponse(netid, packet, reply, msg)
-
- def SYNToClosedPort(self, *args):
- return packets.SYN(999, *args)
-
- def testIPv4ICMPErrorsReflectMark(self):
- self.CheckReflection(4, packets.UDP, packets.ICMPPortUnreachable)
-
- def testIPv6ICMPErrorsReflectMark(self):
- self.CheckReflection(6, packets.UDP, packets.ICMPPortUnreachable)
-
- def testIPv4PingRepliesReflectMarkAndTos(self):
- self.CheckReflection(4, packets.ICMPEcho, packets.ICMPReply)
-
- def testIPv6PingRepliesReflectMarkAndTos(self):
- self.CheckReflection(6, packets.ICMPEcho, packets.ICMPReply)
-
- def testIPv4RSTsReflectMark(self):
- self.CheckReflection(4, self.SYNToClosedPort, packets.RST)
-
- def testIPv6RSTsReflectMark(self):
- self.CheckReflection(6, self.SYNToClosedPort, packets.RST)
-
-
-class TCPAcceptTest(InboundMarkingTest):
-
- MODE_BINDTODEVICE = "SO_BINDTODEVICE"
- MODE_INCOMING_MARK = "incoming mark"
- MODE_EXPLICIT_MARK = "explicit mark"
- MODE_UID = "uid"
-
- @classmethod
- def setUpClass(cls):
- super(TCPAcceptTest, cls).setUpClass()
-
- # Open a port so we can observe SYN+ACKs. Since it's a dual-stack socket it
- # will accept both IPv4 and IPv6 connections. We do this here instead of in
- # each test so we can use the same socket every time. That way, if a kernel
- # bug causes incoming packets to mark the listening socket instead of the
- # accepted socket, the test will fail as soon as the next address/interface
- # combination is tried.
- cls.listenport = 1234
- cls.listensocket = net_test.IPv6TCPSocket()
- cls.listensocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
- cls.listensocket.bind(("::", cls.listenport))
- cls.listensocket.listen(100)
-
- def BounceSocket(self, s):
- """Attempts to invalidate a socket's destination cache entry."""
- if s.family == AF_INET:
- tos = s.getsockopt(SOL_IP, IP_TOS)
- s.setsockopt(net_test.SOL_IP, IP_TOS, 53)
- s.setsockopt(net_test.SOL_IP, IP_TOS, tos)
- else:
- # UDP, 8 bytes dstopts; PAD1, 4 bytes padding; 4 bytes zeros.
- pad8 = "".join(["\x11\x00", "\x01\x04", "\x00" * 4])
- s.setsockopt(net_test.SOL_IPV6, IPV6_DSTOPTS, pad8)
- s.setsockopt(net_test.SOL_IPV6, IPV6_DSTOPTS, "")
-
- def _SetTCPMarkAcceptSysctl(self, value):
- self.SetSysctl(TCP_MARK_ACCEPT_SYSCTL, value)
-
- def CheckTCPConnection(self, mode, listensocket, netid, version,
- myaddr, remoteaddr, packet, reply, msg):
- establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]
-
- # Attempt to confuse the kernel.
- self.BounceSocket(listensocket)
-
- self.ReceivePacketOn(netid, establishing_ack)
-
- # If we're using UID routing, the accept() call has to be run as a UID that
- # is routed to the specified netid, because the UID of the socket returned
- # by accept() is the effective UID of the process that calls it. It doesn't
- # need to be the same UID; any UID that selects the same interface will do.
- with net_test.RunAsUid(self.UidForNetid(netid)):
- s, _ = listensocket.accept()
-
- try:
- # Check that data sent on the connection goes out on the right interface.
- desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack,
- payload=UDP_PAYLOAD)
- s.send(UDP_PAYLOAD)
- self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)
- self.BounceSocket(s)
-
- # Keep up our end of the conversation.
- ack = packets.ACK(version, remoteaddr, myaddr, data)[1]
- self.BounceSocket(listensocket)
- self.ReceivePacketOn(netid, ack)
-
- mark = self.GetSocketMark(s)
- finally:
- self.BounceSocket(s)
- s.close()
-
- if mode == self.MODE_INCOMING_MARK:
- self.assertEquals(netid, mark,
- msg + ": Accepted socket: Expected mark %d, got %d" % (
- netid, mark))
- elif mode != self.MODE_EXPLICIT_MARK:
- self.assertEquals(0, self.GetSocketMark(listensocket))
-
- # Check the FIN was sent on the right interface, and ack it. We don't expect
- # this to fail because by the time the connection is established things are
- # likely working, but a) extra tests are always good and b) extra packets
- # like the FIN (and retransmitted FINs) could cause later tests that expect
- # no packets to fail.
- desc, fin = packets.FIN(version, myaddr, remoteaddr, ack)
- self.ExpectPacketOn(netid, msg + ": expecting %s after close" % desc, fin)
-
- desc, finack = packets.FIN(version, remoteaddr, myaddr, fin)
- self.ReceivePacketOn(netid, finack)
-
- # Since we called close() earlier, the userspace socket object is gone, so
- # the socket has no UID. If we're doing UID routing, the ack might be routed
- # incorrectly. Not much we can do here.
- desc, finackack = packets.ACK(version, myaddr, remoteaddr, finack)
- if mode != self.MODE_UID:
- self.ExpectPacketOn(netid, msg + ": expecting final ack", finackack)
- else:
- self.ClearTunQueues()
-
- def CheckTCP(self, version, modes):
- """Checks that incoming TCP connections work.
-
- Args:
- version: An integer, 4 or 6.
- modes: A list of modes to excercise.
- """
- for syncookies in [0, 2]:
- for mode in modes:
- for netid, iif, ip_if, myaddr, remoteaddr in self.Combinations(version):
- if mode == self.MODE_UID:
- listensocket = self.BuildSocket(6, net_test.TCPSocket, netid, mode)
- listensocket.listen(100)
- else:
- listensocket = self.listensocket
-
- listenport = listensocket.getsockname()[1]
-
- accept_sysctl = 1 if mode == self.MODE_INCOMING_MARK else 0
- self._SetTCPMarkAcceptSysctl(accept_sysctl)
-
- bound_dev = iif if mode == self.MODE_BINDTODEVICE else None
- self.BindToDevice(listensocket, bound_dev)
-
- mark = netid if mode == self.MODE_EXPLICIT_MARK else 0
- self.SetSocketMark(listensocket, mark)
-
- # Generate the packet here instead of in the outer loop, so
- # subsequent TCP connections use different source ports and
- # retransmissions from old connections don't confuse subsequent
- # tests.
- desc, packet = packets.SYN(listenport, version, remoteaddr, myaddr)
-
- if mode:
- reply_desc, reply = packets.SYNACK(version, myaddr, remoteaddr,
- packet)
- else:
- reply_desc, reply = None, None
-
- extra = "mode=%s, syncookies=%d" % (mode, syncookies)
- msg = self._FormatMessage(iif, ip_if, extra, desc, reply_desc)
- reply = self._ReceiveAndExpectResponse(netid, packet, reply, msg)
- if reply:
- self.CheckTCPConnection(mode, listensocket, netid, version, myaddr,
- remoteaddr, packet, reply, msg)
-
- def testBasicTCP(self):
- self.CheckTCP(4, [None, self.MODE_BINDTODEVICE, self.MODE_EXPLICIT_MARK])
- self.CheckTCP(6, [None, self.MODE_BINDTODEVICE, self.MODE_EXPLICIT_MARK])
-
- def testIPv4MarkAccept(self):
- self.CheckTCP(4, [self.MODE_INCOMING_MARK])
-
- def testIPv6MarkAccept(self):
- self.CheckTCP(6, [self.MODE_INCOMING_MARK])
-
- @unittest.skipUnless(multinetwork_base.HAVE_UID_ROUTING, "no UID routes")
- def testIPv4UidAccept(self):
- self.CheckTCP(4, [self.MODE_UID])
-
- @unittest.skipUnless(multinetwork_base.HAVE_UID_ROUTING, "no UID routes")
- def testIPv6UidAccept(self):
- self.CheckTCP(6, [self.MODE_UID])
-
- def testIPv6ExplicitMark(self):
- self.CheckTCP(6, [self.MODE_EXPLICIT_MARK])
-
-
-class RATest(multinetwork_base.MultiNetworkBaseTest):
-
- def testDoesNotHaveObsoleteSysctl(self):
- self.assertFalse(os.path.isfile(
- "/proc/sys/net/ipv6/route/autoconf_table_offset"))
-
- @unittest.skipUnless(multinetwork_base.HAVE_AUTOCONF_TABLE,
- "no support for per-table autoconf")
- def testPurgeDefaultRouters(self):
-
- def CheckIPv6Connectivity(expect_connectivity):
- for netid in self.NETIDS:
- s = net_test.UDPSocket(AF_INET6)
- self.SetSocketMark(s, netid)
- if expect_connectivity:
- self.assertTrue(s.sendto(UDP_PAYLOAD, (net_test.IPV6_ADDR, 1234)))
- else:
- self.assertRaisesErrno(errno.ENETUNREACH, s.sendto, UDP_PAYLOAD,
- (net_test.IPV6_ADDR, 1234))
-
- try:
- CheckIPv6Connectivity(True)
- self.SetIPv6SysctlOnAllIfaces("accept_ra", 1)
- self.SetSysctl("/proc/sys/net/ipv6/conf/all/forwarding", 1)
- CheckIPv6Connectivity(False)
- finally:
- self.SetSysctl("/proc/sys/net/ipv6/conf/all/forwarding", 0)
- for netid in self.NETIDS:
- self.SendRA(netid)
- CheckIPv6Connectivity(True)
-
- def testOnlinkCommunication(self):
- """Checks that on-link communication goes direct and not through routers."""
- for netid in self.tuns:
- # Send a UDP packet to a random on-link destination.
- s = net_test.UDPSocket(AF_INET6)
- iface = self.GetInterfaceName(netid)
- self.BindToDevice(s, iface)
- # dstaddr can never be our address because GetRandomDestination only fills
- # in the lower 32 bits, but our address has 0xff in the byte before that
- # (since it's constructed from the EUI-64 and so has ff:fe in the middle).
- dstaddr = self.GetRandomDestination(self.IPv6Prefix(netid))
- s.sendto(UDP_PAYLOAD, (dstaddr, 53))
-
- # Expect an NS for that destination on the interface.
- myaddr = self.MyAddress(6, netid)
- mymac = self.MyMacAddress(netid)
- desc, expected = packets.NS(myaddr, dstaddr, mymac)
- msg = "Sending UDP packet to on-link destination: expecting %s" % desc
- time.sleep(0.0001) # Required to make the test work on kernel 3.1(!)
- self.ExpectPacketOn(netid, msg, expected)
-
- # Send an NA.
- tgtmac = "02:00:00:00:%02x:99" % netid
- _, reply = packets.NA(dstaddr, myaddr, tgtmac)
- # Don't use ReceivePacketOn, since that uses the router's MAC address as
- # the source. Instead, construct our own Ethernet header with source
- # MAC of tgtmac.
- reply = scapy.Ether(src=tgtmac, dst=mymac) / reply
- self.ReceiveEtherPacketOn(netid, reply)
-
- # Expect the kernel to send the original UDP packet now that the ND cache
- # entry has been populated.
- sport = s.getsockname()[1]
- desc, expected = packets.UDP(6, myaddr, dstaddr, sport=sport)
- msg = "After NA response, expecting %s" % desc
- self.ExpectPacketOn(netid, msg, expected)
-
- # This test documents a known issue: routing tables are never deleted.
- @unittest.skipUnless(multinetwork_base.HAVE_AUTOCONF_TABLE,
- "no support for per-table autoconf")
- def testLeftoverRoutes(self):
- def GetNumRoutes():
- return len(open("/proc/net/ipv6_route").readlines())
-
- num_routes = GetNumRoutes()
- for i in xrange(10, 20):
- try:
- self.tuns[i] = self.CreateTunInterface(i)
- self.SendRA(i)
- self.tuns[i].close()
- finally:
- del self.tuns[i]
- self.assertLess(num_routes, GetNumRoutes())
-
-
-class PMTUTest(InboundMarkingTest):
-
- PAYLOAD_SIZE = 1400
-
- # Socket options to change PMTU behaviour.
- IP_MTU_DISCOVER = 10
- IP_PMTUDISC_DO = 1
- IPV6_DONTFRAG = 62
-
- # Socket options to get the MTU.
- IP_MTU = 14
- IPV6_PATHMTU = 61
-
- def GetSocketMTU(self, version, s):
- if version == 6:
- ip6_mtuinfo = s.getsockopt(net_test.SOL_IPV6, self.IPV6_PATHMTU, 32)
- unused_sockaddr, mtu = struct.unpack("=28sI", ip6_mtuinfo)
- return mtu
- else:
- return s.getsockopt(net_test.SOL_IP, self.IP_MTU)
-
- def DisableFragmentationAndReportErrors(self, version, s):
- if version == 4:
- s.setsockopt(net_test.SOL_IP, self.IP_MTU_DISCOVER, self.IP_PMTUDISC_DO)
- s.setsockopt(net_test.SOL_IP, net_test.IP_RECVERR, 1)
- else:
- s.setsockopt(net_test.SOL_IPV6, self.IPV6_DONTFRAG, 1)
- s.setsockopt(net_test.SOL_IPV6, net_test.IPV6_RECVERR, 1)
-
- def CheckPMTU(self, version, use_connect, modes):
-
- def SendBigPacket(version, s, dstaddr, netid, payload):
- if use_connect:
- s.send(payload)
- else:
- self.SendOnNetid(version, s, dstaddr, 1234, netid, payload, [])
-
- for netid in self.tuns:
- for mode in modes:
- s = self.BuildSocket(version, net_test.UDPSocket, netid, mode)
- self.DisableFragmentationAndReportErrors(version, s)
-
- srcaddr = self.MyAddress(version, netid)
- dst_prefix, intermediate = {
- 4: ("172.19.", "172.16.9.12"),
- 6: ("2001:db8::", "2001:db8::1")
- }[version]
- dstaddr = self.GetRandomDestination(dst_prefix)
-
- if use_connect:
- s.connect((dstaddr, 1234))
-
- payload = self.PAYLOAD_SIZE * "a"
-
- # Send a packet and receive a packet too big.
- SendBigPacket(version, s, dstaddr, netid, payload)
- received = self.ReadAllPacketsOn(netid)
- self.assertEquals(1, len(received))
- _, toobig = packets.ICMPPacketTooBig(version, intermediate, srcaddr,
- received[0])
- self.ReceivePacketOn(netid, toobig)
-
- # Check that another send on the same socket returns EMSGSIZE.
- self.assertRaisesErrno(
- errno.EMSGSIZE,
- SendBigPacket, version, s, dstaddr, netid, payload)
-
- # If this is a connected socket, make sure the socket MTU was set.
- # Note that in IPv4 this only started working in Linux 3.6!
- if use_connect and (version == 6 or net_test.LINUX_VERSION >= (3, 6)):
- self.assertEquals(1280, self.GetSocketMTU(version, s))
-
- s.close()
-
- # Check that other sockets pick up the PMTU we have been told about by
- # connecting another socket to the same destination and getting its MTU.
- # This new socket can use any method to select its outgoing interface;
- # here we use a mark for simplicity.
- s2 = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
- s2.connect((dstaddr, 1234))
- self.assertEquals(1280, self.GetSocketMTU(version, s2))
-
- # Also check the MTU reported by ip route get, this time using the oif.
- routes = self.iproute.GetRoutes(dstaddr, self.ifindices[netid], 0, None)
- self.assertTrue(routes)
- route = routes[0]
- rtmsg, attributes = route
- self.assertEquals(iproute.RTN_UNICAST, rtmsg.type)
- metrics = attributes["RTA_METRICS"]
- self.assertEquals(metrics["RTAX_MTU"], 1280)
-
- def testIPv4BasicPMTU(self):
- """Tests IPv4 path MTU discovery.
-
- Relevant kernel commits:
- upstream net-next:
- 6a66271 ipv4, fib: pass LOOPBACK_IFINDEX instead of 0 to flowi4_iif
-
- android-3.10:
- 4bc64dd ipv4, fib: pass LOOPBACK_IFINDEX instead of 0 to flowi4_iif
- """
-
- self.CheckPMTU(4, True, ["mark", "oif"])
- self.CheckPMTU(4, False, ["mark", "oif"])
-
- def testIPv6BasicPMTU(self):
- self.CheckPMTU(6, True, ["mark", "oif"])
- self.CheckPMTU(6, False, ["mark", "oif"])
-
- @unittest.skipUnless(multinetwork_base.HAVE_UID_ROUTING, "no UID routes")
- def testIPv4UIDPMTU(self):
- self.CheckPMTU(4, True, ["uid"])
- self.CheckPMTU(4, False, ["uid"])
-
- @unittest.skipUnless(multinetwork_base.HAVE_UID_ROUTING, "no UID routes")
- def testIPv6UIDPMTU(self):
- self.CheckPMTU(6, True, ["uid"])
- self.CheckPMTU(6, False, ["uid"])
-
- # Making Path MTU Discovery work on unmarked sockets requires that mark
- # reflection be enabled. Otherwise the kernel has no way to know what routing
- # table the original packet used, and thus it won't be able to clone the
- # correct route.
-
- def testIPv4UnmarkedSocketPMTU(self):
- self.SetMarkReflectSysctls(1)
- try:
- self.CheckPMTU(4, False, [None])
- finally:
- self.SetMarkReflectSysctls(0)
-
- def testIPv6UnmarkedSocketPMTU(self):
- self.SetMarkReflectSysctls(1)
- try:
- self.CheckPMTU(6, False, [None])
- finally:
- self.SetMarkReflectSysctls(0)
-
-
-@unittest.skipUnless(multinetwork_base.HAVE_UID_ROUTING, "no UID routes")
-class UidRoutingTest(multinetwork_base.MultiNetworkBaseTest):
- """Tests that per-UID routing works properly.
-
- Relevant kernel commits:
- android-3.4:
- 0b42874 net: core: Support UID-based routing.
- 0836a0c Handle 'sk' being NULL in UID-based routing.
-
- android-3.10:
- 99a6ea4 net: core: Support UID-based routing.
- 455b09d Handle 'sk' being NULL in UID-based routing.
- """
-
- def GetRulesAtPriority(self, version, priority):
- rules = self.iproute.DumpRules(version)
- out = [(rule, attributes) for rule, attributes in rules
- if attributes.get("FRA_PRIORITY", 0) == priority]
- return out
-
- def CheckInitialTablesHaveNoUIDs(self, version):
- rules = []
- for priority in [0, 32766, 32767]:
- rules.extend(self.GetRulesAtPriority(version, priority))
- for _, attributes in rules:
- self.assertNotIn("FRA_UID_START", attributes)
- self.assertNotIn("FRA_UID_END", attributes)
-
- def testIPv4InitialTablesHaveNoUIDs(self):
- self.CheckInitialTablesHaveNoUIDs(4)
-
- def testIPv6InitialTablesHaveNoUIDs(self):
- self.CheckInitialTablesHaveNoUIDs(6)
-
- def CheckGetAndSetRules(self, version):
- def Random():
- return random.randint(1000000, 2000000)
-
- start, end = tuple(sorted([Random(), Random()]))
- table = Random()
- priority = Random()
-
- try:
- self.iproute.UidRangeRule(version, True, start, end, table,
- priority=priority)
-
- rules = self.GetRulesAtPriority(version, priority)
- self.assertTrue(rules)
- _, attributes = rules[-1]
- self.assertEquals(priority, attributes["FRA_PRIORITY"])
- self.assertEquals(start, attributes["FRA_UID_START"])
- self.assertEquals(end, attributes["FRA_UID_END"])
- self.assertEquals(table, attributes["FRA_TABLE"])
- finally:
- self.iproute.UidRangeRule(version, False, start, end, table,
- priority=priority)
-
- def testIPv4GetAndSetRules(self):
- self.CheckGetAndSetRules(4)
-
- def testIPv6GetAndSetRules(self):
- self.CheckGetAndSetRules(6)
-
- def ExpectNoRoute(self, addr, oif, mark, uid):
- # The lack of a route may be either an error, or an unreachable route.
- try:
- routes = self.iproute.GetRoutes(addr, oif, mark, uid)
- rtmsg, _ = routes[0]
- self.assertEquals(iproute.RTN_UNREACHABLE, rtmsg.type)
- except IOError, e:
- if int(e.errno) != -int(errno.ENETUNREACH):
- raise e
-
- def ExpectRoute(self, addr, oif, mark, uid):
- routes = self.iproute.GetRoutes(addr, oif, mark, uid)
- rtmsg, _ = routes[0]
- self.assertEquals(iproute.RTN_UNICAST, rtmsg.type)
-
- def CheckGetRoute(self, version, addr):
- self.ExpectNoRoute(addr, 0, 0, 0)
- for netid in self.NETIDS:
- uid = self.UidForNetid(netid)
- self.ExpectRoute(addr, 0, 0, uid)
- self.ExpectNoRoute(addr, 0, 0, 0)
-
- def testIPv4RouteGet(self):
- self.CheckGetRoute(4, net_test.IPV4_ADDR)
-
- def testIPv6RouteGet(self):
- self.CheckGetRoute(6, net_test.IPV6_ADDR)
-
-
-class RulesTest(net_test.NetworkTest):
-
- RULE_PRIORITY = 99999
-
- def setUp(self):
- self.iproute = iproute.IPRoute()
- for version in [4, 6]:
- self.iproute.DeleteRulesAtPriority(version, self.RULE_PRIORITY)
-
- def tearDown(self):
- for version in [4, 6]:
- self.iproute.DeleteRulesAtPriority(version, self.RULE_PRIORITY)
-
- def testRuleDeletionMatchesTable(self):
- for version in [4, 6]:
- # Add rules with mark 300 pointing at tables 301 and 302.
- # This checks for a kernel bug where deletion request for tables > 256
- # ignored the table.
- self.iproute.FwmarkRule(version, True, 300, 301,
- priority=self.RULE_PRIORITY)
- self.iproute.FwmarkRule(version, True, 300, 302,
- priority=self.RULE_PRIORITY)
- # Delete rule with mark 300 pointing at table 302.
- self.iproute.FwmarkRule(version, False, 300, 302,
- priority=self.RULE_PRIORITY)
- # Check that the rule pointing at table 301 is still around.
- attributes = [a for _, a in self.iproute.DumpRules(version)
- if a.get("FRA_PRIORITY", 0) == self.RULE_PRIORITY]
- self.assertEquals(1, len(attributes))
- self.assertEquals(301, attributes[0]["FRA_TABLE"])
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/net_test/neighbour_test.py b/tests/net_test/neighbour_test.py
deleted file mode 100755
index 1e7739e..0000000
--- a/tests/net_test/neighbour_test.py
+++ /dev/null
@@ -1,297 +0,0 @@
-#!/usr/bin/python
-#
-# Copyright 2015 The Android Open Source Project
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import errno
-import random
-from socket import * # pylint: disable=wildcard-import
-import time
-import unittest
-
-from scapy import all as scapy
-
-import multinetwork_base
-import net_test
-
-
-RTMGRP_NEIGH = 4
-
-NUD_INCOMPLETE = 0x01
-NUD_REACHABLE = 0x02
-NUD_STALE = 0x04
-NUD_DELAY = 0x08
-NUD_PROBE = 0x10
-NUD_FAILED = 0x20
-NUD_PERMANENT = 0x80
-
-
-# TODO: Support IPv4.
-class NeighbourTest(multinetwork_base.MultiNetworkBaseTest):
-
- # Set a 100-ms retrans timer so we can test for ND retransmits without
- # waiting too long. Apparently this cannot go below 500ms.
- RETRANS_TIME_MS = 500
-
- # This can only be in seconds, so 1000 is the minimum.
- DELAY_TIME_MS = 1000
-
- # Unfortunately, this must be above the delay timer or the kernel ND code will
- # not behave correctly (e.g., go straight from REACHABLE into DELAY). This is
- # is fuzzed by the kernel from 0.5x to 1.5x of its value, so we need a value
- # that's 2x the delay timer.
- REACHABLE_TIME_MS = 2 * DELAY_TIME_MS
-
- @classmethod
- def setUpClass(cls):
- super(NeighbourTest, cls).setUpClass()
- for netid in cls.tuns:
- iface = cls.GetInterfaceName(netid)
- # This can't be set in an RA.
- cls.SetSysctl(
- "/proc/sys/net/ipv6/neigh/%s/delay_first_probe_time" % iface,
- cls.DELAY_TIME_MS / 1000)
-
- def setUp(self):
- super(NeighbourTest, self).setUp()
-
- for netid in self.tuns:
- # Clear the ND cache entries for all routers, so each test starts with
- # the IPv6 default router in state STALE.
- addr = self._RouterAddress(netid, 6)
- ifindex = self.ifindices[netid]
- self.iproute.UpdateNeighbour(6, addr, None, ifindex, NUD_FAILED)
-
- # Configure IPv6 by sending an RA.
- self.SendRA(netid,
- retranstimer=self.RETRANS_TIME_MS,
- reachabletime=self.REACHABLE_TIME_MS)
-
- self.sock = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE)
- self.sock.bind((0, RTMGRP_NEIGH))
- net_test.SetNonBlocking(self.sock)
-
- self.netid = random.choice(self.tuns.keys())
- self.ifindex = self.ifindices[self.netid]
-
- def GetNeighbour(self, addr):
- version = 6 if ":" in addr else 4
- for msg, args in self.iproute.DumpNeighbours(version):
- if args["NDA_DST"] == addr:
- return msg, args
-
- def GetNdEntry(self, addr):
- return self.GetNeighbour(addr)
-
- def CheckNoNdEvents(self):
- self.assertRaisesErrno(errno.EAGAIN, self.sock.recvfrom, 4096, MSG_PEEK)
-
- def assertNeighbourState(self, state, addr):
- self.assertEquals(state, self.GetNdEntry(addr)[0].state)
-
- def assertNeighbourAttr(self, addr, name, value):
- self.assertEquals(value, self.GetNdEntry(addr)[1][name])
-
- def ExpectNeighbourNotification(self, addr, state, attrs=None):
- msg = self.sock.recv(4096)
- msg, actual_attrs = self.iproute.ParseNeighbourMessage(msg)
- self.assertEquals(addr, actual_attrs["NDA_DST"])
- self.assertEquals(state, msg.state)
- if attrs:
- for name in attrs:
- self.assertEquals(attrs[name], actual_attrs[name])
-
- def ExpectProbe(self, is_unicast, addr):
- version = 6 if ":" in addr else 4
- if version == 6:
- llsrc = self.MyMacAddress(self.netid)
- if is_unicast:
- src = self.MyLinkLocalAddress(self.netid)
- dst = addr
- else:
- solicited = inet_pton(AF_INET6, addr)
- last3bytes = tuple([ord(b) for b in solicited[-3:]])
- dst = "ff02::1:ff%02x:%02x%02x" % last3bytes
- src = self.MyAddress(6, self.netid)
- expected = (
- scapy.IPv6(src=src, dst=dst) /
- scapy.ICMPv6ND_NS(tgt=addr) /
- scapy.ICMPv6NDOptSrcLLAddr(lladdr=llsrc)
- )
- msg = "%s probe" % ("Unicast" if is_unicast else "Multicast")
- self.ExpectPacketOn(self.netid, msg, expected)
- else:
- raise NotImplementedError
-
- def ExpectUnicastProbe(self, addr):
- self.ExpectProbe(True, addr)
-
- def ExpectMulticastNS(self, addr):
- self.ExpectProbe(False, addr)
-
- def ReceiveUnicastAdvertisement(self, addr, mac, srcaddr=None, dstaddr=None,
- S=1, O=0, R=1):
- version = 6 if ":" in addr else 4
- if srcaddr is None:
- srcaddr = addr
- if dstaddr is None:
- dstaddr = self.MyLinkLocalAddress(self.netid)
- if version == 6:
- packet = (
- scapy.Ether(src=mac, dst=self.MyMacAddress(self.netid)) /
- scapy.IPv6(src=srcaddr, dst=dstaddr) /
- scapy.ICMPv6ND_NA(tgt=addr, S=S, O=O, R=R) /
- scapy.ICMPv6NDOptDstLLAddr(lladdr=mac)
- )
- self.ReceiveEtherPacketOn(self.netid, packet)
- else:
- raise NotImplementedError
-
- def MonitorSleepMs(self, interval, addr):
- slept = 0
- while slept < interval:
- sleep_ms = min(100, interval - slept)
- time.sleep(sleep_ms / 1000.0)
- slept += sleep_ms
- print self.GetNdEntry(addr)
-
- def MonitorSleep(self, intervalseconds, addr):
- self.MonitorSleepMs(intervalseconds * 1000, addr)
-
- def SleepMs(self, ms):
- time.sleep(ms / 1000.0)
-
- def testNotifications(self):
- """Tests neighbour notifications.
-
- Relevant kernel commits:
- upstream net-next:
- 765c9c6 neigh: Better handling of transition to NUD_PROBE state
- 53385d2 neigh: Netlink notification for administrative NUD state change
- (only checked on kernel v3.13+, not on v3.10)
-
- android-3.10:
- e4a6d6b neigh: Better handling of transition to NUD_PROBE state
-
- android-3.18:
- 2011e72 neigh: Better handling of transition to NUD_PROBE state
- """
-
- router4 = self._RouterAddress(self.netid, 4)
- router6 = self._RouterAddress(self.netid, 6)
- self.assertNeighbourState(NUD_PERMANENT, router4)
- self.assertNeighbourState(NUD_STALE, router6)
-
- # Send a packet and check that we go into DELAY.
- routing_mode = random.choice(["mark", "oif", "uid"])
- s = self.BuildSocket(6, net_test.UDPSocket, self.netid, routing_mode)
- s.connect((net_test.IPV6_ADDR, 53))
- s.send(net_test.UDP_PAYLOAD)
- self.assertNeighbourState(NUD_DELAY, router6)
-
- # Wait for the probe interval, then check that we're in PROBE, and that the
- # kernel has notified us.
- self.SleepMs(self.DELAY_TIME_MS)
- self.ExpectNeighbourNotification(router6, NUD_PROBE)
- self.assertNeighbourState(NUD_PROBE, router6)
- self.ExpectUnicastProbe(router6)
-
- # Respond to the NS and verify we're in REACHABLE again.
- self.ReceiveUnicastAdvertisement(router6, self.RouterMacAddress(self.netid))
- self.assertNeighbourState(NUD_REACHABLE, router6)
- if net_test.LINUX_VERSION >= (3, 13, 0):
- # commit 53385d2 (v3.13) "neigh: Netlink notification for administrative
- # NUD state change" produces notifications for NUD_REACHABLE, but these
- # are not generated on earlier kernels.
- self.ExpectNeighbourNotification(router6, NUD_REACHABLE)
-
- # Wait until the reachable time has passed, and verify we're in STALE.
- self.SleepMs(self.REACHABLE_TIME_MS * 1.5)
- self.assertNeighbourState(NUD_STALE, router6)
- self.ExpectNeighbourNotification(router6, NUD_STALE)
-
- # Send a packet, and verify we go into DELAY and then to PROBE.
- s.send(net_test.UDP_PAYLOAD)
- self.assertNeighbourState(NUD_DELAY, router6)
- self.SleepMs(self.DELAY_TIME_MS)
- self.assertNeighbourState(NUD_PROBE, router6)
- self.ExpectNeighbourNotification(router6, NUD_PROBE)
-
- # Wait for the probes to time out, and expect a FAILED notification.
- self.assertNeighbourAttr(router6, "NDA_PROBES", 1)
- self.ExpectUnicastProbe(router6)
-
- self.SleepMs(self.RETRANS_TIME_MS)
- self.ExpectUnicastProbe(router6)
- self.assertNeighbourAttr(router6, "NDA_PROBES", 2)
-
- self.SleepMs(self.RETRANS_TIME_MS)
- self.ExpectUnicastProbe(router6)
- self.assertNeighbourAttr(router6, "NDA_PROBES", 3)
-
- self.SleepMs(self.RETRANS_TIME_MS)
- self.assertNeighbourState(NUD_FAILED, router6)
- self.ExpectNeighbourNotification(router6, NUD_FAILED, {"NDA_PROBES": 3})
-
- def testRepeatedProbes(self):
- router4 = self._RouterAddress(self.netid, 4)
- router6 = self._RouterAddress(self.netid, 6)
- routermac = self.RouterMacAddress(self.netid)
- self.assertNeighbourState(NUD_PERMANENT, router4)
- self.assertNeighbourState(NUD_STALE, router6)
-
- def ForceProbe(addr, mac):
- self.iproute.UpdateNeighbour(6, addr, None, self.ifindex, NUD_PROBE)
- self.assertNeighbourState(NUD_PROBE, addr)
- self.SleepMs(1) # TODO: Why is this necessary?
- self.assertNeighbourState(NUD_PROBE, addr)
- self.ExpectUnicastProbe(addr)
- self.ReceiveUnicastAdvertisement(addr, mac)
- self.assertNeighbourState(NUD_REACHABLE, addr)
-
- for _ in xrange(5):
- ForceProbe(router6, routermac)
-
- def testIsRouterFlag(self):
- router6 = self._RouterAddress(self.netid, 6)
- self.assertNeighbourState(NUD_STALE, router6)
-
- # Get into FAILED.
- ifindex = self.ifindices[self.netid]
- self.iproute.UpdateNeighbour(6, router6, None, ifindex, NUD_FAILED)
- self.ExpectNeighbourNotification(router6, NUD_FAILED)
- self.assertNeighbourState(NUD_FAILED, router6)
-
- time.sleep(1)
-
- # Send another packet and expect a multicast NS.
- routing_mode = random.choice(["mark", "oif", "uid"])
- s = self.BuildSocket(6, net_test.UDPSocket, self.netid, routing_mode)
- s.connect((net_test.IPV6_ADDR, 53))
- s.send(net_test.UDP_PAYLOAD)
- self.ExpectMulticastNS(router6)
-
- # Receive a unicast NA with the R flag set to 0.
- self.ReceiveUnicastAdvertisement(router6, self.RouterMacAddress(self.netid),
- srcaddr=self._RouterAddress(self.netid, 6),
- dstaddr=self.MyAddress(6, self.netid),
- S=1, O=0, R=0)
-
- # Expect that this takes us to REACHABLE.
- self.ExpectNeighbourNotification(router6, NUD_REACHABLE)
- self.assertNeighbourState(NUD_REACHABLE, router6)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/net_test/net_test.py b/tests/net_test/net_test.py
deleted file mode 100755
index d7ea013..0000000
--- a/tests/net_test/net_test.py
+++ /dev/null
@@ -1,394 +0,0 @@
-#!/usr/bin/python
-#
-# Copyright 2014 The Android Open Source Project
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import fcntl
-import os
-import random
-import re
-from socket import * # pylint: disable=wildcard-import
-import struct
-import unittest
-
-from scapy import all as scapy
-
-SOL_IPV6 = 41
-IP_RECVERR = 11
-IPV6_RECVERR = 25
-IP_TRANSPARENT = 19
-IPV6_TRANSPARENT = 75
-IPV6_TCLASS = 67
-IPV6_FLOWLABEL_MGR = 32
-IPV6_FLOWINFO_SEND = 33
-
-SO_BINDTODEVICE = 25
-SO_MARK = 36
-SO_PROTOCOL = 38
-SO_DOMAIN = 39
-
-ETH_P_IP = 0x0800
-ETH_P_IPV6 = 0x86dd
-
-IPPROTO_GRE = 47
-
-SIOCSIFHWADDR = 0x8924
-
-IPV6_FL_A_GET = 0
-IPV6_FL_A_PUT = 1
-IPV6_FL_A_RENEW = 1
-
-IPV6_FL_F_CREATE = 1
-IPV6_FL_F_EXCL = 2
-
-IPV6_FL_S_NONE = 0
-IPV6_FL_S_EXCL = 1
-IPV6_FL_S_ANY = 255
-
-IFNAMSIZ = 16
-
-IPV4_PING = "\x08\x00\x00\x00\x0a\xce\x00\x03"
-IPV6_PING = "\x80\x00\x00\x00\x0a\xce\x00\x03"
-
-IPV4_ADDR = "8.8.8.8"
-IPV6_ADDR = "2001:4860:4860::8888"
-
-IPV6_SEQ_DGRAM_HEADER = (" sl "
- "local_address "
- "remote_address "
- "st tx_queue rx_queue tr tm->when retrnsmt"
- " uid timeout inode ref pointer drops\n")
-
-# Arbitrary packet payload.
-UDP_PAYLOAD = str(scapy.DNS(rd=1,
- id=random.randint(0, 65535),
- qd=scapy.DNSQR(qname="wWW.GoOGle.CoM",
- qtype="AAAA")))
-
-# Unix group to use if we want to open sockets as non-root.
-AID_INET = 3003
-
-
-def LinuxVersion():
- # Example: "3.4.67-00753-gb7a556f".
- # Get the part before the dash.
- version = os.uname()[2].split("-")[0]
- # Convert it into a tuple such as (3, 4, 67). That allows comparing versions
- # using < and >, since tuples are compared lexicographically.
- version = tuple(int(i) for i in version.split("."))
- return version
-
-
-LINUX_VERSION = LinuxVersion()
-
-
-def SetSocketTimeout(sock, ms):
- s = ms / 1000
- us = (ms % 1000) * 1000
- sock.setsockopt(SOL_SOCKET, SO_RCVTIMEO, struct.pack("LL", s, us))
-
-
-def SetSocketTos(s, tos):
- level = {AF_INET: SOL_IP, AF_INET6: SOL_IPV6}[s.family]
- option = {AF_INET: IP_TOS, AF_INET6: IPV6_TCLASS}[s.family]
- s.setsockopt(level, option, tos)
-
-
-def SetNonBlocking(fd):
- flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0)
- fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
-
-
-# Convenience functions to create sockets.
-def Socket(family, sock_type, protocol):
- s = socket(family, sock_type, protocol)
- SetSocketTimeout(s, 1000)
- return s
-
-
-def PingSocket(family):
- proto = {AF_INET: IPPROTO_ICMP, AF_INET6: IPPROTO_ICMPV6}[family]
- return Socket(family, SOCK_DGRAM, proto)
-
-
-def IPv4PingSocket():
- return PingSocket(AF_INET)
-
-
-def IPv6PingSocket():
- return PingSocket(AF_INET6)
-
-
-def TCPSocket(family):
- s = Socket(family, SOCK_STREAM, IPPROTO_TCP)
- SetNonBlocking(s.fileno())
- return s
-
-
-def IPv4TCPSocket():
- return TCPSocket(AF_INET)
-
-
-def IPv6TCPSocket():
- return TCPSocket(AF_INET6)
-
-
-def UDPSocket(family):
- return Socket(family, SOCK_DGRAM, IPPROTO_UDP)
-
-
-def RawGRESocket(family):
- s = Socket(family, SOCK_RAW, IPPROTO_GRE)
- return s
-
-
-def DisableLinger(sock):
- sock.setsockopt(SOL_SOCKET, SO_LINGER, struct.pack("ii", 1, 0))
-
-
-def CreateSocketPair(family, socktype, addr):
- clientsock = socket(family, socktype, 0)
- listensock = socket(family, socktype, 0)
- listensock.bind((addr, 0))
- addr = listensock.getsockname()
- listensock.listen(1)
- clientsock.connect(addr)
- acceptedsock, _ = listensock.accept()
- DisableLinger(clientsock)
- DisableLinger(acceptedsock)
- listensock.close()
- return clientsock, acceptedsock
-
-
-def GetInterfaceIndex(ifname):
- s = IPv4PingSocket()
- ifr = struct.pack("%dsi" % IFNAMSIZ, ifname, 0)
- ifr = fcntl.ioctl(s, scapy.SIOCGIFINDEX, ifr)
- return struct.unpack("%dsi" % IFNAMSIZ, ifr)[1]
-
-
-def SetInterfaceHWAddr(ifname, hwaddr):
- s = IPv4PingSocket()
- hwaddr = hwaddr.replace(":", "")
- hwaddr = hwaddr.decode("hex")
- if len(hwaddr) != 6:
- raise ValueError("Unknown hardware address length %d" % len(hwaddr))
- ifr = struct.pack("%dsH6s" % IFNAMSIZ, ifname, scapy.ARPHDR_ETHER, hwaddr)
- fcntl.ioctl(s, SIOCSIFHWADDR, ifr)
-
-
-def SetInterfaceState(ifname, up):
- s = IPv4PingSocket()
- ifr = struct.pack("%dsH" % IFNAMSIZ, ifname, 0)
- ifr = fcntl.ioctl(s, scapy.SIOCGIFFLAGS, ifr)
- _, flags = struct.unpack("%dsH" % IFNAMSIZ, ifr)
- if up:
- flags |= scapy.IFF_UP
- else:
- flags &= ~scapy.IFF_UP
- ifr = struct.pack("%dsH" % IFNAMSIZ, ifname, flags)
- ifr = fcntl.ioctl(s, scapy.SIOCSIFFLAGS, ifr)
-
-
-def SetInterfaceUp(ifname):
- return SetInterfaceState(ifname, True)
-
-
-def SetInterfaceDown(ifname):
- return SetInterfaceState(ifname, False)
-
-
-def FormatProcAddress(unformatted):
- groups = []
- for i in xrange(0, len(unformatted), 4):
- groups.append(unformatted[i:i+4])
- formatted = ":".join(groups)
- # Compress the address.
- address = inet_ntop(AF_INET6, inet_pton(AF_INET6, formatted))
- return address
-
-
-def FormatSockStatAddress(address):
- if ":" in address:
- family = AF_INET6
- else:
- family = AF_INET
- binary = inet_pton(family, address)
- out = ""
- for i in xrange(0, len(binary), 4):
- out += "%08X" % struct.unpack("=L", binary[i:i+4])
- return out
-
-
-def GetLinkAddress(ifname, linklocal):
- addresses = open("/proc/net/if_inet6").readlines()
- for address in addresses:
- address = [s for s in address.strip().split(" ") if s]
- if address[5] == ifname:
- if (linklocal and address[0].startswith("fe80")
- or not linklocal and not address[0].startswith("fe80")):
- # Convert the address from raw hex to something with colons in it.
- return FormatProcAddress(address[0])
- return None
-
-
-def GetDefaultRoute(version=6):
- if version == 6:
- routes = open("/proc/net/ipv6_route").readlines()
- for route in routes:
- route = [s for s in route.strip().split(" ") if s]
- if (route[0] == "00000000000000000000000000000000" and route[1] == "00"
- # Routes in non-default tables end up in /proc/net/ipv6_route!!!
- and route[9] != "lo" and not route[9].startswith("nettest")):
- return FormatProcAddress(route[4]), route[9]
- raise ValueError("No IPv6 default route found")
- elif version == 4:
- routes = open("/proc/net/route").readlines()
- for route in routes:
- route = [s for s in route.strip().split("\t") if s]
- if route[1] == "00000000" and route[7] == "00000000":
- gw, iface = route[2], route[0]
- gw = inet_ntop(AF_INET, gw.decode("hex")[::-1])
- return gw, iface
- raise ValueError("No IPv4 default route found")
- else:
- raise ValueError("Don't know about IPv%s" % version)
-
-
-def GetDefaultRouteInterface():
- unused_gw, iface = GetDefaultRoute()
- return iface
-
-
-def MakeFlowLabelOption(addr, label):
- # struct in6_flowlabel_req {
- # struct in6_addr flr_dst;
- # __be32 flr_label;
- # __u8 flr_action;
- # __u8 flr_share;
- # __u16 flr_flags;
- # __u16 flr_expires;
- # __u16 flr_linger;
- # __u32 __flr_pad;
- # /* Options in format of IPV6_PKTOPTIONS */
- # };
- fmt = "16sIBBHHH4s"
- assert struct.calcsize(fmt) == 32
- addr = inet_pton(AF_INET6, addr)
- assert len(addr) == 16
- label = htonl(label & 0xfffff)
- action = IPV6_FL_A_GET
- share = IPV6_FL_S_ANY
- flags = IPV6_FL_F_CREATE
- pad = "\x00" * 4
- return struct.pack(fmt, addr, label, action, share, flags, 0, 0, pad)
-
-
-def SetFlowLabel(s, addr, label):
- opt = MakeFlowLabelOption(addr, label)
- s.setsockopt(SOL_IPV6, IPV6_FLOWLABEL_MGR, opt)
- # Caller also needs to do s.setsockopt(SOL_IPV6, IPV6_FLOWINFO_SEND, 1).
-
-
-# Determine network configuration.
-try:
- GetDefaultRoute(version=4)
- HAVE_IPV4 = True
-except ValueError:
- HAVE_IPV4 = False
-
-try:
- GetDefaultRoute(version=6)
- HAVE_IPV6 = True
-except ValueError:
- HAVE_IPV6 = False
-
-
-class RunAsUid(object):
- """Context guard to run a code block as a given UID."""
-
- def __init__(self, uid):
- self.uid = uid
-
- def __enter__(self):
- if self.uid:
- self.saved_uid = os.geteuid()
- self.saved_groups = os.getgroups()
- if self.uid:
- os.setgroups(self.saved_groups + [AID_INET])
- os.seteuid(self.uid)
-
- def __exit__(self, unused_type, unused_value, unused_traceback):
- if self.uid:
- os.seteuid(self.saved_uid)
- os.setgroups(self.saved_groups)
-
-
-class NetworkTest(unittest.TestCase):
-
- def assertRaisesErrno(self, err_num, f, *args):
- msg = os.strerror(err_num)
- self.assertRaisesRegexp(EnvironmentError, msg, f, *args)
-
- def ReadProcNetSocket(self, protocol):
- # Read file.
- filename = "/proc/net/%s" % protocol
- lines = open(filename).readlines()
-
- # Possibly check, and strip, header.
- if protocol in ["icmp6", "raw6", "udp6"]:
- self.assertEqual(IPV6_SEQ_DGRAM_HEADER, lines[0])
- lines = lines[1:]
-
- # Check contents.
- if protocol.endswith("6"):
- addrlen = 32
- else:
- addrlen = 8
-
- if protocol.startswith("tcp"):
- # Real sockets have 5 extra numbers, timewait sockets have none.
- end_regexp = "(| +[0-9]+ [0-9]+ [0-9]+ [0-9]+ -?[0-9]+|)$"
- elif re.match("icmp|udp|raw", protocol):
- # Drops.
- end_regexp = " +([0-9]+) *$"
- else:
- raise ValueError("Don't know how to parse %s" % filename)
-
- regexp = re.compile(r" *(\d+): " # bucket
- "([0-9A-F]{%d}:[0-9A-F]{4}) " # srcaddr, port
- "([0-9A-F]{%d}:[0-9A-F]{4}) " # dstaddr, port
- "([0-9A-F][0-9A-F]) " # state
- "([0-9A-F]{8}:[0-9A-F]{8}) " # mem
- "([0-9A-F]{2}:[0-9A-F]{8}) " # ?
- "([0-9A-F]{8}) +" # ?
- "([0-9]+) +" # uid
- "([0-9]+) +" # timeout
- "([0-9]+) +" # inode
- "([0-9]+) +" # refcnt
- "([0-9a-f]+)" # sp
- "%s" # icmp has spaces
- % (addrlen, addrlen, end_regexp))
- # Return a list of lists with only source / dest addresses for now.
- # TODO: consider returning a dict or namedtuple instead.
- out = []
- for line in lines:
- (_, src, dst, state, mem,
- _, _, uid, _, _, refcnt, _, extra) = regexp.match(line).groups()
- out.append([src, dst, state, mem, uid, refcnt, extra])
- return out
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/net_test/net_test.sh b/tests/net_test/net_test.sh
deleted file mode 100755
index acac660..0000000
--- a/tests/net_test/net_test.sh
+++ /dev/null
@@ -1,20 +0,0 @@
-#!/bin/bash
-
-# In case IPv6 is compiled as a module.
-[ -f /proc/net/if_inet6 ] || insmod $DIR/kernel/net-next/net/ipv6/ipv6.ko
-
-# Minimal network setup.
-ip link set lo up
-ip link set lo mtu 16436
-ip link set eth0 up
-
-# Allow people to run ping.
-echo "0 65536" > /proc/sys/net/ipv4/ping_group_range
-
-# Fall out to a shell once the test completes or if there's an error.
-trap "exec /bin/bash" ERR EXIT
-
-# Find and run the test.
-test=$(cat /proc/cmdline | sed -re 's/.*net_test=([^ ]*).*/\1/g')
-echo -e "Running $test\n"
-$test
diff --git a/tests/net_test/netlink.py b/tests/net_test/netlink.py
deleted file mode 100644
index 2b8f744..0000000
--- a/tests/net_test/netlink.py
+++ /dev/null
@@ -1,255 +0,0 @@
-#!/usr/bin/python
-#
-# Copyright 2014 The Android Open Source Project
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Partial Python implementation of iproute functionality."""
-
-# pylint: disable=g-bad-todo
-
-import errno
-import os
-import socket
-import struct
-import sys
-
-import cstruct
-
-
-# Request constants.
-NLM_F_REQUEST = 1
-NLM_F_ACK = 4
-NLM_F_REPLACE = 0x100
-NLM_F_EXCL = 0x200
-NLM_F_CREATE = 0x400
-NLM_F_DUMP = 0x300
-
-# Message types.
-NLMSG_ERROR = 2
-NLMSG_DONE = 3
-
-# Data structure formats.
-# These aren't constants, they're classes. So, pylint: disable=invalid-name
-NLMsgHdr = cstruct.Struct("NLMsgHdr", "=LHHLL", "length type flags seq pid")
-NLMsgErr = cstruct.Struct("NLMsgErr", "=i", "error")
-NLAttr = cstruct.Struct("NLAttr", "=HH", "nla_len nla_type")
-
-# Alignment / padding.
-NLA_ALIGNTO = 4
-
-
-def PaddedLength(length):
- # TODO: This padding is probably overly simplistic.
- return NLA_ALIGNTO * ((length / NLA_ALIGNTO) + (length % NLA_ALIGNTO != 0))
-
-
-class NetlinkSocket(object):
- """A basic netlink socket object."""
-
- BUFSIZE = 65536
- DEBUG = False
- # List of netlink messages to print, e.g., [], ["NEIGH", "ROUTE"], or ["ALL"]
- NL_DEBUG = []
-
- def _Debug(self, s):
- if self.DEBUG:
- print s
-
- def _NlAttr(self, nla_type, data):
- datalen = len(data)
- # Pad the data if it's not a multiple of NLA_ALIGNTO bytes long.
- padding = "\x00" * (PaddedLength(datalen) - datalen)
- nla_len = datalen + len(NLAttr)
- return NLAttr((nla_len, nla_type)).Pack() + data + padding
-
- def _NlAttrU32(self, nla_type, value):
- return self._NlAttr(nla_type, struct.pack("=I", value))
-
- def _GetConstantName(self, module, value, prefix):
- thismodule = sys.modules[module]
- for name in dir(thismodule):
- if name.startswith("INET_DIAG_BC"):
- break
- if (name.startswith(prefix) and
- not name.startswith(prefix + "F_") and
- name.isupper() and getattr(thismodule, name) == value):
- return name
- return value
-
- def _Decode(self, command, msg, nla_type, nla_data):
- """No-op, nonspecific version of decode."""
- return nla_type, nla_data
-
- def _ParseAttributes(self, command, family, msg, data):
- """Parses and decodes netlink attributes.
-
- Takes a block of NLAttr data structures, decodes them using Decode, and
- returns the result in a dict keyed by attribute number.
-
- Args:
- command: An integer, the rtnetlink command being carried out.
- family: The address family.
- msg: A Struct, the type of the data after the netlink header.
- data: A byte string containing a sequence of NLAttr data structures.
-
- Returns:
- A dictionary mapping attribute types (integers) to decoded values.
-
- Raises:
- ValueError: There was a duplicate attribute type.
- """
- attributes = {}
- while data:
- # Read the nlattr header.
- nla, data = cstruct.Read(data, NLAttr)
-
- # Read the data.
- datalen = nla.nla_len - len(nla)
- padded_len = PaddedLength(nla.nla_len) - len(nla)
- nla_data, data = data[:datalen], data[padded_len:]
-
- # If it's an attribute we know about, try to decode it.
- nla_name, nla_data = self._Decode(command, msg, nla.nla_type, nla_data)
-
- # We only support unique attributes for now, except for INET_DIAG_NONE,
- # which can appear more than once but doesn't seem to contain any data.
- if nla_name in attributes and nla_name != "INET_DIAG_NONE":
- raise ValueError("Duplicate attribute %s" % nla_name)
-
- attributes[nla_name] = nla_data
- self._Debug(" %s" % str((nla_name, nla_data)))
-
- return attributes
-
- def __init__(self):
- # Global sequence number.
- self.seq = 0
- self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.FAMILY)
- self.sock.connect((0, 0)) # The kernel.
- self.pid = self.sock.getsockname()[1]
-
- def _Send(self, msg):
- # self._Debug(msg.encode("hex"))
- self.seq += 1
- self.sock.send(msg)
-
- def _Recv(self):
- data = self.sock.recv(self.BUFSIZE)
- # self._Debug(data.encode("hex"))
- return data
-
- def _ExpectDone(self):
- response = self._Recv()
- hdr = NLMsgHdr(response)
- if hdr.type != NLMSG_DONE:
- raise ValueError("Expected DONE, got type %d" % hdr.type)
-
- def _ParseAck(self, response):
- # Find the error code.
- hdr, data = cstruct.Read(response, NLMsgHdr)
- if hdr.type == NLMSG_ERROR:
- error = NLMsgErr(data).error
- if error:
- raise IOError(error, os.strerror(-error))
- else:
- raise ValueError("Expected ACK, got type %d" % hdr.type)
-
- def _ExpectAck(self):
- response = self._Recv()
- self._ParseAck(response)
-
- def _SendNlRequest(self, command, data, flags):
- """Sends a netlink request and expects an ack."""
- length = len(NLMsgHdr) + len(data)
- nlmsg = NLMsgHdr((length, command, flags, self.seq, self.pid)).Pack()
-
- self.MaybeDebugCommand(command, nlmsg + data)
-
- # Send the message.
- self._Send(nlmsg + data)
-
- if flags & NLM_F_ACK:
- self._ExpectAck()
-
- def _ParseNLMsg(self, data, msgtype):
- """Parses a Netlink message into a header and a dictionary of attributes."""
- nlmsghdr, data = cstruct.Read(data, NLMsgHdr)
- self._Debug(" %s" % nlmsghdr)
-
- if nlmsghdr.type == NLMSG_ERROR or nlmsghdr.type == NLMSG_DONE:
- print "done"
- return (None, None), data
-
- nlmsg, data = cstruct.Read(data, msgtype)
- self._Debug(" %s" % nlmsg)
-
- # Parse the attributes in the nlmsg.
- attrlen = nlmsghdr.length - len(nlmsghdr) - len(nlmsg)
- attributes = self._ParseAttributes(nlmsghdr.type, nlmsg.family,
- nlmsg, data[:attrlen])
- data = data[attrlen:]
- return (nlmsg, attributes), data
-
- def _GetMsg(self, msgtype):
- data = self._Recv()
- if NLMsgHdr(data).type == NLMSG_ERROR:
- self._ParseAck(data)
- return self._ParseNLMsg(data, msgtype)[0]
-
- def _GetMsgList(self, msgtype, data, expect_done):
- out = []
- while data:
- msg, data = self._ParseNLMsg(data, msgtype)
- if msg is None:
- break
- out.append(msg)
- if expect_done:
- self._ExpectDone()
- return out
-
- def _Dump(self, command, msg, msgtype, attrs):
- """Sends a dump request and returns a list of decoded messages.
-
- Args:
- command: An integer, the command to run (e.g., RTM_NEWADDR).
- msg: A string, the raw bytes of the request (e.g., a packed RTMsg).
- msgtype: A cstruct.Struct, the data type to parse the dump results as.
- attrs: A string, the raw bytes of any request attributes to include.
-
- Returns:
- A list of (msg, attrs) tuples where msg is of type msgtype and attrs is
- a dict of attributes.
- """
- # Create a netlink dump request containing the msg.
- flags = NLM_F_DUMP | NLM_F_REQUEST
- length = len(NLMsgHdr) + len(msg) + len(attrs)
- nlmsghdr = NLMsgHdr((length, command, flags, self.seq, self.pid))
-
- # Send the request.
- self._Send(nlmsghdr.Pack() + msg.Pack() + attrs)
-
- # Keep reading netlink messages until we get a NLMSG_DONE.
- out = []
- while True:
- data = self._Recv()
- response_type = NLMsgHdr(data).type
- if response_type == NLMSG_DONE:
- break
- elif response_type == NLMSG_ERROR:
- # Likely means that the kernel didn't like our dump request.
- # Parse the error and throw an exception.
- self._ParseAck(data)
- out.extend(self._GetMsgList(msgtype, data, False))
-
- return out
diff --git a/tests/net_test/packets.py b/tests/net_test/packets.py
deleted file mode 100644
index c02adc0..0000000
--- a/tests/net_test/packets.py
+++ /dev/null
@@ -1,197 +0,0 @@
-#!/usr/bin/python
-#
-# Copyright 2015 The Android Open Source Project
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import random
-
-from scapy import all as scapy
-from socket import *
-
-import net_test
-
-TCP_FIN = 1
-TCP_SYN = 2
-TCP_RST = 4
-TCP_PSH = 8
-TCP_ACK = 16
-
-TCP_SEQ = 1692871236
-TCP_WINDOW = 14400
-
-PING_IDENT = 0xff19
-PING_PAYLOAD = "foobarbaz"
-PING_SEQ = 3
-PING_TOS = 0x83
-
-# For brevity.
-UDP_PAYLOAD = net_test.UDP_PAYLOAD
-
-
-def RandomPort():
- return random.randint(1025, 65535)
-
-def _GetIpLayer(version):
- return {4: scapy.IP, 6: scapy.IPv6}[version]
-
-def _SetPacketTos(packet, tos):
- if isinstance(packet, scapy.IPv6):
- packet.tc = tos
- elif isinstance(packet, scapy.IP):
- packet.tos = tos
- else:
- raise ValueError("Can't find ToS Field")
-
-def UDP(version, srcaddr, dstaddr, sport=0):
- ip = _GetIpLayer(version)
- # Can't just use "if sport" because None has meaning (it means unspecified).
- if sport == 0:
- sport = RandomPort()
- return ("UDPv%d packet" % version,
- ip(src=srcaddr, dst=dstaddr) /
- scapy.UDP(sport=sport, dport=53) / UDP_PAYLOAD)
-
-def UDPWithOptions(version, srcaddr, dstaddr, sport=0):
- if version == 4:
- packet = (scapy.IP(src=srcaddr, dst=dstaddr, ttl=39, tos=0x83) /
- scapy.UDP(sport=sport, dport=53) /
- UDP_PAYLOAD)
- else:
- packet = (scapy.IPv6(src=srcaddr, dst=dstaddr,
- fl=0xbeef, hlim=39, tc=0x83) /
- scapy.UDP(sport=sport, dport=53) /
- UDP_PAYLOAD)
- return ("UDPv%d packet with options" % version, packet)
-
-def SYN(dport, version, srcaddr, dstaddr, sport=0, seq=TCP_SEQ):
- ip = _GetIpLayer(version)
- if sport == 0:
- sport = RandomPort()
- return ("TCP SYN",
- ip(src=srcaddr, dst=dstaddr) /
- scapy.TCP(sport=sport, dport=dport,
- seq=seq, ack=0,
- flags=TCP_SYN, window=TCP_WINDOW))
-
-def RST(version, srcaddr, dstaddr, packet):
- ip = _GetIpLayer(version)
- original = packet.getlayer("TCP")
- was_syn_or_fin = (original.flags & (TCP_SYN | TCP_FIN)) != 0
- return ("TCP RST",
- ip(src=srcaddr, dst=dstaddr) /
- scapy.TCP(sport=original.dport, dport=original.sport,
- ack=original.seq + was_syn_or_fin, seq=None,
- flags=TCP_RST | TCP_ACK, window=TCP_WINDOW))
-
-def SYNACK(version, srcaddr, dstaddr, packet):
- ip = _GetIpLayer(version)
- original = packet.getlayer("TCP")
- return ("TCP SYN+ACK",
- ip(src=srcaddr, dst=dstaddr) /
- scapy.TCP(sport=original.dport, dport=original.sport,
- ack=original.seq + 1, seq=None,
- flags=TCP_SYN | TCP_ACK, window=None))
-
-def ACK(version, srcaddr, dstaddr, packet, payload=""):
- ip = _GetIpLayer(version)
- original = packet.getlayer("TCP")
- was_syn_or_fin = (original.flags & (TCP_SYN | TCP_FIN)) != 0
- ack_delta = was_syn_or_fin + len(original.payload)
- desc = "TCP data" if payload else "TCP ACK"
- flags = TCP_ACK | TCP_PSH if payload else TCP_ACK
- return (desc,
- ip(src=srcaddr, dst=dstaddr) /
- scapy.TCP(sport=original.dport, dport=original.sport,
- ack=original.seq + ack_delta, seq=original.ack,
- flags=flags, window=TCP_WINDOW) /
- payload)
-
-def FIN(version, srcaddr, dstaddr, packet):
- ip = _GetIpLayer(version)
- original = packet.getlayer("TCP")
- was_syn_or_fin = (original.flags & (TCP_SYN | TCP_FIN)) != 0
- ack_delta = was_syn_or_fin + len(original.payload)
- return ("TCP FIN",
- ip(src=srcaddr, dst=dstaddr) /
- scapy.TCP(sport=original.dport, dport=original.sport,
- ack=original.seq + ack_delta, seq=original.ack,
- flags=TCP_ACK | TCP_FIN, window=TCP_WINDOW))
-
-def GRE(version, srcaddr, dstaddr, proto, packet):
- if version == 4:
- ip = scapy.IP(src=srcaddr, dst=dstaddr, proto=net_test.IPPROTO_GRE)
- else:
- ip = scapy.IPv6(src=srcaddr, dst=dstaddr, nh=net_test.IPPROTO_GRE)
- packet = ip / scapy.GRE(proto=proto) / packet
- return ("GRE packet", packet)
-
-def ICMPPortUnreachable(version, srcaddr, dstaddr, packet):
- if version == 4:
- # Linux hardcodes the ToS on ICMP errors to 0xc0 or greater because of
- # RFC 1812 4.3.2.5 (!).
- return ("ICMPv4 port unreachable",
- scapy.IP(src=srcaddr, dst=dstaddr, proto=1, tos=0xc0) /
- scapy.ICMPerror(type=3, code=3) / packet)
- else:
- return ("ICMPv6 port unreachable",
- scapy.IPv6(src=srcaddr, dst=dstaddr) /
- scapy.ICMPv6DestUnreach(code=4) / packet)
-
-def ICMPPacketTooBig(version, srcaddr, dstaddr, packet):
- if version == 4:
- return ("ICMPv4 fragmentation needed",
- scapy.IP(src=srcaddr, dst=dstaddr, proto=1) /
- scapy.ICMPerror(type=3, code=4, unused=1280) / str(packet)[:64])
- else:
- udp = packet.getlayer("UDP")
- udp.payload = str(udp.payload)[:1280-40-8]
- return ("ICMPv6 Packet Too Big",
- scapy.IPv6(src=srcaddr, dst=dstaddr) /
- scapy.ICMPv6PacketTooBig() / str(packet)[:1232])
-
-def ICMPEcho(version, srcaddr, dstaddr):
- ip = _GetIpLayer(version)
- icmp = {4: scapy.ICMP, 6: scapy.ICMPv6EchoRequest}[version]
- packet = (ip(src=srcaddr, dst=dstaddr) /
- icmp(id=PING_IDENT, seq=PING_SEQ) / PING_PAYLOAD)
- _SetPacketTos(packet, PING_TOS)
- return ("ICMPv%d echo" % version, packet)
-
-def ICMPReply(version, srcaddr, dstaddr, packet):
- ip = _GetIpLayer(version)
- # Scapy doesn't provide an ICMP echo reply constructor.
- icmpv4_reply = lambda **kwargs: scapy.ICMP(type=0, **kwargs)
- icmp = {4: icmpv4_reply, 6: scapy.ICMPv6EchoReply}[version]
- packet = (ip(src=srcaddr, dst=dstaddr) /
- icmp(id=PING_IDENT, seq=PING_SEQ) / PING_PAYLOAD)
- # IPv6 only started copying the tclass to echo replies in 3.14.
- if version == 4 or net_test.LINUX_VERSION >= (3, 14):
- _SetPacketTos(packet, PING_TOS)
- return ("ICMPv%d echo reply" % version, packet)
-
-def NS(srcaddr, tgtaddr, srcmac):
- solicited = inet_pton(AF_INET6, tgtaddr)
- last3bytes = tuple([ord(b) for b in solicited[-3:]])
- solicited = "ff02::1:ff%02x:%02x%02x" % last3bytes
- packet = (scapy.IPv6(src=srcaddr, dst=solicited) /
- scapy.ICMPv6ND_NS(tgt=tgtaddr) /
- scapy.ICMPv6NDOptSrcLLAddr(lladdr=srcmac))
- return ("ICMPv6 NS", packet)
-
-def NA(srcaddr, dstaddr, srcmac):
- packet = (scapy.IPv6(src=srcaddr, dst=dstaddr) /
- scapy.ICMPv6ND_NA(tgt=srcaddr, R=0, S=1, O=1) /
- scapy.ICMPv6NDOptDstLLAddr(lladdr=srcmac))
- return ("ICMPv6 NA", packet)
-
diff --git a/tests/net_test/ping6_test.py b/tests/net_test/ping6_test.py
deleted file mode 100755
index bf51cfa..0000000
--- a/tests/net_test/ping6_test.py
+++ /dev/null
@@ -1,709 +0,0 @@
-#!/usr/bin/python
-#
-# Copyright 2014 The Android Open Source Project
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# pylint: disable=g-bad-todo
-
-import errno
-import os
-import posix
-import random
-from socket import * # pylint: disable=wildcard-import
-import threading
-import time
-import unittest
-
-from scapy import all as scapy
-
-import csocket
-import multinetwork_base
-import net_test
-
-
-HAVE_PROC_NET_ICMP6 = os.path.isfile("/proc/net/icmp6")
-
-ICMP_ECHO = 8
-ICMP_ECHOREPLY = 0
-ICMPV6_ECHO_REQUEST = 128
-ICMPV6_ECHO_REPLY = 129
-
-
-class PingReplyThread(threading.Thread):
-
- MIN_TTL = 10
- INTERMEDIATE_IPV4 = "192.0.2.2"
- INTERMEDIATE_IPV6 = "2001:db8:1:2::ace:d00d"
- NEIGHBOURS = ["fe80::1"]
-
- def __init__(self, tun, mymac, routermac):
- super(PingReplyThread, self).__init__()
- self._tun = tun
- self._stopped = False
- self._mymac = mymac
- self._routermac = routermac
-
- def Stop(self):
- self._stopped = True
-
- def ChecksumValid(self, packet):
- # Get and clear the checksums.
- def GetAndClearChecksum(layer):
- if not layer:
- return
- try:
- checksum = layer.chksum
- del layer.chksum
- except AttributeError:
- checksum = layer.cksum
- del layer.cksum
- return checksum
-
- def GetChecksum(layer):
- try:
- return layer.chksum
- except AttributeError:
- return layer.cksum
-
- layers = ["IP", "ICMP", scapy.ICMPv6EchoRequest]
- sums = {}
- for name in layers:
- sums[name] = GetAndClearChecksum(packet.getlayer(name))
-
- # Serialize the packet, so scapy recalculates the checksums, and compare
- # them with the ones in the packet.
- packet = packet.__class__(str(packet))
- for name in layers:
- layer = packet.getlayer(name)
- if layer and GetChecksum(layer) != sums[name]:
- return False
-
- return True
-
- def SendTimeExceeded(self, version, packet):
- if version == 4:
- src = packet.getlayer(scapy.IP).src
- self.SendPacket(
- scapy.IP(src=self.INTERMEDIATE_IPV4, dst=src) /
- scapy.ICMP(type=11, code=0) /
- packet)
- elif version == 6:
- src = packet.getlayer(scapy.IPv6).src
- self.SendPacket(
- scapy.IPv6(src=self.INTERMEDIATE_IPV6, dst=src) /
- scapy.ICMPv6TimeExceeded(code=0) /
- packet)
-
- def IPv4Packet(self, ip):
- icmp = ip.getlayer(scapy.ICMP)
-
- # We only support ping for now.
- if (ip.proto != IPPROTO_ICMP or
- icmp.type != ICMP_ECHO or
- icmp.code != 0):
- return
-
- # Check the checksums.
- if not self.ChecksumValid(ip):
- return
-
- if ip.ttl < self.MIN_TTL:
- self.SendTimeExceeded(4, ip)
- return
-
- icmp.type = ICMP_ECHOREPLY
- self.SwapAddresses(ip)
- self.SendPacket(ip)
-
- def IPv6Packet(self, ipv6):
- icmpv6 = ipv6.getlayer(scapy.ICMPv6EchoRequest)
-
- # We only support ping for now.
- if (ipv6.nh != IPPROTO_ICMPV6 or
- not icmpv6 or
- icmpv6.type != ICMPV6_ECHO_REQUEST or
- icmpv6.code != 0):
- return
-
- # Check the checksums.
- if not self.ChecksumValid(ipv6):
- return
-
- if ipv6.dst.startswith("ff02::"):
- ipv6.dst = ipv6.src
- for src in self.NEIGHBOURS:
- ipv6.src = src
- icmpv6.type = ICMPV6_ECHO_REPLY
- self.SendPacket(ipv6)
- elif ipv6.hlim < self.MIN_TTL:
- self.SendTimeExceeded(6, ipv6)
- else:
- icmpv6.type = ICMPV6_ECHO_REPLY
- self.SwapAddresses(ipv6)
- self.SendPacket(ipv6)
-
- def SwapAddresses(self, packet):
- src = packet.src
- packet.src = packet.dst
- packet.dst = src
-
- def SendPacket(self, packet):
- packet = scapy.Ether(src=self._routermac, dst=self._mymac) / packet
- try:
- posix.write(self._tun.fileno(), str(packet))
- except ValueError:
- pass
-
- def run(self):
- while not self._stopped:
-
- try:
- packet = posix.read(self._tun.fileno(), 4096)
- except OSError, e:
- if e.errno == errno.EAGAIN:
- continue
- else:
- break
-
- ether = scapy.Ether(packet)
- if ether.type == net_test.ETH_P_IPV6:
- self.IPv6Packet(ether.payload)
- elif ether.type == net_test.ETH_P_IP:
- self.IPv4Packet(ether.payload)
-
-
-class Ping6Test(multinetwork_base.MultiNetworkBaseTest):
-
- @classmethod
- def setUpClass(cls):
- super(Ping6Test, cls).setUpClass()
- cls.netid = random.choice(cls.NETIDS)
- cls.reply_thread = PingReplyThread(
- cls.tuns[cls.netid],
- cls.MyMacAddress(cls.netid),
- cls.RouterMacAddress(cls.netid))
- cls.SetDefaultNetwork(cls.netid)
- cls.reply_thread.start()
-
- @classmethod
- def tearDownClass(cls):
- cls.reply_thread.Stop()
- cls.ClearDefaultNetwork()
- super(Ping6Test, cls).tearDownClass()
-
- def setUp(self):
- self.ifname = self.GetInterfaceName(self.netid)
- self.ifindex = self.ifindices[self.netid]
- self.lladdr = net_test.GetLinkAddress(self.ifname, True)
- self.globaladdr = net_test.GetLinkAddress(self.ifname, False)
-
- def assertValidPingResponse(self, s, data):
- family = s.family
-
- # Receive the reply.
- rcvd, src = s.recvfrom(32768)
- self.assertNotEqual(0, len(rcvd), "No data received")
-
- # If this is a dual-stack socket sending to a mapped IPv4 address, treat it
- # as IPv4.
- if src[0].startswith("::ffff:"):
- family = AF_INET
- src = (src[0].replace("::ffff:", ""), src[1:])
-
- # Check the data being sent is valid.
- self.assertGreater(len(data), 7, "Not enough data for ping packet")
- if family == AF_INET:
- self.assertTrue(data.startswith("\x08\x00"), "Not an IPv4 echo request")
- elif family == AF_INET6:
- self.assertTrue(data.startswith("\x80\x00"), "Not an IPv6 echo request")
- else:
- self.fail("Unknown socket address family %d" * s.family)
-
- # Check address, ICMP type, and ICMP code.
- if family == AF_INET:
- addr, unused_port = src
- self.assertGreaterEqual(len(addr), len("1.1.1.1"))
- self.assertTrue(rcvd.startswith("\x00\x00"), "Not an IPv4 echo reply")
- else:
- addr, unused_port, flowlabel, scope_id = src # pylint: disable=unbalanced-tuple-unpacking
- self.assertGreaterEqual(len(addr), len("::"))
- self.assertTrue(rcvd.startswith("\x81\x00"), "Not an IPv6 echo reply")
- # Check that the flow label is zero and that the scope ID is sane.
- self.assertEqual(flowlabel, 0)
- if addr.startswith("fe80::"):
- self.assertTrue(scope_id in self.ifindices.values())
- else:
- self.assertEquals(0, scope_id)
-
- # TODO: check the checksum. We can't do this easily now for ICMPv6 because
- # we don't have the IP addresses so we can't construct the pseudoheader.
-
- # Check the sequence number and the data.
- self.assertEqual(len(data), len(rcvd))
- self.assertEqual(data[6:].encode("hex"), rcvd[6:].encode("hex"))
-
- def CheckSockStatFile(self, name, srcaddr, srcport, dstaddr, dstport, state,
- txmem=0, rxmem=0):
- expected = ["%s:%04X" % (net_test.FormatSockStatAddress(srcaddr), srcport),
- "%s:%04X" % (net_test.FormatSockStatAddress(dstaddr), dstport),
- "%02X" % state,
- "%08X:%08X" % (txmem, rxmem),
- str(os.getuid()), "2", "0"]
- actual = self.ReadProcNetSocket(name)[-1]
- self.assertListEqual(expected, actual)
-
- def testIPv4SendWithNoConnection(self):
- s = net_test.IPv4PingSocket()
- self.assertRaisesErrno(errno.EDESTADDRREQ, s.send, net_test.IPV4_PING)
-
- def testIPv6SendWithNoConnection(self):
- s = net_test.IPv6PingSocket()
- self.assertRaisesErrno(errno.EDESTADDRREQ, s.send, net_test.IPV6_PING)
-
- def testIPv4LoopbackPingWithConnect(self):
- s = net_test.IPv4PingSocket()
- s.connect(("127.0.0.1", 55))
- data = net_test.IPV4_PING + "foobarbaz"
- s.send(data)
- self.assertValidPingResponse(s, data)
-
- def testIPv6LoopbackPingWithConnect(self):
- s = net_test.IPv6PingSocket()
- s.connect(("::1", 55))
- s.send(net_test.IPV6_PING)
- self.assertValidPingResponse(s, net_test.IPV6_PING)
-
- def testIPv4PingUsingSendto(self):
- s = net_test.IPv4PingSocket()
- written = s.sendto(net_test.IPV4_PING, (net_test.IPV4_ADDR, 55))
- self.assertEquals(len(net_test.IPV4_PING), written)
- self.assertValidPingResponse(s, net_test.IPV4_PING)
-
- def testIPv6PingUsingSendto(self):
- s = net_test.IPv6PingSocket()
- written = s.sendto(net_test.IPV6_PING, (net_test.IPV6_ADDR, 55))
- self.assertEquals(len(net_test.IPV6_PING), written)
- self.assertValidPingResponse(s, net_test.IPV6_PING)
-
- def testIPv4NoCrash(self):
- # Python 2.x does not provide either read() or recvmsg.
- s = net_test.IPv4PingSocket()
- written = s.sendto(net_test.IPV4_PING, ("127.0.0.1", 55))
- self.assertEquals(len(net_test.IPV4_PING), written)
- fd = s.fileno()
- reply = posix.read(fd, 4096)
- self.assertEquals(written, len(reply))
-
- def testIPv6NoCrash(self):
- # Python 2.x does not provide either read() or recvmsg.
- s = net_test.IPv6PingSocket()
- written = s.sendto(net_test.IPV6_PING, ("::1", 55))
- self.assertEquals(len(net_test.IPV6_PING), written)
- fd = s.fileno()
- reply = posix.read(fd, 4096)
- self.assertEquals(written, len(reply))
-
- def testCrossProtocolCrash(self):
- # Checks that an ICMP error containing a ping packet that matches the ID
- # of a socket of the wrong protocol (which can happen when using 464xlat)
- # doesn't crash the kernel.
-
- # We can only test this using IPv6 unreachables and IPv4 ping sockets,
- # because IPv4 packets sent by scapy.send() on loopback are not received by
- # the kernel. So we don't actually use this function yet.
- def GetIPv4Unreachable(port): # pylint: disable=unused-variable
- return (scapy.IP(src="192.0.2.1", dst="127.0.0.1") /
- scapy.ICMP(type=3, code=0) /
- scapy.IP(src="127.0.0.1", dst="127.0.0.1") /
- scapy.ICMP(type=8, id=port, seq=1))
-
- def GetIPv6Unreachable(port):
- return (scapy.IPv6(src="::1", dst="::1") /
- scapy.ICMPv6DestUnreach() /
- scapy.IPv6(src="::1", dst="::1") /
- scapy.ICMPv6EchoRequest(id=port, seq=1, data="foobarbaz"))
-
- # An unreachable matching the ID of a socket of the wrong protocol
- # shouldn't crash.
- s = net_test.IPv4PingSocket()
- s.connect(("127.0.0.1", 12345))
- _, port = s.getsockname()
- scapy.send(GetIPv6Unreachable(port))
- # No crash? Good.
-
- def testCrossProtocolCalls(self):
- """Tests that passing in the wrong family returns EAFNOSUPPORT.
-
- Relevant kernel commits:
- upstream net:
- 91a0b60 net/ping: handle protocol mismatching scenario
- 9145736d net: ping: Return EAFNOSUPPORT when appropriate.
-
- android-3.10:
- 78a6809 net/ping: handle protocol mismatching scenario
- 428e6d6 net: ping: Return EAFNOSUPPORT when appropriate.
- """
-
- def CheckEAFNoSupport(function, *args):
- self.assertRaisesErrno(errno.EAFNOSUPPORT, function, *args)
-
- ipv6sockaddr = csocket.Sockaddr((net_test.IPV6_ADDR, 53))
-
- # In order to check that IPv6 socket calls return EAFNOSUPPORT when passed
- # IPv4 socket address structures, we need to pass down a socket address
- # length argument that's at least sizeof(sockaddr_in6). Otherwise, the calls
- # will fail immediately with EINVAL because the passed-in socket length is
- # too short. So create a sockaddr_in that's as long as a sockaddr_in6.
- ipv4sockaddr = csocket.Sockaddr((net_test.IPV4_ADDR, 53))
- ipv4sockaddr = csocket.SockaddrIn6(
- ipv4sockaddr.Pack() +
- "\x00" * (len(csocket.SockaddrIn6) - len(csocket.SockaddrIn)))
-
- s4 = net_test.IPv4PingSocket()
- s6 = net_test.IPv6PingSocket()
-
- # We can't just call s.connect(), s.bind() etc. with a tuple of the wrong
- # address family, because the Python implementation will just pass garbage
- # down to the kernel. So call the C functions directly.
- CheckEAFNoSupport(csocket.Bind, s4, ipv6sockaddr)
- CheckEAFNoSupport(csocket.Bind, s6, ipv4sockaddr)
- CheckEAFNoSupport(csocket.Connect, s4, ipv6sockaddr)
- CheckEAFNoSupport(csocket.Connect, s6, ipv4sockaddr)
- CheckEAFNoSupport(csocket.Sendmsg,
- s4, ipv6sockaddr, net_test.IPV4_PING, None, 0)
- CheckEAFNoSupport(csocket.Sendmsg,
- s6, ipv4sockaddr, net_test.IPV6_PING, None, 0)
-
- def testIPv4Bind(self):
- # Bind to unspecified address.
- s = net_test.IPv4PingSocket()
- s.bind(("0.0.0.0", 544))
- self.assertEquals(("0.0.0.0", 544), s.getsockname())
-
- # Bind to loopback.
- s = net_test.IPv4PingSocket()
- s.bind(("127.0.0.1", 99))
- self.assertEquals(("127.0.0.1", 99), s.getsockname())
-
- # Binding twice is not allowed.
- self.assertRaisesErrno(errno.EINVAL, s.bind, ("127.0.0.1", 22))
-
- # But binding two different sockets to the same ID is allowed.
- s2 = net_test.IPv4PingSocket()
- s2.bind(("127.0.0.1", 99))
- self.assertEquals(("127.0.0.1", 99), s2.getsockname())
- s3 = net_test.IPv4PingSocket()
- s3.bind(("127.0.0.1", 99))
- self.assertEquals(("127.0.0.1", 99), s3.getsockname())
-
- # If two sockets bind to the same port, the first one to call read() gets
- # the response.
- s4 = net_test.IPv4PingSocket()
- s5 = net_test.IPv4PingSocket()
- s4.bind(("0.0.0.0", 167))
- s5.bind(("0.0.0.0", 167))
- s4.sendto(net_test.IPV4_PING, (net_test.IPV4_ADDR, 44))
- self.assertValidPingResponse(s5, net_test.IPV4_PING)
- net_test.SetSocketTimeout(s4, 100)
- self.assertRaisesErrno(errno.EAGAIN, s4.recv, 32768)
-
- # If SO_REUSEADDR is turned off, then we get EADDRINUSE.
- s6 = net_test.IPv4PingSocket()
- s4.setsockopt(SOL_SOCKET, SO_REUSEADDR, 0)
- self.assertRaisesErrno(errno.EADDRINUSE, s6.bind, ("0.0.0.0", 167))
-
- # Can't bind after sendto.
- s = net_test.IPv4PingSocket()
- s.sendto(net_test.IPV4_PING, (net_test.IPV4_ADDR, 9132))
- self.assertRaisesErrno(errno.EINVAL, s.bind, ("0.0.0.0", 5429))
-
- def testIPv6Bind(self):
- # Bind to unspecified address.
- s = net_test.IPv6PingSocket()
- s.bind(("::", 769))
- self.assertEquals(("::", 769, 0, 0), s.getsockname())
-
- # Bind to loopback.
- s = net_test.IPv6PingSocket()
- s.bind(("::1", 99))
- self.assertEquals(("::1", 99, 0, 0), s.getsockname())
-
- # Binding twice is not allowed.
- self.assertRaisesErrno(errno.EINVAL, s.bind, ("::1", 22))
-
- # But binding two different sockets to the same ID is allowed.
- s2 = net_test.IPv6PingSocket()
- s2.bind(("::1", 99))
- self.assertEquals(("::1", 99, 0, 0), s2.getsockname())
- s3 = net_test.IPv6PingSocket()
- s3.bind(("::1", 99))
- self.assertEquals(("::1", 99, 0, 0), s3.getsockname())
-
- # Binding both IPv4 and IPv6 to the same socket works.
- s4 = net_test.IPv4PingSocket()
- s6 = net_test.IPv6PingSocket()
- s4.bind(("0.0.0.0", 444))
- s6.bind(("::", 666, 0, 0))
-
- # Can't bind after sendto.
- s = net_test.IPv6PingSocket()
- s.sendto(net_test.IPV6_PING, (net_test.IPV6_ADDR, 9132))
- self.assertRaisesErrno(errno.EINVAL, s.bind, ("::", 5429))
-
- def testIPv4InvalidBind(self):
- s = net_test.IPv4PingSocket()
- self.assertRaisesErrno(errno.EADDRNOTAVAIL,
- s.bind, ("255.255.255.255", 1026))
- self.assertRaisesErrno(errno.EADDRNOTAVAIL,
- s.bind, ("224.0.0.1", 651))
- # Binding to an address we don't have only works with IP_TRANSPARENT.
- self.assertRaisesErrno(errno.EADDRNOTAVAIL,
- s.bind, (net_test.IPV4_ADDR, 651))
- try:
- s.setsockopt(SOL_IP, net_test.IP_TRANSPARENT, 1)
- s.bind((net_test.IPV4_ADDR, 651))
- except IOError, e:
- if e.errno == errno.EACCES:
- pass # We're not root. let it go for now.
-
- def testIPv6InvalidBind(self):
- s = net_test.IPv6PingSocket()
- self.assertRaisesErrno(errno.EINVAL,
- s.bind, ("ff02::2", 1026))
-
- # Binding to an address we don't have only works with IPV6_TRANSPARENT.
- self.assertRaisesErrno(errno.EADDRNOTAVAIL,
- s.bind, (net_test.IPV6_ADDR, 651))
- try:
- s.setsockopt(net_test.SOL_IPV6, net_test.IPV6_TRANSPARENT, 1)
- s.bind((net_test.IPV6_ADDR, 651))
- except IOError, e:
- if e.errno == errno.EACCES:
- pass # We're not root. let it go for now.
-
- def testAfUnspecBind(self):
- # Binding to AF_UNSPEC is treated as IPv4 if the address is 0.0.0.0.
- s4 = net_test.IPv4PingSocket()
- sockaddr = csocket.Sockaddr(("0.0.0.0", 12996))
- sockaddr.family = AF_UNSPEC
- csocket.Bind(s4, sockaddr)
- self.assertEquals(("0.0.0.0", 12996), s4.getsockname())
-
- # But not if the address is anything else.
- sockaddr = csocket.Sockaddr(("127.0.0.1", 58234))
- sockaddr.family = AF_UNSPEC
- self.assertRaisesErrno(errno.EAFNOSUPPORT, csocket.Bind, s4, sockaddr)
-
- # This doesn't work for IPv6.
- s6 = net_test.IPv6PingSocket()
- sockaddr = csocket.Sockaddr(("::1", 58997))
- sockaddr.family = AF_UNSPEC
- self.assertRaisesErrno(errno.EAFNOSUPPORT, csocket.Bind, s6, sockaddr)
-
- def testIPv6ScopedBind(self):
- # Can't bind to a link-local address without a scope ID.
- s = net_test.IPv6PingSocket()
- self.assertRaisesErrno(errno.EINVAL,
- s.bind, (self.lladdr, 1026, 0, 0))
-
- # Binding to a link-local address with a scope ID works, and the scope ID is
- # returned by a subsequent getsockname. Interestingly, Python's getsockname
- # returns "fe80:1%foo", even though it does not understand it.
- expected = self.lladdr + "%" + self.ifname
- s.bind((self.lladdr, 4646, 0, self.ifindex))
- self.assertEquals((expected, 4646, 0, self.ifindex), s.getsockname())
-
- # Of course, for the above to work the address actually has to be configured
- # on the machine.
- self.assertRaisesErrno(errno.EADDRNOTAVAIL,
- s.bind, ("fe80::f00", 1026, 0, 1))
-
- # Scope IDs on non-link-local addresses are silently ignored.
- s = net_test.IPv6PingSocket()
- s.bind(("::1", 1234, 0, 1))
- self.assertEquals(("::1", 1234, 0, 0), s.getsockname())
-
- def testBindAffectsIdentifier(self):
- s = net_test.IPv6PingSocket()
- s.bind((self.globaladdr, 0xf976))
- s.sendto(net_test.IPV6_PING, (net_test.IPV6_ADDR, 55))
- self.assertEquals("\xf9\x76", s.recv(32768)[4:6])
-
- s = net_test.IPv6PingSocket()
- s.bind((self.globaladdr, 0xace))
- s.sendto(net_test.IPV6_PING, (net_test.IPV6_ADDR, 55))
- self.assertEquals("\x0a\xce", s.recv(32768)[4:6])
-
- def testLinkLocalAddress(self):
- s = net_test.IPv6PingSocket()
- # Sending to a link-local address with no scope fails with EINVAL.
- self.assertRaisesErrno(errno.EINVAL,
- s.sendto, net_test.IPV6_PING, ("fe80::1", 55))
- # Sending to link-local address with a scope succeeds. Note that Python
- # doesn't understand the "fe80::1%lo" format, even though it returns it.
- s.sendto(net_test.IPV6_PING, ("fe80::1", 55, 0, self.ifindex))
- # No exceptions? Good.
-
- def testMappedAddressFails(self):
- s = net_test.IPv6PingSocket()
- s.sendto(net_test.IPV6_PING, (net_test.IPV6_ADDR, 55))
- self.assertValidPingResponse(s, net_test.IPV6_PING)
- s.sendto(net_test.IPV6_PING, ("2001:4860:4860::8844", 55))
- self.assertValidPingResponse(s, net_test.IPV6_PING)
- self.assertRaisesErrno(errno.EINVAL, s.sendto, net_test.IPV6_PING,
- ("::ffff:192.0.2.1", 55))
-
- @unittest.skipUnless(False, "skipping: does not work yet")
- def testFlowLabel(self):
- s = net_test.IPv6PingSocket()
-
- # Specifying a flowlabel without having set IPV6_FLOWINFO_SEND succeeds but
- # the flow label in the packet is not set.
- s.sendto(net_test.IPV6_PING, (net_test.IPV6_ADDR, 93, 0xdead, 0))
- self.assertValidPingResponse(s, net_test.IPV6_PING) # Checks flow label==0.
-
- # If IPV6_FLOWINFO_SEND is set on the socket, attempting to set a flow label
- # that is not registered with the flow manager should return EINVAL...
- s.setsockopt(net_test.SOL_IPV6, net_test.IPV6_FLOWINFO_SEND, 1)
- # ... but this doesn't work yet.
- if False:
- self.assertRaisesErrno(errno.EINVAL, s.sendto, net_test.IPV6_PING,
- (net_test.IPV6_ADDR, 93, 0xdead, 0))
-
- # After registering the flow label, it gets sent properly, appears in the
- # output packet, and is returned in the response.
- net_test.SetFlowLabel(s, net_test.IPV6_ADDR, 0xdead)
- self.assertEqual(1, s.getsockopt(net_test.SOL_IPV6,
- net_test.IPV6_FLOWINFO_SEND))
- s.sendto(net_test.IPV6_PING, (net_test.IPV6_ADDR, 93, 0xdead, 0))
- _, src = s.recvfrom(32768)
- _, _, flowlabel, _ = src
- self.assertEqual(0xdead, flowlabel & 0xfffff)
-
- def testIPv4Error(self):
- s = net_test.IPv4PingSocket()
- s.setsockopt(SOL_IP, IP_TTL, 2)
- s.setsockopt(SOL_IP, net_test.IP_RECVERR, 1)
- s.sendto(net_test.IPV4_PING, (net_test.IPV4_ADDR, 55))
- # We can't check the actual error because Python 2.7 doesn't implement
- # recvmsg, but we can at least check that the socket returns an error.
- self.assertRaisesErrno(errno.EHOSTUNREACH, s.recv, 32768) # No response.
-
- def testIPv6Error(self):
- s = net_test.IPv6PingSocket()
- s.setsockopt(net_test.SOL_IPV6, IPV6_UNICAST_HOPS, 2)
- s.setsockopt(net_test.SOL_IPV6, net_test.IPV6_RECVERR, 1)
- s.sendto(net_test.IPV6_PING, (net_test.IPV6_ADDR, 55))
- # We can't check the actual error because Python 2.7 doesn't implement
- # recvmsg, but we can at least check that the socket returns an error.
- self.assertRaisesErrno(errno.EHOSTUNREACH, s.recv, 32768) # No response.
-
- def testIPv6MulticastPing(self):
- s = net_test.IPv6PingSocket()
- # Send a multicast ping and check we get at least one duplicate.
- # The setsockopt should not be necessary, but ping_v6_sendmsg has a bug.
- s.setsockopt(net_test.SOL_IPV6, net_test.IPV6_MULTICAST_IF, self.ifindex)
- s.sendto(net_test.IPV6_PING, ("ff02::1", 55, 0, self.ifindex))
- self.assertValidPingResponse(s, net_test.IPV6_PING)
- self.assertValidPingResponse(s, net_test.IPV6_PING)
-
- def testIPv4LargePacket(self):
- s = net_test.IPv4PingSocket()
- data = net_test.IPV4_PING + 20000 * "a"
- s.sendto(data, ("127.0.0.1", 987))
- self.assertValidPingResponse(s, data)
-
- def testIPv6LargePacket(self):
- s = net_test.IPv6PingSocket()
- s.bind(("::", 0xace))
- data = net_test.IPV6_PING + "\x01" + 19994 * "\x00" + "aaaaa"
- s.sendto(data, ("::1", 953))
-
- @unittest.skipUnless(HAVE_PROC_NET_ICMP6, "skipping: no /proc/net/icmp6")
- def testIcmpSocketsNotInIcmp6(self):
- numrows = len(self.ReadProcNetSocket("icmp"))
- numrows6 = len(self.ReadProcNetSocket("icmp6"))
- s = net_test.Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP)
- s.bind(("127.0.0.1", 0xace))
- s.connect(("127.0.0.1", 0xbeef))
- self.assertEquals(numrows + 1, len(self.ReadProcNetSocket("icmp")))
- self.assertEquals(numrows6, len(self.ReadProcNetSocket("icmp6")))
-
- @unittest.skipUnless(HAVE_PROC_NET_ICMP6, "skipping: no /proc/net/icmp6")
- def testIcmp6SocketsNotInIcmp(self):
- numrows = len(self.ReadProcNetSocket("icmp"))
- numrows6 = len(self.ReadProcNetSocket("icmp6"))
- s = net_test.IPv6PingSocket()
- s.bind(("::1", 0xace))
- s.connect(("::1", 0xbeef))
- self.assertEquals(numrows, len(self.ReadProcNetSocket("icmp")))
- self.assertEquals(numrows6 + 1, len(self.ReadProcNetSocket("icmp6")))
-
- def testProcNetIcmp(self):
- s = net_test.Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP)
- s.bind(("127.0.0.1", 0xace))
- s.connect(("127.0.0.1", 0xbeef))
- self.CheckSockStatFile("icmp", "127.0.0.1", 0xace, "127.0.0.1", 0xbeef, 1)
-
- @unittest.skipUnless(HAVE_PROC_NET_ICMP6, "skipping: no /proc/net/icmp6")
- def testProcNetIcmp6(self):
- numrows6 = len(self.ReadProcNetSocket("icmp6"))
- s = net_test.IPv6PingSocket()
- s.bind(("::1", 0xace))
- s.connect(("::1", 0xbeef))
- self.CheckSockStatFile("icmp6", "::1", 0xace, "::1", 0xbeef, 1)
-
- # Check the row goes away when the socket is closed.
- s.close()
- self.assertEquals(numrows6, len(self.ReadProcNetSocket("icmp6")))
-
- # Try send, bind and connect to check the addresses and the state.
- s = net_test.IPv6PingSocket()
- self.assertEqual(0, len(self.ReadProcNetSocket("icmp6")))
- s.sendto(net_test.IPV6_PING, (net_test.IPV6_ADDR, 12345))
- self.assertEqual(1, len(self.ReadProcNetSocket("icmp6")))
-
- # Can't bind after sendto, apparently.
- s = net_test.IPv6PingSocket()
- self.assertEqual(0, len(self.ReadProcNetSocket("icmp6")))
- s.bind((self.lladdr, 0xd00d, 0, self.ifindex))
- self.CheckSockStatFile("icmp6", self.lladdr, 0xd00d, "::", 0, 7)
-
- # Check receive bytes.
- s.setsockopt(net_test.SOL_IPV6, net_test.IPV6_MULTICAST_IF, self.ifindex)
- s.connect(("ff02::1", 0xdead))
- self.CheckSockStatFile("icmp6", self.lladdr, 0xd00d, "ff02::1", 0xdead, 1)
- s.send(net_test.IPV6_PING)
- time.sleep(0.01) # Give the other thread time to reply.
- self.CheckSockStatFile("icmp6", self.lladdr, 0xd00d, "ff02::1", 0xdead, 1,
- txmem=0, rxmem=0x300)
- self.assertValidPingResponse(s, net_test.IPV6_PING)
- self.CheckSockStatFile("icmp6", self.lladdr, 0xd00d, "ff02::1", 0xdead, 1,
- txmem=0, rxmem=0)
-
- def testProcNetUdp6(self):
- s = net_test.Socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP)
- s.bind(("::1", 0xace))
- s.connect(("::1", 0xbeef))
- self.CheckSockStatFile("udp6", "::1", 0xace, "::1", 0xbeef, 1)
-
- def testProcNetRaw6(self):
- s = net_test.Socket(AF_INET6, SOCK_RAW, IPPROTO_RAW)
- s.bind(("::1", 0xace))
- s.connect(("::1", 0xbeef))
- self.CheckSockStatFile("raw6", "::1", 0xff, "::1", 0, 1)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/net_test/ping6_test.sh b/tests/net_test/ping6_test.sh
deleted file mode 100755
index 41dabce..0000000
--- a/tests/net_test/ping6_test.sh
+++ /dev/null
@@ -1,16 +0,0 @@
-#!/bin/bash
-
-# Minimal network initialization.
-ip link set eth0 up
-
-# Wait for autoconf and DAD to complete.
-sleep 3 &
-
-# Block on starting DHCPv4.
-udhcpc -i eth0
-
-# If DHCPv4 took less than 3 seconds, keep waiting.
-wait
-
-# Run the test.
-$(dirname $0)/ping6_test.py
diff --git a/tests/net_test/run_net_test.sh b/tests/net_test/run_net_test.sh
deleted file mode 100755
index 080aac7..0000000
--- a/tests/net_test/run_net_test.sh
+++ /dev/null
@@ -1,125 +0,0 @@
-#!/bin/bash
-
-# Kernel configuration options.
-OPTIONS=" DEBUG_SPINLOCK DEBUG_ATOMIC_SLEEP DEBUG_MUTEXES DEBUG_RT_MUTEXES"
-OPTIONS="$OPTIONS IPV6 IPV6_ROUTER_PREF IPV6_MULTIPLE_TABLES IPV6_ROUTE_INFO"
-OPTIONS="$OPTIONS TUN SYN_COOKIES IP_ADVANCED_ROUTER IP_MULTIPLE_TABLES"
-OPTIONS="$OPTIONS NETFILTER NETFILTER_ADVANCED NETFILTER_XTABLES"
-OPTIONS="$OPTIONS NETFILTER_XT_MARK NETFILTER_XT_TARGET_MARK"
-OPTIONS="$OPTIONS IP_NF_IPTABLES IP_NF_MANGLE"
-OPTIONS="$OPTIONS IP6_NF_IPTABLES IP6_NF_MANGLE INET6_IPCOMP"
-OPTIONS="$OPTIONS IPV6_PRIVACY IPV6_OPTIMISTIC_DAD"
-OPTIONS="$OPTIONS CONFIG_NETFILTER_XT_TARGET_NFLOG"
-OPTIONS="$OPTIONS CONFIG_NETFILTER_XT_MATCH_QUOTA CONFIG_NETFILTER_XT_MATCH_QUOTA2"
-OPTIONS="$OPTIONS CONFIG_NETFILTER_XT_MATCH_QUOTA2_LOG"
-OPTIONS="$OPTIONS CONFIG_INET_UDP_DIAG CONFIG_INET_DIAG_DESTROY"
-
-# For 3.1 kernels, where devtmpfs is not on by default.
-OPTIONS="$OPTIONS DEVTMPFS DEVTMPFS_MOUNT"
-
-# These two break the flo kernel due to differences in -Werror on recent GCC.
-DISABLE_OPTIONS=" CONFIG_REISERFS_FS CONFIG_ANDROID_PMEM"
-
-# How many TAP interfaces to create to provide the VM with real network access
-# via the host. This requires privileges (e.g., root access) on the host.
-#
-# This is not needed to run the tests, but can be used, for example, to allow
-# the VM to update system packages, or to write tests that need access to a
-# real network. The VM does not set up networking by default, but it contains a
-# DHCP client and has the ability to use IPv6 autoconfiguration. This script
-# does not perform any host-level setup beyond configuring tap interfaces;
-# configuring IPv4 NAT and/or IPv6 router advertisements or ND proxying must
-# be done separately.
-NUMTAPINTERFACES=0
-
-# The root filesystem disk image we'll use.
-ROOTFS=net_test.rootfs.20150203
-COMPRESSED_ROOTFS=$ROOTFS.xz
-URL=https://dl.google.com/dl/android/$COMPRESSED_ROOTFS
-
-# Figure out which test to run.
-if [ -z "$1" ]; then
- echo "Usage: $0 <test>" >&2
- exit 1
-fi
-test=$1
-
-set -e
-
-# Check if we need to uncompress the disk image.
-# We use xz because it compresses better: to 42M vs 72M (gzip) / 62M (bzip2).
-cd $(dirname $0)
-if [ ! -f $ROOTFS ]; then
- echo "Deleting $COMPRESSED_ROOTFS" >&2
- rm -f $COMPRESSED_ROOTFS
- echo "Downloading $URL" >&2
- wget $URL
- echo "Uncompressing $COMPRESSED_ROOTFS" >&2
- unxz $COMPRESSED_ROOTFS
-fi
-echo "Using $ROOTFS"
-cd -
-
-# If network access was requested, create NUMTAPINTERFACES tap interfaces on
-# the host, and prepare UML command line params to use them. The interfaces are
-# called <user>TAP0, <user>TAP1, on the host, and eth0, eth1, ..., in the VM.
-if (( $NUMTAPINTERFACES > 0 )); then
- user=${USER:0:10}
- tapinterfaces=
- netconfig=
- for id in $(seq 0 $(( NUMTAPINTERFACES - 1 )) ); do
- tap=${user}TAP$id
- tapinterfaces="$tapinterfaces $tap"
- mac=$(printf fe:fd:00:00:00:%02x $id)
- netconfig="$netconfig eth$id=tuntap,$tap,$mac"
- done
-
- for tap in $tapinterfaces; do
- if ! ip link list $tap > /dev/null; then
- echo "Creating tap interface $tap" >&2
- sudo tunctl -u $USER -t $tap
- sudo ip link set $tap up
- fi
- done
-fi
-
-if [ -z "$KERNEL_BINARY" ]; then
- # Exporting ARCH=um SUBARCH=x86_64 doesn't seem to work, as it "sometimes"
- # (?) results in a 32-bit kernel.
-
- # If there's no kernel config at all, create one or UML won't work.
- [ -f .config ] || make defconfig ARCH=um SUBARCH=x86_64
-
- # Enable the kernel config options listed in $OPTIONS.
- cmdline=${OPTIONS// / -e }
- ./scripts/config $cmdline
-
- # Disable the kernel config options listed in $DISABLE_OPTIONS.
- cmdline=${DISABLE_OPTIONS// / -d }
- ./scripts/config $cmdline
-
- # olddefconfig doesn't work on old kernels.
- if ! make olddefconfig ARCH=um SUBARCH=x86_64 CROSS_COMPILE= ; then
- cat >&2 << EOF
-
-Warning: "make olddefconfig" failed.
-Perhaps this kernel is too old to support it.
-You may get asked lots of questions.
-Keep enter pressed to accept the defaults.
-
-EOF
- fi
-
- # Compile the kernel.
- make -j32 linux ARCH=um SUBARCH=x86_64 CROSS_COMPILE=
- KERNEL_BINARY=./linux
-fi
-
-
-# Get the absolute path to the test file that's being run.
-dir=/host$(dirname $(readlink -f $0))
-
-# Start the VM.
-exec $KERNEL_BINARY umid=net_test ubda=$(dirname $0)/$ROOTFS \
- mem=512M init=/sbin/net_test.sh net_test=$dir/$test \
- $netconfig
diff --git a/tests/net_test/sock_diag.py b/tests/net_test/sock_diag.py
deleted file mode 100755
index b4d9cf6..0000000
--- a/tests/net_test/sock_diag.py
+++ /dev/null
@@ -1,342 +0,0 @@
-#!/usr/bin/python
-#
-# Copyright 2015 The Android Open Source Project
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Partial Python implementation of sock_diag functionality."""
-
-# pylint: disable=g-bad-todo
-
-import errno
-from socket import * # pylint: disable=wildcard-import
-import struct
-
-import cstruct
-import net_test
-import netlink
-
-### Base netlink constants. See include/uapi/linux/netlink.h.
-NETLINK_SOCK_DIAG = 4
-
-### sock_diag constants. See include/uapi/linux/sock_diag.h.
-# Message types.
-SOCK_DIAG_BY_FAMILY = 20
-SOCK_DESTROY = 21
-
-### inet_diag_constants. See include/uapi/linux/inet_diag.h
-# Message types.
-TCPDIAG_GETSOCK = 18
-
-# Request attributes.
-INET_DIAG_REQ_BYTECODE = 1
-
-# Extensions.
-INET_DIAG_NONE = 0
-INET_DIAG_MEMINFO = 1
-INET_DIAG_INFO = 2
-INET_DIAG_VEGASINFO = 3
-INET_DIAG_CONG = 4
-INET_DIAG_TOS = 5
-INET_DIAG_TCLASS = 6
-INET_DIAG_SKMEMINFO = 7
-INET_DIAG_SHUTDOWN = 8
-INET_DIAG_DCTCPINFO = 9
-
-# Bytecode operations.
-INET_DIAG_BC_NOP = 0
-INET_DIAG_BC_JMP = 1
-INET_DIAG_BC_S_GE = 2
-INET_DIAG_BC_S_LE = 3
-INET_DIAG_BC_D_GE = 4
-INET_DIAG_BC_D_LE = 5
-INET_DIAG_BC_AUTO = 6
-INET_DIAG_BC_S_COND = 7
-INET_DIAG_BC_D_COND = 8
-
-# Data structure formats.
-# These aren't constants, they're classes. So, pylint: disable=invalid-name
-InetDiagSockId = cstruct.Struct(
- "InetDiagSockId", "!HH16s16sI8s", "sport dport src dst iface cookie")
-InetDiagReqV2 = cstruct.Struct(
- "InetDiagReqV2", "=BBBxIS", "family protocol ext states id",
- [InetDiagSockId])
-InetDiagMsg = cstruct.Struct(
- "InetDiagMsg", "=BBBBSLLLLL",
- "family state timer retrans id expires rqueue wqueue uid inode",
- [InetDiagSockId])
-InetDiagMeminfo = cstruct.Struct(
- "InetDiagMeminfo", "=IIII", "rmem wmem fmem tmem")
-InetDiagBcOp = cstruct.Struct("InetDiagBcOp", "BBH", "code yes no")
-InetDiagHostcond = cstruct.Struct("InetDiagHostcond", "=BBxxi",
- "family prefix_len port")
-
-SkMeminfo = cstruct.Struct(
- "SkMeminfo", "=IIIIIIII",
- "rmem_alloc rcvbuf wmem_alloc sndbuf fwd_alloc wmem_queued optmem backlog")
-TcpInfo = cstruct.Struct(
- "TcpInfo", "=BBBBBBBxIIIIIIIIIIIIIIIIIIIIIIII",
- "state ca_state retransmits probes backoff options wscale "
- "rto ato snd_mss rcv_mss "
- "unacked sacked lost retrans fackets "
- "last_data_sent last_ack_sent last_data_recv last_ack_recv "
- "pmtu rcv_ssthresh rtt rttvar snd_ssthresh snd_cwnd advmss reordering "
- "rcv_rtt rcv_space "
- "total_retrans") # As of linux 3.13, at least.
-
-TCP_TIME_WAIT = 6
-ALL_NON_TIME_WAIT = 0xffffffff & ~(1 << TCP_TIME_WAIT)
-
-
-class SockDiag(netlink.NetlinkSocket):
-
- FAMILY = NETLINK_SOCK_DIAG
- NL_DEBUG = []
-
- def _Decode(self, command, msg, nla_type, nla_data):
- """Decodes netlink attributes to Python types."""
- if msg.family == AF_INET or msg.family == AF_INET6:
- name = self._GetConstantName(__name__, nla_type, "INET_DIAG")
- else:
- # Don't know what this is. Leave it as an integer.
- name = nla_type
-
- if name in ["INET_DIAG_SHUTDOWN", "INET_DIAG_TOS", "INET_DIAG_TCLASS"]:
- data = ord(nla_data)
- elif name == "INET_DIAG_CONG":
- data = nla_data.strip("\x00")
- elif name == "INET_DIAG_MEMINFO":
- data = InetDiagMeminfo(nla_data)
- elif name == "INET_DIAG_INFO":
- # TODO: Catch the exception and try something else if it's not TCP.
- data = TcpInfo(nla_data)
- elif name == "INET_DIAG_SKMEMINFO":
- data = SkMeminfo(nla_data)
- else:
- data = nla_data
-
- return name, data
-
- def MaybeDebugCommand(self, command, data):
- name = self._GetConstantName(__name__, command, "SOCK_")
- if "ALL" not in self.NL_DEBUG and "SOCK" not in self.NL_DEBUG:
- return
- parsed = self._ParseNLMsg(data, InetDiagReqV2)
- print "%s %s" % (name, str(parsed))
-
- @staticmethod
- def _EmptyInetDiagSockId():
- return InetDiagSockId(("\x00" * len(InetDiagSockId)))
-
- def PackBytecode(self, instructions):
- """Compiles instructions to inet_diag bytecode.
-
- The input is a list of (INET_DIAG_BC_xxx, yes, no, arg) tuples, where yes
- and no are relative jump offsets measured in instructions. The yes branch
- is taken if the instruction matches.
-
- To accept, jump 1 past the last instruction. To reject, jump 2 past the
- last instruction.
-
- The target of a no jump is only valid if it is reachable by following
- only yes jumps from the first instruction - see inet_diag_bc_audit and
- valid_cc. This means that if cond1 and cond2 are two mutually exclusive
- filter terms, it is not possible to implement cond1 OR cond2 using:
-
- ...
- cond1 2 1 arg
- cond2 1 2 arg
- accept
- reject
-
- but only using:
-
- ...
- cond1 1 2 arg
- jmp 1 2
- cond2 1 2 arg
- accept
- reject
-
- The jmp instruction ignores yes and always jumps to no, but yes must be 1
- or the bytecode won't validate. It doesn't have to be jmp - any instruction
- that is guaranteed not to match on real data will do.
-
- Args:
- instructions: list of instruction tuples
-
- Returns:
- A string, the raw bytecode.
- """
- args = []
- positions = [0]
-
- for op, yes, no, arg in instructions:
-
- if yes <= 0 or no <= 0:
- raise ValueError("Jumps must be > 0")
-
- if op in [INET_DIAG_BC_NOP, INET_DIAG_BC_JMP, INET_DIAG_BC_AUTO]:
- arg = ""
- elif op in [INET_DIAG_BC_S_GE, INET_DIAG_BC_S_LE,
- INET_DIAG_BC_D_GE, INET_DIAG_BC_D_LE]:
- arg = "\x00\x00" + struct.pack("=H", arg)
- elif op in [INET_DIAG_BC_S_COND, INET_DIAG_BC_D_COND]:
- addr, prefixlen, port = arg
- family = AF_INET6 if ":" in addr else AF_INET
- addr = inet_pton(family, addr)
- arg = InetDiagHostcond((family, prefixlen, port)).Pack() + addr
- else:
- raise ValueError("Unsupported opcode %d" % op)
-
- args.append(arg)
- length = len(InetDiagBcOp) + len(arg)
- positions.append(positions[-1] + length)
-
- # Reject label.
- positions.append(positions[-1] + 4) # Why 4? Because the kernel uses 4.
- assert len(args) == len(instructions) == len(positions) - 2
-
- # print positions
-
- packed = ""
- for i, (op, yes, no, arg) in enumerate(instructions):
- yes = positions[i + yes] - positions[i]
- no = positions[i + no] - positions[i]
- instruction = InetDiagBcOp((op, yes, no)).Pack() + args[i]
- #print "%3d: %d %3d %3d %s %s" % (positions[i], op, yes, no,
- # arg, instruction.encode("hex"))
- packed += instruction
- #print
-
- return packed
-
- def Dump(self, diag_req, bytecode=""):
- out = self._Dump(SOCK_DIAG_BY_FAMILY, diag_req, InetDiagMsg, bytecode)
- return out
-
- def DumpAllInetSockets(self, protocol, bytecode, sock_id=None, ext=0,
- states=ALL_NON_TIME_WAIT):
- """Dumps IPv4 or IPv6 sockets matching the specified parameters."""
- # DumpSockets(AF_UNSPEC) does not result in dumping all inet sockets, it
- # results in ENOENT.
- if sock_id is None:
- sock_id = self._EmptyInetDiagSockId()
-
- if bytecode:
- bytecode = self._NlAttr(INET_DIAG_REQ_BYTECODE, bytecode)
-
- sockets = []
- for family in [AF_INET, AF_INET6]:
- diag_req = InetDiagReqV2((family, protocol, ext, states, sock_id))
- sockets += self.Dump(diag_req, bytecode)
-
- return sockets
-
- @staticmethod
- def GetRawAddress(family, addr):
- """Fetches the source address from an InetDiagMsg."""
- addrlen = {AF_INET:4, AF_INET6: 16}[family]
- return inet_ntop(family, addr[:addrlen])
-
- @staticmethod
- def GetSourceAddress(diag_msg):
- """Fetches the source address from an InetDiagMsg."""
- return SockDiag.GetRawAddress(diag_msg.family, diag_msg.id.src)
-
- @staticmethod
- def GetDestinationAddress(diag_msg):
- """Fetches the source address from an InetDiagMsg."""
- return SockDiag.GetRawAddress(diag_msg.family, diag_msg.id.dst)
-
- @staticmethod
- def RawAddress(addr):
- """Converts an IP address string to binary format."""
- family = AF_INET6 if ":" in addr else AF_INET
- return inet_pton(family, addr)
-
- @staticmethod
- def PaddedAddress(addr):
- """Converts an IP address string to binary format for InetDiagSockId."""
- padded = SockDiag.RawAddress(addr)
- if len(padded) < 16:
- padded += "\x00" * (16 - len(padded))
- return padded
-
- @staticmethod
- def DiagReqFromSocket(s):
- """Creates an InetDiagReqV2 that matches the specified socket."""
- family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
- protocol = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_PROTOCOL)
- if net_test.LINUX_VERSION >= (3, 8):
- iface = s.getsockopt(SOL_SOCKET, net_test.SO_BINDTODEVICE,
- net_test.IFNAMSIZ)
- iface = GetInterfaceIndex(iface) if iface else 0
- else:
- iface = 0
- src, sport = s.getsockname()[:2]
- try:
- dst, dport = s.getpeername()[:2]
- except error, e:
- if e.errno == errno.ENOTCONN:
- dport = 0
- dst = "::" if family == AF_INET6 else "0.0.0.0"
- else:
- raise e
- src = SockDiag.PaddedAddress(src)
- dst = SockDiag.PaddedAddress(dst)
- sock_id = InetDiagSockId((sport, dport, src, dst, iface, "\x00" * 8))
- return InetDiagReqV2((family, protocol, 0, 0xffffffff, sock_id))
-
- def FindSockDiagFromReq(self, req):
- for diag_msg, attrs in self.Dump(req):
- return diag_msg
- raise ValueError("Dump of %s returned no sockets" % req)
-
- def FindSockDiagFromFd(self, s):
- """Gets an InetDiagMsg from the kernel for the specified socket."""
- req = self.DiagReqFromSocket(s)
- return self.FindSockDiagFromReq(req)
-
- def GetSockDiag(self, req):
- """Gets an InetDiagMsg from the kernel for the specified request."""
- self._SendNlRequest(SOCK_DIAG_BY_FAMILY, req.Pack(), netlink.NLM_F_REQUEST)
- return self._GetMsg(InetDiagMsg)[0]
-
- @staticmethod
- def DiagReqFromDiagMsg(d, protocol):
- """Constructs a diag_req from a diag_msg the kernel has given us."""
- return InetDiagReqV2((d.family, protocol, 0, 1 << d.state, d.id))
-
- def CloseSocket(self, req):
- self._SendNlRequest(SOCK_DESTROY, req.Pack(),
- netlink.NLM_F_REQUEST | netlink.NLM_F_ACK)
-
- def CloseSocketFromFd(self, s):
- diag_msg = self.FindSockDiagFromFd(s)
- protocol = s.getsockopt(SOL_SOCKET, net_test.SO_PROTOCOL)
- req = self.DiagReqFromDiagMsg(diag_msg, protocol)
- return self.CloseSocket(req)
-
-
-if __name__ == "__main__":
- n = SockDiag()
- n.DEBUG = True
- bytecode = ""
- sock_id = n._EmptyInetDiagSockId()
- sock_id.dport = 443
- ext = 1 << (INET_DIAG_TOS - 1) | 1 << (INET_DIAG_TCLASS - 1)
- states = 0xffffffff
- diag_msgs = n.DumpAllInetSockets(IPPROTO_TCP, "",
- sock_id=sock_id, ext=ext, states=states)
- print diag_msgs
diff --git a/tests/net_test/sock_diag_test.py b/tests/net_test/sock_diag_test.py
deleted file mode 100755
index 3c5d0a9..0000000
--- a/tests/net_test/sock_diag_test.py
+++ /dev/null
@@ -1,548 +0,0 @@
-#!/usr/bin/python
-#
-# Copyright 2015 The Android Open Source Project
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import
-from errno import * # pylint: disable=wildcard-import
-import os
-import random
-import re
-from socket import * # pylint: disable=wildcard-import
-import threading
-import time
-import unittest
-
-import multinetwork_base
-import net_test
-import packets
-import sock_diag
-import tcp_test
-
-
-NUM_SOCKETS = 30
-NO_BYTECODE = ""
-
-
-class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest):
-
- @staticmethod
- def _CreateLotsOfSockets():
- # Dict mapping (addr, sport, dport) tuples to socketpairs.
- socketpairs = {}
- for _ in xrange(NUM_SOCKETS):
- family, addr = random.choice([
- (AF_INET, "127.0.0.1"),
- (AF_INET6, "::1"),
- (AF_INET6, "::ffff:127.0.0.1")])
- socketpair = net_test.CreateSocketPair(family, SOCK_STREAM, addr)
- sport, dport = (socketpair[0].getsockname()[1],
- socketpair[1].getsockname()[1])
- socketpairs[(addr, sport, dport)] = socketpair
- return socketpairs
-
- def assertSocketClosed(self, sock):
- self.assertRaisesErrno(ENOTCONN, sock.getpeername)
-
- def assertSocketConnected(self, sock):
- sock.getpeername() # No errors? Socket is alive and connected.
-
- def assertSocketsClosed(self, socketpair):
- for sock in socketpair:
- self.assertSocketClosed(sock)
-
- def setUp(self):
- super(SockDiagBaseTest, self).setUp()
- self.sock_diag = sock_diag.SockDiag()
- self.socketpairs = {}
-
- def tearDown(self):
- for socketpair in self.socketpairs.values():
- for s in socketpair:
- s.close()
- super(SockDiagBaseTest, self).tearDown()
-
-
-class SockDiagTest(SockDiagBaseTest):
-
- def assertSockDiagMatchesSocket(self, s, diag_msg):
- family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
- self.assertEqual(diag_msg.family, family)
-
- src, sport = s.getsockname()[0:2]
- self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src))
- self.assertEqual(diag_msg.id.sport, sport)
-
- if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]:
- dst, dport = s.getpeername()[0:2]
- self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst))
- self.assertEqual(diag_msg.id.dport, dport)
- else:
- self.assertRaisesErrno(ENOTCONN, s.getpeername)
-
- def testFindsMappedSockets(self):
- """Tests that inet_diag_find_one_icsk can find mapped sockets.
-
- Relevant kernel commits:
- android-3.10:
- f77e059 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
- """
- socketpair = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
- "::ffff:127.0.0.1")
- for sock in socketpair:
- diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
- diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
- self.sock_diag.GetSockDiag(diag_req)
- # No errors? Good.
-
- def testFindsAllMySockets(self):
- """Tests that basic socket dumping works.
-
- Relevant commits:
- android-3.4:
- ab4a727 net: inet_diag: zero out uninitialized idiag_{src,dst} fields
- android-3.10
- 3eb409b net: inet_diag: zero out uninitialized idiag_{src,dst} fields
- """
- self.socketpairs = self._CreateLotsOfSockets()
- sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE)
- self.assertGreaterEqual(len(sockets), NUM_SOCKETS)
-
- # Find the cookies for all of our sockets.
- cookies = {}
- for diag_msg, unused_attrs in sockets:
- addr = self.sock_diag.GetSourceAddress(diag_msg)
- sport = diag_msg.id.sport
- dport = diag_msg.id.dport
- if (addr, sport, dport) in self.socketpairs:
- cookies[(addr, sport, dport)] = diag_msg.id.cookie
- elif (addr, dport, sport) in self.socketpairs:
- cookies[(addr, sport, dport)] = diag_msg.id.cookie
-
- # Did we find all the cookies?
- self.assertEquals(2 * NUM_SOCKETS, len(cookies))
-
- socketpairs = self.socketpairs.values()
- random.shuffle(socketpairs)
- for socketpair in socketpairs:
- for sock in socketpair:
- # Check that we can find a diag_msg by scanning a dump.
- self.assertSockDiagMatchesSocket(
- sock,
- self.sock_diag.FindSockDiagFromFd(sock))
- cookie = self.sock_diag.FindSockDiagFromFd(sock).id.cookie
-
- # Check that we can find a diag_msg once we know the cookie.
- req = self.sock_diag.DiagReqFromSocket(sock)
- req.id.cookie = cookie
- diag_msg = self.sock_diag.GetSockDiag(req)
- req.states = 1 << diag_msg.state
- self.assertSockDiagMatchesSocket(sock, diag_msg)
-
- def testBytecodeCompilation(self):
- # pylint: disable=bad-whitespace
- instructions = [
- (sock_diag.INET_DIAG_BC_S_GE, 1, 8, 0), # 0
- (sock_diag.INET_DIAG_BC_D_LE, 1, 7, 0xffff), # 8
- (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::1", 128, -1)), # 16
- (sock_diag.INET_DIAG_BC_JMP, 1, 3, None), # 44
- (sock_diag.INET_DIAG_BC_S_COND, 2, 4, ("127.0.0.1", 32, -1)), # 48
- (sock_diag.INET_DIAG_BC_D_LE, 1, 3, 0x6665), # not used # 64
- (sock_diag.INET_DIAG_BC_NOP, 1, 1, None), # 72
- # 76 acc
- # 80 rej
- ]
- # pylint: enable=bad-whitespace
- bytecode = self.sock_diag.PackBytecode(instructions)
- expected = (
- "0208500000000000"
- "050848000000ffff"
- "071c20000a800000ffffffff00000000000000000000000000000001"
- "01041c00"
- "0718200002200000ffffffff7f000001"
- "0508100000006566"
- "00040400"
- )
- self.assertMultiLineEqual(expected, bytecode.encode("hex"))
- self.assertEquals(76, len(bytecode))
- self.socketpairs = self._CreateLotsOfSockets()
- filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
- allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE)
- self.assertItemsEqual(allsockets, filteredsockets)
-
- # Pick a few sockets in hash table order, and check that the bytecode we
- # compiled selects them properly.
- for socketpair in self.socketpairs.values()[:20]:
- for s in socketpair:
- diag_msg = self.sock_diag.FindSockDiagFromFd(s)
- instructions = [
- (sock_diag.INET_DIAG_BC_S_GE, 1, 5, diag_msg.id.sport),
- (sock_diag.INET_DIAG_BC_S_LE, 1, 4, diag_msg.id.sport),
- (sock_diag.INET_DIAG_BC_D_GE, 1, 3, diag_msg.id.dport),
- (sock_diag.INET_DIAG_BC_D_LE, 1, 2, diag_msg.id.dport),
- ]
- bytecode = self.sock_diag.PackBytecode(instructions)
- self.assertEquals(32, len(bytecode))
- sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
- self.assertEquals(1, len(sockets))
-
- # TODO: why doesn't comparing the cstructs work?
- self.assertEquals(diag_msg.Pack(), sockets[0][0].Pack())
-
- def testCrossFamilyBytecode(self):
- """Checks for a cross-family bug in inet_diag_hostcond matching.
-
- Relevant kernel commits:
- android-3.4:
- f67caec inet_diag: avoid unsafe and nonsensical prefix matches in inet_diag_bc_run()
- """
- # TODO: this is only here because the test fails if there are any open
- # sockets other than the ones it creates itself. Make the bytecode more
- # specific and remove it.
- self.assertFalse(self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, ""))
-
- unused_pair4 = net_test.CreateSocketPair(AF_INET, SOCK_STREAM, "127.0.0.1")
- unused_pair6 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::1")
-
- bytecode4 = self.sock_diag.PackBytecode([
- (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("0.0.0.0", 0, -1))])
- bytecode6 = self.sock_diag.PackBytecode([
- (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::", 0, -1))])
-
- # IPv4/v6 filters must never match IPv6/IPv4 sockets...
- v4sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4)
- self.assertTrue(v4sockets)
- self.assertTrue(all(d.family == AF_INET for d, _ in v4sockets))
-
- v6sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6)
- self.assertTrue(v6sockets)
- self.assertTrue(all(d.family == AF_INET6 for d, _ in v6sockets))
-
- # Except for mapped addresses, which match both IPv4 and IPv6.
- pair5 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
- "::ffff:127.0.0.1")
- diag_msgs = [self.sock_diag.FindSockDiagFromFd(s) for s in pair5]
- v4sockets = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
- bytecode4)]
- v6sockets = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
- bytecode6)]
- self.assertTrue(all(d in v4sockets for d in diag_msgs))
- self.assertTrue(all(d in v6sockets for d in diag_msgs))
-
- def testPortComparisonValidation(self):
- """Checks for a bug in validating port comparison bytecode.
-
- Relevant kernel commits:
- android-3.4:
- 5e1f542 inet_diag: validate port comparison byte code to prevent unsafe reads
- """
- bytecode = sock_diag.InetDiagBcOp((sock_diag.INET_DIAG_BC_D_GE, 4, 8))
- self.assertRaisesErrno(
- EINVAL,
- self.sock_diag.DumpAllInetSockets, IPPROTO_TCP, bytecode.Pack())
-
- def testNonSockDiagCommand(self):
- def DiagDump(code):
- sock_id = self.sock_diag._EmptyInetDiagSockId()
- req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff,
- sock_id))
- self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg, "")
-
- op = sock_diag.SOCK_DIAG_BY_FAMILY
- DiagDump(op) # No errors? Good.
- self.assertRaisesErrno(EINVAL, DiagDump, op + 17)
-
-
-class SockDestroyTest(SockDiagBaseTest):
- """Tests that SOCK_DESTROY works correctly.
-
- Relevant kernel commits:
- net-next:
- b613f56 net: diag: split inet_diag_dump_one_icsk into two
- 64be0ae net: diag: Add the ability to destroy a socket.
- 6eb5d2e net: diag: Support SOCK_DESTROY for inet sockets.
- c1e64e2 net: diag: Support destroying TCP sockets.
- 2010b93 net: tcp: deal with listen sockets properly in tcp_abort.
-
- android-3.4:
- d48ec88 net: diag: split inet_diag_dump_one_icsk into two
- 2438189 net: diag: Add the ability to destroy a socket.
- 7a2ddbc net: diag: Support SOCK_DESTROY for inet sockets.
- 44047b2 net: diag: Support destroying TCP sockets.
- 200dae7 net: tcp: deal with listen sockets properly in tcp_abort.
-
- android-3.10:
- 9eaff90 net: diag: split inet_diag_dump_one_icsk into two
- d60326c net: diag: Add the ability to destroy a socket.
- 3d4ce85 net: diag: Support SOCK_DESTROY for inet sockets.
- 529dfc6 net: diag: Support destroying TCP sockets.
- 9c712fe net: tcp: deal with listen sockets properly in tcp_abort.
-
- android-3.18:
- 100263d net: diag: split inet_diag_dump_one_icsk into two
- 194c5f3 net: diag: Add the ability to destroy a socket.
- 8387ea2 net: diag: Support SOCK_DESTROY for inet sockets.
- b80585a net: diag: Support destroying TCP sockets.
- 476c6ce net: tcp: deal with listen sockets properly in tcp_abort.
- """
-
- def testClosesSockets(self):
- self.socketpairs = self._CreateLotsOfSockets()
- for _, socketpair in self.socketpairs.iteritems():
- # Close one of the sockets.
- # This will send a RST that will close the other side as well.
- s = random.choice(socketpair)
- if random.randrange(0, 2) == 1:
- self.sock_diag.CloseSocketFromFd(s)
- else:
- diag_msg = self.sock_diag.FindSockDiagFromFd(s)
-
- # Get the cookie wrong and ensure that we get an error and the socket
- # is not closed.
- real_cookie = diag_msg.id.cookie
- diag_msg.id.cookie = os.urandom(len(real_cookie))
- req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
- self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req)
- self.assertSocketConnected(s)
-
- # Now close it with the correct cookie.
- req.id.cookie = real_cookie
- self.sock_diag.CloseSocket(req)
-
- # Check that both sockets in the pair are closed.
- self.assertSocketsClosed(socketpair)
-
- def testNonTcpSockets(self):
- s = socket(AF_INET6, SOCK_DGRAM, 0)
- s.connect(("::1", 53))
- self.sock_diag.FindSockDiagFromFd(s) # No exceptions? Good.
- self.assertRaisesErrno(EOPNOTSUPP, self.sock_diag.CloseSocketFromFd, s)
-
- # TODO:
- # Test that killing unix sockets returns EOPNOTSUPP.
-
-
-class SocketExceptionThread(threading.Thread):
-
- def __init__(self, sock, operation):
- self.exception = None
- super(SocketExceptionThread, self).__init__()
- self.daemon = True
- self.sock = sock
- self.operation = operation
-
- def run(self):
- try:
- self.operation(self.sock)
- except IOError, e:
- self.exception = e
-
-
-class SockDiagTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
-
- def testIpv4MappedSynRecvSocket(self):
- """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets.
-
- Relevant kernel commits:
- android-3.4:
- 457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state
- """
- netid = random.choice(self.tuns.keys())
- self.IncomingConnection(5, tcp_test.TCP_SYN_RECV, netid)
- sock_id = self.sock_diag._EmptyInetDiagSockId()
- sock_id.sport = self.port
- states = 1 << tcp_test.TCP_SYN_RECV
- req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id))
- children = self.sock_diag.Dump(req, NO_BYTECODE)
-
- self.assertTrue(children)
- for child, unused_args in children:
- self.assertEqual(tcp_test.TCP_SYN_RECV, child.state)
- self.assertEqual(self.sock_diag.PaddedAddress(self.remoteaddr),
- child.id.dst)
- self.assertEqual(self.sock_diag.PaddedAddress(self.myaddr),
- child.id.src)
-
-
-class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
-
- def setUp(self):
- super(SockDestroyTcpTest, self).setUp()
- self.netid = random.choice(self.tuns.keys())
-
- def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True):
- """Closes the socket and checks whether a RST is sent or not."""
- if sock is not None:
- self.assertIsNone(req, "Must specify sock or req, not both")
- self.sock_diag.CloseSocketFromFd(sock)
- self.assertRaisesErrno(EINVAL, sock.accept)
- else:
- self.assertIsNone(sock, "Must specify sock or req, not both")
- self.sock_diag.CloseSocket(req)
-
- if expect_reset:
- desc, rst = self.RstPacket()
- msg = "%s: expecting %s: " % (msg, desc)
- self.ExpectPacketOn(self.netid, msg, rst)
- else:
- msg = "%s: " % msg
- self.ExpectNoPacketsOn(self.netid, msg)
-
- if sock is not None and do_close:
- sock.close()
-
- def CheckTcpReset(self, state, statename):
- for version in [4, 5, 6]:
- msg = "Closing incoming IPv%d %s socket" % (version, statename)
- self.IncomingConnection(version, state, self.netid)
- self.CheckRstOnClose(self.s, None, False, msg)
- if state != tcp_test.TCP_LISTEN:
- msg = "Closing accepted IPv%d %s socket" % (version, statename)
- self.CheckRstOnClose(self.accepted, None, True, msg)
-
- def testTcpResets(self):
- """Checks that closing sockets in appropriate states sends a RST."""
- self.CheckTcpReset(tcp_test.TCP_LISTEN, "TCP_LISTEN")
- self.CheckTcpReset(tcp_test.TCP_ESTABLISHED, "TCP_ESTABLISHED")
- self.CheckTcpReset(tcp_test.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT")
-
- def FindChildSockets(self, s):
- """Finds the SYN_RECV child sockets of a given listening socket."""
- d = self.sock_diag.FindSockDiagFromFd(self.s)
- req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
- req.states = 1 << tcp_test.TCP_SYN_RECV | 1 << tcp_test.TCP_ESTABLISHED
- req.id.cookie = "\x00" * 8
- children = self.sock_diag.Dump(req, NO_BYTECODE)
- return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
- for d, _ in children]
-
- def CheckChildSocket(self, version, statename, parent_first):
- state = getattr(tcp_test, statename)
-
- self.IncomingConnection(version, state, self.netid)
-
- d = self.sock_diag.FindSockDiagFromFd(self.s)
- parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
- children = self.FindChildSockets(self.s)
- self.assertEquals(1, len(children))
-
- is_established = (state == tcp_test.TCP_NOT_YET_ACCEPTED)
-
- # The new TCP listener code in 4.4 makes SYN_RECV sockets live in the
- # regular TCP hash tables, and inet_diag_find_one_icsk can find them.
- # Before 4.4, we can see those sockets in dumps, but we can't fetch
- # or close them.
- can_close_children = is_established or net_test.LINUX_VERSION >= (4, 4)
-
- for child in children:
- if can_close_children:
- self.sock_diag.GetSockDiag(child) # No errors? Good, child found.
- else:
- self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)
-
- def CloseParent(expect_reset):
- msg = "Closing parent IPv%d %s socket %s child" % (
- version, statename, "before" if parent_first else "after")
- self.CheckRstOnClose(self.s, None, expect_reset, msg)
- self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, parent)
-
- def CheckChildrenClosed():
- for child in children:
- self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)
-
- def CloseChildren():
- for child in children:
- msg = "Closing child IPv%d %s socket %s parent" % (
- version, statename, "after" if parent_first else "before")
- self.sock_diag.GetSockDiag(child)
- self.CheckRstOnClose(None, child, is_established, msg)
- self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)
- CheckChildrenClosed()
-
- if parent_first:
- # Closing the parent will close child sockets, which will send a RST,
- # iff they are already established.
- CloseParent(is_established)
- if is_established:
- CheckChildrenClosed()
- elif can_close_children:
- CloseChildren()
- CheckChildrenClosed()
- self.s.close()
- else:
- if can_close_children:
- CloseChildren()
- CloseParent(False)
- self.s.close()
-
- def testChildSockets(self):
- for version in [4, 5, 6]:
- self.CheckChildSocket(version, "TCP_SYN_RECV", False)
- self.CheckChildSocket(version, "TCP_SYN_RECV", True)
- self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", False)
- self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", True)
-
- def CloseDuringBlockingCall(self, sock, call, expected_errno):
- thread = SocketExceptionThread(sock, call)
- thread.start()
- time.sleep(0.1)
- self.sock_diag.CloseSocketFromFd(sock)
- thread.join(1)
- self.assertFalse(thread.is_alive())
- self.assertIsNotNone(thread.exception)
- self.assertTrue(isinstance(thread.exception, IOError),
- "Expected IOError, got %s" % thread.exception)
- self.assertEqual(expected_errno, thread.exception.errno)
- self.assertSocketClosed(sock)
-
- def testAcceptInterrupted(self):
- """Tests that accept() is interrupted by SOCK_DESTROY."""
- for version in [4, 5, 6]:
- self.IncomingConnection(version, tcp_test.TCP_LISTEN, self.netid)
- self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL)
- self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo")
- self.assertRaisesErrno(EINVAL, self.s.accept)
-
- def testReadInterrupted(self):
- """Tests that read() is interrupted by SOCK_DESTROY."""
- for version in [4, 5, 6]:
- self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
- self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096),
- ECONNABORTED)
- self.assertRaisesErrno(EPIPE, self.accepted.send, "foo")
-
- def testConnectInterrupted(self):
- """Tests that connect() is interrupted by SOCK_DESTROY."""
- for version in [4, 5, 6]:
- family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
- s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
- self.SelectInterface(s, self.netid, "mark")
- if version == 5:
- remoteaddr = "::ffff:" + self.GetRemoteAddress(4)
- version = 4
- else:
- remoteaddr = self.GetRemoteAddress(version)
- s.bind(("", 0))
- _, sport = s.getsockname()[:2]
- self.CloseDuringBlockingCall(
- s, lambda sock: sock.connect((remoteaddr, 53)), ECONNABORTED)
- desc, syn = packets.SYN(53, version, self.MyAddress(version, self.netid),
- remoteaddr, sport=sport, seq=None)
- self.ExpectPacketOn(self.netid, desc, syn)
- msg = "SOCK_DESTROY of socket in connect, expected no RST"
- self.ExpectNoPacketsOn(self.netid, msg)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/net_test/srcaddr_selection_test.py b/tests/net_test/srcaddr_selection_test.py
deleted file mode 100755
index d3efdd9..0000000
--- a/tests/net_test/srcaddr_selection_test.py
+++ /dev/null
@@ -1,345 +0,0 @@
-#!/usr/bin/python
-#
-# Copyright 2014 The Android Open Source Project
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import errno
-import random
-from socket import * # pylint: disable=wildcard-import
-import time
-import unittest
-
-from scapy import all as scapy
-
-import csocket
-import iproute
-import multinetwork_base
-import packets
-import net_test
-
-# Setsockopt values.
-IPV6_ADDR_PREFERENCES = 72
-IPV6_PREFER_SRC_PUBLIC = 0x0002
-
-
-class IPv6SourceAddressSelectionTest(multinetwork_base.MultiNetworkBaseTest):
- """Test for IPv6 source address selection.
-
- Relevant kernel commits:
- upstream net-next:
- 7fd2561 net: ipv6: Add a sysctl to make optimistic addresses useful candidates
- c58da4c net: ipv6: allow explicitly choosing optimistic addresses
- 9131f3d ipv6: Do not iterate over all interfaces when finding source address on specific interface.
- c0b8da1 ipv6: Fix finding best source address in ipv6_dev_get_saddr().
- c15df30 ipv6: Remove unused arguments for __ipv6_dev_get_saddr().
- 3985e8a ipv6: sysctl to restrict candidate source addresses
-
- android-3.10:
- 2ce95507 net: ipv6: Add a sysctl to make optimistic addresses useful candidates
- 0065bf4 net: ipv6: allow choosing optimistic addresses with use_optimistic
- 0633924 ipv6: sysctl to restrict candidate source addresses
- """
-
- def SetIPv6Sysctl(self, ifname, sysctl, value):
- self.SetSysctl("/proc/sys/net/ipv6/conf/%s/%s" % (ifname, sysctl), value)
-
- def SetDAD(self, ifname, value):
- self.SetSysctl("/proc/sys/net/ipv6/conf/%s/accept_dad" % ifname, value)
- self.SetSysctl("/proc/sys/net/ipv6/conf/%s/dad_transmits" % ifname, value)
-
- def SetOptimisticDAD(self, ifname, value):
- self.SetSysctl("/proc/sys/net/ipv6/conf/%s/optimistic_dad" % ifname, value)
-
- def SetUseTempaddrs(self, ifname, value):
- self.SetSysctl("/proc/sys/net/ipv6/conf/%s/use_tempaddr" % ifname, value)
-
- def SetUseOptimistic(self, ifname, value):
- self.SetSysctl("/proc/sys/net/ipv6/conf/%s/use_optimistic" % ifname, value)
-
- def GetSourceIP(self, netid, mode="mark"):
- s = self.BuildSocket(6, net_test.UDPSocket, netid, mode)
- # Because why not...testing for temporary addresses is a separate thing.
- s.setsockopt(IPPROTO_IPV6, IPV6_ADDR_PREFERENCES, IPV6_PREFER_SRC_PUBLIC)
-
- s.connect((net_test.IPV6_ADDR, 123))
- src_addr = s.getsockname()[0]
- self.assertTrue(src_addr)
- return src_addr
-
- def assertAddressNotPresent(self, address):
- self.assertRaises(IOError, self.iproute.GetAddress, address)
-
- def assertAddressHasExpectedAttributes(
- self, address, expected_ifindex, expected_flags):
- ifa_msg = self.iproute.GetAddress(address)[0]
- self.assertEquals(AF_INET6 if ":" in address else AF_INET, ifa_msg.family)
- self.assertEquals(64, ifa_msg.prefixlen)
- self.assertEquals(iproute.RT_SCOPE_UNIVERSE, ifa_msg.scope)
- self.assertEquals(expected_ifindex, ifa_msg.index)
- self.assertEquals(expected_flags, ifa_msg.flags & expected_flags)
-
- def AddressIsTentative(self, address):
- ifa_msg = self.iproute.GetAddress(address)[0]
- return ifa_msg.flags & iproute.IFA_F_TENTATIVE
-
- def BindToAddress(self, address):
- s = net_test.UDPSocket(AF_INET6)
- s.bind((address, 0, 0, 0))
-
- def SendWithSourceAddress(self, address, netid, dest=net_test.IPV6_ADDR):
- pktinfo = multinetwork_base.MakePktInfo(6, address, 0)
- cmsgs = [(net_test.SOL_IPV6, IPV6_PKTINFO, pktinfo)]
- s = self.BuildSocket(6, net_test.UDPSocket, netid, "mark")
- return csocket.Sendmsg(s, (dest, 53), "Hello", cmsgs, 0)
-
- def assertAddressUsable(self, address, netid):
- self.BindToAddress(address)
- self.SendWithSourceAddress(address, netid)
- # No exceptions? Good.
-
- def assertAddressNotUsable(self, address, netid):
- self.assertRaisesErrno(errno.EADDRNOTAVAIL, self.BindToAddress, address)
- self.assertRaisesErrno(errno.EINVAL,
- self.SendWithSourceAddress, address, netid)
-
- def assertAddressSelected(self, address, netid):
- self.assertEquals(address, self.GetSourceIP(netid))
-
- def assertAddressNotSelected(self, address, netid):
- self.assertNotEquals(address, self.GetSourceIP(netid))
-
- def WaitForDad(self, address):
- for _ in xrange(20):
- if not self.AddressIsTentative(address):
- return
- time.sleep(0.1)
- raise AssertionError("%s did not complete DAD after 2 seconds")
-
-
-class MultiInterfaceSourceAddressSelectionTest(IPv6SourceAddressSelectionTest):
-
- def setUp(self):
- # [0] Make sure DAD, optimistic DAD, and the use_optimistic option
- # are all consistently disabled at the outset.
- for netid in self.tuns:
- ifname = self.GetInterfaceName(netid)
- self.SetDAD(ifname, 0)
- self.SetOptimisticDAD(ifname, 0)
- self.SetUseTempaddrs(ifname, 0)
- self.SetUseOptimistic(ifname, 0)
- self.SetIPv6Sysctl(ifname, "use_oif_addrs_only", 0)
-
- # [1] Pick an interface on which to test.
- self.test_netid = random.choice(self.tuns.keys())
- self.test_ip = self.MyAddress(6, self.test_netid)
- self.test_ifindex = self.ifindices[self.test_netid]
- self.test_ifname = self.GetInterfaceName(self.test_netid)
- self.test_lladdr = net_test.GetLinkAddress(self.test_ifname, True)
-
- # [2] Delete the test interface's IPv6 address.
- self.iproute.DelAddress(self.test_ip, 64, self.test_ifindex)
- self.assertAddressNotPresent(self.test_ip)
-
- self.assertAddressNotUsable(self.test_ip, self.test_netid)
- # Verify that the link-local address is not tentative.
- self.assertFalse(self.AddressIsTentative(self.test_lladdr))
-
-
-class TentativeAddressTest(MultiInterfaceSourceAddressSelectionTest):
-
- def testRfc6724Behaviour(self):
- # [3] Get an IPv6 address back, in DAD start-up.
- self.SetDAD(self.test_ifname, 1) # Enable DAD
- # Send a RA to start SLAAC and subsequent DAD.
- self.SendRA(self.test_netid, 0)
- # Get flags and prove tentative-ness.
- self.assertAddressHasExpectedAttributes(
- self.test_ip, self.test_ifindex, iproute.IFA_F_TENTATIVE)
-
- # Even though the interface has an IPv6 address, its tentative nature
- # prevents it from being selected.
- self.assertAddressNotUsable(self.test_ip, self.test_netid)
- self.assertAddressNotSelected(self.test_ip, self.test_netid)
-
- # Busy wait for DAD to complete (should be less than 1 second).
- self.WaitForDad(self.test_ip)
-
- # The test_ip should have completed DAD by now, and should be the
- # chosen source address, eligible to bind to, etc.
- self.assertAddressUsable(self.test_ip, self.test_netid)
- self.assertAddressSelected(self.test_ip, self.test_netid)
-
-
-class OptimisticAddressTest(MultiInterfaceSourceAddressSelectionTest):
-
- def testRfc6724Behaviour(self):
- # [3] Get an IPv6 address back, in optimistic DAD start-up.
- self.SetDAD(self.test_ifname, 1) # Enable DAD
- self.SetOptimisticDAD(self.test_ifname, 1)
- # Send a RA to start SLAAC and subsequent DAD.
- self.SendRA(self.test_netid, 0)
- # Get flags and prove optimism.
- self.assertAddressHasExpectedAttributes(
- self.test_ip, self.test_ifindex, iproute.IFA_F_OPTIMISTIC)
-
- # Optimistic addresses are usable but are not selected.
- if net_test.LinuxVersion() >= (3, 18, 0):
- # The version checked in to android kernels <= 3.10 requires the
- # use_optimistic sysctl to be turned on.
- self.assertAddressUsable(self.test_ip, self.test_netid)
- self.assertAddressNotSelected(self.test_ip, self.test_netid)
-
- # Busy wait for DAD to complete (should be less than 1 second).
- self.WaitForDad(self.test_ip)
-
- # The test_ip should have completed DAD by now, and should be the
- # chosen source address.
- self.assertAddressUsable(self.test_ip, self.test_netid)
- self.assertAddressSelected(self.test_ip, self.test_netid)
-
-
-class OptimisticAddressOkayTest(MultiInterfaceSourceAddressSelectionTest):
-
- def testModifiedRfc6724Behaviour(self):
- # [3] Get an IPv6 address back, in optimistic DAD start-up.
- self.SetDAD(self.test_ifname, 1) # Enable DAD
- self.SetOptimisticDAD(self.test_ifname, 1)
- self.SetUseOptimistic(self.test_ifname, 1)
- # Send a RA to start SLAAC and subsequent DAD.
- self.SendRA(self.test_netid, 0)
- # Get flags and prove optimistism.
- self.assertAddressHasExpectedAttributes(
- self.test_ip, self.test_ifindex, iproute.IFA_F_OPTIMISTIC)
-
- # The interface has an IPv6 address and, despite its optimistic nature,
- # the use_optimistic option allows it to be selected.
- self.assertAddressUsable(self.test_ip, self.test_netid)
- self.assertAddressSelected(self.test_ip, self.test_netid)
-
-
-class ValidBeforeOptimisticTest(MultiInterfaceSourceAddressSelectionTest):
-
- def testModifiedRfc6724Behaviour(self):
- # [3] Add a valid IPv6 address to this interface and verify it is
- # selected as the source address.
- preferred_ip = self.IPv6Prefix(self.test_netid) + "cafe"
- self.iproute.AddAddress(preferred_ip, 64, self.test_ifindex)
- self.assertAddressHasExpectedAttributes(
- preferred_ip, self.test_ifindex, iproute.IFA_F_PERMANENT)
- self.assertEquals(preferred_ip, self.GetSourceIP(self.test_netid))
-
- # [4] Get another IPv6 address, in optimistic DAD start-up.
- self.SetDAD(self.test_ifname, 1) # Enable DAD
- self.SetOptimisticDAD(self.test_ifname, 1)
- self.SetUseOptimistic(self.test_ifname, 1)
- # Send a RA to start SLAAC and subsequent DAD.
- self.SendRA(self.test_netid, 0)
- # Get flags and prove optimism.
- self.assertAddressHasExpectedAttributes(
- self.test_ip, self.test_ifindex, iproute.IFA_F_OPTIMISTIC)
-
- # Since the interface has another IPv6 address, the optimistic address
- # is not selected--the other, valid address is chosen.
- self.assertAddressUsable(self.test_ip, self.test_netid)
- self.assertAddressNotSelected(self.test_ip, self.test_netid)
- self.assertAddressSelected(preferred_ip, self.test_netid)
-
-
-class DadFailureTest(MultiInterfaceSourceAddressSelectionTest):
-
- def testDadFailure(self):
- # [3] Get an IPv6 address back, in optimistic DAD start-up.
- self.SetDAD(self.test_ifname, 1) # Enable DAD
- self.SetOptimisticDAD(self.test_ifname, 1)
- self.SetUseOptimistic(self.test_ifname, 1)
- # Send a RA to start SLAAC and subsequent DAD.
- self.SendRA(self.test_netid, 0)
- # Prove optimism and usability.
- self.assertAddressHasExpectedAttributes(
- self.test_ip, self.test_ifindex, iproute.IFA_F_OPTIMISTIC)
- self.assertAddressUsable(self.test_ip, self.test_netid)
- self.assertAddressSelected(self.test_ip, self.test_netid)
-
- # Send a NA for the optimistic address, indicating address conflict
- # ("DAD defense").
- conflict_macaddr = "02:00:0b:ad:d0:0d"
- dad_defense = (scapy.Ether(src=conflict_macaddr, dst="33:33:33:00:00:01") /
- scapy.IPv6(src=self.test_ip, dst="ff02::1") /
- scapy.ICMPv6ND_NA(tgt=self.test_ip, R=0, S=0, O=1) /
- scapy.ICMPv6NDOptDstLLAddr(lladdr=conflict_macaddr))
- self.ReceiveEtherPacketOn(self.test_netid, dad_defense)
-
- # The address should have failed DAD, and therefore no longer be usable.
- self.assertAddressNotUsable(self.test_ip, self.test_netid)
- self.assertAddressNotSelected(self.test_ip, self.test_netid)
-
- # TODO(ek): verify that an RTM_DELADDR issued for the DAD-failed address.
-
-
-class NoNsFromOptimisticTest(MultiInterfaceSourceAddressSelectionTest):
-
- def testSendToOnlinkDestination(self):
- # [3] Get an IPv6 address back, in optimistic DAD start-up.
- self.SetDAD(self.test_ifname, 1) # Enable DAD
- self.SetOptimisticDAD(self.test_ifname, 1)
- self.SetUseOptimistic(self.test_ifname, 1)
- # Send a RA to start SLAAC and subsequent DAD.
- self.SendRA(self.test_netid, 0)
- # Prove optimism and usability.
- self.assertAddressHasExpectedAttributes(
- self.test_ip, self.test_ifindex, iproute.IFA_F_OPTIMISTIC)
- self.assertAddressUsable(self.test_ip, self.test_netid)
- self.assertAddressSelected(self.test_ip, self.test_netid)
-
- # [4] Send to an on-link destination and observe a Neighbor Solicitation
- # packet with a source address that is NOT the optimistic address.
- # In this setup, the only usable address is the link-local address.
- onlink_dest = self.GetRandomDestination(self.IPv6Prefix(self.test_netid))
- self.SendWithSourceAddress(self.test_ip, self.test_netid, onlink_dest)
-
- if net_test.LinuxVersion() >= (3, 18, 0):
- # Older versions will actually choose the optimistic address to
- # originate Neighbor Solications (RFC violation).
- expected_ns = packets.NS(
- self.test_lladdr,
- onlink_dest,
- self.MyMacAddress(self.test_netid))[1]
- self.ExpectPacketOn(self.test_netid, "link-local NS", expected_ns)
-
-
-# TODO(ek): add tests listening for netlink events.
-
-
-class DefaultCandidateSrcAddrsTest(MultiInterfaceSourceAddressSelectionTest):
-
- def testChoosesNonInterfaceSourceAddress(self):
- self.SetIPv6Sysctl(self.test_ifname, "use_oif_addrs_only", 0)
- src_ip = self.GetSourceIP(self.test_netid)
- self.assertFalse(src_ip in [self.test_ip, self.test_lladdr])
- self.assertTrue(src_ip in
- [self.MyAddress(6, netid)
- for netid in self.tuns if netid != self.test_netid])
-
-
-class RestrictedCandidateSrcAddrsTest(MultiInterfaceSourceAddressSelectionTest):
-
- def testChoosesOnlyInterfaceSourceAddress(self):
- self.SetIPv6Sysctl(self.test_ifname, "use_oif_addrs_only", 1)
- # self.test_ifname does not have a global IPv6 address, so the only
- # candidate is the existing link-local address.
- self.assertAddressSelected(self.test_lladdr, self.test_netid)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/net_test/tcp_nuke_addr_test.py b/tests/net_test/tcp_nuke_addr_test.py
deleted file mode 100755
index b0ba27d..0000000
--- a/tests/net_test/tcp_nuke_addr_test.py
+++ /dev/null
@@ -1,250 +0,0 @@
-#!/usr/bin/python
-#
-# Copyright 2015 The Android Open Source Project
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import contextlib
-import errno
-import fcntl
-import resource
-import os
-from socket import * # pylint: disable=wildcard-import
-import struct
-import threading
-import time
-import unittest
-
-import csocket
-import cstruct
-import net_test
-
-IPV4_LOOPBACK_ADDR = "127.0.0.1"
-IPV6_LOOPBACK_ADDR = "::1"
-LOOPBACK_DEV = "lo"
-LOOPBACK_IFINDEX = 1
-
-SIOCKILLADDR = 0x8939
-
-DEFAULT_TCP_PORT = 8001
-DEFAULT_BUFFER_SIZE = 20
-DEFAULT_TEST_MESSAGE = "TCP NUKE ADDR TEST"
-DEFAULT_TEST_RUNS = 100
-HASH_TEST_RUNS = 4000
-HASH_TEST_NOFILE = 16384
-
-
-Ifreq = cstruct.Struct("Ifreq", "=16s16s", "name data")
-In6Ifreq = cstruct.Struct("In6Ifreq", "=16sIi", "addr prefixlen ifindex")
-
-@contextlib.contextmanager
-def RunInBackground(thread):
- """Starts a thread and waits until it joins.
-
- Args:
- thread: A not yet started threading.Thread object.
- """
- try:
- thread.start()
- yield thread
- finally:
- thread.join()
-
-
-def TcpAcceptAndReceive(listening_sock, buffer_size=DEFAULT_BUFFER_SIZE):
- """Accepts a single connection and blocks receiving data from it.
-
- Args:
- listening_socket: A socket in LISTEN state.
- buffer_size: Size of buffer where to read a message.
- """
- connection, _ = listening_sock.accept()
- with contextlib.closing(connection):
- _ = connection.recv(buffer_size)
-
-
-def ExchangeMessage(addr_family, ip_addr):
- """Creates a listening socket, accepts a connection and sends data to it.
-
- Args:
- addr_family: The address family (e.g. AF_INET6).
- ip_addr: The IP address (IPv4 or IPv6 depending on the addr_family).
- tcp_port: The TCP port to listen on.
- """
- # Bind to a random port and connect to it.
- test_addr = (ip_addr, 0)
- with contextlib.closing(
- socket(addr_family, SOCK_STREAM)) as listening_socket:
- listening_socket.bind(test_addr)
- test_addr = listening_socket.getsockname()
- listening_socket.listen(1)
- with RunInBackground(threading.Thread(target=TcpAcceptAndReceive,
- args=(listening_socket,))):
- with contextlib.closing(
- socket(addr_family, SOCK_STREAM)) as client_socket:
- client_socket.connect(test_addr)
- client_socket.send(DEFAULT_TEST_MESSAGE)
-
-
-def KillAddrIoctl(addr):
- """Calls the SIOCKILLADDR ioctl on the provided IP address.
-
- Args:
- addr The IP address to pass to the ioctl.
-
- Raises:
- ValueError: If addr is of an unsupported address family.
- """
- family, _, _, _, _ = getaddrinfo(addr, None, AF_UNSPEC, SOCK_DGRAM, 0,
- AI_NUMERICHOST)[0]
- if family == AF_INET6:
- addr = inet_pton(AF_INET6, addr)
- ifreq = In6Ifreq((addr, 128, LOOPBACK_IFINDEX)).Pack()
- elif family == AF_INET:
- addr = inet_pton(AF_INET, addr)
- sockaddr = csocket.SockaddrIn((AF_INET, 0, addr)).Pack()
- ifreq = Ifreq((LOOPBACK_DEV, sockaddr)).Pack()
- else:
- raise ValueError('Address family %r not supported.' % family)
- datagram_socket = socket(family, SOCK_DGRAM)
- fcntl.ioctl(datagram_socket.fileno(), SIOCKILLADDR, ifreq)
- datagram_socket.close()
-
-
-class ExceptionalReadThread(threading.Thread):
-
- def __init__(self, sock):
- self.sock = sock
- self.exception = None
- super(ExceptionalReadThread, self).__init__()
- self.daemon = True
-
- def run(self):
- try:
- read = self.sock.recv(4096)
- except Exception, e:
- self.exception = e
-
-# For convenience.
-def CreateIPv4SocketPair():
- return net_test.CreateSocketPair(AF_INET, SOCK_STREAM, IPV4_LOOPBACK_ADDR)
-
-def CreateIPv6SocketPair():
- return net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, IPV6_LOOPBACK_ADDR)
-
-
-class TcpNukeAddrTest(net_test.NetworkTest):
-
- def testTimewaitSockets(self):
- """Tests that SIOCKILLADDR works as expected.
-
- Relevant kernel commits:
- https://www.codeaurora.org/cgit/quic/la/kernel/msm-3.18/commit/net/ipv4/tcp.c?h=aosp/android-3.10&id=1dcd3a1fa2fe78251cc91700eb1d384ab02e2dd6
- """
- for i in xrange(DEFAULT_TEST_RUNS):
- ExchangeMessage(AF_INET6, IPV6_LOOPBACK_ADDR)
- KillAddrIoctl(IPV6_LOOPBACK_ADDR)
- ExchangeMessage(AF_INET, IPV4_LOOPBACK_ADDR)
- KillAddrIoctl(IPV4_LOOPBACK_ADDR)
- # Test passes if kernel does not crash.
-
- def testClosesIPv6Sockets(self):
- """Tests that SIOCKILLADDR closes IPv6 sockets and unblocks threads."""
-
- threadpairs = []
-
- for i in xrange(DEFAULT_TEST_RUNS):
- clientsock, acceptedsock = CreateIPv6SocketPair()
- clientthread = ExceptionalReadThread(clientsock)
- clientthread.start()
- serverthread = ExceptionalReadThread(acceptedsock)
- serverthread.start()
- threadpairs.append((clientthread, serverthread))
-
- KillAddrIoctl(IPV6_LOOPBACK_ADDR)
-
- def CheckThreadException(thread):
- thread.join(100)
- self.assertFalse(thread.is_alive())
- self.assertIsNotNone(thread.exception)
- self.assertTrue(isinstance(thread.exception, IOError))
- self.assertEquals(errno.ETIMEDOUT, thread.exception.errno)
- self.assertRaisesErrno(errno.ENOTCONN, thread.sock.getpeername)
- self.assertRaisesErrno(errno.EISCONN, thread.sock.connect,
- (IPV6_LOOPBACK_ADDR, 53))
- self.assertRaisesErrno(errno.EPIPE, thread.sock.send, "foo")
-
- for clientthread, serverthread in threadpairs:
- CheckThreadException(clientthread)
- CheckThreadException(serverthread)
-
- def assertSocketsClosed(self, socketpair):
- for sock in socketpair:
- self.assertRaisesErrno(errno.ENOTCONN, sock.getpeername)
-
- def assertSocketsNotClosed(self, socketpair):
- for sock in socketpair:
- self.assertTrue(sock.getpeername())
-
- def testAddresses(self):
- socketpair = CreateIPv4SocketPair()
- KillAddrIoctl("::")
- self.assertSocketsNotClosed(socketpair)
- KillAddrIoctl("::1")
- self.assertSocketsNotClosed(socketpair)
- KillAddrIoctl("127.0.0.3")
- self.assertSocketsNotClosed(socketpair)
- KillAddrIoctl("0.0.0.0")
- self.assertSocketsNotClosed(socketpair)
- KillAddrIoctl("127.0.0.1")
- self.assertSocketsClosed(socketpair)
-
- socketpair = CreateIPv6SocketPair()
- KillAddrIoctl("0.0.0.0")
- self.assertSocketsNotClosed(socketpair)
- KillAddrIoctl("127.0.0.1")
- self.assertSocketsNotClosed(socketpair)
- KillAddrIoctl("::2")
- self.assertSocketsNotClosed(socketpair)
- KillAddrIoctl("::")
- self.assertSocketsNotClosed(socketpair)
- KillAddrIoctl("::1")
- self.assertSocketsClosed(socketpair)
-
-
-class TcpNukeAddrHashTest(net_test.NetworkTest):
-
- def setUp(self):
- self.nofile = resource.getrlimit(resource.RLIMIT_NOFILE)
- resource.setrlimit(resource.RLIMIT_NOFILE, (HASH_TEST_NOFILE,
- HASH_TEST_NOFILE))
-
- def tearDown(self):
- resource.setrlimit(resource.RLIMIT_NOFILE, self.nofile)
-
- def testClosesAllSockets(self):
- socketpairs = []
- for i in xrange(HASH_TEST_RUNS):
- socketpairs.append(CreateIPv4SocketPair())
- socketpairs.append(CreateIPv6SocketPair())
-
- KillAddrIoctl(IPV4_LOOPBACK_ADDR)
- KillAddrIoctl(IPV6_LOOPBACK_ADDR)
-
- for socketpair in socketpairs:
- for sock in socketpair:
- self.assertRaisesErrno(errno.ENOTCONN, sock.getpeername)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/net_test/tcp_test.py b/tests/net_test/tcp_test.py
deleted file mode 100644
index 81a6884..0000000
--- a/tests/net_test/tcp_test.py
+++ /dev/null
@@ -1,124 +0,0 @@
-#!/usr/bin/python
-#
-# Copyright 2015 The Android Open Source Project
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import time
-from socket import * # pylint: disable=wildcard-import
-
-import net_test
-import multinetwork_base
-import packets
-
-# TCP states. See include/net/tcp_states.h.
-TCP_ESTABLISHED = 1
-TCP_SYN_SENT = 2
-TCP_SYN_RECV = 3
-TCP_FIN_WAIT1 = 4
-TCP_FIN_WAIT2 = 5
-TCP_TIME_WAIT = 6
-TCP_CLOSE = 7
-TCP_CLOSE_WAIT = 8
-TCP_LAST_ACK = 9
-TCP_LISTEN = 10
-TCP_CLOSING = 11
-TCP_NEW_SYN_RECV = 12
-
-TCP_NOT_YET_ACCEPTED = -1
-
-
-class TcpBaseTest(multinetwork_base.MultiNetworkBaseTest):
-
- def tearDown(self):
- if hasattr(self, "s"):
- self.s.close()
- super(TcpBaseTest, self).tearDown()
-
- def OpenListenSocket(self, version, netid):
- self.port = packets.RandomPort()
- family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
- address = {4: "0.0.0.0", 5: "::", 6: "::"}[version]
- s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
- s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
- s.bind((address, self.port))
- # We haven't configured inbound iptables marking, so bind explicitly.
- self.SelectInterface(s, netid, "mark")
- s.listen(100)
- return s
-
- def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
- pkt = super(TcpBaseTest, self)._ReceiveAndExpectResponse(netid, packet,
- reply, msg)
- self.last_packet = pkt
- return pkt
-
- def ReceivePacketOn(self, netid, packet):
- super(TcpBaseTest, self).ReceivePacketOn(netid, packet)
- self.last_packet = packet
-
- def RstPacket(self):
- return packets.RST(self.version, self.myaddr, self.remoteaddr,
- self.last_packet)
-
- def IncomingConnection(self, version, end_state, netid):
- self.s = self.OpenListenSocket(version, netid)
- self.end_state = end_state
-
- remoteaddr = self.remoteaddr = self.GetRemoteAddress(version)
- myaddr = self.myaddr = self.MyAddress(version, netid)
-
- if version == 5: version = 4
- self.version = version
-
- if end_state == TCP_LISTEN:
- return
-
- desc, syn = packets.SYN(self.port, version, remoteaddr, myaddr)
- synack_desc, synack = packets.SYNACK(version, myaddr, remoteaddr, syn)
- msg = "Received %s, expected to see reply %s" % (desc, synack_desc)
- reply = self._ReceiveAndExpectResponse(netid, syn, synack, msg)
- if end_state == TCP_SYN_RECV:
- return
-
- establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]
- self.ReceivePacketOn(netid, establishing_ack)
-
- if end_state == TCP_NOT_YET_ACCEPTED:
- return
-
- self.accepted, _ = self.s.accept()
- net_test.DisableLinger(self.accepted)
-
- if end_state == TCP_ESTABLISHED:
- return
-
- desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack,
- payload=net_test.UDP_PAYLOAD)
- self.accepted.send(net_test.UDP_PAYLOAD)
- self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)
-
- desc, fin = packets.FIN(version, remoteaddr, myaddr, data)
- fin = packets._GetIpLayer(version)(str(fin))
- ack_desc, ack = packets.ACK(version, myaddr, remoteaddr, fin)
- msg = "Received %s, expected to see reply %s" % (desc, ack_desc)
-
- # TODO: Why can't we use this?
- # self._ReceiveAndExpectResponse(netid, fin, ack, msg)
- self.ReceivePacketOn(netid, fin)
- time.sleep(0.1)
- self.ExpectPacketOn(netid, msg + ": expecting %s" % ack_desc, ack)
- if end_state == TCP_CLOSE_WAIT:
- return
-
- raise ValueError("Invalid TCP state %d specified" % end_state)