Refactor the logic that creates and sets up split operators so that it can be reused by split3/split4 in the future
PiperOrigin-RevId: 433018910
diff --git a/src/subgraph/even-split2.c b/src/subgraph/even-split2.c
index 97c3e43..7ff03ef 100644
--- a/src/subgraph/even-split2.c
+++ b/src/subgraph/even-split2.c
@@ -10,6 +10,60 @@
#include <xnnpack/params.h>
#include <xnnpack/subgraph.h>
+static size_t calculate_batch_size(const struct xnn_value* input, size_t axis)
+{
+ size_t batch_size = 1;
+ for (size_t i = 0; i < axis; i++) {
+ batch_size *= input->shape.dim[i];
+ }
+ return batch_size;
+}
+
+static size_t calculate_input_stride(const struct xnn_value* input, size_t axis)
+{
+ size_t input_stride = 1;
+ for (size_t i = axis; i < input->shape.num_dims; i++) {
+ input_stride *= input->shape.dim[i];
+ }
+ return input_stride;
+}
+
+static enum xnn_status create_even_split_operator_helper(
+ const struct xnn_node* node,
+ size_t channels,
+ size_t input_stride,
+ size_t output_stride,
+ struct xnn_operator_data* opdata,
+ size_t index)
+{
+ switch (node->compute_type) {
+ #ifndef XNN_NO_F16_OPERATORS
+ case xnn_compute_type_fp16: {
+ return xnn_create_copy_nc_x16(
+ channels, input_stride, output_stride, node->flags, &opdata->operator_objects[index]);
+ }
+ #endif // !defined(XNN_NO_F16_OPERATORS)
+ case xnn_compute_type_fp32: {
+ return xnn_create_copy_nc_x32(
+ channels, input_stride, output_stride, node->flags, &opdata->operator_objects[index]);
+ }
+ #ifndef XNN_NO_QS8_OPERATORS
+ case xnn_compute_type_qs8:
+ #endif // !defined(XNN_NO_QS8_OPERATORS)
+ #ifndef XNN_NO_QU8_OPERATORS
+ case xnn_compute_type_qu8:
+ #endif // !defined(XNN_NO_QU8_OPERATORS)
+ #if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
+ {
+ return xnn_create_copy_nc_x8(
+ channels, input_stride, output_stride, node->flags, &opdata->operator_objects[index]);
+ }
+ #endif // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
+ default:
+ XNN_UNREACHABLE;
+ }
+}
+
static enum xnn_status create_even_split_operator(
const struct xnn_node* node,
const struct xnn_value* values,
@@ -31,66 +85,61 @@
assert(output2_id < num_values);
const size_t axis = node->params.even_split.axis;
- size_t batch_size = 1, channels = 1;
- for (size_t i = 0; i < axis; i++) {
- batch_size *= values[input_id].shape.dim[i];
- }
- for (size_t i = axis; i < values[input_id].shape.num_dims; i++) {
- channels *= values[input_id].shape.dim[i];
- }
- const size_t input_stride = channels;
- // Divide by 2 since we are splitting into 2 outputs.
- channels /= 2;
+ const size_t batch_size = calculate_batch_size(&values[input_id], axis);
+ const size_t input_stride = calculate_input_stride(&values[input_id], axis);
+ assert(input_stride % 2 == 0);
+ const size_t channels = input_stride / 2;
const size_t output_stride = channels;
enum xnn_status status;
- switch (node->compute_type) {
-#ifndef XNN_NO_F16_OPERATORS
- case xnn_compute_type_fp16: {
- status = xnn_create_copy_nc_x16(channels, input_stride, output_stride, node->flags, &opdata->operator_objects[0]);
- if (status != xnn_status_success) {
- break;
+ status = create_even_split_operator_helper(node, channels, input_stride, output_stride, opdata, 0);
+ if (status != xnn_status_success) {
+ return status;
+ }
+ status = create_even_split_operator_helper(node, channels, input_stride, output_stride, opdata, 1);
+ if (status != xnn_status_success) {
+ return status;
+ }
+
+ opdata->inputs[0] = input_id;
+ opdata->outputs[0] = output1_id;
+ opdata->outputs[1] = output2_id;
+ opdata->batch_size = batch_size;
+
+ return status;
+}
+
+static enum xnn_status setup_even_split_operator_helper(
+ const size_t channels,
+ const void* input_data,
+ void* output_data,
+ const struct xnn_operator_data* opdata,
+ size_t index,
+ pthreadpool_t threadpool)
+{
+ switch (opdata->operator_objects[0]->type) {
+ #ifndef XNN_NO_F16_OPERATORS
+ case xnn_operator_type_copy_nc_x16: {
+ return xnn_setup_copy_nc_x16(
+ opdata->operator_objects[index], opdata->batch_size, (const uint16_t*) input_data + index * channels,
+ output_data, threadpool);
}
- status = xnn_create_copy_nc_x16(channels, input_stride, output_stride, node->flags, &opdata->operator_objects[1]);
- break;
+ #endif // !defined(XNN_NO_F16_OPERATORS)
+ case xnn_operator_type_copy_nc_x32: {
+ return xnn_setup_copy_nc_x32(
+ opdata->operator_objects[index], opdata->batch_size, (const uint32_t*) input_data + index * channels,
+ output_data, threadpool);
}
-#endif // !defined(XNN_NO_F16_OPERATORS)
- case xnn_compute_type_fp32: {
- status = xnn_create_copy_nc_x32(channels, input_stride, output_stride, node->flags, &opdata->operator_objects[0]);
- if (status != xnn_status_success) {
- break;
+ #if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
+ case xnn_operator_type_copy_nc_x8: {
+ return xnn_setup_copy_nc_x8(
+ opdata->operator_objects[index], opdata->batch_size, (const uint8_t*) input_data + index * channels,
+ output_data, threadpool);
}
- status = xnn_create_copy_nc_x32(channels, input_stride, output_stride, node->flags, &opdata->operator_objects[1]);
- break;
- }
-#ifndef XNN_NO_QS8_OPERATORS
- case xnn_compute_type_qs8:
-#endif // !defined(XNN_NO_QS8_OPERATORS)
-#ifndef XNN_NO_QU8_OPERATORS
- case xnn_compute_type_qu8:
-#endif // !defined(XNN_NO_QU8_OPERATORS)
-#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
- {
- status = xnn_create_copy_nc_x8(channels, input_stride, output_stride, node->flags, &opdata->operator_objects[0]);
- if (status != xnn_status_success) {
- break;
- }
- status = xnn_create_copy_nc_x8(channels, input_stride, output_stride, node->flags, &opdata->operator_objects[1]);
- break;
- }
-#endif // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
+ #endif // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
default:
XNN_UNREACHABLE;
}
-
- if (status == xnn_status_success) {
- opdata->inputs[0] = input_id;
- opdata->outputs[0] = output1_id;
- opdata->outputs[1] = output2_id;
- opdata->batch_size = batch_size;
- }
-
- return status;
}
static enum xnn_status setup_even_split_operator(
@@ -123,50 +172,13 @@
void* output2_data = output2_blob->data;
assert(output2_data != NULL);
- enum xnn_status status;
- size_t channels = opdata->operator_objects[0]->channels;
+ const size_t channels = opdata->operator_objects[0]->channels;
- switch (opdata->operator_objects[0]->type) {
-#ifndef XNN_NO_F16_OPERATORS
- case xnn_operator_type_copy_nc_x16: {
- status =
- xnn_setup_copy_nc_x16(opdata->operator_objects[0], opdata->batch_size, input_data, output1_data, threadpool);
- if (status != xnn_status_success) {
- return status;
- }
- status = xnn_setup_copy_nc_x16(
- opdata->operator_objects[1], opdata->batch_size, (const uint16_t*) input_data + channels, output2_data,
- threadpool);
- return status;
- }
-#endif // !defined(XNN_NO_F16_OPERATORS)
- case xnn_operator_type_copy_nc_x32: {
- status =
- xnn_setup_copy_nc_x32(opdata->operator_objects[0], opdata->batch_size, input_data, output1_data, threadpool);
- if (status != xnn_status_success) {
- return status;
- }
- status = xnn_setup_copy_nc_x32(
- opdata->operator_objects[1], opdata->batch_size, (const uint32_t*) input_data + channels, output2_data,
- threadpool);
- return status;
- }
-#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
- case xnn_operator_type_copy_nc_x8: {
- status =
- xnn_setup_copy_nc_x8(opdata->operator_objects[0], opdata->batch_size, input_data, output1_data, threadpool);
- if (status != xnn_status_success) {
- return status;
- }
- status = xnn_setup_copy_nc_x8(
- opdata->operator_objects[1], opdata->batch_size, (const uint8_t*) input_data + channels, output2_data,
- threadpool);
- return status;
- }
-#endif // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
- default:
- XNN_UNREACHABLE;
+ enum xnn_status status = setup_even_split_operator_helper(channels, input_data, output1_data, opdata, 0, threadpool);
+ if (status != xnn_status_success) {
+ return status;
}
+ return setup_even_split_operator_helper(channels, input_data, output2_data, opdata, 1, threadpool);
}
enum xnn_status xnn_define_even_split2(