Go: Update generated wrapper functions for TensorFlow ops.
PiperOrigin-RevId: 354012686
Change-Id: I8cee7b8fab56608ae2aaeb43ea896505e0c898f4
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index b53e75f..64c61c1 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -33195,16 +33195,32 @@
return op.Output(0)
}
+// XlaShardingAttr is an optional argument to XlaSharding.
+type XlaShardingAttr func(optionalAttr)
+
+// XlaShardingSharding sets the optional sharding attribute to value.
+// If not specified, defaults to ""
+func XlaShardingSharding(value string) XlaShardingAttr {
+ return func(m optionalAttr) {
+ m["sharding"] = value
+ }
+}
+
// An op which shards the input based on the given sharding attribute.
-func XlaSharding(scope *Scope, input tf.Output) (output tf.Output) {
+func XlaSharding(scope *Scope, input tf.Output, optional ...XlaShardingAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
Type: "XlaSharding",
Input: []tf.Input{
input,
},
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)