Snap for 6198741 from f0e6d47072e7b5714f23b97473afb03e61a2ee8b to sdk-release
Change-Id: Ia0c9276f1a11fb55c4513eaa2956d407fe2b57ec
diff --git a/Blueprints b/Blueprints
index c3c8975..ecc0792 100644
--- a/Blueprints
+++ b/Blueprints
@@ -19,15 +19,14 @@
"package_ctx.go",
"scope.go",
"singleton_ctx.go",
- "unpack.go",
],
testSrcs: [
"context_test.go",
"glob_test.go",
+ "module_ctx_test.go",
"ninja_strings_test.go",
"ninja_writer_test.go",
"splice_modules_test.go",
- "unpack_test.go",
"visit_test.go",
],
}
@@ -75,6 +74,9 @@
bootstrap_go_package {
name: "blueprint-proptools",
pkgPath: "github.com/google/blueprint/proptools",
+ deps: [
+ "blueprint-parser",
+ ],
srcs: [
"proptools/clone.go",
"proptools/escape.go",
@@ -83,6 +85,7 @@
"proptools/proptools.go",
"proptools/tag.go",
"proptools/typeequal.go",
+ "proptools/unpack.go",
],
testSrcs: [
"proptools/clone_test.go",
@@ -91,6 +94,7 @@
"proptools/filter_test.go",
"proptools/tag_test.go",
"proptools/typeequal_test.go",
+ "proptools/unpack_test.go",
],
}
diff --git a/bootstrap/bootstrap.go b/bootstrap/bootstrap.go
index 1799e68..bb85e7d 100644
--- a/bootstrap/bootstrap.go
+++ b/bootstrap/bootstrap.go
@@ -27,6 +27,7 @@
"github.com/google/blueprint/pathtools"
)
+const mainSubDir = ".primary"
const bootstrapSubDir = ".bootstrap"
const miniBootstrapSubDir = ".minibootstrap"
@@ -143,15 +144,16 @@
"depfile")
_ = pctx.VariableFunc("BinDir", func(config interface{}) (string, error) {
- return binDir(), nil
+ return bootstrapBinDir(), nil
})
_ = pctx.VariableFunc("ToolDir", func(config interface{}) (string, error) {
return toolDir(config), nil
})
- docsDir = filepath.Join(bootstrapDir, "docs")
+ docsDir = filepath.Join(mainDir, "docs")
+ mainDir = filepath.Join("$buildDir", mainSubDir)
bootstrapDir = filepath.Join("$buildDir", bootstrapSubDir)
miniBootstrapDir = filepath.Join("$buildDir", miniBootstrapSubDir)
@@ -165,7 +167,7 @@
isGoBinary()
}
-func binDir() string {
+func bootstrapBinDir() string {
return filepath.Join(BuildDir, bootstrapSubDir, "bin")
}
@@ -307,14 +309,14 @@
return
}
- g.pkgRoot = packageRoot(ctx)
+ g.pkgRoot = packageRoot(ctx, g.config)
g.archiveFile = filepath.Join(g.pkgRoot,
filepath.FromSlash(g.properties.PkgPath)+".a")
ctx.VisitDepsDepthFirstIf(isGoPluginFor(name),
func(module blueprint.Module) { hasPlugins = true })
if hasPlugins {
- pluginSrc = filepath.Join(moduleGenSrcDir(ctx), "plugin.go")
+ pluginSrc = filepath.Join(moduleGenSrcDir(ctx, g.config), "plugin.go")
genSrcs = append(genSrcs, pluginSrc)
}
@@ -332,9 +334,9 @@
}
if g.config.runGoTests {
- testArchiveFile := filepath.Join(testRoot(ctx),
+ testArchiveFile := filepath.Join(testRoot(ctx, g.config),
filepath.FromSlash(g.properties.PkgPath)+".a")
- g.testResultFile = buildGoTest(ctx, testRoot(ctx), testArchiveFile,
+ g.testResultFile = buildGoTest(ctx, testRoot(ctx, g.config), testArchiveFile,
g.properties.PkgPath, srcs, genSrcs,
testSrcs)
}
@@ -395,9 +397,9 @@
func (g *goBinary) GenerateBuildActions(ctx blueprint.ModuleContext) {
var (
name = ctx.ModuleName()
- objDir = moduleObjDir(ctx)
+ objDir = moduleObjDir(ctx, g.config)
archiveFile = filepath.Join(objDir, name+".a")
- testArchiveFile = filepath.Join(testRoot(ctx), name+".a")
+ testArchiveFile = filepath.Join(testRoot(ctx, g.config), name+".a")
aoutFile = filepath.Join(objDir, "a.out")
hasPlugins = false
pluginSrc = ""
@@ -406,14 +408,16 @@
if g.properties.Tool_dir {
g.installPath = filepath.Join(toolDir(ctx.Config()), name)
+ } else if g.config.stage == StageMain {
+ g.installPath = filepath.Join(mainDir, "bin", name)
} else {
- g.installPath = filepath.Join(binDir(), name)
+ g.installPath = filepath.Join(bootstrapDir, "bin", name)
}
ctx.VisitDepsDepthFirstIf(isGoPluginFor(name),
func(module blueprint.Module) { hasPlugins = true })
if hasPlugins {
- pluginSrc = filepath.Join(moduleGenSrcDir(ctx), "plugin.go")
+ pluginSrc = filepath.Join(moduleGenSrcDir(ctx, g.config), "plugin.go")
genSrcs = append(genSrcs, pluginSrc)
}
@@ -433,7 +437,7 @@
}
if g.config.runGoTests {
- deps = buildGoTest(ctx, testRoot(ctx), testArchiveFile,
+ deps = buildGoTest(ctx, testRoot(ctx, g.config), testArchiveFile,
name, srcs, genSrcs, testSrcs)
}
@@ -687,7 +691,7 @@
if s.config.stage == StagePrimary {
mainNinjaFile := filepath.Join("$buildDir", "build.ninja")
- primaryBuilderNinjaGlobFile := filepath.Join(BuildDir, bootstrapSubDir, "build-globs.ninja")
+ primaryBuilderNinjaGlobFile := absolutePath(filepath.Join(BuildDir, bootstrapSubDir, "build-globs.ninja"))
if _, err := os.Stat(primaryBuilderNinjaGlobFile); os.IsNotExist(err) {
err = ioutil.WriteFile(primaryBuilderNinjaGlobFile, nil, 0666)
@@ -758,18 +762,26 @@
}
}
+// stageDir returns mainDir when config is in StageMain and bootstrapDir otherwise.
+func stageDir(config *Config) string {
+	if config.stage == StageMain {
+		return mainDir
+	}
+	return bootstrapDir
+}
+
// packageRoot returns the module-specific package root directory path. This
// directory is where the final package .a files are output and where dependant
// modules search for this package via -I arguments.
-func packageRoot(ctx blueprint.ModuleContext) string {
- return filepath.Join(bootstrapDir, ctx.ModuleName(), "pkg")
+func packageRoot(ctx blueprint.ModuleContext, config *Config) string {
+ return filepath.Join(stageDir(config), ctx.ModuleName(), "pkg")
}
// testRoot returns the module-specific package root directory path used for
// building tests. The .a files generated here will include everything from
// packageRoot, plus the test-only code.
-func testRoot(ctx blueprint.ModuleContext) string {
- return filepath.Join(bootstrapDir, ctx.ModuleName(), "test")
+func testRoot(ctx blueprint.ModuleContext, config *Config) string {
+ return filepath.Join(stageDir(config), ctx.ModuleName(), "test")
}
// moduleSrcDir returns the path of the directory that all source file paths are
@@ -779,11 +791,11 @@
}
// moduleObjDir returns the module-specific object directory path.
-func moduleObjDir(ctx blueprint.ModuleContext) string {
- return filepath.Join(bootstrapDir, ctx.ModuleName(), "obj")
+func moduleObjDir(ctx blueprint.ModuleContext, config *Config) string {
+ return filepath.Join(stageDir(config), ctx.ModuleName(), "obj")
}
// moduleGenSrcDir returns the module-specific generated sources path.
-func moduleGenSrcDir(ctx blueprint.ModuleContext) string {
- return filepath.Join(bootstrapDir, ctx.ModuleName(), "gen")
+func moduleGenSrcDir(ctx blueprint.ModuleContext, config *Config) string {
+ return filepath.Join(stageDir(config), ctx.ModuleName(), "gen")
}
diff --git a/bootstrap/bpdoc/bpdoc.go b/bootstrap/bpdoc/bpdoc.go
index 4acfc5d..4abf2e7 100644
--- a/bootstrap/bpdoc/bpdoc.go
+++ b/bootstrap/bpdoc/bpdoc.go
@@ -6,7 +6,6 @@
"reflect"
"sort"
- "github.com/google/blueprint"
"github.com/google/blueprint/proptools"
)
@@ -146,14 +145,6 @@
return nil, fmt.Errorf("nesting point %q not found", nestedName)
}
- key, value, err := blueprint.HasFilter(nestPoint.Tag)
- if err != nil {
- return nil, err
- }
- if key != "" {
- nested.IncludeByTag(key, value)
- }
-
nestPoint.Nest(nested)
}
mt.PropertyStructs = append(mt.PropertyStructs, ps)
diff --git a/bootstrap/cleanup.go b/bootstrap/cleanup.go
index 4a8ce25..6444081 100644
--- a/bootstrap/cleanup.go
+++ b/bootstrap/cleanup.go
@@ -70,7 +70,7 @@
for _, filePath := range filePaths {
isTarget := targets[filePath]
if !isTarget {
- err = removeFileAndEmptyDirs(filePath)
+ err = removeFileAndEmptyDirs(absolutePath(filePath))
if err != nil {
return err
}
diff --git a/bootstrap/command.go b/bootstrap/command.go
index bf6bbe9..1e3b2fe 100644
--- a/bootstrap/command.go
+++ b/bootstrap/command.go
@@ -47,6 +47,8 @@
BuildDir string
NinjaBuildDir string
SrcDir string
+
+ absSrcDir string
)
func init() {
@@ -76,8 +78,10 @@
debug.SetGCPercent(-1)
}
+ absSrcDir = ctx.SrcDir()
+
if cpuprofile != "" {
- f, err := os.Create(cpuprofile)
+ f, err := os.Create(absolutePath(cpuprofile))
if err != nil {
fatalf("error opening cpuprofile: %s", err)
}
@@ -87,7 +91,7 @@
}
if traceFile != "" {
- f, err := os.Create(traceFile)
+ f, err := os.Create(absolutePath(traceFile))
if err != nil {
fatalf("error opening trace: %s", err)
}
@@ -140,7 +144,7 @@
ctx.RegisterSingletonType("glob", globSingletonFactory(ctx))
- deps, errs := ctx.ParseFileList(filepath.Dir(bootstrapConfig.topLevelBlueprintsFile), filesToParse)
+ deps, errs := ctx.ParseFileList(filepath.Dir(bootstrapConfig.topLevelBlueprintsFile), filesToParse, config)
if len(errs) > 0 {
fatalErrors(errs)
}
@@ -155,7 +159,7 @@
deps = append(deps, extraDeps...)
if docFile != "" {
- err := writeDocs(ctx, docFile)
+ err := writeDocs(ctx, absolutePath(docFile))
if err != nil {
fatalErrors([]error{err})
}
@@ -180,7 +184,7 @@
var buf *bufio.Writer
if stage != StageMain || !emptyNinjaFile {
- f, err = os.OpenFile(outFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, outFilePermissions)
+ f, err = os.OpenFile(absolutePath(outFile), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, outFilePermissions)
if err != nil {
fatalf("error opening Ninja file: %s", err)
}
@@ -190,6 +194,25 @@
out = ioutil.Discard
}
+ if globFile != "" {
+ buffer, errs := generateGlobNinjaFile(ctx.Globs)
+ if len(errs) > 0 {
+ fatalErrors(errs)
+ }
+
+ err = ioutil.WriteFile(absolutePath(globFile), buffer, outFilePermissions)
+ if err != nil {
+ fatalf("error writing %s: %s", globFile, err)
+ }
+ }
+
+ if depFile != "" {
+ err := deptools.WriteDepFile(absolutePath(depFile), outFile, deps)
+ if err != nil {
+ fatalf("error writing depfile: %s", err)
+ }
+ }
+
err = ctx.WriteBuildFile(out)
if err != nil {
fatalf("error writing Ninja file contents: %s", err)
@@ -209,25 +232,6 @@
}
}
- if globFile != "" {
- buffer, errs := generateGlobNinjaFile(ctx.Globs)
- if len(errs) > 0 {
- fatalErrors(errs)
- }
-
- err = ioutil.WriteFile(globFile, buffer, outFilePermissions)
- if err != nil {
- fatalf("error writing %s: %s", outFile, err)
- }
- }
-
- if depFile != "" {
- err := deptools.WriteDepFile(depFile, outFile, deps)
- if err != nil {
- fatalf("error writing depfile: %s", err)
- }
- }
-
if c, ok := config.(ConfigRemoveAbandonedFilesUnder); ok {
under, except := c.RemoveAbandonedFilesUnder()
err := removeAbandonedFilesUnder(ctx, bootstrapConfig, SrcDir, under, except)
@@ -237,7 +241,7 @@
}
if memprofile != "" {
- f, err := os.Create(memprofile)
+ f, err := os.Create(absolutePath(memprofile))
if err != nil {
fatalf("error opening memprofile: %s", err)
}
@@ -268,3 +272,10 @@
}
os.Exit(1)
}
+
+func absolutePath(path string) string {
+	if !filepath.IsAbs(path) { // relative paths are resolved against absSrcDir
+		return filepath.Join(absSrcDir, path)
+	}
+	return path
+}
diff --git a/bootstrap/config.go b/bootstrap/config.go
index 0772b0a..9499aeb 100644
--- a/bootstrap/config.go
+++ b/bootstrap/config.go
@@ -30,7 +30,7 @@
}
var (
- // These variables are the only configuration needed by the boostrap
+ // These variables are the only configuration needed by the bootstrap
// modules.
srcDir = bootstrapVariable("srcDir", func() string {
return SrcDir
diff --git a/bootstrap/glob.go b/bootstrap/glob.go
index 9841611..52dbf2f 100644
--- a/bootstrap/glob.go
+++ b/bootstrap/glob.go
@@ -131,8 +131,14 @@
depFile := fileListFile + ".d"
fileList := strings.Join(g.Files, "\n") + "\n"
- pathtools.WriteFileIfChanged(fileListFile, []byte(fileList), 0666)
- deptools.WriteDepFile(depFile, fileListFile, g.Deps)
+ err := pathtools.WriteFileIfChanged(absolutePath(fileListFile), []byte(fileList), 0666)
+ if err != nil {
+ panic(fmt.Errorf("error writing %s: %s", fileListFile, err))
+ }
+ err = deptools.WriteDepFile(absolutePath(depFile), fileListFile, g.Deps)
+ if err != nil {
+ panic(fmt.Errorf("error writing %s: %s", depFile, err))
+ }
GlobFile(ctx, g.Pattern, g.Excludes, fileListFile, depFile)
} else {
diff --git a/context.go b/context.go
index cedf3d8..1ad1588 100644
--- a/context.go
+++ b/context.go
@@ -96,15 +96,15 @@
// set during PrepareBuildActions
pkgNames map[*packageContext]string
liveGlobals *liveTracker
- globalVariables map[Variable]*ninjaString
+ globalVariables map[Variable]ninjaString
globalPools map[Pool]*poolDef
globalRules map[Rule]*ruleDef
// set during PrepareBuildActions
- ninjaBuildDir *ninjaString // The builddir special Ninja variable
- requiredNinjaMajor int // For the ninja_required_version variable
- requiredNinjaMinor int // For the ninja_required_version variable
- requiredNinjaMicro int // For the ninja_required_version variable
+ ninjaBuildDir ninjaString // The builddir special Ninja variable
+ requiredNinjaMajor int // For the ninja_required_version variable
+ requiredNinjaMinor int // For the ninja_required_version variable
+ requiredNinjaMicro int // For the ninja_required_version variable
subninjas []string
@@ -114,6 +114,7 @@
globs map[string]GlobPath
globLock sync.Mutex
+ srcDir string
fs pathtools.FileSystem
moduleListFile string
}
@@ -157,11 +158,19 @@
buildDefs []*buildDef
}
+type moduleAlias struct {
+ variantName string // variantName copied from the module that created the alias; compared in moduleMatchingVariant
+ variant variationMap // variations this alias answers to when a dependency lookup falls through to group.aliases
+ dependencyVariant variationMap // dependencyVariant copied from the module that created the alias
+ target *moduleInfo // module returned on a match; re-pointed or dropped once its logicModule becomes nil
+}
+
type moduleGroup struct {
name string
ninjaName string
modules []*moduleInfo
+ aliases []*moduleAlias
namespace Namespace
}
@@ -197,6 +206,7 @@
// set during each runMutator
splitModules []*moduleInfo
+ aliasTarget *moduleInfo
// set during PrepareBuildActions
actionDefs localBuildActions
@@ -242,6 +252,9 @@
type variationMap map[string]string
func (vm variationMap) clone() variationMap {
+ if vm == nil {
+ return nil
+ }
newVm := make(variationMap)
for k, v := range vm {
newVm[k] = v
@@ -436,6 +449,15 @@
c.nameInterface = i
}
+// SetSrcDir records path as the source root and replaces the Context's filesystem with pathtools.NewOsFs(path).
+func (c *Context) SetSrcDir(path string) {
+	c.srcDir = path
+	c.fs = pathtools.NewOsFs(path)
+}
+
+// SrcDir returns the source root recorded by SetSrcDir.
+func (c *Context) SrcDir() string { return c.srcDir }
+
func singletonPkgPath(singleton Singleton) string {
typ := reflect.TypeOf(singleton)
for typ.Kind() == reflect.Ptr {
@@ -619,17 +641,19 @@
// which the future output will depend is returned. This list will include both
// Blueprints file paths as well as directory paths for cases where wildcard
// subdirs are found.
-func (c *Context) ParseBlueprintsFiles(rootFile string) (deps []string, errs []error) {
+func (c *Context) ParseBlueprintsFiles(rootFile string,
+ config interface{}) (deps []string, errs []error) {
+
baseDir := filepath.Dir(rootFile)
pathsToParse, err := c.ListModulePaths(baseDir)
if err != nil {
return nil, []error{err}
}
- return c.ParseFileList(baseDir, pathsToParse)
+ return c.ParseFileList(baseDir, pathsToParse, config)
}
-func (c *Context) ParseFileList(rootDir string, filePaths []string) (deps []string,
- errs []error) {
+func (c *Context) ParseFileList(rootDir string, filePaths []string,
+ config interface{}) (deps []string, errs []error) {
if len(filePaths) < 1 {
return nil, []error{fmt.Errorf("no paths provided to parse")}
@@ -637,7 +661,12 @@
c.dependenciesReady = false
- moduleCh := make(chan *moduleInfo)
+ type newModuleInfo struct {
+ *moduleInfo
+ added chan<- struct{}
+ }
+
+ moduleCh := make(chan newModuleInfo)
errsCh := make(chan []error)
doneCh := make(chan struct{})
var numErrs uint32
@@ -649,24 +678,47 @@
return
}
+ addedCh := make(chan struct{})
+
+ var scopedModuleFactories map[string]ModuleFactory
+
+ var addModule func(module *moduleInfo) []error
+ addModule = func(module *moduleInfo) (errs []error) {
+ moduleCh <- newModuleInfo{module, addedCh}
+ <-addedCh
+ var newModules []*moduleInfo
+ newModules, errs = runAndRemoveLoadHooks(c, config, module, &scopedModuleFactories)
+ if len(errs) > 0 {
+ return errs
+ }
+ for _, n := range newModules {
+ errs = addModule(n)
+ if len(errs) > 0 {
+ return errs
+ }
+ }
+ return nil
+ }
+
for _, def := range file.Defs {
- var module *moduleInfo
- var errs []error
switch def := def.(type) {
case *parser.Module:
- module, errs = c.processModuleDef(def, file.Name)
+ module, errs := c.processModuleDef(def, file.Name, scopedModuleFactories)
+ if len(errs) == 0 && module != nil {
+ errs = addModule(module)
+ }
+
+ if len(errs) > 0 {
+ atomic.AddUint32(&numErrs, uint32(len(errs)))
+ errsCh <- errs
+ }
+
case *parser.Assignment:
// Already handled via Scope object
default:
panic("unknown definition type")
}
- if len(errs) > 0 {
- atomic.AddUint32(&numErrs, uint32(len(errs)))
- errsCh <- errs
- } else if module != nil {
- moduleCh <- module
- }
}
}
@@ -686,7 +738,10 @@
case newErrs := <-errsCh:
errs = append(errs, newErrs...)
case module := <-moduleCh:
- newErrs := c.addModule(module)
+ newErrs := c.addModule(module.moduleInfo)
+ if module.added != nil {
+ module.added <- struct{}{}
+ }
if len(newErrs) > 0 {
errs = append(errs, newErrs...)
}
@@ -877,6 +932,10 @@
c.fs = pathtools.MockFs(files)
}
+// SetFs replaces the pathtools.FileSystem stored in the Context. SetSrcDir also
+// assigns it, so whichever of the two is called last determines the filesystem.
+func (c *Context) SetFs(fs pathtools.FileSystem) { c.fs = fs }
+
// openAndParse opens and parses a single Blueprints file, and returns the results
func (c *Context) openAndParse(filename string, scope *parser.Scope, rootDir string,
parent *fileParseContext) (file *parser.File,
@@ -1143,8 +1202,8 @@
}
for i := range newProperties {
- dst := reflect.ValueOf(newProperties[i]).Elem()
- src := reflect.ValueOf(origModule.properties[i]).Elem()
+ dst := reflect.ValueOf(newProperties[i])
+ src := reflect.ValueOf(origModule.properties[i])
proptools.CopyProperties(dst, src)
}
@@ -1178,6 +1237,9 @@
}
newVariant := origModule.variant.clone()
+ if newVariant == nil {
+ newVariant = make(variationMap)
+ }
newVariant[mutatorName] = variationName
m := *origModule
@@ -1261,6 +1323,19 @@
return strings.Join(names, ", ")
}
+// prettyPrintGroupVariants returns group's variants and aliases, sorted, for error messages.
+func (c *Context) prettyPrintGroupVariants(group *moduleGroup) string {
+	variants := make([]string, 0, len(group.modules)+len(group.aliases))
+	for _, mod := range group.modules {
+		variants = append(variants, c.prettyPrintVariant(mod.variant))
+	}
+	for _, alias := range group.aliases {
+		variants = append(variants, c.prettyPrintVariant(alias.variant)+" (alias to "+c.prettyPrintVariant(alias.target.variant)+")")
+	}
+	sort.Strings(variants)
+	return strings.Join(variants, "\n ")
+}
+
func (c *Context) newModule(factory ModuleFactory) *moduleInfo {
logicModule, properties := factory()
@@ -1275,9 +1350,12 @@
}
func (c *Context) processModuleDef(moduleDef *parser.Module,
- relBlueprintsFile string) (*moduleInfo, []error) {
+ relBlueprintsFile string, scopedModuleFactories map[string]ModuleFactory) (*moduleInfo, []error) {
factory, ok := c.moduleFactories[moduleDef.Type]
+ if !ok && scopedModuleFactories != nil {
+ factory, ok = scopedModuleFactories[moduleDef.Type]
+ }
if !ok {
if c.ignoreUnknownModuleTypes {
return nil, nil
@@ -1296,8 +1374,17 @@
module.relBlueprintsFile = relBlueprintsFile
- propertyMap, errs := unpackProperties(moduleDef.Properties, module.properties...)
+ propertyMap, errs := proptools.UnpackProperties(moduleDef.Properties, module.properties...)
if len(errs) > 0 {
+ for i, err := range errs {
+ if unpackErr, ok := err.(*proptools.UnpackError); ok {
+ err = &BlueprintError{
+ Err: unpackErr.Err,
+ Pos: unpackErr.Pos,
+ }
+ errs[i] = err
+ }
+ }
return nil, errs
}
@@ -1410,9 +1497,9 @@
// findMatchingVariant searches the moduleGroup for a module with the same variant as module,
// and returns the matching module, or nil if one is not found.
-func (c *Context) findMatchingVariant(module *moduleInfo, possible []*moduleInfo, reverse bool) *moduleInfo {
- if len(possible) == 1 {
- return possible[0]
+func (c *Context) findMatchingVariant(module *moduleInfo, possible *moduleGroup, reverse bool) *moduleInfo {
+ if len(possible.modules) == 1 {
+ return possible.modules[0]
} else {
var variantToMatch variationMap
if !reverse {
@@ -1423,11 +1510,16 @@
// For reverse dependency, use all the variants
variantToMatch = module.variant
}
- for _, m := range possible {
+ for _, m := range possible.modules {
if m.variant.equal(variantToMatch) {
return m
}
}
+ for _, m := range possible.aliases {
+ if m.variant.equal(variantToMatch) {
+ return m.target
+ }
+ }
}
return nil
@@ -1445,7 +1537,7 @@
}}
}
- possibleDeps := c.modulesFromName(depName, module.namespace())
+ possibleDeps := c.moduleGroupFromName(depName, module.namespace())
if possibleDeps == nil {
return c.discoveredMissingDependencies(module, depName)
}
@@ -1456,17 +1548,11 @@
return nil
}
- variants := make([]string, len(possibleDeps))
- for i, mod := range possibleDeps {
- variants[i] = c.prettyPrintVariant(mod.variant)
- }
- sort.Strings(variants)
-
return []error{&BlueprintError{
Err: fmt.Errorf("dependency %q of %q missing variant:\n %s\navailable variants:\n %s",
depName, module.Name(),
c.prettyPrintVariant(module.dependencyVariant),
- strings.Join(variants, "\n ")),
+ c.prettyPrintGroupVariants(possibleDeps)),
Pos: module.pos,
}}
}
@@ -1479,7 +1565,7 @@
}}
}
- possibleDeps := c.modulesFromName(destName, module.namespace())
+ possibleDeps := c.moduleGroupFromName(destName, module.namespace())
if possibleDeps == nil {
return nil, []error{&BlueprintError{
Err: fmt.Errorf("%q has a reverse dependency on undefined module %q",
@@ -1492,17 +1578,11 @@
return m, nil
}
- variants := make([]string, len(possibleDeps))
- for i, mod := range possibleDeps {
- variants[i] = c.prettyPrintVariant(mod.variant)
- }
- sort.Strings(variants)
-
return nil, []error{&BlueprintError{
Err: fmt.Errorf("reverse dependency %q of %q missing variant:\n %s\navailable variants:\n %s",
destName, module.Name(),
c.prettyPrintVariant(module.dependencyVariant),
- strings.Join(variants, "\n ")),
+ c.prettyPrintGroupVariants(possibleDeps)),
Pos: module.pos,
}}
}
@@ -1513,7 +1593,7 @@
panic("BaseDependencyTag is not allowed to be used directly!")
}
- possibleDeps := c.modulesFromName(depName, module.namespace())
+ possibleDeps := c.moduleGroupFromName(depName, module.namespace())
if possibleDeps == nil {
return c.discoveredMissingDependencies(module, depName)
}
@@ -1524,55 +1604,67 @@
var newVariant variationMap
if !far {
newVariant = module.dependencyVariant.clone()
- } else {
- newVariant = make(variationMap)
}
for _, v := range variations {
+ if newVariant == nil {
+ newVariant = make(variationMap)
+ }
newVariant[v.Mutator] = v.Variation
}
- for _, m := range possibleDeps {
- var found bool
+ check := func(variant variationMap) bool {
if far {
- found = m.variant.subset(newVariant)
+ return variant.subset(newVariant)
} else {
- found = m.variant.equal(newVariant)
- }
- if found {
- if module == m {
- return []error{&BlueprintError{
- Err: fmt.Errorf("%q depends on itself", depName),
- Pos: module.pos,
- }}
- }
- // AddVariationDependency allows adding a dependency on itself, but only if
- // that module is earlier in the module list than this one, since we always
- // run GenerateBuildActions in order for the variants of a module
- if m.group == module.group && beforeInModuleList(module, m, module.group.modules) {
- return []error{&BlueprintError{
- Err: fmt.Errorf("%q depends on later version of itself", depName),
- Pos: module.pos,
- }}
- }
- module.newDirectDeps = append(module.newDirectDeps, depInfo{m, tag})
- atomic.AddUint32(&c.depsModified, 1)
- return nil
+ return variant.equal(newVariant)
}
}
- variants := make([]string, len(possibleDeps))
- for i, mod := range possibleDeps {
- variants[i] = c.prettyPrintVariant(mod.variant)
+ var foundDep *moduleInfo
+ for _, m := range possibleDeps.modules {
+ if check(m.variant) {
+ foundDep = m
+ break
+ }
}
- sort.Strings(variants)
- return []error{&BlueprintError{
- Err: fmt.Errorf("dependency %q of %q missing variant:\n %s\navailable variants:\n %s",
- depName, module.Name(),
- c.prettyPrintVariant(newVariant),
- strings.Join(variants, "\n ")),
- Pos: module.pos,
- }}
+ if foundDep == nil {
+ for _, m := range possibleDeps.aliases {
+ if check(m.variant) {
+ foundDep = m.target
+ break
+ }
+ }
+ }
+
+ if foundDep == nil {
+ return []error{&BlueprintError{
+ Err: fmt.Errorf("dependency %q of %q missing variant:\n %s\navailable variants:\n %s",
+ depName, module.Name(),
+ c.prettyPrintVariant(newVariant),
+ c.prettyPrintGroupVariants(possibleDeps)),
+ Pos: module.pos,
+ }}
+ }
+
+ if module == foundDep {
+ return []error{&BlueprintError{
+ Err: fmt.Errorf("%q depends on itself", depName),
+ Pos: module.pos,
+ }}
+ }
+ // AddVariationDependency allows adding a dependency on itself, but only if
+ // that module is earlier in the module list than this one, since we always
+ // run GenerateBuildActions in order for the variants of a module
+ if foundDep.group == module.group && beforeInModuleList(module, foundDep, module.group.modules) {
+ return []error{&BlueprintError{
+ Err: fmt.Errorf("%q depends on later version of itself", depName),
+ Pos: module.pos,
+ }}
+ }
+ module.newDirectDeps = append(module.newDirectDeps, depInfo{foundDep, tag})
+ atomic.AddUint32(&c.depsModified, 1)
+ return nil
}
func (c *Context) addInterVariantDependency(origModule *moduleInfo, tag DependencyTag,
@@ -2166,6 +2258,16 @@
group.modules, i = spliceModules(group.modules, i, module.splitModules)
}
+ // Create any new aliases.
+ if module.aliasTarget != nil {
+ group.aliases = append(group.aliases, &moduleAlias{
+ variantName: module.variantName,
+ variant: module.variant,
+ dependencyVariant: module.dependencyVariant,
+ target: module.aliasTarget,
+ })
+ }
+
// Fix up any remaining dependencies on modules that were split into variants
// by replacing them with the first variant
for j, dep := range module.directDeps {
@@ -2182,6 +2284,21 @@
module.directDeps = append(module.directDeps, module.newDirectDeps...)
module.newDirectDeps = nil
}
+
+ // Forward or delete any dangling aliases.
+ for i := 0; i < len(group.aliases); i++ {
+ alias := group.aliases[i]
+
+ if alias.target.logicModule == nil {
+ if alias.target.aliasTarget != nil {
+ alias.target = alias.target.aliasTarget
+ } else {
+ // The alias was left dangling, remove it.
+ group.aliases = append(group.aliases[:i], group.aliases[i+1:]...)
+ i--
+ }
+ }
+ }
}
// Add in any new reverse dependencies that were added by the mutator
@@ -2510,18 +2627,24 @@
}
func (c *Context) moduleMatchingVariant(module *moduleInfo, name string) *moduleInfo {
- targets := c.modulesFromName(name, module.namespace())
+ group := c.moduleGroupFromName(name, module.namespace())
- if targets == nil {
+ if group == nil {
return nil
}
- for _, m := range targets {
+ for _, m := range group.modules {
if module.variantName == m.variantName {
return m
}
}
+ for _, m := range group.aliases {
+ if module.variantName == m.variantName {
+ return m.target
+ }
+ }
+
return nil
}
@@ -2573,10 +2696,10 @@
}
}
-func (c *Context) modulesFromName(name string, namespace Namespace) []*moduleInfo {
+func (c *Context) moduleGroupFromName(name string, namespace Namespace) *moduleGroup {
group, exists := c.nameInterface.ModuleFromName(name, namespace)
if exists {
- return group.modules
+ return group.moduleGroup
}
return nil
}
@@ -2665,7 +2788,7 @@
}
}
-func (c *Context) setNinjaBuildDir(value *ninjaString) {
+func (c *Context) setNinjaBuildDir(value ninjaString) {
if c.ninjaBuildDir == nil {
c.ninjaBuildDir = value
}
@@ -2731,7 +2854,7 @@
}
func (c *Context) checkForVariableReferenceCycles(
- variables map[Variable]*ninjaString, pkgNames map[*packageContext]string) {
+ variables map[Variable]ninjaString, pkgNames map[*packageContext]string) {
visited := make(map[Variable]bool) // variables that were already checked
checking := make(map[Variable]bool) // variables actively being checked
@@ -2744,7 +2867,7 @@
defer delete(checking, v)
value := variables[v]
- for _, dep := range value.variables {
+ for _, dep := range value.Variables() {
if checking[dep] {
// This is a cycle.
return []Variable{dep, v}
@@ -3229,7 +3352,7 @@
// First visit variables on which this variable depends.
value := c.globalVariables[v]
- for _, dep := range value.variables {
+ for _, dep := range value.Variables() {
if !visited[dep] {
err := walk(dep)
if err != nil {
diff --git a/context_test.go b/context_test.go
index 0d783dc..d4e9bf7 100644
--- a/context_test.go
+++ b/context_test.go
@@ -168,7 +168,7 @@
ctx.RegisterModuleType("foo_module", newFooModule)
ctx.RegisterModuleType("bar_module", newBarModule)
- _, errs := ctx.ParseBlueprintsFiles("Blueprints")
+ _, errs := ctx.ParseBlueprintsFiles("Blueprints", nil)
if len(errs) > 0 {
t.Errorf("unexpected parse errors:")
for _, err := range errs {
@@ -188,7 +188,7 @@
var outputDown string
var outputUp string
- topModule := ctx.modulesFromName("A", nil)[0]
+ topModule := ctx.moduleGroupFromName("A", nil).modules[0]
ctx.walkDeps(topModule, false,
func(dep depInfo, parent *moduleInfo) bool {
outputDown += ctx.ModuleName(dep.module.logicModule)
@@ -260,7 +260,7 @@
ctx.RegisterModuleType("foo_module", newFooModule)
ctx.RegisterModuleType("bar_module", newBarModule)
- _, errs := ctx.ParseBlueprintsFiles("Blueprints")
+ _, errs := ctx.ParseBlueprintsFiles("Blueprints", nil)
if len(errs) > 0 {
t.Errorf("unexpected parse errors:")
for _, err := range errs {
@@ -280,7 +280,7 @@
var outputDown string
var outputUp string
- topModule := ctx.modulesFromName("A", nil)[0]
+ topModule := ctx.moduleGroupFromName("A", nil).modules[0]
ctx.walkDeps(topModule, true,
func(dep depInfo, parent *moduleInfo) bool {
outputDown += ctx.ModuleName(dep.module.logicModule)
@@ -316,7 +316,7 @@
ctx.RegisterModuleType("foo_module", newFooModule)
ctx.RegisterModuleType("bar_module", newBarModule)
- _, errs := ctx.ParseBlueprintsFiles("Blueprints")
+ _, errs := ctx.ParseBlueprintsFiles("Blueprints", nil)
if len(errs) > 0 {
t.Errorf("unexpected parse errors:")
for _, err := range errs {
@@ -334,10 +334,10 @@
t.FailNow()
}
- a := ctx.modulesFromName("A", nil)[0].logicModule.(*fooModule)
- b := ctx.modulesFromName("B", nil)[0].logicModule.(*barModule)
- c := ctx.modulesFromName("C", nil)[0].logicModule.(*barModule)
- d := ctx.modulesFromName("D", nil)[0].logicModule.(*fooModule)
+ a := ctx.moduleGroupFromName("A", nil).modules[0].logicModule.(*fooModule)
+ b := ctx.moduleGroupFromName("B", nil).modules[0].logicModule.(*barModule)
+ c := ctx.moduleGroupFromName("C", nil).modules[0].logicModule.(*barModule)
+ d := ctx.moduleGroupFromName("D", nil).modules[0].logicModule.(*fooModule)
checkDeps := func(m Module, expected string) {
var deps []string
@@ -499,7 +499,7 @@
ctx.RegisterModuleType("foo_module", newFooModule)
ctx.RegisterModuleType("bar_module", newBarModule)
- _, errs := ctx.ParseBlueprintsFiles("Blueprints")
+ _, errs := ctx.ParseBlueprintsFiles("Blueprints", nil)
expectedErrs := []error{
errors.New(`Blueprints:6:4: property 'name' is missing from a module`),
diff --git a/live_tracker.go b/live_tracker.go
index 5e13a87..40e1930 100644
--- a/live_tracker.go
+++ b/live_tracker.go
@@ -24,7 +24,7 @@
sync.Mutex
config interface{} // Used to evaluate variable, rule, and pool values.
- variables map[Variable]*ninjaString
+ variables map[Variable]ninjaString
pools map[Pool]*poolDef
rules map[Rule]*ruleDef
}
@@ -32,7 +32,7 @@
func newLiveTracker(config interface{}) *liveTracker {
return &liveTracker{
config: config,
- variables: make(map[Variable]*ninjaString),
+ variables: make(map[Variable]ninjaString),
pools: make(map[Pool]*poolDef),
rules: make(map[Rule]*ruleDef),
}
@@ -170,7 +170,7 @@
return nil
}
-func (l *liveTracker) addNinjaStringListDeps(list []*ninjaString) error {
+func (l *liveTracker) addNinjaStringListDeps(list []ninjaString) error {
for _, str := range list {
err := l.addNinjaStringDeps(str)
if err != nil {
@@ -180,8 +180,8 @@
return nil
}
-func (l *liveTracker) addNinjaStringDeps(str *ninjaString) error {
- for _, v := range str.variables {
+func (l *liveTracker) addNinjaStringDeps(str ninjaString) error {
+ for _, v := range str.Variables() {
err := l.addVariable(v)
if err != nil {
return err
diff --git a/module_ctx.go b/module_ctx.go
index be5d974..012142e 100644
--- a/module_ctx.go
+++ b/module_ctx.go
@@ -17,6 +17,7 @@
import (
"fmt"
"path/filepath"
+ "sync"
"text/scanner"
"github.com/google/blueprint/pathtools"
@@ -120,7 +121,7 @@
DynamicDependencies(DynamicDependerModuleContext) []string
}
-type BaseModuleContext interface {
+type EarlyModuleContext interface {
// Module returns the current module as a Module. It should rarely be necessary, as the module already has a
// reference to itself.
Module() Module
@@ -136,6 +137,10 @@
// RegisterModuleType.
ModuleType() string
+ // BlueprintsFile returns the name of the blueprint file that contains the definition of this
+ // module.
+ BlueprintsFile() string
+
// Config returns the config object that was passed to Context.PrepareBuildActions.
Config() interface{}
@@ -179,6 +184,13 @@
// default SimpleNameInterface if Context.SetNameInterface was not called.
Namespace() Namespace
+ // ModuleFactories returns a map of all of the global ModuleFactories by name.
+ ModuleFactories() map[string]ModuleFactory
+}
+
+type BaseModuleContext interface {
+ EarlyModuleContext
+
// GetDirectDepWithTag returns the Module the direct dependency with the specified name, or nil if
// none exists. It panics if the dependency does not have the specified tag.
GetDirectDepWithTag(name string, tag DependencyTag) Module
@@ -343,6 +355,10 @@
return filepath.Dir(d.module.relBlueprintsFile)
}
+func (d *baseModuleContext) BlueprintsFile() string {
+ return d.module.relBlueprintsFile
+}
+
func (d *baseModuleContext) Config() interface{} {
return d.config
}
@@ -597,6 +613,14 @@
m.ninjaFileDeps = append(m.ninjaFileDeps, deps...)
}
+func (m *baseModuleContext) ModuleFactories() map[string]ModuleFactory {
+ ret := make(map[string]ModuleFactory, len(m.context.moduleFactories))
+ for k, v := range m.context.moduleFactories {
+ ret[k] = v
+ }
+ return ret // a copy, so callers cannot mutate the context's factory registry
+}
+
func (m *moduleContext) ModuleSubDir() string {
return m.module.variantName
}
@@ -785,6 +809,13 @@
// specified name with the current variant of this module. Replacements don't take effect until
// after the mutator pass is finished.
ReplaceDependencies(string)
+
+ // AliasVariation takes a variationName that was passed to CreateVariations for this module, and creates an
+ // alias from the current variant to the new variant. The alias will be valid until the next time a mutator
+ // calls CreateVariations or CreateLocalVariations on this module without also calling AliasVariation. The
+ // alias can be used to add dependencies on the newly created variant using the variant map from before
+ // CreateVariations was run.
+ AliasVariation(variationName string)
}
// A Mutator function is called for each Module, and can use
@@ -838,6 +869,9 @@
for i, module := range modules {
ret = append(ret, module.logicModule)
if !local {
+ if module.dependencyVariant == nil {
+ module.dependencyVariant = make(variationMap)
+ }
module.dependencyVariant[mctx.name] = variationNames[i]
}
}
@@ -854,6 +888,25 @@
return ret
}
+func (mctx *mutatorContext) AliasVariation(variationName string) {
+ if mctx.module.aliasTarget != nil {
+ panic(fmt.Errorf("AliasVariation already called"))
+ }
+
+ for _, variant := range mctx.newVariations {
+ if variant.variant[mctx.name] == variationName {
+ mctx.module.aliasTarget = variant
+ return
+ }
+ }
+
+ var foundVariations []string
+ for _, variant := range mctx.newVariations {
+ foundVariations = append(foundVariations, variant.variant[mctx.name])
+ }
+ panic(fmt.Errorf("no %q variation in module variations %q", variationName, foundVariations))
+}
+
func (mctx *mutatorContext) SetDependencyVariation(variationName string) {
mctx.context.convertDepsToVariation(mctx.module, mctx.name, variationName, nil)
}
@@ -966,3 +1019,107 @@
func (s *SimpleName) Name() string {
return s.Properties.Name
}
+
+// Load Hooks
+
+type LoadHookContext interface {
+ EarlyModuleContext
+
+ // CreateModule creates a new module by calling the factory method for the specified moduleType, and applies
+ // the specified property structs to it as if the properties were set in a blueprint file.
+ CreateModule(ModuleFactory, ...interface{}) Module
+
+ // RegisterScopedModuleType creates a new module type that is scoped to the current Blueprints
+ // file.
+ RegisterScopedModuleType(name string, factory ModuleFactory)
+}
+
+func (l *loadHookContext) CreateModule(factory ModuleFactory, props ...interface{}) Module {
+ module := l.context.newModule(factory)
+
+ module.relBlueprintsFile = l.module.relBlueprintsFile
+ module.pos = l.module.pos
+ module.propertyPos = l.module.propertyPos
+ module.createdBy = l.module
+
+ for _, p := range props {
+ err := proptools.AppendMatchingProperties(module.properties, p, nil)
+ if err != nil {
+ panic(err)
+ }
+ }
+
+ l.newModules = append(l.newModules, module)
+
+ return module.logicModule
+}
+
+func (l *loadHookContext) RegisterScopedModuleType(name string, factory ModuleFactory) {
+ if _, exists := l.context.moduleFactories[name]; exists {
+ panic(fmt.Errorf("A global module type named %q already exists", name))
+ }
+
+ if _, exists := (*l.scopedModuleFactories)[name]; exists {
+ panic(fmt.Errorf("A module type named %q already exists in this scope", name))
+ }
+
+ if *l.scopedModuleFactories == nil {
+ (*l.scopedModuleFactories) = make(map[string]ModuleFactory)
+ }
+
+ (*l.scopedModuleFactories)[name] = factory
+}
+
+type loadHookContext struct {
+ baseModuleContext
+ newModules []*moduleInfo
+ scopedModuleFactories *map[string]ModuleFactory
+}
+
+type LoadHook func(ctx LoadHookContext)
+
+// Load hooks need to be added by module factories, which don't have any parameter to get to the
+// Context, and only produce a Module interface with no base implementation, so the load hooks
+// must be stored in a global map. The key is a pointer allocated by the module factory, so there
+// is no chance of collisions even if tests are running in parallel with multiple contexts. The
+// contents should be short-lived, they are added during a module factory and removed immediately
+// after the module factory returns.
+var pendingHooks sync.Map
+
+func AddLoadHook(module Module, hook LoadHook) {
+ // Only one goroutine can be processing a given module, so no additional locking is required
+ // for the slice stored in the sync.Map.
+ v, exists := pendingHooks.Load(module)
+ if !exists {
+ v, _ = pendingHooks.LoadOrStore(module, new([]LoadHook))
+ }
+ hooks := v.(*[]LoadHook)
+ *hooks = append(*hooks, hook)
+}
+
+func runAndRemoveLoadHooks(ctx *Context, config interface{}, module *moduleInfo,
+ scopedModuleFactories *map[string]ModuleFactory) (newModules []*moduleInfo, errs []error) {
+ // Run every load hook registered for module via AddLoadHook, then forget them.
+ if v, exists := pendingHooks.Load(module.logicModule); exists {
+ hooks := v.(*[]LoadHook)
+ for _, hook := range *hooks {
+ // Fresh context per hook: a shared one would re-append earlier hooks' modules and errors.
+ mctx := &loadHookContext{
+ baseModuleContext: baseModuleContext{
+ context: ctx,
+ config: config,
+ module: module,
+ },
+ scopedModuleFactories: scopedModuleFactories,
+ }
+ hook(mctx)
+ newModules = append(newModules, mctx.newModules...)
+ errs = append(errs, mctx.errs...)
+ }
+ pendingHooks.Delete(module.logicModule)
+
+ return newModules, errs
+ }
+
+ return nil, nil
+}
diff --git a/module_ctx_test.go b/module_ctx_test.go
new file mode 100644
index 0000000..e72be8c
--- /dev/null
+++ b/module_ctx_test.go
@@ -0,0 +1,197 @@
+// Copyright 2019 Google Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package blueprint
+
+import (
+ "reflect"
+ "strings"
+ "testing"
+)
+
+type moduleCtxTestModule struct {
+ SimpleName
+}
+
+func newModuleCtxTestModule() (Module, []interface{}) {
+ m := &moduleCtxTestModule{}
+ return m, []interface{}{&m.SimpleName.Properties}
+}
+
+func (f *moduleCtxTestModule) GenerateBuildActions(ModuleContext) {
+}
+
+func noCreateAliasMutator(name string) func(ctx BottomUpMutatorContext) {
+ return func(ctx BottomUpMutatorContext) {
+ if ctx.ModuleName() == name {
+ ctx.CreateVariations("a", "b")
+ }
+ }
+}
+
+func createAliasMutator(name string) func(ctx BottomUpMutatorContext) {
+ return func(ctx BottomUpMutatorContext) {
+ if ctx.ModuleName() == name {
+ ctx.CreateVariations("a", "b")
+ ctx.AliasVariation("b")
+ }
+ }
+}
+
+func addVariantDepsMutator(variants []Variation, tag DependencyTag, from, to string) func(ctx BottomUpMutatorContext) {
+ return func(ctx BottomUpMutatorContext) {
+ if ctx.ModuleName() == from {
+ ctx.AddVariationDependencies(variants, tag, to)
+ }
+ }
+}
+
+func TestAliases(t *testing.T) {
+ runWithFailures := func(ctx *Context, expectedErr string) {
+ t.Helper()
+ bp := `
+ test {
+ name: "foo",
+ }
+
+ test {
+ name: "bar",
+ }
+ `
+
+ mockFS := map[string][]byte{
+ "Blueprints": []byte(bp),
+ }
+
+ ctx.MockFileSystem(mockFS)
+
+ _, errs := ctx.ParseFileList(".", []string{"Blueprints"}, nil)
+ if len(errs) > 0 {
+ t.Errorf("unexpected parse errors:")
+ for _, err := range errs {
+ t.Errorf(" %s", err)
+ }
+ }
+
+ _, errs = ctx.ResolveDependencies(nil)
+ if len(errs) > 0 {
+ if expectedErr == "" {
+ t.Errorf("unexpected dep errors:")
+ for _, err := range errs {
+ t.Errorf(" %s", err)
+ }
+ } else {
+ for _, err := range errs {
+ if strings.Contains(err.Error(), expectedErr) {
+ continue
+ } else {
+ t.Errorf("unexpected dep error: %s", err)
+ }
+ }
+ }
+ } else if expectedErr != "" {
+ t.Errorf("missing dep error: %s", expectedErr)
+ }
+ }
+
+ run := func(ctx *Context) {
+ t.Helper()
+ runWithFailures(ctx, "")
+ }
+
+ t.Run("simple", func(t *testing.T) {
+ // Creates a module "bar" with variants "a" and "b" and alias "" -> "b".
+ // Tests a dependency from "foo" to "bar" variant "b" through alias "".
+ ctx := NewContext()
+ ctx.RegisterModuleType("test", newModuleCtxTestModule)
+ ctx.RegisterBottomUpMutator("1", createAliasMutator("bar"))
+ ctx.RegisterBottomUpMutator("2", addVariantDepsMutator(nil, nil, "foo", "bar"))
+
+ run(ctx)
+
+ foo := ctx.moduleGroupFromName("foo", nil).modules[0]
+ barB := ctx.moduleGroupFromName("bar", nil).modules[1]
+
+ if g, w := barB.variantName, "b"; g != w {
+ t.Fatalf("expected bar.modules[1] variant to be %q, got %q", w, g)
+ }
+
+ if g, w := foo.forwardDeps, []*moduleInfo{barB}; !reflect.DeepEqual(g, w) {
+ t.Fatalf("expected foo deps to be %q, got %q", w, g)
+ }
+ })
+
+ t.Run("chained", func(t *testing.T) {
+ // Creates a module "bar" with variants "a_a", "a_b", "b_a" and "b_b" and aliases "" -> "b_b",
+ // "a" -> "a_b", and "b" -> "b_b".
+ // Tests a dependency from "foo" to "bar" variant "b_b" through alias "".
+ ctx := NewContext()
+ ctx.RegisterModuleType("test", newModuleCtxTestModule)
+ ctx.RegisterBottomUpMutator("1", createAliasMutator("bar"))
+ ctx.RegisterBottomUpMutator("2", createAliasMutator("bar"))
+ ctx.RegisterBottomUpMutator("3", addVariantDepsMutator(nil, nil, "foo", "bar"))
+
+ run(ctx)
+
+ foo := ctx.moduleGroupFromName("foo", nil).modules[0]
+ barBB := ctx.moduleGroupFromName("bar", nil).modules[3]
+
+ if g, w := barBB.variantName, "b_b"; g != w {
+ t.Fatalf("expected bar.modules[3] variant to be %q, got %q", w, g)
+ }
+
+ if g, w := foo.forwardDeps, []*moduleInfo{barBB}; !reflect.DeepEqual(g, w) {
+ t.Fatalf("expected foo deps to be %q, got %q", w, g)
+ }
+ })
+
+ t.Run("chained2", func(t *testing.T) {
+ // Creates a module "bar" with variants "a_a", "a_b", "b_a" and "b_b" and aliases "" -> "b_b",
+ // "a" -> "a_b", and "b" -> "b_b".
+ // Tests a dependency from "foo" to "bar" variant "a_b" through alias "a".
+ ctx := NewContext()
+ ctx.RegisterModuleType("test", newModuleCtxTestModule)
+ ctx.RegisterBottomUpMutator("1", createAliasMutator("bar"))
+ ctx.RegisterBottomUpMutator("2", createAliasMutator("bar"))
+ ctx.RegisterBottomUpMutator("3", addVariantDepsMutator([]Variation{{"1", "a"}}, nil, "foo", "bar"))
+
+ run(ctx)
+
+ foo := ctx.moduleGroupFromName("foo", nil).modules[0]
+ barAB := ctx.moduleGroupFromName("bar", nil).modules[1]
+
+ if g, w := barAB.variantName, "a_b"; g != w {
+ t.Fatalf("expected bar.modules[1] variant to be %q, got %q", w, g)
+ }
+
+ if g, w := foo.forwardDeps, []*moduleInfo{barAB}; !reflect.DeepEqual(g, w) {
+ t.Fatalf("expected foo deps to be %q, got %q", w, g)
+ }
+ })
+
+ t.Run("removed dangling alias", func(t *testing.T) {
+ // Creates a module "bar" with variants "a" and "b" and aliases "" -> "b", then splits the variants into
+ // "a_a", "a_b", "b_a" and "b_b" without creating new aliases.
+ // Tests a dependency from "foo" to removed "bar" alias "" fails.
+ ctx := NewContext()
+ ctx.RegisterModuleType("test", newModuleCtxTestModule)
+ ctx.RegisterBottomUpMutator("1", createAliasMutator("bar"))
+ ctx.RegisterBottomUpMutator("2", noCreateAliasMutator("bar"))
+ ctx.RegisterBottomUpMutator("3", addVariantDepsMutator(nil, nil, "foo", "bar"))
+
+ runWithFailures(ctx, `dependency "bar" of "foo" missing variant:`+"\n \n"+
+ "available variants:"+
+ "\n 1:a, 2:a\n 1:a, 2:b\n 1:b, 2:a\n 1:b, 2:b")
+ })
+}
diff --git a/ninja_defs.go b/ninja_defs.go
index 61846fe..c5d0e4b 100644
--- a/ninja_defs.go
+++ b/ninja_defs.go
@@ -128,11 +128,11 @@
// A ruleDef describes a rule definition. It does not include the name of the
// rule.
type ruleDef struct {
- CommandDeps []*ninjaString
- CommandOrderOnly []*ninjaString
+ CommandDeps []ninjaString
+ CommandOrderOnly []ninjaString
Comment string
Pool Pool
- Variables map[string]*ninjaString
+ Variables map[string]ninjaString
}
func parseRuleParams(scope scope, params *RuleParams) (*ruleDef,
@@ -141,7 +141,7 @@
r := &ruleDef{
Comment: params.Comment,
Pool: params.Pool,
- Variables: make(map[string]*ninjaString),
+ Variables: make(map[string]ninjaString),
}
if params.Command == "" {
@@ -252,13 +252,13 @@
Comment string
Rule Rule
RuleDef *ruleDef
- Outputs []*ninjaString
- ImplicitOutputs []*ninjaString
- Inputs []*ninjaString
- Implicits []*ninjaString
- OrderOnly []*ninjaString
- Args map[Variable]*ninjaString
- Variables map[string]*ninjaString
+ Outputs []ninjaString
+ ImplicitOutputs []ninjaString
+ Inputs []ninjaString
+ Implicits []ninjaString
+ OrderOnly []ninjaString
+ Args map[Variable]ninjaString
+ Variables map[string]ninjaString
Optional bool
}
@@ -273,9 +273,9 @@
Rule: rule,
}
- setVariable := func(name string, value *ninjaString) {
+ setVariable := func(name string, value ninjaString) {
if b.Variables == nil {
- b.Variables = make(map[string]*ninjaString)
+ b.Variables = make(map[string]ninjaString)
}
b.Variables[name] = value
}
@@ -339,7 +339,7 @@
argNameScope := rule.scope()
if len(params.Args) > 0 {
- b.Args = make(map[Variable]*ninjaString)
+ b.Args = make(map[Variable]ninjaString)
for name, value := range params.Args {
if !rule.isArg(name) {
return nil, fmt.Errorf("unknown argument %q", name)
@@ -419,7 +419,7 @@
return nw.BlankLine()
}
-func valueList(list []*ninjaString, pkgNames map[*packageContext]string,
+func valueList(list []ninjaString, pkgNames map[*packageContext]string,
escaper *strings.Replacer) []string {
result := make([]string, len(list))
@@ -429,7 +429,7 @@
return result
}
-func writeVariables(nw *ninjaWriter, variables map[string]*ninjaString,
+func writeVariables(nw *ninjaWriter, variables map[string]ninjaString,
pkgNames map[*packageContext]string) error {
var keys []string
for k := range variables {
diff --git a/ninja_strings.go b/ninja_strings.go
index 5b8767d..190cae9 100644
--- a/ninja_strings.go
+++ b/ninja_strings.go
@@ -34,21 +34,28 @@
":", "$:")
)
-type ninjaString struct {
+type ninjaString interface {
+ Value(pkgNames map[*packageContext]string) string
+ ValueWithEscaper(pkgNames map[*packageContext]string, escaper *strings.Replacer) string
+ Eval(variables map[Variable]ninjaString) (string, error)
+ Variables() []Variable
+}
+
+type varNinjaString struct {
strings []string
variables []Variable
}
+type literalNinjaString string
+
type scope interface {
LookupVariable(name string) (Variable, error)
IsRuleVisible(rule Rule) bool
IsPoolVisible(pool Pool) bool
}
-func simpleNinjaString(str string) *ninjaString {
- return &ninjaString{
- strings: []string{str},
- }
+func simpleNinjaString(str string) ninjaString {
+ return literalNinjaString(str)
}
type parseState struct {
@@ -57,7 +64,7 @@
pendingStr string
stringStart int
varStart int
- result *ninjaString
+ result *varNinjaString
}
func (ps *parseState) pushVariable(v Variable) {
@@ -84,10 +91,16 @@
// parseNinjaString parses an unescaped ninja string (i.e. all $<something>
// occurrences are expected to be variables or $$) and returns a list of the
// variable names that the string references.
-func parseNinjaString(scope scope, str string) (*ninjaString, error) {
+func parseNinjaString(scope scope, str string) (ninjaString, error) {
// naively pre-allocate slices by counting $ signs
n := strings.Count(str, "$")
- result := &ninjaString{
+ if n == 0 {
+ if strings.HasPrefix(str, " ") {
+ str = "$" + str
+ }
+ return literalNinjaString(str), nil
+ }
+ result := &varNinjaString{
strings: make([]string, 0, n+1),
variables: make([]Variable, 0, n),
}
@@ -253,13 +266,13 @@
}
}
-func parseNinjaStrings(scope scope, strs []string) ([]*ninjaString,
+func parseNinjaStrings(scope scope, strs []string) ([]ninjaString,
error) {
if len(strs) == 0 {
return nil, nil
}
- result := make([]*ninjaString, len(strs))
+ result := make([]ninjaString, len(strs))
for i, str := range strs {
ninjaStr, err := parseNinjaString(scope, str)
if err != nil {
@@ -270,11 +283,11 @@
return result, nil
}
-func (n *ninjaString) Value(pkgNames map[*packageContext]string) string {
+func (n varNinjaString) Value(pkgNames map[*packageContext]string) string {
return n.ValueWithEscaper(pkgNames, defaultEscaper)
}
-func (n *ninjaString) ValueWithEscaper(pkgNames map[*packageContext]string,
+func (n varNinjaString) ValueWithEscaper(pkgNames map[*packageContext]string,
escaper *strings.Replacer) string {
if len(n.strings) == 1 {
@@ -293,7 +306,7 @@
return str.String()
}
-func (n *ninjaString) Eval(variables map[Variable]*ninjaString) (string, error) {
+func (n varNinjaString) Eval(variables map[Variable]ninjaString) (string, error) {
str := n.strings[0]
for i, v := range n.variables {
variable, ok := variables[v]
@@ -309,6 +322,27 @@
return str, nil
}
+func (n varNinjaString) Variables() []Variable {
+ return n.variables
+}
+
+func (l literalNinjaString) Value(pkgNames map[*packageContext]string) string {
+ return l.ValueWithEscaper(pkgNames, defaultEscaper)
+}
+
+func (l literalNinjaString) ValueWithEscaper(pkgNames map[*packageContext]string,
+ escaper *strings.Replacer) string {
+ return escaper.Replace(string(l))
+}
+
+func (l literalNinjaString) Eval(variables map[Variable]ninjaString) (string, error) {
+ return string(l), nil
+}
+
+func (l literalNinjaString) Variables() []Variable {
+ return nil
+}
+
func validateNinjaName(name string) error {
for i, r := range name {
valid := (r >= 'a' && r <= 'z') ||
diff --git a/ninja_strings_test.go b/ninja_strings_test.go
index 0e0de64..c1e05f7 100644
--- a/ninja_strings_test.go
+++ b/ninja_strings_test.go
@@ -22,10 +22,11 @@
)
var ninjaParseTestCases = []struct {
- input string
- vars []string
- strs []string
- err string
+ input string
+ vars []string
+ strs []string
+ literal bool
+ err string
}{
{
input: "abc def $ghi jkl",
@@ -56,6 +57,7 @@
input: "foo $$ bar",
vars: nil,
strs: []string{"foo $$ bar"},
+ // this is technically a literal, but not recognized as such due to the $$
},
{
input: "$foo${bar}",
@@ -68,16 +70,22 @@
strs: []string{"", "$$"},
},
{
- input: "foo bar",
- vars: nil,
- strs: []string{"foo bar"},
+ input: "foo bar",
+ vars: nil,
+ strs: []string{"foo bar"},
+ literal: true,
},
{
- input: " foo ",
- vars: nil,
- strs: []string{"$ foo "},
+ input: " foo ",
+ vars: nil,
+ strs: []string{"$ foo "},
+ literal: true,
},
{
+ input: " $foo ",
+ vars: []string{"foo"},
+ strs: []string{"$ ", " "},
+ }, {
input: "foo $ bar",
err: "invalid character after '$' at byte offset 5",
},
@@ -114,19 +122,25 @@
expectedVars = append(expectedVars, v)
}
+ var expected ninjaString
+ if len(testCase.strs) > 0 {
+ if testCase.literal {
+ expected = literalNinjaString(testCase.strs[0])
+ } else {
+ expected = &varNinjaString{
+ strings: testCase.strs,
+ variables: expectedVars,
+ }
+ }
+ }
+
output, err := parseNinjaString(scope, testCase.input)
if err == nil {
- if !reflect.DeepEqual(output.variables, expectedVars) {
- t.Errorf("incorrect variable list:")
+ if !reflect.DeepEqual(output, expected) {
+ t.Errorf("incorrect ninja string:")
t.Errorf(" input: %q", testCase.input)
- t.Errorf(" expected: %#v", expectedVars)
- t.Errorf(" got: %#v", output.variables)
- }
- if !reflect.DeepEqual(output.strings, testCase.strs) {
- t.Errorf("incorrect string list:")
- t.Errorf(" input: %q", testCase.input)
- t.Errorf(" expected: %#v", testCase.strs)
- t.Errorf(" got: %#v", output.strings)
+ t.Errorf(" expected: %#v", expected)
+ t.Errorf(" got: %#v", output)
}
}
var errStr string
@@ -156,7 +170,7 @@
}
expect := []Variable{ImpVar}
- if !reflect.DeepEqual(output.variables, expect) {
+ if !reflect.DeepEqual(output.(*varNinjaString).variables, expect) {
t.Errorf("incorrect output:")
t.Errorf(" input: %q", input)
t.Errorf(" expected: %#v", expect)
diff --git a/package_ctx.go b/package_ctx.go
index c55152a..088239e 100644
--- a/package_ctx.go
+++ b/package_ctx.go
@@ -292,7 +292,7 @@
return packageNamespacePrefix(pkgNames[v.pctx]) + v.name_
}
-func (v *staticVariable) value(interface{}) (*ninjaString, error) {
+func (v *staticVariable) value(interface{}) (ninjaString, error) {
ninjaStr, err := parseNinjaString(v.pctx.scope, v.value_)
if err != nil {
err = fmt.Errorf("error parsing variable %s value: %s", v, err)
@@ -392,7 +392,7 @@
return packageNamespacePrefix(pkgNames[v.pctx]) + v.name_
}
-func (v *variableFunc) value(config interface{}) (*ninjaString, error) {
+func (v *variableFunc) value(config interface{}) (ninjaString, error) {
value, err := v.value_(config)
if err != nil {
return nil, err
@@ -452,7 +452,7 @@
return v.name_
}
-func (v *argVariable) value(config interface{}) (*ninjaString, error) {
+func (v *argVariable) value(config interface{}) (ninjaString, error) {
return nil, errVariableIsArg
}
diff --git a/parser/ast.go b/parser/ast.go
index b5053bb..57c4948 100644
--- a/parser/ast.go
+++ b/parser/ast.go
@@ -164,6 +164,7 @@
Int64Type
ListType
MapType
+ NotEvaluatedType
)
func (t Type) String() string {
@@ -178,6 +179,8 @@
return "list"
case MapType:
return "map"
+ case NotEvaluatedType:
+ return "notevaluated"
default:
panic(fmt.Errorf("Unknown type %d", t))
}
@@ -476,6 +479,29 @@
return string(buf)
}
+type NotEvaluated struct {
+ Position scanner.Position
+}
+
+func (n NotEvaluated) Copy() Expression {
+ return NotEvaluated{Position: n.Position}
+}
+
+func (n NotEvaluated) String() string {
+ return "Not Evaluated"
+}
+
+func (n NotEvaluated) Type() Type {
+ return NotEvaluatedType
+}
+
+func (n NotEvaluated) Eval() Expression {
+ return NotEvaluated{Position: n.Position}
+}
+
+func (n NotEvaluated) Pos() scanner.Position { return n.Position }
+func (n NotEvaluated) End() scanner.Position { return n.Position }
+
func endPos(pos scanner.Position, n int) scanner.Position {
pos.Offset += n
pos.Column += n
diff --git a/parser/parser.go b/parser/parser.go
index cb86246..6ae5df3 100644
--- a/parser/parser.go
+++ b/parser/parser.go
@@ -484,6 +484,8 @@
}
value = assignment.Value
}
+ } else {
+ value = &NotEvaluated{}
}
value = &Variable{
Name: text,
diff --git a/parser/parser_test.go b/parser/parser_test.go
index 6377dc1..70151ad 100644
--- a/parser/parser_test.go
+++ b/parser/parser_test.go
@@ -1122,3 +1122,24 @@
}
}
}
+
+func TestParserNotEvaluated(t *testing.T) {
+ // When parsing without evaluation, create variables correctly
+ scope := NewScope(nil)
+ input := "FOO=abc\n"
+ _, errs := Parse("", bytes.NewBufferString(input), scope)
+ if errs != nil {
+ t.Errorf("unexpected errors:")
+ for _, err := range errs {
+ t.Errorf(" %s", err)
+ }
+ t.FailNow()
+ }
+ assignment, found := scope.Get("FOO")
+ if !found {
+ t.Fatalf("Expected to find FOO after parsing %s", input)
+ }
+ if s := assignment.String(); strings.Contains(s, "PANIC") {
+ t.Errorf("Attempt to print FOO returned %s", s)
+ }
+}
diff --git a/pathtools/fs.go b/pathtools/fs.go
index 8329392..21754d0 100644
--- a/pathtools/fs.go
+++ b/pathtools/fs.go
@@ -36,7 +36,7 @@
DontFollowSymlinks = ShouldFollowSymlinks(false)
)
-var OsFs FileSystem = osFs{}
+var OsFs FileSystem = &osFs{}
func MockFs(files map[string][]byte) FileSystem {
fs := &mockFs{
@@ -123,11 +123,52 @@
}
// osFs implements FileSystem using the local disk.
-type osFs struct{}
+type osFs struct {
+ srcDir string
+}
-func (osFs) Open(name string) (ReaderAtSeekerCloser, error) { return os.Open(name) }
-func (osFs) Exists(name string) (bool, bool, error) {
- stat, err := os.Stat(name)
+func NewOsFs(path string) FileSystem {
+ return &osFs{srcDir: path}
+}
+
+func (fs *osFs) toAbs(path string) string {
+ if filepath.IsAbs(path) {
+ return path
+ }
+ return filepath.Join(fs.srcDir, path)
+}
+
+func (fs *osFs) removeSrcDirPrefix(path string) string {
+ if fs.srcDir == "" {
+ return path
+ }
+ rel, err := filepath.Rel(fs.srcDir, path)
+ if err != nil {
+ panic(fmt.Errorf("unexpected failure in removeSrcDirPrefix filepath.Rel(%s, %s): %s", fs.srcDir, path, err))
+ }
+ // rel == ".." (path is srcDir's parent) is also outside srcDir; a "../" prefix check alone misses it.
+ if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
+ panic(fmt.Errorf("unexpected relative path outside directory in removeSrcDirPrefix filepath.Rel(%s, %s): %s",
+ fs.srcDir, path, rel))
+ }
+ return rel
+}
+
+func (fs *osFs) removeSrcDirPrefixes(paths []string) []string {
+ if fs.srcDir != "" {
+ for i, path := range paths {
+ paths[i] = fs.removeSrcDirPrefix(path)
+ }
+ }
+ return paths
+}
+
+func (fs *osFs) Open(name string) (ReaderAtSeekerCloser, error) {
+ return os.Open(fs.toAbs(name))
+}
+
+func (fs *osFs) Exists(name string) (bool, bool, error) {
+ stat, err := os.Stat(fs.toAbs(name))
if err == nil {
return true, stat.IsDir(), nil
} else if os.IsNotExist(err) {
@@ -137,45 +178,47 @@
}
}
-func (osFs) IsDir(name string) (bool, error) {
- info, err := os.Stat(name)
+func (fs *osFs) IsDir(name string) (bool, error) {
+ info, err := os.Stat(fs.toAbs(name))
if err != nil {
return false, err
}
return info.IsDir(), nil
}
-func (osFs) IsSymlink(name string) (bool, error) {
- if info, err := os.Lstat(name); err != nil {
+func (fs *osFs) IsSymlink(name string) (bool, error) {
+ if info, err := os.Lstat(fs.toAbs(name)); err != nil {
return false, err
} else {
return info.Mode()&os.ModeSymlink != 0, nil
}
}
-func (fs osFs) Glob(pattern string, excludes []string, follow ShouldFollowSymlinks) (matches, dirs []string, err error) {
+func (fs *osFs) Glob(pattern string, excludes []string, follow ShouldFollowSymlinks) (matches, dirs []string, err error) {
return startGlob(fs, pattern, excludes, follow)
}
-func (osFs) glob(pattern string) ([]string, error) {
- return filepath.Glob(pattern)
+func (fs *osFs) glob(pattern string) ([]string, error) {
+ paths, err := filepath.Glob(fs.toAbs(pattern))
+ fs.removeSrcDirPrefixes(paths)
+ return paths, err
}
-func (osFs) Lstat(path string) (stats os.FileInfo, err error) {
- return os.Lstat(path)
+func (fs *osFs) Lstat(path string) (stats os.FileInfo, err error) {
+ return os.Lstat(fs.toAbs(path))
}
-func (osFs) Stat(path string) (stats os.FileInfo, err error) {
- return os.Stat(path)
+func (fs *osFs) Stat(path string) (stats os.FileInfo, err error) {
+ return os.Stat(fs.toAbs(path))
}
// Returns a list of all directories under dir
-func (osFs) ListDirsRecursive(name string, follow ShouldFollowSymlinks) (dirs []string, err error) {
- return listDirsRecursive(OsFs, name, follow)
+func (fs *osFs) ListDirsRecursive(name string, follow ShouldFollowSymlinks) (dirs []string, err error) {
+ return listDirsRecursive(fs, name, follow)
}
-func (osFs) ReadDirNames(name string) ([]string, error) {
- dir, err := os.Open(name)
+func (fs *osFs) ReadDirNames(name string) ([]string, error) {
+ dir, err := os.Open(fs.toAbs(name))
if err != nil {
return nil, err
}
@@ -190,8 +233,8 @@
return contents, nil
}
-func (osFs) Readlink(name string) (string, error) {
- return os.Readlink(name)
+func (fs *osFs) Readlink(name string) (string, error) {
+ return os.Readlink(fs.toAbs(name))
}
type mockFs struct {
diff --git a/pathtools/fs_test.go b/pathtools/fs_test.go
index 1b5c458..3b4d4d0 100644
--- a/pathtools/fs_test.go
+++ b/pathtools/fs_test.go
@@ -22,6 +22,8 @@
"testing"
)
+const testdataDir = "testdata/dangling"
+
func symlinkMockFs() *mockFs {
files := []string{
"a/a/a",
@@ -101,6 +103,42 @@
}
}
+func runTestFs(t *testing.T, f func(t *testing.T, fs FileSystem, dir string)) {
+ wd, err := os.Getwd()
+ if err != nil {
+ t.Fatalf("os.Getwd: %s", err)
+ }
+ absTestDataDir := filepath.Join(wd, testdataDir)
+
+ // chdir fails the (sub)test instead of silently running from the wrong directory.
+ chdir := func(t *testing.T, dir string) {
+ if err := os.Chdir(dir); err != nil {
+ t.Fatalf("os.Chdir(%q): %s", dir, err)
+ }
+ }
+
+ run := func(t *testing.T, fs FileSystem) {
+ t.Run("relpath", func(t *testing.T) { f(t, fs, "") })
+ t.Run("abspath", func(t *testing.T) { f(t, fs, absTestDataDir) })
+ }
+
+ t.Run("mock", func(t *testing.T) { f(t, symlinkMockFs(), "") })
+
+ t.Run("os", func(t *testing.T) {
+ chdir(t, absTestDataDir)
+ defer os.Chdir(wd)
+ run(t, OsFs)
+ })
+
+ t.Run("os relative srcDir", func(t *testing.T) { run(t, NewOsFs(testdataDir)) })
+
+ t.Run("os absolute srcDir", func(t *testing.T) {
+ chdir(t, "/")
+ defer os.Chdir(wd)
+ run(t, NewOsFs(absTestDataDir))
+ })
+}
+
func TestFs_IsDir(t *testing.T) {
testCases := []struct {
name string
@@ -148,26 +186,17 @@
{"c/f/missing", false, syscall.ENOTDIR},
}
- mock := symlinkMockFs()
- fsList := []FileSystem{mock, OsFs}
- names := []string{"mock", "os"}
-
- os.Chdir("testdata/dangling")
- defer os.Chdir("../..")
-
- for i, fs := range fsList {
- t.Run(names[i], func(t *testing.T) {
- for _, test := range testCases {
- t.Run(test.name, func(t *testing.T) {
- got, err := fs.IsDir(test.name)
- checkErr(t, test.err, err)
- if got != test.isDir {
- t.Errorf("want: %v, got %v", test.isDir, got)
- }
- })
- }
- })
- }
+ runTestFs(t, func(t *testing.T, fs FileSystem, dir string) {
+ for _, test := range testCases {
+ t.Run(test.name, func(t *testing.T) {
+ got, err := fs.IsDir(filepath.Join(dir, test.name))
+ checkErr(t, test.err, err)
+ if got != test.isDir {
+ t.Errorf("want: %v, got %v", test.isDir, got)
+ }
+ })
+ }
+ })
}
func TestFs_ListDirsRecursiveFollowSymlinks(t *testing.T) {
@@ -199,27 +228,21 @@
{"missing", nil, os.ErrNotExist},
}
- mock := symlinkMockFs()
- fsList := []FileSystem{mock, OsFs}
- names := []string{"mock", "os"}
-
- os.Chdir("testdata/dangling")
- defer os.Chdir("../..")
-
- for i, fs := range fsList {
- t.Run(names[i], func(t *testing.T) {
-
- for _, test := range testCases {
- t.Run(test.name, func(t *testing.T) {
- got, err := fs.ListDirsRecursive(test.name, FollowSymlinks)
- checkErr(t, test.err, err)
- if !reflect.DeepEqual(got, test.dirs) {
- t.Errorf("want: %v, got %v", test.dirs, got)
- }
- })
- }
- })
- }
+ runTestFs(t, func(t *testing.T, fs FileSystem, dir string) {
+ for _, test := range testCases {
+ t.Run(test.name, func(t *testing.T) {
+ got, err := fs.ListDirsRecursive(filepath.Join(dir, test.name), FollowSymlinks)
+ checkErr(t, test.err, err)
+ want := append([]string(nil), test.dirs...)
+ for i := range want {
+ want[i] = filepath.Join(dir, want[i])
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("want: %v, got %v", want, got)
+ }
+ })
+ }
+ })
}
func TestFs_ListDirsRecursiveDontFollowSymlinks(t *testing.T) {
@@ -251,27 +274,21 @@
{"missing", nil, os.ErrNotExist},
}
- mock := symlinkMockFs()
- fsList := []FileSystem{mock, OsFs}
- names := []string{"mock", "os"}
-
- os.Chdir("testdata/dangling")
- defer os.Chdir("../..")
-
- for i, fs := range fsList {
- t.Run(names[i], func(t *testing.T) {
-
- for _, test := range testCases {
- t.Run(test.name, func(t *testing.T) {
- got, err := fs.ListDirsRecursive(test.name, DontFollowSymlinks)
- checkErr(t, test.err, err)
- if !reflect.DeepEqual(got, test.dirs) {
- t.Errorf("want: %v, got %v", test.dirs, got)
- }
- })
- }
- })
- }
+ runTestFs(t, func(t *testing.T, fs FileSystem, dir string) {
+ for _, test := range testCases {
+ t.Run(test.name, func(t *testing.T) {
+ got, err := fs.ListDirsRecursive(filepath.Join(dir, test.name), DontFollowSymlinks)
+ checkErr(t, test.err, err)
+ want := append([]string(nil), test.dirs...)
+ for i := range want {
+ want[i] = filepath.Join(dir, want[i])
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("want: %v, got %v", want, got)
+ }
+ })
+ }
+ })
}
func TestFs_Readlink(t *testing.T) {
@@ -320,27 +337,17 @@
{"dangling/missing/missing", "", os.ErrNotExist},
}
- mock := symlinkMockFs()
- fsList := []FileSystem{mock, OsFs}
- names := []string{"mock", "os"}
-
- os.Chdir("testdata/dangling")
- defer os.Chdir("../..")
-
- for i, fs := range fsList {
- t.Run(names[i], func(t *testing.T) {
-
- for _, test := range testCases {
- t.Run(test.from, func(t *testing.T) {
- got, err := fs.Readlink(test.from)
- checkErr(t, test.err, err)
- if got != test.to {
- t.Errorf("fs.Readlink(%q) want: %q, got %q", test.from, test.to, got)
- }
- })
- }
- })
- }
+ runTestFs(t, func(t *testing.T, fs FileSystem, dir string) {
+ for _, test := range testCases {
+ t.Run(test.from, func(t *testing.T) {
+ got, err := fs.Readlink(test.from)
+ checkErr(t, test.err, err)
+ if got != test.to {
+ t.Errorf("fs.Readlink(%q) want: %q, got %q", test.from, test.to, got)
+ }
+ })
+ }
+ })
}
func TestFs_Lstat(t *testing.T) {
@@ -391,34 +398,24 @@
{"dangling/missing/missing", 0, 0, os.ErrNotExist},
}
- mock := symlinkMockFs()
- fsList := []FileSystem{mock, OsFs}
- names := []string{"mock", "os"}
-
- os.Chdir("testdata/dangling")
- defer os.Chdir("../..")
-
- for i, fs := range fsList {
- t.Run(names[i], func(t *testing.T) {
-
- for _, test := range testCases {
- t.Run(test.name, func(t *testing.T) {
- got, err := fs.Lstat(test.name)
- checkErr(t, test.err, err)
- if err != nil {
- return
- }
- if got.Mode()&os.ModeType != test.mode {
- t.Errorf("fs.Lstat(%q).Mode()&os.ModeType want: %x, got %x",
- test.name, test.mode, got.Mode()&os.ModeType)
- }
- if test.mode == 0 && got.Size() != test.size {
- t.Errorf("fs.Lstat(%q).Size() want: %d, got %d", test.name, test.size, got.Size())
- }
- })
- }
- })
- }
+ runTestFs(t, func(t *testing.T, fs FileSystem, dir string) {
+ for _, test := range testCases {
+ t.Run(test.name, func(t *testing.T) {
+ got, err := fs.Lstat(filepath.Join(dir, test.name))
+ checkErr(t, test.err, err)
+ if err != nil {
+ return
+ }
+ if got.Mode()&os.ModeType != test.mode {
+ t.Errorf("fs.Lstat(%q).Mode()&os.ModeType want: %x, got %x",
+ test.name, test.mode, got.Mode()&os.ModeType)
+ }
+ if test.mode == 0 && got.Size() != test.size {
+ t.Errorf("fs.Lstat(%q).Size() want: %d, got %d", test.name, test.size, got.Size())
+ }
+ })
+ }
+ })
}
func TestFs_Stat(t *testing.T) {
@@ -469,34 +466,24 @@
{"dangling/missing/missing", 0, 0, os.ErrNotExist},
}
- mock := symlinkMockFs()
- fsList := []FileSystem{mock, OsFs}
- names := []string{"mock", "os"}
-
- os.Chdir("testdata/dangling")
- defer os.Chdir("../..")
-
- for i, fs := range fsList {
- t.Run(names[i], func(t *testing.T) {
-
- for _, test := range testCases {
- t.Run(test.name, func(t *testing.T) {
- got, err := fs.Stat(test.name)
- checkErr(t, test.err, err)
- if err != nil {
- return
- }
- if got.Mode()&os.ModeType != test.mode {
- t.Errorf("fs.Stat(%q).Mode()&os.ModeType want: %x, got %x",
- test.name, test.mode, got.Mode()&os.ModeType)
- }
- if test.mode == 0 && got.Size() != test.size {
- t.Errorf("fs.Stat(%q).Size() want: %d, got %d", test.name, test.size, got.Size())
- }
- })
- }
- })
- }
+ runTestFs(t, func(t *testing.T, fs FileSystem, dir string) {
+ for _, test := range testCases {
+ t.Run(test.name, func(t *testing.T) {
+ got, err := fs.Stat(filepath.Join(dir, test.name))
+ checkErr(t, test.err, err)
+ if err != nil {
+ return
+ }
+ if got.Mode()&os.ModeType != test.mode {
+ t.Errorf("fs.Stat(%q).Mode()&os.ModeType want: %x, got %x",
+ test.name, test.mode, got.Mode()&os.ModeType)
+ }
+ if test.mode == 0 && got.Size() != test.size {
+ t.Errorf("fs.Stat(%q).Size() want: %d, got %d", test.name, test.size, got.Size())
+ }
+ })
+ }
+ })
}
func TestMockFs_glob(t *testing.T) {
@@ -537,28 +524,19 @@
{"missing", nil},
}
- mock := symlinkMockFs()
- fsList := []FileSystem{mock, OsFs}
- names := []string{"mock", "os"}
-
- os.Chdir("testdata/dangling")
- defer os.Chdir("../..")
-
- for i, fs := range fsList {
- t.Run(names[i], func(t *testing.T) {
- for _, test := range testCases {
- t.Run(test.pattern, func(t *testing.T) {
- got, err := fs.glob(test.pattern)
- if err != nil {
- t.Fatal(err)
- }
- if !reflect.DeepEqual(got, test.files) {
- t.Errorf("want: %v, got %v", test.files, got)
- }
- })
- }
- })
- }
+ runTestFs(t, func(t *testing.T, fs FileSystem, dir string) {
+ for _, test := range testCases {
+ t.Run(test.pattern, func(t *testing.T) {
+ got, err := fs.glob(test.pattern)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(got, test.files) {
+ t.Errorf("want: %v, got %v", test.files, got)
+ }
+ })
+ }
+ })
}
func syscallError(err error) error {
diff --git a/proptools/clone.go b/proptools/clone.go
index fe4e115..9e985f1 100644
--- a/proptools/clone.go
+++ b/proptools/clone.go
@@ -20,13 +20,32 @@
"sync"
)
+// CloneProperties takes a reflect.Value of a pointer to a struct and returns a reflect.Value
+// of a pointer to a new struct that contains copies of the values of its fields. It recursively
+// clones struct pointers and interfaces that contain struct pointers.
func CloneProperties(structValue reflect.Value) reflect.Value {
- result := reflect.New(structValue.Type())
- CopyProperties(result.Elem(), structValue)
+ if !isStructPtr(structValue.Type()) {
+ panic(fmt.Errorf("CloneProperties expected *struct, got %s", structValue.Type()))
+ }
+ result := reflect.New(structValue.Type().Elem())
+ copyProperties(result.Elem(), structValue.Elem())
return result
}
+// CopyProperties takes destination and source reflect.Values of pointers to structs and copies
+// each field from the source into the destination. It recursively copies struct pointers and
+// interfaces that contain struct pointers.
func CopyProperties(dstValue, srcValue reflect.Value) {
+ if !isStructPtr(dstValue.Type()) {
+ panic(fmt.Errorf("CopyProperties expected dstValue *struct, got %s", dstValue.Type()))
+ }
+ if !isStructPtr(srcValue.Type()) {
+ panic(fmt.Errorf("CopyProperties expected srcValue *struct, got %s", srcValue.Type()))
+ }
+ copyProperties(dstValue.Elem(), srcValue.Elem())
+}
+
+func copyProperties(dstValue, srcValue reflect.Value) {
typ := dstValue.Type()
if srcValue.Type() != typ {
panic(fmt.Errorf("can't copy mismatching types (%s <- %s)",
@@ -47,7 +66,7 @@
case reflect.Bool, reflect.String, reflect.Int, reflect.Uint:
dstFieldValue.Set(srcFieldValue)
case reflect.Struct:
- CopyProperties(dstFieldValue, srcFieldValue)
+ copyProperties(dstFieldValue, srcFieldValue)
case reflect.Slice:
if !srcFieldValue.IsNil() {
if srcFieldValue != dstFieldValue {
@@ -67,13 +86,9 @@
srcFieldValue = srcFieldValue.Elem()
- if srcFieldValue.Kind() != reflect.Ptr {
- panic(fmt.Errorf("can't clone field %q: interface refers to a non-pointer",
- field.Name))
- }
- if srcFieldValue.Type().Elem().Kind() != reflect.Struct {
- panic(fmt.Errorf("can't clone field %q: interface points to a non-struct",
- field.Name))
+ if !isStructPtr(srcFieldValue.Type()) {
+ panic(fmt.Errorf("can't clone field %q: expected interface to contain *struct, found %s",
+ field.Name, srcFieldValue.Type()))
}
if dstFieldValue.IsNil() || dstFieldValue.Elem().Type() != srcFieldValue.Type() {
@@ -93,13 +108,11 @@
break
}
- srcFieldValue := srcFieldValue.Elem()
-
- switch srcFieldValue.Kind() {
+ switch srcFieldValue.Elem().Kind() {
case reflect.Struct:
if !dstFieldValue.IsNil() {
// Re-use the existing allocation.
- CopyProperties(dstFieldValue.Elem(), srcFieldValue)
+ copyProperties(dstFieldValue.Elem(), srcFieldValue.Elem())
break
} else {
newValue := CloneProperties(srcFieldValue)
@@ -110,21 +123,30 @@
}
}
case reflect.Bool, reflect.Int64, reflect.String:
- newValue := reflect.New(srcFieldValue.Type())
- newValue.Elem().Set(srcFieldValue)
+ newValue := reflect.New(srcFieldValue.Elem().Type())
+ newValue.Elem().Set(srcFieldValue.Elem())
origDstFieldValue.Set(newValue)
default:
- panic(fmt.Errorf("can't clone field %q: points to a %s",
- field.Name, srcFieldValue.Kind()))
+ panic(fmt.Errorf("can't clone pointer field %q type %s",
+ field.Name, srcFieldValue.Type()))
}
default:
- panic(fmt.Errorf("unexpected kind for property struct field %q: %s",
- field.Name, srcFieldValue.Kind()))
+ panic(fmt.Errorf("unexpected type for property struct field %q: %s",
+ field.Name, srcFieldValue.Type()))
}
}
}
+// ZeroProperties takes a reflect.Value of a pointer to a struct and replaces all of its fields
+// with zero values, recursing into struct, pointer to struct and interface fields.
func ZeroProperties(structValue reflect.Value) {
+ if !isStructPtr(structValue.Type()) {
+ panic(fmt.Errorf("ZeroProperties expected *struct, got %s", structValue.Type()))
+ }
+ zeroProperties(structValue.Elem())
+}
+
+func zeroProperties(structValue reflect.Value) {
typ := structValue.Type()
for i, field := range typeFields(typ) {
@@ -146,13 +168,9 @@
// We leave the pointer intact and zero out the struct that's
// pointed to.
fieldValue = fieldValue.Elem()
- if fieldValue.Kind() != reflect.Ptr {
- panic(fmt.Errorf("can't zero field %q: interface refers to a non-pointer",
- field.Name))
- }
- if fieldValue.Type().Elem().Kind() != reflect.Struct {
- panic(fmt.Errorf("can't zero field %q: interface points to a non-struct",
- field.Name))
+ if !isStructPtr(fieldValue.Type()) {
+ panic(fmt.Errorf("can't zero field %q: expected interface to contain *struct, found %s",
+ field.Name, fieldValue.Type()))
}
fallthrough
case reflect.Ptr:
@@ -161,7 +179,7 @@
if fieldValue.IsNil() {
break
}
- ZeroProperties(fieldValue.Elem())
+ zeroProperties(fieldValue.Elem())
case reflect.Bool, reflect.Int64, reflect.String:
fieldValue.Set(reflect.Zero(fieldValue.Type()))
default:
@@ -169,7 +187,7 @@
field.Name, fieldValue.Elem().Kind()))
}
case reflect.Struct:
- ZeroProperties(fieldValue)
+ zeroProperties(fieldValue)
default:
panic(fmt.Errorf("unexpected kind for property struct field %q: %s",
field.Name, fieldValue.Kind()))
@@ -177,9 +195,15 @@
}
}
+// CloneEmptyProperties takes a reflect.Value of a pointer to a struct and returns a reflect.Value
+// of a pointer to a new struct that has the zero values for its fields. It recursively clones
+// struct pointers and interfaces that contain struct pointers.
func CloneEmptyProperties(structValue reflect.Value) reflect.Value {
- result := reflect.New(structValue.Type())
- cloneEmptyProperties(result.Elem(), structValue)
+ if !isStructPtr(structValue.Type()) {
+ panic(fmt.Errorf("CloneEmptyProperties expected *struct, got %s", structValue.Type()))
+ }
+ result := reflect.New(structValue.Type().Elem())
+ cloneEmptyProperties(result.Elem(), structValue.Elem())
return result
}
@@ -206,13 +230,9 @@
}
srcFieldValue = srcFieldValue.Elem()
- if srcFieldValue.Kind() != reflect.Ptr {
- panic(fmt.Errorf("can't clone empty field %q: interface refers to a non-pointer",
- field.Name))
- }
- if srcFieldValue.Type().Elem().Kind() != reflect.Struct {
- panic(fmt.Errorf("can't clone empty field %q: interface points to a non-struct",
- field.Name))
+ if !isStructPtr(srcFieldValue.Type()) {
+ panic(fmt.Errorf("can't clone empty field %q: expected interface to contain *struct, found %s",
+ field.Name, srcFieldValue.Type()))
}
newValue := reflect.New(srcFieldValue.Type()).Elem()
@@ -226,7 +246,7 @@
if srcFieldValue.IsNil() {
break
}
- newValue := CloneEmptyProperties(srcFieldValue.Elem())
+ newValue := CloneEmptyProperties(srcFieldValue)
if dstFieldInterfaceValue.IsValid() {
dstFieldInterfaceValue.Set(newValue)
} else {
diff --git a/proptools/clone_test.go b/proptools/clone_test.go
index 660f1c0..3c03451 100644
--- a/proptools/clone_test.go
+++ b/proptools/clone_test.go
@@ -277,7 +277,7 @@
for _, testCase := range clonePropertiesTestCases {
testString := fmt.Sprintf("%s", testCase.in)
- got := CloneProperties(reflect.ValueOf(testCase.in).Elem()).Interface()
+ got := CloneProperties(reflect.ValueOf(testCase.in)).Interface()
if !reflect.DeepEqual(testCase.out, got) {
t.Errorf("test case %s", testString)
@@ -499,7 +499,7 @@
for _, testCase := range cloneEmptyPropertiesTestCases {
testString := fmt.Sprintf("%#v", testCase.in)
- got := CloneEmptyProperties(reflect.ValueOf(testCase.in).Elem()).Interface()
+ got := CloneEmptyProperties(reflect.ValueOf(testCase.in)).Interface()
if !reflect.DeepEqual(testCase.out, got) {
t.Errorf("test case %s", testString)
@@ -514,8 +514,8 @@
for _, testCase := range cloneEmptyPropertiesTestCases {
testString := fmt.Sprintf("%#v", testCase.in)
- got := CloneProperties(reflect.ValueOf(testCase.in).Elem()).Interface()
- ZeroProperties(reflect.ValueOf(got).Elem())
+ got := CloneProperties(reflect.ValueOf(testCase.in)).Interface()
+ ZeroProperties(reflect.ValueOf(got))
if !reflect.DeepEqual(testCase.out, got) {
t.Errorf("test case %s", testString)
diff --git a/proptools/extend.go b/proptools/extend.go
index b1a35d0..d3c2b79 100644
--- a/proptools/extend.go
+++ b/proptools/extend.go
@@ -150,6 +150,7 @@
const (
Append Order = iota
Prepend
+ Replace
)
type ExtendPropertyFilterFunc func(property string,
@@ -172,6 +173,12 @@
return Prepend, nil
}
+func OrderReplace(property string,
+ dstField, srcField reflect.StructField,
+ dstValue, srcValue interface{}) (Order, error) {
+ return Replace, nil
+}
+
type ExtendPropertyError struct {
Err error
Property string
@@ -267,7 +274,7 @@
}
// Step into source pointers to structs
- if srcFieldValue.Kind() == reflect.Ptr && srcFieldValue.Type().Elem().Kind() == reflect.Struct {
+ if isStructPtr(srcFieldValue.Type()) {
if srcFieldValue.IsNil() {
continue
}
@@ -316,7 +323,7 @@
}
// Step into destination pointers to structs
- if dstFieldValue.Kind() == reflect.Ptr && dstFieldValue.Type().Elem().Kind() == reflect.Struct {
+ if isStructPtr(dstFieldValue.Type()) {
if dstFieldValue.IsNil() {
dstFieldValue = reflect.New(dstFieldValue.Type().Elem())
origDstFieldValue.Set(dstFieldValue)
@@ -428,9 +435,12 @@
if prepend {
newSlice = reflect.AppendSlice(newSlice, srcFieldValue)
newSlice = reflect.AppendSlice(newSlice, dstFieldValue)
- } else {
+ } else if order == Append {
newSlice = reflect.AppendSlice(newSlice, dstFieldValue)
newSlice = reflect.AppendSlice(newSlice, srcFieldValue)
+ } else {
+ // replace
+ newSlice = reflect.AppendSlice(newSlice, srcFieldValue)
}
dstFieldValue.Set(newSlice)
case reflect.Ptr:
@@ -491,11 +501,8 @@
func getStruct(in interface{}) (reflect.Value, error) {
value := reflect.ValueOf(in)
- if value.Kind() != reflect.Ptr {
- return reflect.Value{}, fmt.Errorf("expected pointer to struct, got %T", in)
- }
- if value.Type().Elem().Kind() != reflect.Struct {
- return reflect.Value{}, fmt.Errorf("expected pointer to struct, got %T", in)
+ if !isStructPtr(value.Type()) {
+ return reflect.Value{}, fmt.Errorf("expected pointer to struct, got %s", value.Type())
}
if value.IsNil() {
return reflect.Value{}, getStructEmptyError{}
diff --git a/proptools/extend_test.go b/proptools/extend_test.go
index 66adabb..d591ce6 100644
--- a/proptools/extend_test.go
+++ b/proptools/extend_test.go
@@ -23,12 +23,12 @@
)
type appendPropertyTestCase struct {
- in1 interface{}
- in2 interface{}
- out interface{}
- prepend bool
- filter ExtendPropertyFilterFunc
- err error
+ in1 interface{}
+ in2 interface{}
+ out interface{}
+ order Order // default is Append
+ filter ExtendPropertyFilterFunc
+ err error
}
func appendPropertiesTestCases() []appendPropertyTestCase {
@@ -76,7 +76,7 @@
B3: true,
B4: false,
},
- prepend: true,
+ order: Prepend,
},
{
// Append strings
@@ -101,7 +101,7 @@
out: &struct{ S string }{
S: "string2string1",
},
- prepend: true,
+ order: Prepend,
},
{
// Append pointer to bool
@@ -174,7 +174,7 @@
B8: BoolPtr(false),
B9: BoolPtr(false),
},
- prepend: true,
+ order: Prepend,
},
{
// Append pointer to integer
@@ -226,7 +226,7 @@
I2: Int64Ptr(33),
I3: nil,
},
- prepend: true,
+ order: Prepend,
},
{
// Append pointer to strings
@@ -261,7 +261,7 @@
S3: StringPtr("string4"),
S4: nil,
},
- prepend: true,
+ order: Prepend,
},
{
// Append slice
@@ -286,7 +286,20 @@
out: &struct{ S []string }{
S: []string{"string2", "string1"},
},
- prepend: true,
+ order: Prepend,
+ },
+ {
+ // Replace slice
+ in1: &struct{ S []string }{
+ S: []string{"string1"},
+ },
+ in2: &struct{ S []string }{
+ S: []string{"string2"},
+ },
+ out: &struct{ S []string }{
+ S: []string{"string2"},
+ },
+ order: Replace,
},
{
// Append empty slice
@@ -317,7 +330,23 @@
S1: []string{"string1"},
S2: []string{"string2"},
},
- prepend: true,
+ order: Prepend,
+ },
+ {
+ // Replace empty slice
+ in1: &struct{ S1, S2 []string }{
+ S1: []string{"string1"},
+ S2: []string{},
+ },
+ in2: &struct{ S1, S2 []string }{
+ S1: []string{},
+ S2: []string{"string2"},
+ },
+ out: &struct{ S1, S2 []string }{
+ S1: []string{},
+ S2: []string{"string2"},
+ },
+ order: Replace,
},
{
// Append nil slice
@@ -346,7 +375,41 @@
S2: []string{"string2"},
S3: nil,
},
- prepend: true,
+ order: Prepend,
+ },
+ {
+ // Replace nil slice
+ in1: &struct{ S1, S2, S3 []string }{
+ S1: []string{"string1"},
+ },
+ in2: &struct{ S1, S2, S3 []string }{
+ S2: []string{"string2"},
+ },
+ out: &struct{ S1, S2, S3 []string }{
+ S1: []string{"string1"},
+ S2: []string{"string2"},
+ S3: nil,
+ },
+ order: Replace,
+ },
+ {
+ // Replace embedded slice
+ in1: &struct{ S *struct{ S1 []string } }{
+ S: &struct{ S1 []string }{
+ S1: []string{"string1"},
+ },
+ },
+ in2: &struct{ S *struct{ S1 []string } }{
+ S: &struct{ S1 []string }{
+ S1: []string{"string2"},
+ },
+ },
+ out: &struct{ S *struct{ S1 []string } }{
+ S: &struct{ S1 []string }{
+ S1: []string{"string2"},
+ },
+ },
+ order: Replace,
},
{
// Append pointer
@@ -383,7 +446,7 @@
S: "string2string1",
},
},
- prepend: true,
+ order: Prepend,
},
{
// Append interface
@@ -420,7 +483,7 @@
S: "string2string1",
},
},
- prepend: true,
+ order: Prepend,
},
{
// Unexported field
@@ -938,12 +1001,16 @@
var err error
var testType string
- if testCase.prepend {
- testType = "prepend"
- err = PrependProperties(got, testCase.in2, testCase.filter)
- } else {
+ switch testCase.order {
+ case Append:
testType = "append"
err = AppendProperties(got, testCase.in2, testCase.filter)
+ case Prepend:
+ testType = "prepend"
+ err = PrependProperties(got, testCase.in2, testCase.filter)
+ case Replace:
+ testType = "replace"
+ err = ExtendProperties(got, testCase.in2, testCase.filter, OrderReplace)
}
check(t, testType, testString, got, err, testCase.out, testCase.err)
@@ -961,17 +1028,24 @@
order := func(property string,
dstField, srcField reflect.StructField,
dstValue, srcValue interface{}) (Order, error) {
- if testCase.prepend {
- return Prepend, nil
- } else {
+ switch testCase.order {
+ case Append:
return Append, nil
+ case Prepend:
+ return Prepend, nil
+ case Replace:
+ return Replace, nil
}
+ return Append, errors.New("unknown order")
}
- if testCase.prepend {
+ switch testCase.order {
+ case Prepend:
testType = "prepend"
- } else {
+ case Append:
testType = "append"
+ case Replace:
+ testType = "replace"
}
err = ExtendProperties(got, testCase.in2, testCase.filter, order)
@@ -981,12 +1055,12 @@
}
type appendMatchingPropertiesTestCase struct {
- in1 []interface{}
- in2 interface{}
- out []interface{}
- prepend bool
- filter ExtendPropertyFilterFunc
- err error
+ in1 []interface{}
+ in2 interface{}
+ out []interface{}
+ order Order // default is Append
+ filter ExtendPropertyFilterFunc
+ err error
}
func appendMatchingPropertiesTestCases() []appendMatchingPropertiesTestCase {
@@ -1014,7 +1088,7 @@
out: []interface{}{&struct{ S string }{
S: "string2string1",
}},
- prepend: true,
+ order: Prepend,
},
{
// Append all
@@ -1264,12 +1338,16 @@
var err error
var testType string
- if testCase.prepend {
- testType = "prepend matching"
- err = PrependMatchingProperties(got, testCase.in2, testCase.filter)
- } else {
- testType = "append matching"
+ switch testCase.order {
+ case Append:
+ testType = "append matching"
err = AppendMatchingProperties(got, testCase.in2, testCase.filter)
+ case Prepend:
+ testType = "prepend matching"
+ err = PrependMatchingProperties(got, testCase.in2, testCase.filter)
+ case Replace:
+ testType = "replace matching"
+ err = ExtendMatchingProperties(got, testCase.in2, testCase.filter, OrderReplace)
}
check(t, testType, testString, got, err, testCase.out, testCase.err)
@@ -1287,17 +1365,24 @@
order := func(property string,
dstField, srcField reflect.StructField,
dstValue, srcValue interface{}) (Order, error) {
- if testCase.prepend {
- return Prepend, nil
- } else {
+ switch testCase.order {
+ case Append:
return Append, nil
+ case Prepend:
+ return Prepend, nil
+ case Replace:
+ return Replace, nil
}
+ return Append, errors.New("unknown order")
}
- if testCase.prepend {
+ switch testCase.order {
+ case Prepend:
testType = "prepend matching"
- } else {
+ case Append:
testType = "append matching"
+ case Replace:
+ testType = "replace matching"
}
err = ExtendMatchingProperties(got, testCase.in2, testCase.filter, order)
diff --git a/proptools/filter.go b/proptools/filter.go
index 7a61b02..e6b3336 100644
--- a/proptools/filter.go
+++ b/proptools/filter.go
@@ -15,12 +15,59 @@
package proptools
import (
+ "fmt"
"reflect"
+ "strconv"
)
type FilterFieldPredicate func(field reflect.StructField, string string) (bool, reflect.StructField)
-func filterPropertyStructFields(fields []reflect.StructField, prefix string, predicate FilterFieldPredicate) (filteredFields []reflect.StructField, filtered bool) {
+type cantFitPanic struct {
+ field reflect.StructField
+ size int
+}
+
+func (x cantFitPanic) Error() string {
+ return fmt.Sprintf("Can't fit field %s %s %s size %d into %d",
+ x.field.Name, x.field.Type.String(), strconv.Quote(string(x.field.Tag)),
+ fieldToTypeNameSize(x.field, true)+2, x.size)
+}
+
+// All runtime created structs will have a name that starts with "struct {" and ends with "}"
+const emptyStructTypeNameSize = len("struct {}")
+
+func filterPropertyStructFields(fields []reflect.StructField, prefix string, maxTypeNameSize int,
+ predicate FilterFieldPredicate) (filteredFieldsShards [][]reflect.StructField, filtered bool) {
+
+ structNameSize := emptyStructTypeNameSize
+
+ var filteredFields []reflect.StructField
+
+ appendAndShardIfNameFull := func(field reflect.StructField) {
+ fieldTypeNameSize := fieldToTypeNameSize(field, true)
+ // Every field will have a space before it and either a semicolon or space after it.
+ fieldTypeNameSize += 2
+
+ if maxTypeNameSize > 0 && structNameSize+fieldTypeNameSize > maxTypeNameSize {
+ if len(filteredFields) == 0 {
+ if isStruct(field.Type) || isStructPtr(field.Type) {
+ // An error fitting the nested struct should have been caught when recursing
+ // into the nested struct.
+ panic(fmt.Errorf("Shouldn't happen: can't fit nested struct %q (%d) into %d",
+ field.Type.String(), len(field.Type.String()), maxTypeNameSize-structNameSize))
+ }
+ panic(cantFitPanic{field, maxTypeNameSize - structNameSize})
+
+ }
+ filteredFieldsShards = append(filteredFieldsShards, filteredFields)
+ filteredFields = nil
+ structNameSize = emptyStructTypeNameSize
+ }
+
+ filteredFields = append(filteredFields, field)
+ structNameSize += fieldTypeNameSize
+ }
+
for _, field := range fields {
var keep bool
if keep, field = predicate(field, prefix); !keep {
@@ -33,32 +80,61 @@
subPrefix = prefix + "." + subPrefix
}
- // Recurse into structs
- switch field.Type.Kind() {
- case reflect.Struct:
- var subFiltered bool
- field.Type, subFiltered = filterPropertyStruct(field.Type, subPrefix, predicate)
- filtered = filtered || subFiltered
- if field.Type == nil {
- continue
- }
- case reflect.Ptr:
- if field.Type.Elem().Kind() == reflect.Struct {
- nestedType, subFiltered := filterPropertyStruct(field.Type.Elem(), subPrefix, predicate)
- filtered = filtered || subFiltered
- if nestedType == nil {
- continue
- }
- field.Type = reflect.PtrTo(nestedType)
- }
- case reflect.Interface:
- panic("Interfaces are not supported in filtered property structs")
+ ptrToStruct := false
+ if isStructPtr(field.Type) {
+ ptrToStruct = true
}
- filteredFields = append(filteredFields, field)
+ // Recurse into structs
+ if ptrToStruct || isStruct(field.Type) {
+ subMaxTypeNameSize := maxTypeNameSize
+ if maxTypeNameSize > 0 {
+ // In the worst case where only this nested struct will fit in the outer struct, the
+ // outer struct will contribute struct{}, the name and tag of the field that contains
+ // the nested struct, and one space before and after the field.
+ subMaxTypeNameSize -= emptyStructTypeNameSize + fieldToTypeNameSize(field, false) + 2
+ }
+ typ := field.Type
+ if ptrToStruct {
+ subMaxTypeNameSize -= len("*")
+ typ = typ.Elem()
+ }
+ nestedTypes, subFiltered := filterPropertyStruct(typ, subPrefix, subMaxTypeNameSize, predicate)
+ filtered = filtered || subFiltered
+ if nestedTypes == nil {
+ continue
+ }
+
+ for _, nestedType := range nestedTypes {
+ if ptrToStruct {
+ nestedType = reflect.PtrTo(nestedType)
+ }
+ field.Type = nestedType
+ appendAndShardIfNameFull(field)
+ }
+ } else {
+ appendAndShardIfNameFull(field)
+ }
}
- return filteredFields, filtered
+ if len(filteredFields) > 0 {
+ filteredFieldsShards = append(filteredFieldsShards, filteredFields)
+ }
+
+ return filteredFieldsShards, filtered
+}
+
+func fieldToTypeNameSize(field reflect.StructField, withType bool) int {
+ nameSize := len(field.Name)
+ nameSize += len(" ")
+ if withType {
+ nameSize += len(field.Type.String())
+ }
+ if field.Tag != "" {
+ nameSize += len(" ")
+ nameSize += len(strconv.Quote(string(field.Tag)))
+ }
+ return nameSize
}
// FilterPropertyStruct takes a reflect.Type that is either a struct or a pointer to a struct, and returns a
@@ -66,10 +142,20 @@
// that is true if the new struct type has fewer fields than the original type. If there are no fields in the
// original type for which predicate returns true it returns nil and true.
func FilterPropertyStruct(prop reflect.Type, predicate FilterFieldPredicate) (filteredProp reflect.Type, filtered bool) {
- return filterPropertyStruct(prop, "", predicate)
+ filteredFieldsShards, filtered := filterPropertyStruct(prop, "", -1, predicate)
+ switch len(filteredFieldsShards) {
+ case 0:
+ return nil, filtered
+ case 1:
+ return filteredFieldsShards[0], filtered
+ default:
+ panic("filterPropertyStruct should only return 1 struct if maxNameSize < 0")
+ }
}
-func filterPropertyStruct(prop reflect.Type, prefix string, predicate FilterFieldPredicate) (filteredProp reflect.Type, filtered bool) {
+func filterPropertyStruct(prop reflect.Type, prefix string, maxNameSize int,
+ predicate FilterFieldPredicate) (filteredProp []reflect.Type, filtered bool) {
+
var fields []reflect.StructField
ptr := prop.Kind() == reflect.Ptr
@@ -81,22 +167,26 @@
fields = append(fields, prop.Field(i))
}
- filteredFields, filtered := filterPropertyStructFields(fields, prefix, predicate)
+ filteredFieldsShards, filtered := filterPropertyStructFields(fields, prefix, maxNameSize, predicate)
- if len(filteredFields) == 0 {
+ if len(filteredFieldsShards) == 0 {
return nil, true
}
if !filtered {
if ptr {
- return reflect.PtrTo(prop), false
+ return []reflect.Type{reflect.PtrTo(prop)}, false
}
- return prop, false
+ return []reflect.Type{prop}, false
}
- ret := reflect.StructOf(filteredFields)
- if ptr {
- ret = reflect.PtrTo(ret)
+ var ret []reflect.Type
+ for _, filteredFields := range filteredFieldsShards {
+ p := reflect.StructOf(filteredFields)
+ if ptr {
+ p = reflect.PtrTo(p)
+ }
+ ret = append(ret, p)
}
return ret, true
@@ -109,51 +199,6 @@
// level fields in it to attempt to avoid hitting the 65535 byte type name length limit in reflect.StructOf
// (reflect.nameFrom: name too long), although the limit can still be reached with a single struct field with many
// fields in it.
-func FilterPropertyStructSharded(prop reflect.Type, predicate FilterFieldPredicate) (filteredProp []reflect.Type, filtered bool) {
- var fields []reflect.StructField
-
- ptr := prop.Kind() == reflect.Ptr
- if ptr {
- prop = prop.Elem()
- }
-
- for i := 0; i < prop.NumField(); i++ {
- fields = append(fields, prop.Field(i))
- }
-
- fields, filtered = filterPropertyStructFields(fields, "", predicate)
- if !filtered {
- if ptr {
- return []reflect.Type{reflect.PtrTo(prop)}, false
- }
- return []reflect.Type{prop}, false
- }
-
- if len(fields) == 0 {
- return nil, true
- }
-
- shards := shardFields(fields, 10)
-
- for _, shard := range shards {
- s := reflect.StructOf(shard)
- if ptr {
- s = reflect.PtrTo(s)
- }
- filteredProp = append(filteredProp, s)
- }
-
- return filteredProp, true
-}
-
-func shardFields(fields []reflect.StructField, shardSize int) [][]reflect.StructField {
- ret := make([][]reflect.StructField, 0, (len(fields)+shardSize-1)/shardSize)
- for len(fields) > shardSize {
- ret = append(ret, fields[0:shardSize])
- fields = fields[shardSize:]
- }
- if len(fields) > 0 {
- ret = append(ret, fields)
- }
- return ret
+func FilterPropertyStructSharded(prop reflect.Type, maxTypeNameSize int, predicate FilterFieldPredicate) (filteredProp []reflect.Type, filtered bool) {
+ return filterPropertyStruct(prop, "", maxTypeNameSize, predicate)
}
diff --git a/proptools/filter_test.go b/proptools/filter_test.go
index 695549a..0ea04bb 100644
--- a/proptools/filter_test.go
+++ b/proptools/filter_test.go
@@ -16,6 +16,7 @@
import (
"reflect"
+ "strings"
"testing"
)
@@ -237,3 +238,281 @@
})
}
}
+
+func TestFilterPropertyStructSharded(t *testing.T) {
+ tests := []struct {
+ name string
+ maxNameSize int
+ in interface{}
+ out []interface{}
+ filtered bool
+ }{
+ // Property tests
+ {
+ name: "basic",
+ maxNameSize: 20,
+ in: &struct {
+ A *string `keep:"true"`
+ B *string `keep:"true"`
+ C *string
+ }{},
+ out: []interface{}{
+ &struct {
+ A *string
+ }{},
+ &struct {
+ B *string
+ }{},
+ },
+ filtered: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ out, filtered := filterPropertyStruct(reflect.TypeOf(test.in), "", test.maxNameSize,
+ func(field reflect.StructField, prefix string) (bool, reflect.StructField) {
+ if HasTag(field, "keep", "true") {
+ field.Tag = ""
+ return true, field
+ }
+ return false, field
+ })
+ if filtered != test.filtered {
+ t.Errorf("expected filtered %v, got %v", test.filtered, filtered)
+ }
+ var expected []reflect.Type
+ for _, t := range test.out {
+ expected = append(expected, reflect.TypeOf(t))
+ }
+ if !reflect.DeepEqual(out, expected) {
+ t.Errorf("expected type %v, got %v", expected, out)
+ }
+ })
+ }
+}
+
+func Test_fieldToTypeNameSize(t *testing.T) {
+ tests := []struct {
+ name string
+ field reflect.StructField
+ }{
+ {
+ name: "string",
+ field: reflect.StructField{
+ Name: "Foo",
+ Type: reflect.TypeOf(""),
+ },
+ },
+ {
+ name: "string pointer",
+ field: reflect.StructField{
+ Name: "Foo",
+ Type: reflect.TypeOf(StringPtr("")),
+ },
+ },
+ {
+ name: "anonymous struct",
+ field: reflect.StructField{
+ Name: "Foo",
+ Type: reflect.TypeOf(struct{ foo string }{}),
+ },
+ }, {
+ name: "anonymous struct pointer",
+ field: reflect.StructField{
+ Name: "Foo",
+ Type: reflect.TypeOf(&struct{ foo string }{}),
+ },
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ typeName := reflect.StructOf([]reflect.StructField{test.field}).String()
+ typeName = strings.TrimPrefix(typeName, "struct { ")
+ typeName = strings.TrimSuffix(typeName, " }")
+ if g, w := fieldToTypeNameSize(test.field, true), len(typeName); g != w {
+ t.Errorf("want fieldToTypeNameSize(..., true) = %v, got %v", w, g)
+ }
+ if g, w := fieldToTypeNameSize(test.field, false), len(typeName)-len(test.field.Type.String()); g != w {
+ t.Errorf("want fieldToTypeNameSize(..., false) = %v, got %v", w, g)
+ }
+ })
+ }
+}
+
+func Test_filterPropertyStructFields(t *testing.T) {
+ type args struct {
+ }
+ tests := []struct {
+ name string
+ maxTypeNameSize int
+ in interface{}
+ out []interface{}
+ }{
+ {
+ name: "empty",
+ maxTypeNameSize: -1,
+ in: struct{}{},
+ out: nil,
+ },
+ {
+ name: "one",
+ maxTypeNameSize: -1,
+ in: struct {
+ A *string
+ }{},
+ out: []interface{}{
+ struct {
+ A *string
+ }{},
+ },
+ },
+ {
+ name: "two",
+ maxTypeNameSize: 20,
+ in: struct {
+ A *string
+ B *string
+ }{},
+ out: []interface{}{
+ struct {
+ A *string
+ }{},
+ struct {
+ B *string
+ }{},
+ },
+ },
+ {
+ name: "nested",
+ maxTypeNameSize: 36,
+ in: struct {
+ AAAAA struct {
+ A string
+ }
+ BBBBB struct {
+ B string
+ }
+ }{},
+ out: []interface{}{
+ struct {
+ AAAAA struct {
+ A string
+ }
+ }{},
+ struct {
+ BBBBB struct {
+ B string
+ }
+ }{},
+ },
+ },
+ {
+ name: "nested pointer",
+ maxTypeNameSize: 37,
+ in: struct {
+ AAAAA *struct {
+ A string
+ }
+ BBBBB *struct {
+ B string
+ }
+ }{},
+ out: []interface{}{
+ struct {
+ AAAAA *struct {
+ A string
+ }
+ }{},
+ struct {
+ BBBBB *struct {
+ B string
+ }
+ }{},
+ },
+ },
+ {
+ name: "doubly nested",
+ maxTypeNameSize: 49,
+ in: struct {
+ AAAAA struct {
+ A struct {
+ A string
+ }
+ }
+ BBBBB struct {
+ B struct {
+ B string
+ }
+ }
+ }{},
+ out: []interface{}{
+ struct {
+ AAAAA struct {
+ A struct {
+ A string
+ }
+ }
+ }{},
+ struct {
+ BBBBB struct {
+ B struct {
+ B string
+ }
+ }
+ }{},
+ },
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ inType := reflect.TypeOf(test.in)
+ var in []reflect.StructField
+ for i := 0; i < inType.NumField(); i++ {
+ in = append(in, inType.Field(i))
+ }
+
+ keep := func(field reflect.StructField, string string) (bool, reflect.StructField) {
+ return true, field
+ }
+
+ // When positive, maxTypeNameSize must be the smallest size that fits: one byte less must panic with cantFitPanic.
+ if test.maxTypeNameSize > 0 {
+ correctPanic := false
+ func() {
+ defer func() {
+ if r := recover(); r != nil {
+ if _, ok := r.(cantFitPanic); ok {
+ correctPanic = true
+ } else {
+ panic(r)
+ }
+ }
+ }()
+
+ _, _ = filterPropertyStructFields(in, "", test.maxTypeNameSize-1, keep)
+ }()
+
+ if !correctPanic {
+ t.Errorf("filterPropertyStructFields() with size-1 should produce cantFitPanic")
+ }
+ }
+
+ filteredFieldsShards, _ := filterPropertyStructFields(in, "", test.maxTypeNameSize, keep)
+
+ var out []interface{}
+ for _, filteredFields := range filteredFieldsShards {
+ typ := reflect.StructOf(filteredFields)
+ if test.maxTypeNameSize > 0 && len(typ.String()) > test.maxTypeNameSize {
+ t.Errorf("out %q expected size <= %d, got %d",
+ typ.String(), test.maxTypeNameSize, len(typ.String()))
+ }
+ out = append(out, reflect.Zero(typ).Interface())
+ }
+
+ if g, w := out, test.out; !reflect.DeepEqual(g, w) {
+ t.Errorf("filterPropertyStructFields() want %v, got %v", w, g)
+ }
+ })
+ }
+}
diff --git a/proptools/proptools.go b/proptools/proptools.go
index 6881828..c44a4a8 100644
--- a/proptools/proptools.go
+++ b/proptools/proptools.go
@@ -15,19 +15,35 @@
package proptools
import (
+ "reflect"
+ "strings"
"unicode"
"unicode/utf8"
)
+// PropertyNameForField converts the name of a field in a property struct to the property name that
+// might appear in a Blueprints file. Since the property struct fields must always be exported
+// to be accessed with reflection and the canonical Blueprints style is lowercased names, it
+// lowercases the first rune unless the rest of the name (after the first rune, which is always
+// uppercase) has an uppercase rune and no lowercase runes, in which case it is returned as-is.
func PropertyNameForField(fieldName string) string {
r, size := utf8.DecodeRuneInString(fieldName)
propertyName := string(unicode.ToLower(r))
+ if size == len(fieldName) {
+ return propertyName
+ }
+ if strings.IndexFunc(fieldName[size:], unicode.IsLower) == -1 &&
+ strings.IndexFunc(fieldName[size:], unicode.IsUpper) != -1 {
+ return fieldName
+ }
if len(fieldName) > size {
propertyName += fieldName[size:]
}
return propertyName
}
+// FieldNameForProperty converts the name of a property that might appear in a Blueprints file to
+// the name of a field in a property struct by uppercasing the first rune.
func FieldNameForProperty(propertyName string) string {
r, size := utf8.DecodeRuneInString(propertyName)
fieldName := string(unicode.ToUpper(r))
@@ -97,3 +113,11 @@
func Int(i *int64) int {
return IntDefault(i, 0)
}
+
+func isStruct(t reflect.Type) bool {
+ return t.Kind() == reflect.Struct
+}
+
+func isStructPtr(t reflect.Type) bool {
+ return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct
+}
diff --git a/proptools/proptools_test.go b/proptools/proptools_test.go
new file mode 100644
index 0000000..207ee1b
--- /dev/null
+++ b/proptools/proptools_test.go
@@ -0,0 +1,114 @@
+// Copyright 2020 Google Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proptools
+
+import "testing"
+
+func TestPropertyNameForField(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ want string
+ }{
+ {
+ name: "short",
+ input: "S",
+ want: "s",
+ },
+ {
+ name: "long",
+ input: "String",
+ want: "string",
+ },
+ {
+ name: "uppercase",
+ input: "STRING",
+ want: "STRING",
+ },
+ {
+ name: "mixed",
+ input: "StRiNg",
+ want: "stRiNg",
+ },
+ {
+ name: "underscore",
+ input: "Under_score",
+ want: "under_score",
+ },
+ {
+ name: "uppercase underscore",
+ input: "UNDER_SCORE",
+ want: "UNDER_SCORE",
+ },
+ {
+ name: "x86",
+ input: "X86",
+ want: "x86",
+ },
+ {
+ name: "x86_64",
+ input: "X86_64",
+ want: "x86_64",
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := PropertyNameForField(tt.input); got != tt.want {
+ t.Errorf("PropertyNameForField(%v) = %v, want %v", tt.input, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestFieldNameForProperty(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ want string
+ }{
+ {
+ name: "short lowercase",
+ input: "s",
+ want: "S",
+ },
+ {
+ name: "short uppercase",
+ input: "S",
+ want: "S",
+ },
+ {
+ name: "long lowercase",
+ input: "string",
+ want: "String",
+ },
+ {
+ name: "long uppercase",
+ input: "STRING",
+ want: "STRING",
+ },
+ {
+ name: "mixed",
+ input: "StRiNg",
+ want: "StRiNg",
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := FieldNameForProperty(tt.input); got != tt.want {
+ t.Errorf("FieldNameForProperty(%v) = %v, want %v", tt.input, got, tt.want)
+ }
+ })
+ }
+}
diff --git a/proptools/tag.go b/proptools/tag.go
index af5b97e..d69853a 100644
--- a/proptools/tag.go
+++ b/proptools/tag.go
@@ -36,7 +36,7 @@
// are tagged with the given key and value, including ones found in embedded structs or pointers to structs.
func PropertyIndexesWithTag(ps interface{}, key, value string) [][]int {
t := reflect.TypeOf(ps)
- if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct {
+ if !isStructPtr(t) {
panic(fmt.Errorf("type %s is not a pointer to a struct", t))
}
t = t.Elem()
@@ -49,7 +49,7 @@
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
ft := field.Type
- if ft.Kind() == reflect.Struct || (ft.Kind() == reflect.Ptr && ft.Elem().Kind() == reflect.Struct) {
+ if isStruct(ft) || isStructPtr(ft) {
if ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
diff --git a/unpack.go b/proptools/unpack.go
similarity index 73%
rename from unpack.go
rename to proptools/unpack.go
index 3156599..344327f 100644
--- a/unpack.go
+++ b/proptools/unpack.go
@@ -12,24 +12,33 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package blueprint
+package proptools
import (
"fmt"
"reflect"
- "strconv"
- "strings"
+ "text/scanner"
"github.com/google/blueprint/parser"
- "github.com/google/blueprint/proptools"
)
+const maxUnpackErrors = 10
+
+type UnpackError struct {
+ Err error
+ Pos scanner.Position
+}
+
+func (e *UnpackError) Error() string {
+ return fmt.Sprintf("%s: %s", e.Pos, e.Err)
+}
+
type packedProperty struct {
property *parser.Property
unpacked bool
}
-func unpackProperties(propertyDefs []*parser.Property,
+func UnpackProperties(propertyDefs []*parser.Property,
propertiesStructs ...interface{}) (map[string]*parser.Property, []error) {
propertyMap := make(map[string]*packedProperty)
@@ -40,19 +49,16 @@
for _, properties := range propertiesStructs {
propertiesValue := reflect.ValueOf(properties)
- if propertiesValue.Kind() != reflect.Ptr {
- panic("properties must be a pointer to a struct")
+ if !isStructPtr(propertiesValue.Type()) {
+ panic(fmt.Errorf("properties must be *struct, got %s",
+ propertiesValue.Type()))
}
-
propertiesValue = propertiesValue.Elem()
- if propertiesValue.Kind() != reflect.Struct {
- panic("properties must be a pointer to a struct")
- }
- newErrs := unpackStructValue("", propertiesValue, propertyMap, "", "")
+ newErrs := unpackStructValue("", propertiesValue, propertyMap)
errs = append(errs, newErrs...)
- if len(errs) >= maxErrors {
+ if len(errs) >= maxUnpackErrors {
return nil, errs
}
}
@@ -63,7 +69,7 @@
for name, packedProperty := range propertyMap {
result[name] = packedProperty.property
if !packedProperty.unpacked {
- err := &BlueprintError{
+ err := &UnpackError{
Err: fmt.Errorf("unrecognized property %q", name),
Pos: packedProperty.property.ColonPos,
}
@@ -88,15 +94,15 @@
// We've already added this property.
continue
}
- errs = append(errs, &BlueprintError{
+ errs = append(errs, &UnpackError{
Err: fmt.Errorf("property %q already defined", name),
Pos: propertyDef.ColonPos,
})
- errs = append(errs, &BlueprintError{
+ errs = append(errs, &UnpackError{
Err: fmt.Errorf("<-- previous definition here"),
Pos: first.property.ColonPos,
})
- if len(errs) >= maxErrors {
+ if len(errs) >= maxUnpackErrors {
return errs
}
continue
@@ -119,7 +125,7 @@
}
func unpackStructValue(namePrefix string, structValue reflect.Value,
- propertyMap map[string]*packedProperty, filterKey, filterValue string) []error {
+ propertyMap map[string]*packedProperty) []error {
structType := structValue.Type()
@@ -142,7 +148,7 @@
continue
}
- propertyName := namePrefix + proptools.PropertyNameForField(field.Name)
+ propertyName := namePrefix + PropertyNameForField(field.Name)
if !fieldValue.CanSet() {
panic(fmt.Errorf("field %s is not settable", propertyName))
@@ -163,7 +169,7 @@
case reflect.Slice:
elemType := field.Type.Elem()
if elemType.Kind() != reflect.String {
- if !proptools.HasTag(field, "blueprint", "mutated") {
+ if !HasTag(field, "blueprint", "mutated") {
panic(fmt.Errorf("field %s is a non-string slice", propertyName))
}
}
@@ -195,7 +201,7 @@
}
case reflect.Int, reflect.Uint:
- if !proptools.HasTag(field, "blueprint", "mutated") {
+ if !HasTag(field, "blueprint", "mutated") {
panic(fmt.Errorf(`int field %s must be tagged blueprint:"mutated"`, propertyName))
}
@@ -203,8 +209,8 @@
panic(fmt.Errorf("unsupported kind for field %s: %s", propertyName, kind))
}
- if field.Anonymous && fieldValue.Kind() == reflect.Struct {
- newErrs := unpackStructValue(namePrefix, fieldValue, propertyMap, filterKey, filterValue)
+ if field.Anonymous && isStruct(fieldValue.Type()) {
+ newErrs := unpackStructValue(namePrefix, fieldValue, propertyMap)
errs = append(errs, newErrs...)
continue
}
@@ -216,25 +222,13 @@
packedProperty.unpacked = true
- if proptools.HasTag(field, "blueprint", "mutated") {
+ if HasTag(field, "blueprint", "mutated") {
errs = append(errs,
- &BlueprintError{
+ &UnpackError{
Err: fmt.Errorf("mutated field %s cannot be set in a Blueprint file", propertyName),
Pos: packedProperty.property.ColonPos,
})
- if len(errs) >= maxErrors {
- return errs
- }
- continue
- }
-
- if filterKey != "" && !proptools.HasTag(field, filterKey, filterValue) {
- errs = append(errs,
- &BlueprintError{
- Err: fmt.Errorf("filtered field %s cannot be set in a Blueprint file", propertyName),
- Pos: packedProperty.property.ColonPos,
- })
- if len(errs) >= maxErrors {
+ if len(errs) >= maxUnpackErrors {
return errs
}
continue
@@ -242,29 +236,12 @@
var newErrs []error
- if fieldValue.Kind() == reflect.Struct {
- localFilterKey, localFilterValue := filterKey, filterValue
- if k, v, err := HasFilter(field.Tag); err != nil {
- errs = append(errs, err)
- if len(errs) >= maxErrors {
- return errs
- }
- } else if k != "" {
- if filterKey != "" {
- errs = append(errs, fmt.Errorf("nested filter tag not supported on field %q",
- field.Name))
- if len(errs) >= maxErrors {
- return errs
- }
- } else {
- localFilterKey, localFilterValue = k, v
- }
- }
+ if isStruct(fieldValue.Type()) {
newErrs = unpackStruct(propertyName+".", fieldValue,
- packedProperty.property, propertyMap, localFilterKey, localFilterValue)
+ packedProperty.property, propertyMap)
errs = append(errs, newErrs...)
- if len(errs) >= maxErrors {
+ if len(errs) >= maxUnpackErrors {
return errs
}
@@ -276,12 +253,12 @@
propertyValue, err := propertyToValue(fieldValue.Type(), packedProperty.property)
if err != nil {
errs = append(errs, err)
- if len(errs) >= maxErrors {
+ if len(errs) >= maxUnpackErrors {
return errs
}
}
- proptools.ExtendBasicType(fieldValue, propertyValue, proptools.Append)
+ ExtendBasicType(fieldValue, propertyValue, Append)
}
return errs
@@ -354,8 +331,7 @@
}
func unpackStruct(namePrefix string, structValue reflect.Value,
- property *parser.Property, propertyMap map[string]*packedProperty,
- filterKey, filterValue string) []error {
+ property *parser.Property, propertyMap map[string]*packedProperty) []error {
m, ok := property.Value.Eval().(*parser.Map)
if !ok {
@@ -370,31 +346,5 @@
return errs
}
- return unpackStructValue(namePrefix, structValue, propertyMap, filterKey, filterValue)
-}
-
-func HasFilter(field reflect.StructTag) (k, v string, err error) {
- tag := field.Get("blueprint")
- for _, entry := range strings.Split(tag, ",") {
- if strings.HasPrefix(entry, "filter") {
- if !strings.HasPrefix(entry, "filter(") || !strings.HasSuffix(entry, ")") {
- return "", "", fmt.Errorf("unexpected format for filter %q: missing ()", entry)
- }
- entry = strings.TrimPrefix(entry, "filter(")
- entry = strings.TrimSuffix(entry, ")")
-
- s := strings.Split(entry, ":")
- if len(s) != 2 {
- return "", "", fmt.Errorf("unexpected format for filter %q: expected single ':'", entry)
- }
- k = s[0]
- v, err = strconv.Unquote(s[1])
- if err != nil {
- return "", "", fmt.Errorf("unexpected format for filter %q: %s", entry, err.Error())
- }
- return k, v, nil
- }
- }
-
- return "", "", nil
+ return unpackStructValue(namePrefix, structValue, propertyMap)
}
diff --git a/unpack_test.go b/proptools/unpack_test.go
similarity index 73%
rename from unpack_test.go
rename to proptools/unpack_test.go
index d6b88ab..942dbb8 100644
--- a/unpack_test.go
+++ b/proptools/unpack_test.go
@@ -12,17 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package blueprint
+package proptools
import (
"bytes"
- "fmt"
"reflect"
"testing"
"text/scanner"
"github.com/google/blueprint/parser"
- "github.com/google/blueprint/proptools"
)
var validUnpackTestCases = []struct {
@@ -34,18 +32,18 @@
{
input: `
m {
- name: "abc",
+ s: "abc",
blank: "",
}
`,
output: []interface{}{
- struct {
- Name *string
+ &struct {
+ S *string
Blank *string
Unset *string
}{
- Name: proptools.StringPtr("abc"),
- Blank: proptools.StringPtr(""),
+ S: StringPtr("abc"),
+ Blank: StringPtr(""),
Unset: nil,
},
},
@@ -54,14 +52,14 @@
{
input: `
m {
- name: "abc",
+ s: "abc",
}
`,
output: []interface{}{
- struct {
- Name string
+ &struct {
+ S string
}{
- Name: "abc",
+ S: "abc",
},
},
},
@@ -73,7 +71,7 @@
}
`,
output: []interface{}{
- struct {
+ &struct {
IsGood bool
}{
IsGood: true,
@@ -89,13 +87,13 @@
}
`,
output: []interface{}{
- struct {
+ &struct {
IsGood *bool
IsBad *bool
IsUgly *bool
}{
- IsGood: proptools.BoolPtr(true),
- IsBad: proptools.BoolPtr(false),
+ IsGood: BoolPtr(true),
+ IsBad: BoolPtr(false),
IsUgly: nil,
},
},
@@ -110,7 +108,7 @@
}
`,
output: []interface{}{
- struct {
+ &struct {
Stuff []string
Empty []string
Nil []string
@@ -128,18 +126,18 @@
input: `
m {
nested: {
- name: "abc",
+ s: "abc",
}
}
`,
output: []interface{}{
- struct {
+ &struct {
Nested struct {
- Name string
+ S string
}
}{
- Nested: struct{ Name string }{
- Name: "abc",
+ Nested: struct{ S string }{
+ S: "abc",
},
},
},
@@ -149,16 +147,16 @@
input: `
m {
nested: {
- name: "def",
+ s: "def",
}
}
`,
output: []interface{}{
- struct {
+ &struct {
Nested interface{}
}{
- Nested: &struct{ Name string }{
- Name: "def",
+ Nested: &struct{ S string }{
+ S: "def",
},
},
},
@@ -175,7 +173,7 @@
}
`,
output: []interface{}{
- struct {
+ &struct {
Nested struct {
Foo string
}
@@ -202,7 +200,7 @@
}
`,
output: []interface{}{
- struct {
+ &struct {
Nested struct {
Foo string `allowNested:"true"`
} `blueprint:"filter(allowNested:\"true\")"`
@@ -220,64 +218,31 @@
},
},
- {
- input: `
- m {
- nested: {
- foo: "abc",
- },
- bar: false,
- baz: ["def", "ghi"],
- }
- `,
- output: []interface{}{
- struct {
- Nested struct {
- Foo string
- } `blueprint:"filter(allowNested:\"true\")"`
- Bar bool
- Baz []string
- }{
- Nested: struct{ Foo string }{
- Foo: "",
- },
- Bar: false,
- Baz: []string{"def", "ghi"},
- },
- },
- errs: []error{
- &BlueprintError{
- Err: fmt.Errorf("filtered field nested.foo cannot be set in a Blueprint file"),
- Pos: mkpos(30, 4, 9),
- },
- },
- },
-
// Anonymous struct
{
input: `
m {
- name: "abc",
+ s: "abc",
nested: {
- name: "def",
+ s: "def",
},
}
`,
output: []interface{}{
- struct {
+ &struct {
EmbeddedStruct
Nested struct {
EmbeddedStruct
}
}{
EmbeddedStruct: EmbeddedStruct{
- Name: "abc",
+ S: "abc",
},
Nested: struct {
EmbeddedStruct
}{
EmbeddedStruct: EmbeddedStruct{
- Name: "def",
+ S: "def",
},
},
},
@@ -288,27 +253,27 @@
{
input: `
m {
- name: "abc",
+ s: "abc",
nested: {
- name: "def",
+ s: "def",
},
}
`,
output: []interface{}{
- struct {
+ &struct {
EmbeddedInterface
Nested struct {
EmbeddedInterface
}
}{
- EmbeddedInterface: &struct{ Name string }{
- Name: "abc",
+ EmbeddedInterface: &struct{ S string }{
+ S: "abc",
},
Nested: struct {
EmbeddedInterface
}{
- EmbeddedInterface: &struct{ Name string }{
- Name: "def",
+ EmbeddedInterface: &struct{ S string }{
+ S: "def",
},
},
},
@@ -319,32 +284,32 @@
{
input: `
m {
- name: "abc",
+ s: "abc",
nested: {
- name: "def",
+ s: "def",
},
}
`,
output: []interface{}{
- struct {
- Name string
+ &struct {
+ S string
EmbeddedStruct
Nested struct {
- Name string
+ S string
EmbeddedStruct
}
}{
- Name: "abc",
+ S: "abc",
EmbeddedStruct: EmbeddedStruct{
- Name: "abc",
+ S: "abc",
},
Nested: struct {
- Name string
+ S string
EmbeddedStruct
}{
- Name: "def",
+ S: "def",
EmbeddedStruct: EmbeddedStruct{
- Name: "def",
+ S: "def",
},
},
},
@@ -355,32 +320,32 @@
{
input: `
m {
- name: "abc",
+ s: "abc",
nested: {
- name: "def",
+ s: "def",
},
}
`,
output: []interface{}{
- struct {
- Name string
+ &struct {
+ S string
EmbeddedInterface
Nested struct {
- Name string
+ S string
EmbeddedInterface
}
}{
- Name: "abc",
- EmbeddedInterface: &struct{ Name string }{
- Name: "abc",
+ S: "abc",
+ EmbeddedInterface: &struct{ S string }{
+ S: "abc",
},
Nested: struct {
- Name string
+ S string
EmbeddedInterface
}{
- Name: "def",
- EmbeddedInterface: &struct{ Name string }{
- Name: "def",
+ S: "def",
+ EmbeddedInterface: &struct{ S string }{
+ S: "def",
},
},
},
@@ -394,18 +359,18 @@
string = "def"
list_with_variable = [string]
m {
- name: string,
+ s: string,
list: list,
list2: list_with_variable,
}
`,
output: []interface{}{
- struct {
- Name string
+ &struct {
+ S string
List []string
List2 []string
}{
- Name: "def",
+ S: "def",
List: []string{"abc"},
List2: []string{"def"},
},
@@ -417,30 +382,30 @@
input: `
m {
nested: {
- name: "abc",
+ s: "abc",
}
}
`,
output: []interface{}{
- struct {
+ &struct {
Nested struct {
- Name string
+ S string
}
}{
- Nested: struct{ Name string }{
- Name: "abc",
+ Nested: struct{ S string }{
+ S: "abc",
},
},
- struct {
+ &struct {
Nested struct {
- Name string
+ S string
}
}{
- Nested: struct{ Name string }{
- Name: "abc",
+ Nested: struct{ S string }{
+ S: "abc",
},
},
- struct {
+ &struct {
}{},
},
},
@@ -450,25 +415,25 @@
input: `
m {
nested: {
- name: "abc",
+ s: "abc",
}
}
`,
output: []interface{}{
- struct {
+ &struct {
Nested *struct {
- Name string
+ S string
}
}{
- Nested: &struct{ Name string }{
- Name: "abc",
+ Nested: &struct{ S string }{
+ S: "abc",
},
},
},
empty: []interface{}{
&struct {
Nested *struct {
- Name string
+ S string
}
}{},
},
@@ -479,16 +444,16 @@
input: `
m {
nested: {
- name: "abc",
+ s: "abc",
}
}
`,
output: []interface{}{
- struct {
+ &struct {
Nested interface{}
}{
Nested: &EmbeddedStruct{
- Name: "abc",
+ S: "abc",
},
},
},
@@ -513,7 +478,7 @@
}
`,
output: []interface{}{
- struct {
+ &struct {
String string
String_ptr *string
Bool bool
@@ -521,9 +486,9 @@
List []string
}{
String: "012abc",
- String_ptr: proptools.StringPtr("abc"),
+ String_ptr: StringPtr("abc"),
Bool: true,
- Bool_ptr: proptools.BoolPtr(false),
+ Bool_ptr: BoolPtr(false),
List: []string{"0", "1", "2", "a", "b", "c"},
},
},
@@ -536,18 +501,30 @@
List []string
}{
String: "012",
- String_ptr: proptools.StringPtr("012"),
+ String_ptr: StringPtr("012"),
Bool: true,
- Bool_ptr: proptools.BoolPtr(true),
+ Bool_ptr: BoolPtr(true),
List: []string{"0", "1", "2"},
},
},
},
+ // Capitalized property
+ {
+ input: `
+ m {
+ CAPITALIZED: "foo",
+ }
+ `,
+ output: []interface{}{
+ &struct {
+ CAPITALIZED string
+ }{
+ CAPITALIZED: "foo",
+ },
+ },
+ },
}
-type EmbeddedStruct struct{ Name string }
-type EmbeddedInterface interface{}
-
func TestUnpackProperties(t *testing.T) {
for _, testCase := range validUnpackTestCases {
r := bytes.NewBufferString(testCase.input)
@@ -572,10 +549,10 @@
output = testCase.empty
} else {
for _, p := range testCase.output {
- output = append(output, proptools.CloneEmptyProperties(reflect.ValueOf(p)).Interface())
+ output = append(output, CloneEmptyProperties(reflect.ValueOf(p)).Interface())
}
}
- _, errs = unpackProperties(module.Properties, output...)
+ _, errs = UnpackProperties(module.Properties, output...)
if len(errs) != 0 && len(testCase.errs) == 0 {
t.Errorf("test case: %s", testCase.input)
t.Errorf("unexpected unpack errors:")
@@ -596,7 +573,7 @@
}
for i := range output {
- got := reflect.ValueOf(output[i]).Elem().Interface()
+ got := reflect.ValueOf(output[i]).Interface()
if !reflect.DeepEqual(got, testCase.output[i]) {
t.Errorf("test case: %s", testCase.input)
t.Errorf("incorrect output:")
diff --git a/scope.go b/scope.go
index 84db0cf..0a520d9 100644
--- a/scope.go
+++ b/scope.go
@@ -28,7 +28,7 @@
packageContext() *packageContext
name() string // "foo"
fullName(pkgNames map[*packageContext]string) string // "pkg.foo" or "path.to.pkg.foo"
- value(config interface{}) (*ninjaString, error)
+ value(config interface{}) (ninjaString, error)
String() string
}
@@ -351,7 +351,7 @@
type localVariable struct {
namePrefix string
name_ string
- value_ *ninjaString
+ value_ ninjaString
}
func (l *localVariable) packageContext() *packageContext {
@@ -366,7 +366,7 @@
return l.namePrefix + l.name_
}
-func (l *localVariable) value(interface{}) (*ninjaString, error) {
+func (l *localVariable) value(interface{}) (ninjaString, error) {
return l.value_, nil
}
diff --git a/visit_test.go b/visit_test.go
index 873e72c..efaadba 100644
--- a/visit_test.go
+++ b/visit_test.go
@@ -125,7 +125,7 @@
`),
})
- _, errs := ctx.ParseBlueprintsFiles("Blueprints")
+ _, errs := ctx.ParseBlueprintsFiles("Blueprints", nil)
if len(errs) > 0 {
t.Errorf("unexpected parse errors:")
for _, err := range errs {
@@ -149,13 +149,13 @@
func TestVisit(t *testing.T) {
ctx := setupVisitTest(t)
- topModule := ctx.modulesFromName("A", nil)[0].logicModule.(*visitModule)
+ topModule := ctx.moduleGroupFromName("A", nil).modules[0].logicModule.(*visitModule)
assertString(t, topModule.properties.VisitDepsDepthFirst, "FEDCB")
assertString(t, topModule.properties.VisitDepsDepthFirstIf, "FEDC")
assertString(t, topModule.properties.VisitDirectDeps, "B")
assertString(t, topModule.properties.VisitDirectDepsIf, "")
- eModule := ctx.modulesFromName("E", nil)[0].logicModule.(*visitModule)
+ eModule := ctx.moduleGroupFromName("E", nil).modules[0].logicModule.(*visitModule)
assertString(t, eModule.properties.VisitDepsDepthFirst, "F")
assertString(t, eModule.properties.VisitDepsDepthFirstIf, "F")
assertString(t, eModule.properties.VisitDirectDeps, "FF")
diff --git a/vnames.go.json b/vnames.go.json
new file mode 100644
index 0000000..ba239c1
--- /dev/null
+++ b/vnames.go.json
@@ -0,0 +1,9 @@
+[
+ {
+ "pattern": "(.*)",
+ "vname": {
+ "corpus": "android.googlesource.com/platform/superproject",
+ "path": "build/blueprint/@1@"
+ }
+ }
+]