Fix up checksums instead of recalculating them.

Currently the checksums of translated packets are calculated
from scratch by checksumming the translated packet. This is slow
and does not work in the case of fragments, because the whole
packet is not available. Instead, calculate the checksum by
adjusting the checksum of the original packet.

Bug: 11542311
Bug: 12116252
Change-Id: I6b78a94ca5bd96b13ee2653b6200551193b3dcc1
diff --git a/checksum.c b/checksum.c
index a4dc9b8..099be6a 100644
--- a/checksum.c
+++ b/checksum.c
@@ -49,19 +49,27 @@
   return checksum;
 }
 
-/* function: ip_checksum_finish
- * close the checksum
+/* function: ip_checksum_fold
+ * folds a 32-bit partial checksum into 16 bits by adding the carries back in (no complement)
  * temp_sum - sum from ip_checksum_add
+ * returns: the folded checksum in network byte order
  */
-uint16_t ip_checksum_finish(uint32_t temp_sum) {
+uint16_t ip_checksum_fold(uint32_t temp_sum) {
   while(temp_sum > 0xffff)
     temp_sum = (temp_sum >> 16) + (temp_sum & 0xFFFF);
 
-  temp_sum = (~temp_sum) & 0xffff;
-
   return temp_sum;
 }
 
+/* function: ip_checksum_finish
+ * folds the sum to 16 bits (via ip_checksum_fold) and returns its ones-complement
+ * temp_sum - sum from ip_checksum_add
+ * returns: a header checksum value in network byte order
+ */
+uint16_t ip_checksum_finish(uint32_t temp_sum) {
+  return ~ip_checksum_fold(temp_sum);
+}
+
 /* function: ip_checksum
  * combined ip_checksum_add and ip_checksum_finish
  * data - data to checksum
@@ -113,3 +121,23 @@
 
   return current;
 }
+
+/* function: ip_checksum_adjust
+ * derives a new checksum from an existing one and the old/new pseudo-header sums, per RFC 1624
+ * checksum    - the header checksum in the original packet in network byte order
+ * old_hdr_sum - the pseudo-header checksum of the original packet (32-bit, may be unfolded)
+ * new_hdr_sum - the pseudo-header checksum of the translated packet (32-bit, may be unfolded)
+ * returns: the new header checksum in network byte order
+ */
+uint16_t ip_checksum_adjust(uint16_t checksum, uint32_t old_hdr_sum, uint32_t new_hdr_sum) {
+  // Algorithm suggested in RFC 1624 (eqn. 3), as ~(~checksum + new_hdr_sum - old_hdr_sum).
+  // http://tools.ietf.org/html/rfc1624#section-3
+  checksum = ~checksum;
+  uint16_t folded_sum = ip_checksum_fold(checksum + new_hdr_sum);
+  uint16_t folded_old = ip_checksum_fold(old_hdr_sum);
+  if (folded_sum > folded_old) {
+    return ~(folded_sum - folded_old);
+  } else {
+    return ~(folded_sum - folded_old - 1);  // end-around borrow; equal sums yield 0x0000
+  }
+}
diff --git a/checksum.h b/checksum.h
index 473f5f5..44921f0 100644
--- a/checksum.h
+++ b/checksum.h
@@ -25,4 +25,6 @@
 uint32_t ipv6_pseudo_header_checksum(uint32_t current, const struct ip6_hdr *ip6, uint16_t len);
 uint32_t ipv4_pseudo_header_checksum(uint32_t current, const struct iphdr *ip, uint16_t len);
 
+uint16_t ip_checksum_adjust(uint16_t checksum, uint32_t old_hdr_sum, uint32_t new_hdr_sum);
+
 #endif /* __CHECKSUM_H__ */
diff --git a/ipv4.c b/ipv4.c
index b5cbf80..1d5b0b2 100644
--- a/ipv4.c
+++ b/ipv4.c
@@ -70,7 +70,7 @@
   uint8_t nxthdr;
   const char *next_header;
   size_t len_left;
-  uint32_t checksum;
+  uint32_t old_sum, new_sum;
   int iov_len;
 
   if(len < sizeof(struct iphdr)) {
@@ -121,14 +121,17 @@
   out[pos].iov_len = sizeof(struct ip6_hdr);
 
   // Calculate the pseudo-header checksum.
-  checksum = ipv6_pseudo_header_checksum(0, ip6_targ, len_left);
+  old_sum = ipv4_pseudo_header_checksum(0, header, len_left);
+  new_sum = ipv6_pseudo_header_checksum(0, ip6_targ, len_left);
 
   if (nxthdr == IPPROTO_ICMPV6) {
-    iov_len = icmp_packet(out, pos + 1, (const struct icmphdr *) next_header, checksum, len_left);
+    iov_len = icmp_packet(out, pos + 1, (const struct icmphdr *) next_header, new_sum, len_left);
   } else if (nxthdr == IPPROTO_TCP) {
-    iov_len = tcp_packet(out, pos + 1, (const struct tcphdr *) next_header, checksum, len_left);
+    iov_len = tcp_packet(out, pos + 1, (const struct tcphdr *) next_header, old_sum, new_sum,
+                         len_left);
   } else if (nxthdr == IPPROTO_UDP) {
-    iov_len = udp_packet(out, pos + 1, (const struct udphdr *) next_header, checksum, len_left);
+    iov_len = udp_packet(out, pos + 1, (const struct udphdr *) next_header, old_sum, new_sum,
+                         len_left);
   } else if (nxthdr == IPPROTO_GRE) {
     iov_len = generic_packet(out, pos + 1, next_header, len_left);
   } else {
diff --git a/ipv6.c b/ipv6.c
index 79303ec..e4a73fe 100644
--- a/ipv6.c
+++ b/ipv6.c
@@ -88,7 +88,7 @@
   uint8_t protocol;
   const char *next_header;
   size_t len_left;
-  uint32_t checksum;
+  uint32_t old_sum, new_sum;
   int iov_len;
 
   if(len < sizeof(struct ip6_hdr)) {
@@ -133,16 +133,17 @@
   out[pos].iov_len = sizeof(struct iphdr);
 
   // Calculate the pseudo-header checksum.
-  checksum = ipv4_pseudo_header_checksum(0, ip_targ, len_left);
+  old_sum = ipv6_pseudo_header_checksum(0, ip6, len_left);
+  new_sum = ipv4_pseudo_header_checksum(0, ip_targ, len_left);
 
   // does not support IPv6 extension headers, this will drop any packet with them
   if (protocol == IPPROTO_ICMP) {
     iov_len = icmp6_packet(out, pos + 1, (const struct icmp6_hdr *) next_header, len_left);
   } else if (ip6->ip6_nxt == IPPROTO_TCP) {
-    iov_len = tcp_packet(out, pos + 1, (const struct tcphdr *) next_header, checksum,
+    iov_len = tcp_packet(out, pos + 1, (const struct tcphdr *) next_header, old_sum, new_sum,
                          len_left);
   } else if (ip6->ip6_nxt == IPPROTO_UDP) {
-    iov_len = udp_packet(out, pos + 1, (const struct udphdr *) next_header, checksum,
+    iov_len = udp_packet(out, pos + 1, (const struct udphdr *) next_header, old_sum, new_sum,
                          len_left);
   } else if (ip6->ip6_nxt == IPPROTO_GRE) {
     iov_len = generic_packet(out, pos + 1, next_header, len_left);
diff --git a/translate.c b/translate.c
index 00ea0b9..9a0f1b5 100644
--- a/translate.c
+++ b/translate.c
@@ -208,12 +208,10 @@
     // The pseudo-header checksum was calculated on the transport length of the original IPv4
     // packet that we were asked to translate. This transport length is 20 bytes smaller than it
     // needs to be, because the ICMP error contains an IPv4 header, which we will be translating to
-    // an IPv6 header, which is 20 bytes longer. Fix it up here. This is simpler than the
-    // alternative, which is to always update the pseudo-header checksum in all UDP/TCP/ICMP
-    // translation functions (rather than pre-calculating it when translating the IPv4 header).
+    // an IPv6 header, which is 20 bytes longer. Fix it up here.
     // We only need to do this for ICMP->ICMPv6, not ICMPv6->ICMP, because ICMP does not use the
     // pseudo-header when calculating its checksum (as the IPv4 header has its own checksum).
-    checksum = htonl(ntohl(checksum) + 20);
+    checksum = checksum + htons(20);
   } else if (icmp6_type == ICMP6_ECHO_REQUEST || icmp6_type == ICMP6_ECHO_REPLY) {
     // Ping packet.
     icmp6_targ->icmp6_id = icmp->un.echo.id;
@@ -298,10 +296,12 @@
  * takes a udp packet and sets it up for translation
  * out      - output packet
  * udp      - pointer to udp header in packet
- * checksum - pseudo-header checksum
+ * old_sum  - pseudo-header checksum of old header
+ * new_sum  - pseudo-header checksum of new header
  * len      - size of ip payload
  */
-int udp_packet(clat_packet out, int pos, const struct udphdr *udp, uint32_t checksum, size_t len) {
+int udp_packet(clat_packet out, int pos, const struct udphdr *udp,
+               uint32_t old_sum, uint32_t new_sum, size_t len) {
   const char *payload;
   size_t payload_size;
 
@@ -313,7 +313,7 @@
   payload = (const char *) (udp + 1);
   payload_size = len - sizeof(struct udphdr);
 
-  return udp_translate(out, pos, udp, checksum, payload, payload_size);
+  return udp_translate(out, pos, udp, old_sum, new_sum, payload, payload_size);
 }
 
 /* function: tcp_packet
@@ -324,7 +324,8 @@
  * len      - size of ip payload
  * returns: the highest position in the output clat_packet that's filled in
  */
-int tcp_packet(clat_packet out, int pos, const struct tcphdr *tcp, uint32_t checksum, size_t len) {
+int tcp_packet(clat_packet out, int pos, const struct tcphdr *tcp,
+               uint32_t old_sum, uint32_t new_sum, size_t len) {
   const char *payload;
   size_t payload_size, header_size;
 
@@ -347,20 +348,21 @@
   payload = ((const char *) tcp) + header_size;
   payload_size = len - header_size;
 
-  return tcp_translate(out, pos, tcp, header_size, checksum, payload, payload_size);
+  return tcp_translate(out, pos, tcp, header_size, old_sum, new_sum, payload, payload_size);
 }
 
 /* function: udp_translate
  * common between ipv4/ipv6 - setup checksum and send udp packet
  * out          - output packet
  * udp          - udp header
- * checksum     - pseudo-header checksum
+ * old_sum      - pseudo-header checksum of old header
+ * new_sum      - pseudo-header checksum of new header
  * payload      - tcp payload
  * payload_size - size of payload
  * returns: the highest position in the output clat_packet that's filled in
  */
-int udp_translate(clat_packet out, int pos, const struct udphdr *udp, uint32_t checksum,
-                  const char *payload, size_t payload_size) {
+int udp_translate(clat_packet out, int pos, const struct udphdr *udp, uint32_t old_sum,
+                  uint32_t new_sum, const char *payload, size_t payload_size) {
   struct udphdr *udp_targ = out[pos].iov_base;
 
   memcpy(udp_targ, udp, sizeof(struct udphdr));
@@ -369,8 +371,22 @@
   out[CLAT_POS_PAYLOAD].iov_base = (char *) payload;
   out[CLAT_POS_PAYLOAD].iov_len = payload_size;
 
-  udp_targ->check = 0;  // Checksum field must be 0 when calculating checksum.
-  udp_targ->check = packet_checksum(checksum, out, pos);
+  if (udp_targ->check) {
+    udp_targ->check = ip_checksum_adjust(udp->check, old_sum, new_sum);
+  } else {
+    // Zero checksums are special. RFC 768 says, "An all zero transmitted checksum value means that
+    // the transmitter generated no checksum (for debugging or for higher level protocols that
+    // don't care)." However, in IPv6 zero UDP checksums were only permitted by RFC 6935 (2013). So
+    // for safety we recompute it.
+    udp_targ->check = 0;  // Checksum field must be 0 when calculating checksum.
+    udp_targ->check = packet_checksum(new_sum, out, pos);
+  }
+
+  // RFC 768: "If the computed checksum is zero, it is transmitted as all ones (the equivalent
+  // in one's complement arithmetic)."
+  if (!udp_targ->check) {
+    udp_targ->check = 0xffff;
+  }
 
   return CLAT_POS_PAYLOAD + 1;
 }
@@ -389,7 +405,7 @@
  * TODO: hosts without pmtu discovery - non DF packets will rely on fragmentation (unimplemented)
  */
 int tcp_translate(clat_packet out, int pos, const struct tcphdr *tcp, size_t header_size,
-                  uint32_t checksum, const char *payload, size_t payload_size) {
+                  uint32_t old_sum, uint32_t new_sum, const char *payload, size_t payload_size) {
   struct tcphdr *tcp_targ = out[pos].iov_base;
   out[pos].iov_len = header_size;
 
@@ -406,8 +422,7 @@
   out[CLAT_POS_PAYLOAD].iov_base = (char *)payload;
   out[CLAT_POS_PAYLOAD].iov_len = payload_size;
 
-  tcp_targ->check = 0;  // Checksum field must be 0 when calculating checksum.
-  tcp_targ->check = packet_checksum(checksum, out, pos);
+  tcp_targ->check = ip_checksum_adjust(tcp->check, old_sum, new_sum);
 
   return CLAT_POS_PAYLOAD + 1;
 }
diff --git a/translate.h b/translate.h
index 9f1ac15..cfb7bbbf 100644
--- a/translate.h
+++ b/translate.h
@@ -61,12 +61,14 @@
 int generic_packet(clat_packet out, int pos, const char *payload, size_t len);
 
 // Translate TCP and UDP packets.
-int tcp_packet(clat_packet out, int pos, const struct tcphdr *tcp, uint32_t checksum, size_t len);
-int udp_packet(clat_packet out, int pos, const struct udphdr *udp, uint32_t checksum, size_t len);
+int tcp_packet(clat_packet out, int pos, const struct tcphdr *tcp,
+               uint32_t old_sum, uint32_t new_sum, size_t len);
+int udp_packet(clat_packet out, int pos, const struct udphdr *udp,
+               uint32_t old_sum, uint32_t new_sum, size_t len);
 
 int tcp_translate(clat_packet out, int pos, const struct tcphdr *tcp, size_t header_size,
-                  uint32_t checksum, const char *payload, size_t payload_size);
-int udp_translate(clat_packet out, int pos, const struct udphdr *udp, uint32_t checksum,
-                  const char *payload, size_t payload_size);
+                  uint32_t old_sum, uint32_t new_sum, const char *payload, size_t payload_size);
+int udp_translate(clat_packet out, int pos, const struct udphdr *udp,
+                  uint32_t old_sum, uint32_t new_sum, const char *payload, size_t payload_size);
 
 #endif /* __TRANSLATE_H__ */