autotest: gs_cache_client: handle extracting a large number of files
am: 871ba53a9d

Change-Id: Ice1b1caf1915bb1f0645e0e89aee4bad5c5a048d
diff --git a/client/common_lib/cros/gs_cache_client.py b/client/common_lib/cros/gs_cache_client.py
index e0b46b6..340d879 100644
--- a/client/common_lib/cros/gs_cache_client.py
+++ b/client/common_lib/cros/gs_cache_client.py
@@ -29,6 +29,7 @@
 from autotest_lib.client.common_lib import utils
 from autotest_lib.client.common_lib.cros import dev_server
 from autotest_lib.client.common_lib.cros import retry
+from autotest_lib.client.common_lib.cros import string_utils
 
 from chromite.lib import metrics
 
@@ -47,6 +48,7 @@
 _SSH_CALL_TIMEOUT_SECONDS = 60
 
 _MESSAGE_LENGTH_MAX_CHARS = 200
+_MAX_URL_QUERY_LENGTH = 4096
 
 # Exit code of `curl` when cannot connect to host. Man curl for details.
 _CURL_RC_CANNOT_CONNECT_TO_HOST = 7
@@ -54,6 +56,13 @@
 METRICS_PATH = 'chromeos/autotest/gs_cache_client'
 
 
+def _truncate_long_message(message):
+    """Truncate too long message (e.g. url) to limited length."""
+    if len(message) > _MESSAGE_LENGTH_MAX_CHARS:
+        message = '%s...' % message[:_MESSAGE_LENGTH_MAX_CHARS]
+    return message
+
+
 class CommunicationError(Exception):
     """Raised when has errors in communicate with GS Cache server."""
 
@@ -96,7 +105,7 @@
         @throws CommunicationError when the SSH command failed.
         """
         cmd = 'ssh %s \'curl "%s"\'' % (self._hostname, utils.sh_escape(url))
-        logging.debug('Gs Cache call: %s', cmd)
+        logging.debug('Gs Cache call: %s', _truncate_long_message(cmd))
         try:
             result = utils.run(cmd, timeout=_SSH_CALL_TIMEOUT_SECONDS)
         except error.CmdError as err:
@@ -140,14 +149,14 @@
         if _USE_SSH_CONNECTION and self._is_in_restricted_subnet:
             return self._ssh_call(url)
         else:
-            logging.debug('Gs Cache call: %s', url)
+            logging.debug('Gs Cache call: %s', _truncate_long_message(url))
             # TODO(guocb): Re-use the TCP connection.
             try:
                 rsp = requests.get(url)
-            except requests.ConnectionError:
+            except requests.ConnectionError as err:
                 raise NoGsCacheServerError(
-                        'Cannot connect to Gs Cache at %s via HTTP.'
-                        % self._netloc)
+                        'Cannot connect to Gs Cache at %s via HTTP: %s'
+                        % (self._netloc, err))
             if not rsp.ok:
                 msg = 'HTTP request: GET %s\nHTTP Response: %d: %s' % (
                         rsp.url, rsp.status_code, rsp.content)
@@ -159,19 +168,29 @@
 
         @param bucket: The bucket of the file on GS.
         @param archive: The path of archive on GS (bucket part not included).
-        @param files: A path, or a path list of files to be extracted.
+        @param files: A list of files to be extracted.
 
         @return A dict of extracted files, in format of
                 {filename: content, ...}.
         @throws ResponseContentError if the response is not in JSON format.
         """
-        rsp_content = self._call('extract', bucket, archive, {'file': files})
+        rsp_contents = []
+        # Too many files to extract results in a URL so long that the HTTP
+        # server may respond with a 414 error. So we split them into multiple
+        # requests if necessary.
+        for part_of_files in string_utils.join_longest_with_length_limit(
+                files, _MAX_URL_QUERY_LENGTH, separator='&file=',
+                do_join=False):
+            rsp_contents.append(self._call('extract', bucket, archive,
+                                           {'file': part_of_files}))
+        content_dict = {}
         try:
-            content_dict = json.loads(rsp_content)
+            for content in rsp_contents:
+                content_dict.update(json.loads(content))
         except ValueError as err:
             raise ResponseContentError(
                 'Got ValueError "%s" when decoding to JSON format. The '
-                'response content is: %s' % (err, rsp_content))
+                'response content is: %s' % (err, rsp_contents))
 
         return content_dict
 
@@ -253,14 +272,9 @@
         try:
             map_file_content = content_dict[map_file_name]
         except KeyError:
-            # content_dict may have too many keys which makes the exception
-            # message less readable. So truncate it to reasonable length.
-            content_dict_str = str(content_dict)
-            if len(content_dict_str) > _MESSAGE_LENGTH_MAX_CHARS:
-                content_dict_str = (
-                        '%s...' % content_dict_str[:_MESSAGE_LENGTH_MAX_CHARS])
-            raise ResponseContentError("File '%s' isn't in response: %s" %
-                                       (map_file_name, content_dict_str))
+            raise ResponseContentError(
+                    "File '%s' isn't in response: %s" %
+                    (map_file_name, _truncate_long_message(str(content_dict))))
         try:
             suite_to_control_files = json.loads(map_file_content)
         except ValueError as err:
diff --git a/client/common_lib/cros/gs_cache_client_unittest.py b/client/common_lib/cros/gs_cache_client_unittest.py
index 41d80a6..df56ae9 100755
--- a/client/common_lib/cros/gs_cache_client_unittest.py
+++ b/client/common_lib/cros/gs_cache_client_unittest.py
@@ -48,6 +48,16 @@
                     'file')
             self.assertEqual(result, {})
 
