Refactor logic to create and setup split operators to allow them to be reused by split3/split4 in the future

PiperOrigin-RevId: 433018910
diff --git a/src/subgraph/even-split2.c b/src/subgraph/even-split2.c
index 97c3e43..7ff03ef 100644
--- a/src/subgraph/even-split2.c
+++ b/src/subgraph/even-split2.c
@@ -10,6 +10,60 @@
 #include <xnnpack/params.h>
 #include <xnnpack/subgraph.h>
 
+static size_t calculate_batch_size(const struct xnn_value* input, size_t axis)
+{
+  size_t batch_size = 1;
+  for (size_t i = 0; i < axis; i++) {
+    batch_size *= input->shape.dim[i];
+  }
+  return batch_size;
+}
+
+static size_t calculate_input_stride(const struct xnn_value* input, size_t axis)
+{
+  size_t input_stride = 1;
+  for (size_t i = axis; i < input->shape.num_dims; i++) {
+    input_stride *= input->shape.dim[i];
+  }
+  return input_stride;
+}
+
+static enum xnn_status create_even_split_operator_helper(
+    const struct xnn_node* node,
+    size_t channels,
+    size_t input_stride,
+    size_t output_stride,
+    struct xnn_operator_data* opdata,
+    size_t index)
+{
+  switch (node->compute_type) {
+    #ifndef XNN_NO_F16_OPERATORS
+      case xnn_compute_type_fp16: {
+        return xnn_create_copy_nc_x16(
+            channels, input_stride, output_stride, node->flags, &opdata->operator_objects[index]);
+      }
+    #endif  // !defined(XNN_NO_F16_OPERATORS)
+    case xnn_compute_type_fp32: {
+      return xnn_create_copy_nc_x32(
+          channels, input_stride, output_stride, node->flags, &opdata->operator_objects[index]);
+    }
+    #ifndef XNN_NO_QS8_OPERATORS
+      case xnn_compute_type_qs8:
+    #endif  // !defined(XNN_NO_QS8_OPERATORS)
+    #ifndef XNN_NO_QU8_OPERATORS
+      case xnn_compute_type_qu8:
+    #endif  // !defined(XNN_NO_QU8_OPERATORS)
+    #if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
+      {
+        return xnn_create_copy_nc_x8(
+            channels, input_stride, output_stride, node->flags, &opdata->operator_objects[index]);
+      }
+    #endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
+    default:
+      XNN_UNREACHABLE;
+  }
+}
+
 static enum xnn_status create_even_split_operator(
   const struct xnn_node* node,
   const struct xnn_value* values,
@@ -31,66 +85,61 @@
   assert(output2_id < num_values);
 
   const size_t axis = node->params.even_split.axis;
-  size_t batch_size = 1, channels = 1;
-  for (size_t i = 0; i < axis; i++) {
-    batch_size *= values[input_id].shape.dim[i];
-  }
-  for (size_t i = axis; i < values[input_id].shape.num_dims; i++) {
-    channels *= values[input_id].shape.dim[i];
-  }
-  const size_t input_stride = channels;
-  // Divide by 2 since we are splitting into 2 outputs.
-  channels /= 2;
+  const size_t batch_size = calculate_batch_size(&values[input_id], axis);
+  const size_t input_stride = calculate_input_stride(&values[input_id], axis);
+  assert(input_stride % 2 == 0);
+  const size_t channels = input_stride / 2;
   const size_t output_stride = channels;
 
   enum xnn_status status;
-  switch (node->compute_type) {
-#ifndef XNN_NO_F16_OPERATORS
-    case xnn_compute_type_fp16: {
-      status = xnn_create_copy_nc_x16(channels, input_stride, output_stride, node->flags, &opdata->operator_objects[0]);
-      if (status != xnn_status_success) {
-        break;
+  status = create_even_split_operator_helper(node, channels, input_stride, output_stride, opdata, 0);
+  if (status != xnn_status_success) {
+    return status;
+  }
+  status = create_even_split_operator_helper(node, channels, input_stride, output_stride, opdata, 1);
+  if (status != xnn_status_success) {
+    return status;
+  }
+
+  opdata->inputs[0] = input_id;
+  opdata->outputs[0] = output1_id;
+  opdata->outputs[1] = output2_id;
+  opdata->batch_size = batch_size;
+
+  return status;
+}
+
+static enum xnn_status setup_even_split_operator_helper(
+    const size_t channels,
+    const void* input_data,
+    void* output_data,
+    const struct xnn_operator_data* opdata,
+    size_t index,
+    pthreadpool_t threadpool)
+{
+  switch (opdata->operator_objects[0]->type) {
+    #ifndef XNN_NO_F16_OPERATORS
+      case xnn_operator_type_copy_nc_x16: {
+        return xnn_setup_copy_nc_x16(
+            opdata->operator_objects[index], opdata->batch_size, (const uint16_t*) input_data + index * channels,
+            output_data, threadpool);
       }
-      status = xnn_create_copy_nc_x16(channels, input_stride, output_stride, node->flags, &opdata->operator_objects[1]);
-      break;
+    #endif  // !defined(XNN_NO_F16_OPERATORS)
+    case xnn_operator_type_copy_nc_x32: {
+      return xnn_setup_copy_nc_x32(
+          opdata->operator_objects[index], opdata->batch_size, (const uint32_t*) input_data + index * channels,
+          output_data, threadpool);
     }
-#endif  // !defined(XNN_NO_F16_OPERATORS)
-    case xnn_compute_type_fp32: {
-      status = xnn_create_copy_nc_x32(channels, input_stride, output_stride, node->flags, &opdata->operator_objects[0]);
-      if (status != xnn_status_success) {
-        break;
+    #if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
+      case xnn_operator_type_copy_nc_x8: {
+        return xnn_setup_copy_nc_x8(
+            opdata->operator_objects[index], opdata->batch_size, (const uint8_t*) input_data + index * channels,
+            output_data, threadpool);
       }
-      status = xnn_create_copy_nc_x32(channels, input_stride, output_stride, node->flags, &opdata->operator_objects[1]);
-      break;
-    }
-#ifndef XNN_NO_QS8_OPERATORS
-    case xnn_compute_type_qs8:
-#endif  // !defined(XNN_NO_QS8_OPERATORS)
-#ifndef XNN_NO_QU8_OPERATORS
-    case xnn_compute_type_qu8:
-#endif  // !defined(XNN_NO_QU8_OPERATORS)
-#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
-    {
-      status = xnn_create_copy_nc_x8(channels, input_stride, output_stride, node->flags, &opdata->operator_objects[0]);
-      if (status != xnn_status_success) {
-        break;
-      }
-      status = xnn_create_copy_nc_x8(channels, input_stride, output_stride, node->flags, &opdata->operator_objects[1]);
-      break;
-    }
-#endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
+    #endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
     default:
       XNN_UNREACHABLE;
   }
-
-  if (status == xnn_status_success) {
-    opdata->inputs[0] = input_id;
-    opdata->outputs[0] = output1_id;
-    opdata->outputs[1] = output2_id;
-    opdata->batch_size = batch_size;
-  }
-
-  return status;
 }
 
 static enum xnn_status setup_even_split_operator(
@@ -123,50 +172,13 @@
   void* output2_data = output2_blob->data;
   assert(output2_data != NULL);
 
-  enum xnn_status status;
-  size_t channels = opdata->operator_objects[0]->channels;
+  const size_t channels = opdata->operator_objects[0]->channels;
 
-  switch (opdata->operator_objects[0]->type) {
-#ifndef XNN_NO_F16_OPERATORS
-    case xnn_operator_type_copy_nc_x16: {
-      status =
-        xnn_setup_copy_nc_x16(opdata->operator_objects[0], opdata->batch_size, input_data, output1_data, threadpool);
-      if (status != xnn_status_success) {
-        return status;
-      }
-      status = xnn_setup_copy_nc_x16(
-        opdata->operator_objects[1], opdata->batch_size, (const uint16_t*) input_data + channels, output2_data,
-        threadpool);
-      return status;
-    }
-#endif  // !defined(XNN_NO_F16_OPERATORS)
-    case xnn_operator_type_copy_nc_x32: {
-      status =
-        xnn_setup_copy_nc_x32(opdata->operator_objects[0], opdata->batch_size, input_data, output1_data, threadpool);
-      if (status != xnn_status_success) {
-        return status;
-      }
-      status = xnn_setup_copy_nc_x32(
-        opdata->operator_objects[1], opdata->batch_size, (const uint32_t*) input_data + channels, output2_data,
-        threadpool);
-      return status;
-    }
-#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
-    case xnn_operator_type_copy_nc_x8: {
-      status =
-        xnn_setup_copy_nc_x8(opdata->operator_objects[0], opdata->batch_size, input_data, output1_data, threadpool);
-      if (status != xnn_status_success) {
-        return status;
-      }
-      status = xnn_setup_copy_nc_x8(
-        opdata->operator_objects[1], opdata->batch_size, (const uint8_t*) input_data + channels, output2_data,
-        threadpool);
-      return status;
-    }
-#endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
-    default:
-      XNN_UNREACHABLE;
+  enum xnn_status status = setup_even_split_operator_helper(channels, input_data, output1_data, opdata, 0, threadpool);
+  if (status != xnn_status_success) {
+    return status;
   }
+  return setup_even_split_operator_helper(channels, input_data, output2_data, opdata, 1, threadpool);
 }
 
 enum xnn_status xnn_define_even_split2(