blob: f912c9d38003c12e4075c2d972336407850f6027 [file] [log] [blame]
// Copyright (C) 2015 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package template
import (
"bytes"
"flag"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"text/template"
"android.googlesource.com/platform/tools/gpu/api"
"android.googlesource.com/platform/tools/gpu/api/apic/commands"
"android.googlesource.com/platform/tools/gpu/api/resolver"
"android.googlesource.com/platform/tools/gpu/tools/copyright"
)
// command describes the "template" verb. Its Run function is attached, and
// the command is handed to commands.Register, in init.
var command = &commands.Command{
	Name:      "template",
	ShortHelp: "Passes the ast to a template for code generation",
}

// Flags owned by the "template" verb.
var (
	// dir is the -dir flag; its default is whatever cwd() returns while the
	// package variables are being initialised.
	dir = command.Flags.String("dir", cwd(), "The output directory")

	// tracer is the -t flag holding the template function trace expression.
	tracer = command.Flags.String("t", "", "The template function trace expression")

	// deps is the -deps flag naming the dependencies file to generate.
	// writeDeps does nothing while it is empty.
	deps = command.Flags.String("deps", "", "The dependancies file to generate")
)
// init finishes the set-up of the "template" command: it attaches doTemplate
// as the run function, adds the -G flag backed by globalList and then passes
// the command to commands.Register.
func init() {
	command.Run = doTemplate
	command.Flags.Var(&globalList, "G", "A global value setting for the template")
	commands.Register(command)
}
// cwd returns the current working directory as reported by os.Getwd.
// The error is intentionally discarded: the result only seeds the default
// value of the -dir flag, so whatever os.Getwd hands back is used as is.
func cwd() (wd string) {
	wd, _ = os.Getwd()
	return wd
}
// inputs accumulates the paths recorded through inputDep.
var inputs []string

