blob: 51f3878f9190d3a0520607808b1f46e05fe23d33 [file] [log] [blame]
/*
Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package internal
import (
"bytes"
"go/format"
"testing"
"github.com/golang/protobuf/proto"
adpb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/github.com/tensorflow/tensorflow/tensorflow/go/core/framework/api_def_go_proto"
odpb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/github.com/tensorflow/tensorflow/tensorflow/go/core/framework/op_def_go_proto"
)
// GetAPIDef returns the ApiDef for opdef after applying the overrides given
// in apidefText (an ApiDef text proto).
//
// It wraps opdef in a single-element OpList, builds an API definition map
// from that list, feeds apidefText into the map and finally looks the entry
// up again by opdef.Name. Any error along the way aborts the test through
// t.Fatal.
func GetAPIDef(t *testing.T, opdef *odpb.OpDef, apidefText string) *adpb.ApiDef {
	ops := &odpb.OpList{Op: []*odpb.OpDef{opdef}}
	apimap, err := newAPIDefMap(ops)
	if err != nil {
		t.Fatal(err)
	}
	// Apply the caller-supplied overrides on top of the entries derived from ops.
	if err := apimap.Put(apidefText); err != nil {
		t.Fatal(err)
	}
	apidef, err := apimap.Get(opdef.Name)
	if err != nil {
		t.Fatal(err)
	}
	return apidef
}
// TestGenerateOp validates the generated source code for an op.
// The OpDefs used by the test cases are simplified forms of real ops.
//
// Every case supplies an OpDef and an ApiDef as text protos plus the Go
// source that is expected for them. The generated code and the expected code
// are both passed through format.Source before the byte-wise comparison.
func TestGenerateOp(t *testing.T) {
	// opCase is one table entry: tag names the subtest, opdef and apidef are
	// text protos, wanted is the expected generated Go source.
	type opCase struct {
		tag    string
		opdef  string
		apidef string
		wanted string
	}
	cases := []opCase{
		{
			tag: "NoOp",
			opdef: `
name: "NoOp"
`,
			apidef: `
op: <
graph_op_name: "NoOp"
summary: "No. Op."
>
`,
			wanted: `
// No. Op.
//
// Returns the created operation.
func NoOp(scope *Scope) (o *tf.Operation) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
Type: "NoOp",
}
return scope.AddOperation(opspec)
}
`,
		},
		{
			tag: "NoAttributes",
			opdef: `
name: "Add"
input_arg: <
name: "x"
type_attr: "T"
>
input_arg: <
name: "y"
type_attr: "T"
>
output_arg: <
name: "z"
type_attr: "T"
>
attr: <
name: "T"
type: "type"
allowed_values: <
list: <
type: DT_FLOAT
type: DT_INT64
>
>
>
`,
			apidef: `
op: <
graph_op_name: "Add"
summary: "Returns x + y element-wise."
description: "Blah blah",
>
`,
			wanted: `
// Returns x + y element-wise.
//
// Blah blah
func Add(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
Type: "Add",
Input: []tf.Input{
x, y,
},
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
`,
		},
		{
			tag: "RequiredAttributes",
			opdef: `
name: "Cast"
input_arg: <
name: "x"
type_attr: "SrcT"
>
output_arg: <
name: "y"
type_attr: "DstT"
>
attr: <
name: "SrcT"
type: "type"
>
attr: <
name: "DstT"
type: "type"
>
`,
			apidef: `
op: <
graph_op_name: "Cast"
summary: "Cast x of type SrcT to y of DstT."
>
`,
			wanted: `
// Cast x of type SrcT to y of DstT.
func Cast(scope *Scope, x tf.Output, DstT tf.DataType) (y tf.Output) {
if scope.Err() != nil {
return
}
attrs := map[string]interface{}{"DstT": DstT}
opspec := tf.OpSpec{
Type: "Cast",
Input: []tf.Input{
x,
},
Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
`,
		},
		{
			tag: "OptionalAttributes",
			opdef: `
name: "DecodeJpeg"
input_arg: <
name: "contents"
type: DT_STRING
>
output_arg: <
name: "image"
type: DT_UINT8
>
attr: <
name: "channels"
type: "int"
default_value: <
i: 0
>
>
attr: <
name: "fancy_upscaling"
type: "bool"
default_value: <
b: true
>
>
attr: <
name: "acceptable_fraction"
type: "float"
default_value: <
f: 1
>
>
`,
			apidef: `
op: <
graph_op_name: "DecodeJpeg"
in_arg: <
name: "contents"
description: "0-D. The JPEG-encoded image."
>
out_arg: <
name: "image"
description: "3-D with shape [height, width, channels]"
>
attr: <
name: "channels"
description: "Number of color channels for the decoded image."
>
attr: <
name: "fancy_upscaling"
description: "If true use a slower but nicer upscaling of the\nchroma planes (yuv420/422 only)."
>
attr: <
name: "acceptable_fraction"
description: "The minimum required fraction of lines before a truncated\ninput is accepted."
>
summary: "Decode a JPEG-encoded image to a uint8 tensor."
description: "Norna dorna fjord\nkajorna\nhahaha"
>
`,
			wanted: `
// DecodeJpegAttr is an optional argument to DecodeJpeg.
type DecodeJpegAttr func(optionalAttr)
// DecodeJpegChannels sets the optional channels attribute to value.
//
// value: Number of color channels for the decoded image.
// If not specified, defaults to 0
func DecodeJpegChannels(value int64) DecodeJpegAttr {
return func(m optionalAttr) {
m["channels"] = value
}
}
// DecodeJpegFancyUpscaling sets the optional fancy_upscaling attribute to value.
//
// value: If true use a slower but nicer upscaling of the
// chroma planes (yuv420/422 only).
// If not specified, defaults to true
func DecodeJpegFancyUpscaling(value bool) DecodeJpegAttr {
return func(m optionalAttr) {
m["fancy_upscaling"] = value
}
}
// DecodeJpegAcceptableFraction sets the optional acceptable_fraction attribute to value.
//
// value: The minimum required fraction of lines before a truncated
// input is accepted.
// If not specified, defaults to 1
func DecodeJpegAcceptableFraction(value float32) DecodeJpegAttr {
return func(m optionalAttr) {
m["acceptable_fraction"] = value
}
}
// Decode a JPEG-encoded image to a uint8 tensor.
//
// Norna dorna fjord
// kajorna
// hahaha
//
// Arguments:
// contents: 0-D. The JPEG-encoded image.
//
// Returns 3-D with shape [height, width, channels]
func DecodeJpeg(scope *Scope, contents tf.Output, optional ...DecodeJpegAttr) (image tf.Output) {
if scope.Err() != nil {
return
}
attrs := map[string]interface{}{}
for _, a := range optional {
a(attrs)
}
opspec := tf.OpSpec{
Type: "DecodeJpeg",
Input: []tf.Input{
contents,
},
Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
`,
		},
		{
			tag: "MultipleOutputs",
			opdef: `
name: "TwoOutputs"
input_arg: <
name: "input"
type_attr: "T"
>
output_arg <
name: "x"
type_attr: "T"
>
output_arg <
name: "y"
type_attr: "T"
>
attr: <
name: "T"
type: "type"
>
`,
			apidef: `
op: <
graph_op_name: "TwoOutputs"
summary: "Op that produces multiple outputs"
>
`,
			wanted: `
// Op that produces multiple outputs
func TwoOutputs(scope *Scope, input tf.Output) (x tf.Output, y tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
Type: "TwoOutputs",
Input: []tf.Input{
input,
},
}
op := scope.AddOperation(opspec)
return op.Output(0), op.Output(1)
}
`,
		},
		{
			tag: "ListOutput",
			opdef: `
name: "ShapeN"
input_arg: <
name: "input"
type_attr: "T"
number_attr: "N"
>
output_arg: <
name: "output"
type_attr: "out_type"
number_attr: "N"
>
attr: <
name: "N"
type: "int"
has_minimum: true
minimum: 1
>
attr: <
name: "T"
type: "type"
>
attr: <
name: "out_type"
type: "type"
default_value: <
type: DT_INT32
>
allowed_values: <
list: <
type: DT_INT32
type: DT_INT64
>
>
>
`,
			apidef: `
op: <
graph_op_name: "ShapeN"
summary: "Returns shape of tensors."
description: "Some description here."
>
`,
			wanted: `
// ShapeNAttr is an optional argument to ShapeN.
type ShapeNAttr func(optionalAttr)
// ShapeNOutType sets the optional out_type attribute to value.
// If not specified, defaults to DT_INT32
func ShapeNOutType(value tf.DataType) ShapeNAttr {
return func(m optionalAttr) {
m["out_type"] = value
}
}
// Returns shape of tensors.
//
// Some description here.
func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []tf.Output) {
if scope.Err() != nil {
return
}
attrs := map[string]interface{}{}
for _, a := range optional {
a(attrs)
}
opspec := tf.OpSpec{
Type: "ShapeN",
Input: []tf.Input{
tf.OutputList(input),
},
Attrs: attrs,
}
op := scope.AddOperation(opspec)
if scope.Err() != nil {
return
}
var idx int
var err error
if output, idx, err = makeOutputList(op, idx, "output"); err != nil {
scope.UpdateErr("ShapeN", err)
return
}
return output
}
`,
		},
		{
			tag: "ApiDefOverrides",
			opdef: `
name: "TestOp"
input_arg: <
name: "a"
type: DT_STRING
>
input_arg: <
name: "b"
type: DT_STRING
>
output_arg: <
name: "c"
type: DT_UINT8
>
attr: <
name: "d"
type: "int"
default_value: <
i: 0
>
>
`,
			apidef: `
op: <
graph_op_name: "TestOp"
in_arg: <
name: "a"
rename_to: "aa"
description: "Description for aa."
>
in_arg: <
name: "b"
rename_to: "bb"
description: "Description for bb."
>
arg_order: "b"
arg_order: "a"
out_arg: <
name: "c"
rename_to: "cc"
description: "Description for cc."
>
attr: <
name: "d"
rename_to: "dd"
description: "Description for dd."
>
summary: "Summary for TestOp."
description: "Description for TestOp."
>
`,
			wanted: `
// TestOpAttr is an optional argument to TestOp.
type TestOpAttr func(optionalAttr)
// TestOpDd sets the optional dd attribute to value.
//
// value: Description for dd.
// If not specified, defaults to 0
func TestOpDd(value int64) TestOpAttr {
return func(m optionalAttr) {
m["d"] = value
}
}
// Summary for TestOp.
//
// Description for TestOp.
//
// Arguments:
// bb: Description for bb.
// aa: Description for aa.
//
// Returns Description for cc.
func TestOp(scope *Scope, bb tf.Output, aa tf.Output, optional ...TestOpAttr) (cc tf.Output) {
if scope.Err() != nil {
return
}
attrs := map[string]interface{}{}
for _, a := range optional {
a(attrs)
}
opspec := tf.OpSpec{
Type: "TestOp",
Input: []tf.Input{
aa, bb,
},
Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
`,
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.tag, func(t *testing.T) {
			// Parse the OpDef text proto for this case.
			opdef := new(odpb.OpDef)
			if err := proto.UnmarshalText(tc.opdef, opdef); err != nil {
				t.Fatal(err)
			}
			apidef := GetAPIDef(t, opdef, tc.apidef)

			// Emit the wrapper function source into an in-memory buffer.
			generated := new(bytes.Buffer)
			if err := generateFunctionForOp(generated, opdef, apidef); err != nil {
				t.Fatal(err)
			}

			// Run both sides through the formatter before comparing them.
			got, err := format.Source(generated.Bytes())
			if err != nil {
				t.Fatalf("Unable to format: %v\n%s", err, generated.Bytes())
			}
			want, err := format.Source([]byte(tc.wanted))
			if err != nil {
				t.Fatalf("Unable to format: %v\n%s", err, tc.wanted)
			}
			if bytes.Equal(got, want) {
				return
			}
			t.Fatalf("Got:\n%s\nWant:\n%s\n", got, want)
		})
	}
}