[Ruby] Message.decode/encode: Add max_recursion_depth option (#9218)

* Message.decode/encode: Add max_recursion_depth option

This allows increasing the recursing depth from the default of 64, by
setting the "max_recursion_depth" to the desired integer value. This is
useful to encode or decode complex nested protobuf messages that otherwise
error out with a RuntimeError or "Error occurred during parsing".

Fixes #1493

* Address review comments

Co-authored-by: Adam Cozzette <acozzette@google.com>
diff --git a/ruby/ext/google/protobuf_c/message.c b/ruby/ext/google/protobuf_c/message.c
index 7feee75..8d25b79 100644
--- a/ruby/ext/google/protobuf_c/message.c
+++ b/ruby/ext/google/protobuf_c/message.c
@@ -953,13 +953,35 @@
 
 /*
  * call-seq:
- *     MessageClass.decode(data) => message
+ *     MessageClass.decode(data, options) => message
  *
  * Decodes the given data (as a string containing bytes in protocol buffers wire
  * format) under the interpretration given by this message class's definition
  * and returns a message object with the corresponding field values.
+ * @param options [Hash] options for the decoder
+ *  max_recursion_depth: set to maximum decoding depth for message (default is 64)
  */
-static VALUE Message_decode(VALUE klass, VALUE data) {
+static VALUE Message_decode(int argc, VALUE* argv, VALUE klass) {
+  VALUE data = argv[0];
+  int options = 0;
+
+  if (argc < 1 || argc > 2) {
+    rb_raise(rb_eArgError, "Expected 1 or 2 arguments.");
+  }
+
+  if (argc == 2) {
+    VALUE hash_args = argv[1];
+    if (TYPE(hash_args) != T_HASH) {
+      rb_raise(rb_eArgError, "Expected hash arguments.");
+    }
+
+    VALUE depth = rb_hash_lookup(hash_args, ID2SYM(rb_intern("max_recursion_depth")));
+
+    if (depth != Qnil && TYPE(depth) == T_FIXNUM) {
+      options |= UPB_DECODE_MAXDEPTH(FIX2INT(depth));
+    }
+  }
+
   if (TYPE(data) != T_STRING) {
     rb_raise(rb_eArgError, "Expected string for binary protobuf data.");
   }
@@ -969,7 +991,7 @@
 
   upb_DecodeStatus status = upb_Decode(
       RSTRING_PTR(data), RSTRING_LEN(data), (upb_Message*)msg->msg,
-      upb_MessageDef_MiniTable(msg->msgdef), NULL, 0, Arena_get(msg->arena));
+      upb_MessageDef_MiniTable(msg->msgdef), NULL, options, Arena_get(msg->arena));
 
   if (status != kUpb_DecodeStatus_Ok) {
     rb_raise(cParseError, "Error occurred during parsing");
@@ -1043,24 +1065,43 @@
 
 /*
  * call-seq:
- *     MessageClass.encode(msg) => bytes
+ *     MessageClass.encode(msg, options) => bytes
  *
  * Encodes the given message object to its serialized form in protocol buffers
  * wire format.
+ * @param options [Hash] options for the encoder
+ *  max_recursion_depth: set to maximum encoding depth for message (default is 64)
  */
-static VALUE Message_encode(VALUE klass, VALUE msg_rb) {
-  Message* msg = ruby_to_Message(msg_rb);
+static VALUE Message_encode(int argc, VALUE* argv, VALUE klass) {
+  Message* msg = ruby_to_Message(argv[0]);
+  int options = 0;
   const char* data;
   size_t size;
 
-  if (CLASS_OF(msg_rb) != klass) {
+  if (CLASS_OF(argv[0]) != klass) {
     rb_raise(rb_eArgError, "Message of wrong type.");
   }
 
-  upb_Arena* arena = upb_Arena_New();
+  if (argc < 1 || argc > 2) {
+    rb_raise(rb_eArgError, "Expected 1 or 2 arguments.");
+  }
 
-  data = upb_Encode(msg->msg, upb_MessageDef_MiniTable(msg->msgdef), 0, arena,
-                    &size);
+  if (argc == 2) {
+    VALUE hash_args = argv[1];
+    if (TYPE(hash_args) != T_HASH) {
+      rb_raise(rb_eArgError, "Expected hash arguments.");
+    }
+    VALUE depth = rb_hash_lookup(hash_args, ID2SYM(rb_intern("max_recursion_depth")));
+
+    if (depth != Qnil && TYPE(depth) == T_FIXNUM) {
+      options |= UPB_DECODE_MAXDEPTH(FIX2INT(depth));
+    }
+  }
+
+  upb_Arena *arena = upb_Arena_New();
+
+  data = upb_Encode(msg->msg, upb_MessageDef_MiniTable(msg->msgdef),
+                    options, arena, &size);
 
   if (data) {
     VALUE ret = rb_str_new(data, size);
@@ -1186,8 +1227,8 @@
   rb_define_method(klass, "to_s", Message_inspect, 0);
   rb_define_method(klass, "[]", Message_index, 1);
   rb_define_method(klass, "[]=", Message_index_set, 2);
-  rb_define_singleton_method(klass, "decode", Message_decode, 1);
-  rb_define_singleton_method(klass, "encode", Message_encode, 1);
+  rb_define_singleton_method(klass, "decode", Message_decode, -1);
+  rb_define_singleton_method(klass, "encode", Message_encode, -1);
   rb_define_singleton_method(klass, "decode_json", Message_decode_json, -1);
   rb_define_singleton_method(klass, "encode_json", Message_encode_json, -1);
   rb_define_singleton_method(klass, "descriptor", Message_descriptor, 0);
diff --git a/ruby/lib/google/protobuf.rb b/ruby/lib/google/protobuf.rb
index f939a4c..b7a6711 100644
--- a/ruby/lib/google/protobuf.rb
+++ b/ruby/lib/google/protobuf.rb
@@ -59,16 +59,16 @@
 module Google
   module Protobuf
 
-    def self.encode(msg)
-      msg.to_proto
+    def self.encode(msg, options = {})
+      msg.to_proto(options)
     end
 
     def self.encode_json(msg, options = {})
       msg.to_json(options)
     end
 
-    def self.decode(klass, proto)
-      klass.decode(proto)
+    def self.decode(klass, proto, options = {})
+      klass.decode(proto, options)
     end
 
     def self.decode_json(klass, json, options = {})
diff --git a/ruby/lib/google/protobuf/message_exts.rb b/ruby/lib/google/protobuf/message_exts.rb
index f432f89..6608521 100644
--- a/ruby/lib/google/protobuf/message_exts.rb
+++ b/ruby/lib/google/protobuf/message_exts.rb
@@ -44,8 +44,8 @@
         self.class.encode_json(self, options)
       end
 
-      def to_proto
-        self.class.encode(self)
+      def to_proto(options = {})
+        self.class.encode(self, options)
       end
 
     end
diff --git a/ruby/src/main/java/com/google/protobuf/jruby/RubyMap.java b/ruby/src/main/java/com/google/protobuf/jruby/RubyMap.java
index f7379b1..b5e4903 100644
--- a/ruby/src/main/java/com/google/protobuf/jruby/RubyMap.java
+++ b/ruby/src/main/java/com/google/protobuf/jruby/RubyMap.java
@@ -389,7 +389,7 @@
         return newMap;
     }
 
-    protected List<DynamicMessage> build(ThreadContext context, RubyDescriptor descriptor, int depth) {
+    protected List<DynamicMessage> build(ThreadContext context, RubyDescriptor descriptor, int depth, int maxRecursionDepth) {
         List<DynamicMessage> list = new ArrayList<DynamicMessage>();
         RubyClass rubyClass = (RubyClass) descriptor.msgclass(context);
         FieldDescriptor keyField = descriptor.getField("key");
@@ -398,7 +398,7 @@
             RubyMessage mapMessage = (RubyMessage) rubyClass.newInstance(context, Block.NULL_BLOCK);
             mapMessage.setField(context, keyField, key);
             mapMessage.setField(context, valueField, table.get(key));
-            list.add(mapMessage.build(context, depth + 1));
+            list.add(mapMessage.build(context, depth + 1, maxRecursionDepth));
         }
         return list;
     }
diff --git a/ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java b/ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java
index cf59f62..2ba132e 100644
--- a/ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java
+++ b/ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java
@@ -39,6 +39,7 @@
 import com.google.protobuf.Descriptors.FileDescriptor;
 import com.google.protobuf.Descriptors.OneofDescriptor;
 import com.google.protobuf.ByteString;
+import com.google.protobuf.CodedInputStream;
 import com.google.protobuf.DynamicMessage;
 import com.google.protobuf.InvalidProtocolBufferException;
 import com.google.protobuf.Message;
@@ -461,35 +462,63 @@
 
     /*
      * call-seq:
-     *     MessageClass.encode(msg) => bytes
+     *     MessageClass.encode(msg, options = {}) => bytes
      *
      * Encodes the given message object to its serialized form in protocol buffers
      * wire format.
+     * @param options [Hash] options for the encoder
+     *  max_recursion_depth: set to maximum encoding depth for message (default is 64)
      */
-    @JRubyMethod(meta = true)
-    public static IRubyObject encode(ThreadContext context, IRubyObject recv, IRubyObject value) {
-        if (recv != value.getMetaClass()) {
-            throw context.runtime.newArgumentError("Tried to encode a " + value.getMetaClass() + " message with " + recv);
+    @JRubyMethod(required = 1, optional = 1, meta = true)
+    public static IRubyObject encode(ThreadContext context, IRubyObject recv, IRubyObject[] args) {
+        if (recv != args[0].getMetaClass()) {
+            throw context.runtime.newArgumentError("Tried to encode a " + args[0].getMetaClass() + " message with " + recv);
         }
-        RubyMessage message = (RubyMessage) value;
-        return context.runtime.newString(new ByteList(message.build(context).toByteArray()));
+        RubyMessage message = (RubyMessage) args[0];
+        int maxRecursionDepthInt = SINK_MAXIMUM_NESTING;
+
+        if (args.length > 1) {
+            RubyHash options = (RubyHash) args[1];
+            IRubyObject maxRecursionDepth = options.fastARef(context.runtime.newSymbol("max_recursion_depth"));
+
+            if (maxRecursionDepth != null) {
+                maxRecursionDepthInt = ((RubyNumeric) maxRecursionDepth).getIntValue();
+            }
+        }
+        return context.runtime.newString(new ByteList(message.build(context, 0, maxRecursionDepthInt).toByteArray()));
     }
 
     /*
      * call-seq:
-     *     MessageClass.decode(data) => message
+     *     MessageClass.decode(data, options = {}) => message
      *
      * Decodes the given data (as a string containing bytes in protocol buffers wire
      * format) under the interpretation given by this message class's definition
      * and returns a message object with the corresponding field values.
+     * @param options [Hash] options for the decoder
+     *  max_recursion_depth: set to maximum decoding depth for message (default is 100)
      */
-    @JRubyMethod(meta = true)
-    public static IRubyObject decode(ThreadContext context, IRubyObject recv, IRubyObject data) {
+    @JRubyMethod(required = 1, optional = 1, meta = true)
+    public static IRubyObject decode(ThreadContext context, IRubyObject recv, IRubyObject[] args) {
+        IRubyObject data = args[0];
         byte[] bin = data.convertToString().getBytes();
+        CodedInputStream input = CodedInputStream.newInstance(bin);
         RubyMessage ret = (RubyMessage) ((RubyClass) recv).newInstance(context, Block.NULL_BLOCK);
+
+        if (args.length == 2) {
+            if (!(args[1] instanceof RubyHash)) {
+                throw context.runtime.newArgumentError("Expected hash arguments.");
+            }
+
+            IRubyObject maxRecursionDepth = ((RubyHash) args[1]).fastARef(context.runtime.newSymbol("max_recursion_depth"));
+            if (maxRecursionDepth != null) {
+                input.setRecursionLimit(((RubyNumeric) maxRecursionDepth).getIntValue());
+            }
+        }
+
         try {
-            ret.builder.mergeFrom(bin);
-        } catch (InvalidProtocolBufferException e) {
+            ret.builder.mergeFrom(input);
+        } catch (Exception e) {
             throw RaiseException.from(context.runtime, (RubyClass) context.runtime.getClassFromPath("Google::Protobuf::ParseError"), e.getMessage());
         }
 
@@ -541,7 +570,7 @@
         printer = printer.usingTypeRegistry(JsonFormat.TypeRegistry.newBuilder().add(message.descriptor).build());
 
         try {
-            result = printer.print(message.build(context));
+            result = printer.print(message.build(context, 0, SINK_MAXIMUM_NESTING));
         } catch (InvalidProtocolBufferException e) {
             throw runtime.newRuntimeError(e.getMessage());
         } catch (IllegalArgumentException e) {
@@ -635,12 +664,8 @@
         return ret;
     }
 
-    protected DynamicMessage build(ThreadContext context) {
-        return build(context, 0);
-    }
-
-    protected DynamicMessage build(ThreadContext context, int depth) {
-        if (depth > SINK_MAXIMUM_NESTING) {
+    protected DynamicMessage build(ThreadContext context, int depth, int maxRecursionDepth) {
+        if (depth >= maxRecursionDepth) {
             throw context.runtime.newRuntimeError("Maximum recursion depth exceeded during encoding.");
         }
 
@@ -651,7 +676,7 @@
             if (value instanceof RubyMap) {
                 builder.clearField(fieldDescriptor);
                 RubyDescriptor mapDescriptor = (RubyDescriptor) getDescriptorForField(context, fieldDescriptor);
-                for (DynamicMessage kv : ((RubyMap) value).build(context, mapDescriptor, depth)) {
+                for (DynamicMessage kv : ((RubyMap) value).build(context, mapDescriptor, depth, maxRecursionDepth)) {
                     builder.addRepeatedField(fieldDescriptor, kv);
                 }
 
@@ -660,7 +685,7 @@
 
                 builder.clearField(fieldDescriptor);
                 for (int i = 0; i < repeatedField.size(); i++) {
-                    Object item = convert(context, fieldDescriptor, repeatedField.get(i), depth,
+                    Object item = convert(context, fieldDescriptor, repeatedField.get(i), depth, maxRecursionDepth,
                         /*isDefaultValueForBytes*/ false);
                     builder.addRepeatedField(fieldDescriptor, item);
                 }
@@ -682,7 +707,7 @@
                     fieldDescriptor.getFullName().equals("google.protobuf.FieldDescriptorProto.default_value")) {
                     isDefaultStringForBytes = true;
                 }
-                builder.setField(fieldDescriptor, convert(context, fieldDescriptor, value, depth, isDefaultStringForBytes));
+                builder.setField(fieldDescriptor, convert(context, fieldDescriptor, value, depth, maxRecursionDepth, isDefaultStringForBytes));
             }
         }
 
@@ -702,7 +727,7 @@
                 builder.clearField(fieldDescriptor);
                 RubyDescriptor mapDescriptor = (RubyDescriptor) getDescriptorForField(context,
                     fieldDescriptor);
-                for (DynamicMessage kv : ((RubyMap) value).build(context, mapDescriptor, depth)) {
+                for (DynamicMessage kv : ((RubyMap) value).build(context, mapDescriptor, depth, maxRecursionDepth)) {
                     builder.addRepeatedField(fieldDescriptor, kv);
                 }
             }
@@ -814,7 +839,8 @@
     // convert a ruby object to protobuf type, skip type check since it is checked on the way in
     private Object convert(ThreadContext context,
                            FieldDescriptor fieldDescriptor,
-                           IRubyObject value, int depth, boolean isDefaultStringForBytes) {
+                           IRubyObject value, int depth, int maxRecursionDepth,
+                           boolean isDefaultStringForBytes) {
         Object val = null;
         switch (fieldDescriptor.getType()) {
             case INT32:
@@ -855,7 +881,7 @@
                 }
                 break;
             case MESSAGE:
-                val = ((RubyMessage) value).build(context, depth + 1);
+                val = ((RubyMessage) value).build(context, depth + 1, maxRecursionDepth);
                 break;
             case ENUM:
                 EnumDescriptor enumDescriptor = fieldDescriptor.getEnumType();
@@ -1214,7 +1240,7 @@
     private static final String CONST_SUFFIX = "_const";
     private static final String HAS_PREFIX = "has_";
     private static final String QUESTION_MARK = "?";
-    private static final int SINK_MAXIMUM_NESTING = 63;
+    private static final int SINK_MAXIMUM_NESTING = 64;
 
     private Descriptor descriptor;
     private DynamicMessage.Builder builder;
diff --git a/ruby/tests/encode_decode_test.rb b/ruby/tests/encode_decode_test.rb
index 429ac43..9513cc3 100755
--- a/ruby/tests/encode_decode_test.rb
+++ b/ruby/tests/encode_decode_test.rb
@@ -101,4 +101,55 @@
     assert_match json, "{\"CustomJsonName\":42}"
   end
 
+  def test_decode_depth_limit
+    msg = A::B::C::TestMessage.new(
+      optional_msg: A::B::C::TestMessage.new(
+        optional_msg: A::B::C::TestMessage.new(
+          optional_msg: A::B::C::TestMessage.new(
+            optional_msg: A::B::C::TestMessage.new(
+              optional_msg: A::B::C::TestMessage.new(
+              )
+            )
+          )
+        )
+      )
+    )
+    msg_encoded = A::B::C::TestMessage.encode(msg)
+    msg_out = A::B::C::TestMessage.decode(msg_encoded)
+    assert_match msg.to_json, msg_out.to_json
+
+    assert_raise Google::Protobuf::ParseError do
+      A::B::C::TestMessage.decode(msg_encoded, { max_recursion_depth: 4 })
+    end
+
+    msg_out = A::B::C::TestMessage.decode(msg_encoded, { max_recursion_depth: 5 })
+    assert_match msg.to_json, msg_out.to_json
+  end
+
+  def test_encode_depth_limit
+    msg = A::B::C::TestMessage.new(
+      optional_msg: A::B::C::TestMessage.new(
+        optional_msg: A::B::C::TestMessage.new(
+          optional_msg: A::B::C::TestMessage.new(
+            optional_msg: A::B::C::TestMessage.new(
+              optional_msg: A::B::C::TestMessage.new(
+              )
+            )
+          )
+        )
+      )
+    )
+    msg_encoded = A::B::C::TestMessage.encode(msg)
+    msg_out = A::B::C::TestMessage.decode(msg_encoded)
+    assert_match msg.to_json, msg_out.to_json
+
+    assert_raise RuntimeError do
+      A::B::C::TestMessage.encode(msg, { max_recursion_depth: 5 })
+    end
+
+    msg_encoded = A::B::C::TestMessage.encode(msg, { max_recursion_depth: 6 })
+    msg_out = A::B::C::TestMessage.decode(msg_encoded)
+    assert_match msg.to_json, msg_out.to_json
+  end
+
 end