+    def test_extract_many_files_via_http(self):
+        """Test extracting many files via http."""
+        with mock.patch('requests.get') as m:
+            m.return_value = mock.MagicMock(ok=True, content='{}')
+            result = self.api.extract(
+                    gs_cache_client._CROS_IMAGE_ARCHIVE_BUCKET, 'archive',
+                    ['the_file'] * 1000)
+            self.assertEqual(result, {})
+            self.assertTrue(m.call_count > 1)
+
     @mock.patch('time.sleep')
     @mock.patch('time.time', side_effect=itertools.cycle([0, 400]))
     def test_extract_via_ssh_has_error(self, *args):
diff --git a/client/common_lib/cros/string_utils.py b/client/common_lib/cros/string_utils.py
index af5af14..2f13406 100644
--- a/client/common_lib/cros/string_utils.py
+++ b/client/common_lib/cros/string_utils.py
@@ -3,8 +3,7 @@
 # Use of this source code is governed by a BSD-style license that can be
 # found in the LICENSE file.
 
-"""A collection of classes/functions to manipulate strings.  """
-
+"""A collection of classes/functions to manipulate strings."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -16,7 +15,8 @@
     """Raised when string is too long to manipulate."""
 
 
-def join_longest_with_length_limit(string_list, length_limit, separator=''):
+def join_longest_with_length_limit(string_list, length_limit, separator='',
+                                   do_join=True):
     """Join strings to meet length limit and yield results.
 
     Join strings from |string_list| using |separator| and yield the results.
@@ -24,13 +24,15 @@
     |length_limit|. In other words, this function yields minimum number of
     result strings.
 
-    An error will be raised when any stirng in |string_list| is longer than
+    An error will be raised when any string in |string_list| is longer than
     |length_limit| because the result string joined must be longer than
     |length_limit| in any case.
 
     @param string_list: A list of strings to be joined.
     @param length_limit: The maximum length of the result string.
     @param separator: The separator to join strings.
+    @param do_join: If True, yield each group joined into one string; if False,
+        yield each group as a list of strings instead.
 
     @yield The result string.
     @throws StringTooLongError when any string in |string_list| is longer than
@@ -51,22 +53,22 @@
     length_limit += len_sep
     # Call str.join directly when possible.
     if sum(length_list) + len_sep * len(string_list) <= length_limit:
-        yield separator.join(string_list)
+        yield separator.join(string_list) if do_join else string_list
         return
 
-    result = ''
+    result = []
     new_length_limit = length_limit
     while string_list:
         index = bisect.bisect_right(length_list,
                                     new_length_limit - len_sep) - 1
         if index < 0:  # All available strings are longer than the limit.
-            yield result[:-len_sep]
-            result = ''
+            yield separator.join(result) if do_join else result
+            result = []
             new_length_limit = length_limit
             continue
 
-        result = '%s%s%s' % (result, string_list.pop(index), separator)
+        result.append(string_list.pop(index))
         new_length_limit -= length_list.pop(index) + len_sep
 
     if result:
-        yield result[:-len_sep]
+        yield separator.join(result) if do_join else result
diff --git a/client/common_lib/cros/string_utils_unittest.py b/client/common_lib/cros/string_utils_unittest.py
index 11a7822..3d812d7 100755
--- a/client/common_lib/cros/string_utils_unittest.py
+++ b/client/common_lib/cros/string_utils_unittest.py
@@ -18,20 +18,23 @@
 
 class JoinLongestWithLengthLimitTest(unittest.TestCase):
     """Unit test of join_longest_with_length_limit."""
+    def setUp(self):
+        """Setup."""
+        self.strings = ['abc', '12', 'sssss']
+
     def test_basic(self):
         """The basic test case."""
-        strings = ['abc', '12', 'sssss']
         result = list(string_utils.join_longest_with_length_limit(
-                strings, 6, separator=','))
+                self.strings, 6, separator=','))
         self.assertEqual(len(result), 2)
+        self.assertTrue(type(result[0]) is str)
 
     def test_short_strings(self):
         """Test with short strings to be joined with big limit."""
-        strings = ['abc', '12', 'sssss']
         sep = mock.MagicMock()
         result = list(string_utils.join_longest_with_length_limit(
-                strings, 100, separator=sep))
-        sep.join.assert_called()
+                self.strings, 100, separator=sep))
+        sep.join.assert_called_once()
 
     def test_string_too_long(self):
         """Test with too long string to be joined."""
@@ -41,13 +44,19 @@
 
     def test_long_sep(self):
         """Test with long seperator."""
-        strings = ['abc', '12', 'sssss']
         result = list(string_utils.join_longest_with_length_limit(
-                strings, 6, separator='|very long separator|'))
+                self.strings, 6, separator='|very long separator|'))
         # Though the string to be joined is short, we still will have 3 result
         # because each of them plus separator is longer than the limit.
         self.assertEqual(len(result), 3)
 
+    def test_do_not_join(self):
+        """Test yielding list instead of string."""
+        result = list(string_utils.join_longest_with_length_limit(
+                self.strings, 6, separator=',', do_join=False))
+        self.assertEqual(len(result), 2)
+        self.assertTrue(type(result[0]) is list)
+
 
 if __name__ == '__main__':
     unittest.main()