// outputs accumulates the paths recorded through outputDep.
var outputs []string
// inputDep records name as an input dependency of the current run.
// The path is stored in absolute form. If it cannot be made absolute, the
// name is recorded exactly as given instead of as an empty entry.
func inputDep(name string) {
	path, err := filepath.Abs(name)
	if err != nil {
		// filepath.Abs hands back "" together with the error; keep the
		// caller's path so the dependency is not lost.
		path = name
	}
	inputs = append(inputs, path)
}
// outputDep records name as an output produced by the current run.
// The path is stored in absolute form. If it cannot be made absolute, the
// name is recorded exactly as given instead of as an empty entry.
func outputDep(name string) {
	path, err := filepath.Abs(name)
	if err != nil {
		// filepath.Abs hands back "" together with the error; keep the
		// caller's path so the output is not lost.
		path = name
	}
	outputs = append(outputs, path)
}
// writeDeps writes the recorded input and output paths to the file named by
// the -deps flag. It does nothing and returns nil when the flag is empty.
//
// The file holds an "==Inputs==" header followed by one input path per line,
// then an "==Outputs==" header followed by one output path per line.
//
// It returns the first error hit while creating, writing or closing the file.
func writeDeps() error {
	if len(*deps) == 0 {
		return nil
	}
	commands.Logf("Write deps to %v\n", *deps)

	// Render the listing in memory first: writes to a bytes.Buffer cannot
	// fail, which leaves a single file write whose error has to be checked.
	var listing bytes.Buffer
	fmt.Fprintln(&listing, "==Inputs==")
	for _, entry := range inputs {
		fmt.Fprintln(&listing, entry)
	}
	fmt.Fprintln(&listing, "==Outputs==")
	for _, entry := range outputs {
		fmt.Fprintln(&listing, entry)
	}

	file, err := os.Create(*deps)
	if err != nil {
		return err
	}
	if _, err := file.Write(listing.Bytes()); err != nil {
		// The write error is the one worth reporting; still release the handle.
		file.Close()
		return err
	}
	return file.Close()
}
// execute runs the template active against data, making it the active
// template of f for the duration of the call. A nil writer leaves the writer
// already installed on f in place. The previous active template and writer
// are restored when the call returns.
func (f *Functions) execute(active *template.Template, writer io.Writer, data interface{}) error {
	savedActive, savedWriter := f.active, f.writer
	defer func() {
		f.active, f.writer = savedActive, savedWriter
	}()

	f.active = active
	if writer != nil {
		f.writer = writer
	}
	return f.active.Execute(f.writer, data)
}
// Include loads each of the named templates and executes its main body.
// File names are resolved relative to the directory of the template that is
// doing the include. A name that is already defined in the template set is
// skipped. Failures are handed to commands.MaybeError; the returned error is
// always nil.
func (f *Functions) Include(templates ...string) error {
	base := ""
	if f.active != nil {
		base = filepath.Dir(f.active.Name())
	}

	for _, name := range templates {
		if base != "" {
			name = filepath.Join(base, name)
		}
		if f.templates.Lookup(name) != nil {
			// A template with this name already exists; nothing to load.
			continue
		}

		commands.Logf("Reading template %q\n", name)
		inputDep(name)
		source, err := f.loader(name)
		commands.MaybeError(name, err)

		parsed, err := f.templates.New(name).Parse(string(source))
		commands.MaybeError(name, err)

		commands.Logf("Executing template %q\n", parsed.Name())
		// Whatever the body writes directly lands in this buffer and is dropped.
		var sink bytes.Buffer
		commands.MaybeError(parsed.Name(), f.execute(parsed, &sink, f.api))
	}
	return nil
}
// Write stores value in the file fileName, which is resolved relative to the
// base path of f, and records that file as an output dependency. The string
// result is always empty; the error is whatever ioutil.WriteFile reports.
func (f *Functions) Write(fileName string, value string) (string, error) {
	target := filepath.Join(f.basePath, fileName)
	commands.Logf("Writing output to %q\n", target)
	// The output is recorded before the write is attempted.
	outputDep(target)
	err := ioutil.WriteFile(target, []byte(value), 0666)
	return "", err
}
// Copyright emits the copyright header specified by name with the «Tool» set to tool.
// The header is built by copyright.Build with the year fixed to "2015"; the
// error result is always nil.
func (f *Functions) Copyright(name string, tool string) (string, error) {
	info := copyright.Info{Year: "2015", Tool: tool}
	header := copyright.Build(name, info)
	return header, nil
}
// doTemplate is the run function of the "template" command.
//
// It expects two positional arguments in flags: the api file and the main
// template file. The api file is resolved with api.Resolve, a Functions value
// is built around the result (templates are loaded with ioutil.ReadFile), the
// main template is included, and finally the dependencies file is written.
//
// Failures are reported through the commands package (Usage, CheckErrors,
// MaybeError); the function itself returns nothing.
func doTemplate(flags flag.FlagSet) {
	args := flags.Args()
	// NOTE(review): commands.Usage presumably does not return; otherwise the
	// indexing below would panic on a short argument list - confirm.
	if len(args) < 1 {
		commands.Usage("Missing api file\n")
	}
	apiName := args[0]
	if len(args) < 2 {
		commands.Usage("Missing template file\n")
	}
	mainTemplate := args[1]

	commands.Logf("Reading api file %q\n", apiName)
	inputDep(apiName)
	commands.Logf("Compiling api file %q\n", apiName)
	mappings := resolver.ASTToSemantic{}
	compiled, errs := api.Resolve(apiName, mappings)
	commands.CheckErrors(apiName, errs)

	f := NewFunctions(apiName, compiled, ioutil.ReadFile, nil)
	commands.MaybeError(mainTemplate, f.Include(mainTemplate))

	// Report a failed dependencies write instead of discarding the error.
	commands.MaybeError(*deps, writeDeps())
}