# blob: d70e0d6c0ab2cfefb55ef04ad67b2a74f83ac48d [file] [log] [blame]
# Description:
# TensorFlow Java API.
# Targets in this package are private unless they declare their own visibility.
package(
    default_visibility = ["//visibility:private"],
)

licenses(["notice"])  # Apache 2.0
load(":build_defs.bzl", "JAVACOPTS")
load(":src/gen/gen_ops.bzl", "tf_java_op_gen_srcjar")
load(
"//tensorflow:tensorflow.bzl",
"tf_binary_additional_srcs",
"tf_cc_binary",
"tf_copts",
"tf_custom_op_library",
"tf_java_test",
"tf_cc_test",
)
# Main Java API library. Compiles the core sources together with the op
# sources, carries the JNI shared library as runtime data, and applies the
# :processor annotation-processor plugin.
java_library(
    name = "tensorflow",
    srcs = [
        ":java_op_sources",
        ":java_sources",
    ],
    data = [
        ":libtensorflow_jni",
    ],
    javacopts = JAVACOPTS,
    plugins = [
        ":processor",
    ],
    visibility = ["//visibility:public"],
)
# NOTE(ashankar): This filegroup exists so that the Java API sources can be
# included in the Android Inference Library .aar. At some point it might make
# sense to have a .aar rule here instead.
filegroup(
    name = "java_sources",
    # Non-recursive globs: only .java files directly inside these two
    # directories are picked up.
    srcs = glob(
        [
            "src/main/java/org/tensorflow/*.java",
            "src/main/java/org/tensorflow/types/*.java",
        ],
    ),
    visibility = [
        "//tensorflow/contrib/android:__pkg__",
        "//tensorflow/java:__pkg__",
    ],
)
# Annotation-processor plugin that runs OperatorProcessor; the processor
# classes themselves are provided by :processor_library.
java_plugin(
    name = "processor",
    generates_api = True,
    processor_class = "org.tensorflow.processor.OperatorProcessor",
    visibility = ["//visibility:public"],
    deps = [
        ":processor_library",
    ],
)
# Library holding the annotation-processor implementation that backs the
# :processor plugin.
java_library(
    name = "processor_library",
    srcs = glob(["src/gen/java/org/tensorflow/processor/**/*.java"]),
    javacopts = JAVACOPTS,
    # Listed explicitly instead of through glob(): the path has no wildcard,
    # and a constant glob silently evaluates to [] when the file is missing,
    # which would drop the META-INF/services entry without any build error.
    resources = [
        "src/gen/resources/META-INF/services/javax.annotation.processing.Processor",
    ],
    deps = [
        "@com_google_guava",
        "@com_squareup_javapoet",
    ],
)
# All op wrapper sources: the checked-in files under org/tensorflow/op
# (recursive glob) plus the generated ones from :java_op_gen_sources.
filegroup(
    name = "java_op_sources",
    srcs = glob(
        ["src/main/java/org/tensorflow/op/**/*.java"],
    ) + [
        ":java_op_gen_sources",
    ],
    visibility = [
        "//tensorflow/java:__pkg__",
    ],
)
# Generated op wrapper sources, produced by the tf_java_op_gen_srcjar macro
# from the listed API definitions using :java_op_gen_tool.
tf_java_op_gen_srcjar(
    name = "java_op_gen_sources",
    api_def_srcs = [
        "//tensorflow/core/api_def:base_api_def",
        "//tensorflow/core/api_def:java_api_def",
    ],
    # Java package under which the generated classes are placed
    # (presumably as sub-packages; confirm in src/gen/gen_ops.bzl).
    base_package = "org.tensorflow.op",
    gen_tool = ":java_op_gen_tool",
)
# Command-line op generator binary, used as gen_tool by :java_op_gen_sources.
tf_cc_binary(
    name = "java_op_gen_tool",
    srcs = ["src/gen/cc/op_gen_main.cc"],
    copts = tf_copts(),
    # No extra link options on Windows; everywhere else link against libm.
    linkopts = select({
        "//tensorflow:windows": [],
        "//conditions:default": [
            "-lm",
        ],
    }),
    linkstatic = 1,
    deps = [
        ":java_op_gen_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:ops",
    ],
)
# C++ library behind the op generator: op spec handling, the generator itself
# and the source writer. Shared by :java_op_gen_tool and :source_writer_test.
cc_library(
    name = "java_op_gen_lib",
    srcs = [
        "src/gen/cc/op_generator.cc",
        "src/gen/cc/op_specs.cc",
        "src/gen/cc/source_writer.cc",
    ],
    hdrs = [
        "src/gen/cc/java_defs.h",
        "src/gen/cc/op_generator.h",
        "src/gen/cc/op_specs.h",
        "src/gen/cc/source_writer.h",
    ],
    copts = tf_copts(),
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:op_gen_lib",
        "//tensorflow/core:protos_all_cc",
        "@com_googlesource_code_re2//:re2",
    ],
)
# Test-only helper library (TestUtil.java) shared by the Java tests below.
java_library(
    name = "testutil",
    testonly = 1,
    srcs = [
        "src/test/java/org/tensorflow/TestUtil.java",
    ],
    javacopts = JAVACOPTS,
    deps = [
        ":tensorflow",
    ],
)
# Java test target for org.tensorflow.GraphTest.
tf_java_test(
    name = "GraphTest",
    size = "small",
    srcs = [
        "src/test/java/org/tensorflow/GraphTest.java",
    ],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.GraphTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)
# Java test target for org.tensorflow.OperationBuilderTest.
tf_java_test(
    name = "OperationBuilderTest",
    size = "small",
    srcs = [
        "src/test/java/org/tensorflow/OperationBuilderTest.java",
    ],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.OperationBuilderTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)
# Java test target for org.tensorflow.OperationTest.
tf_java_test(
    name = "OperationTest",
    size = "small",
    srcs = [
        "src/test/java/org/tensorflow/OperationTest.java",
    ],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.OperationTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)
# Java test target for org.tensorflow.SavedModelBundleTest. The
# saved_model_half_plus_two target is made available to the test as data.
tf_java_test(
    name = "SavedModelBundleTest",
    size = "small",
    srcs = [
        "src/test/java/org/tensorflow/SavedModelBundleTest.java",
    ],
    data = [
        "//tensorflow/cc/saved_model:saved_model_half_plus_two",
    ],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.SavedModelBundleTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)
# Java test target for org.tensorflow.SessionTest.
tf_java_test(
    name = "SessionTest",
    size = "small",
    srcs = [
        "src/test/java/org/tensorflow/SessionTest.java",
    ],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.SessionTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)
# Java test target for org.tensorflow.ShapeTest.
tf_java_test(
    name = "ShapeTest",
    size = "small",
    srcs = [
        "src/test/java/org/tensorflow/ShapeTest.java",
    ],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.ShapeTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)
# Custom-op shared library built from my_test_op.cc; supplied as test data to
# :TensorFlowTest.
tf_custom_op_library(
    name = "my_test_op.so",
    srcs = [
        "src/test/native/my_test_op.cc",
    ],
)
# Java test target for org.tensorflow.TensorFlowTest. Receives the
# :my_test_op.so custom-op library as data; unlike the other tests here it
# does not depend on :testutil.
tf_java_test(
    name = "TensorFlowTest",
    size = "small",
    srcs = [
        "src/test/java/org/tensorflow/TensorFlowTest.java",
    ],
    data = [
        ":my_test_op.so",
    ],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.TensorFlowTest",
    deps = [
        ":tensorflow",
        "@junit",
    ],
)
# Java test target for org.tensorflow.TensorTest.
tf_java_test(
    name = "TensorTest",
    size = "small",
    srcs = [
        "src/test/java/org/tensorflow/TensorTest.java",
    ],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.TensorTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)
# Java test target for org.tensorflow.op.ScopeTest.
tf_java_test(
    name = "ScopeTest",
    size = "small",
    srcs = [
        "src/test/java/org/tensorflow/op/ScopeTest.java",
    ],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.op.ScopeTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)
# Java test target for org.tensorflow.op.PrimitiveOpTest.
tf_java_test(
    name = "PrimitiveOpTest",
    size = "small",
    srcs = [
        "src/test/java/org/tensorflow/op/PrimitiveOpTest.java",
    ],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.op.PrimitiveOpTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)
# Java test target for org.tensorflow.op.OperandsTest.
tf_java_test(
    name = "OperandsTest",
    size = "small",
    srcs = [
        "src/test/java/org/tensorflow/op/OperandsTest.java",
    ],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.op.OperandsTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)
# Java test target for org.tensorflow.op.core.ConstantTest.
tf_java_test(
    name = "ConstantTest",
    size = "small",
    srcs = [
        "src/test/java/org/tensorflow/op/core/ConstantTest.java",
    ],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.op.core.ConstantTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)
# Java test target for org.tensorflow.op.core.GeneratedOperationsTest.
tf_java_test(
    name = "GeneratedOperationsTest",
    size = "small",
    srcs = [
        "src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java",
    ],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.op.core.GeneratedOperationsTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)
# Java test target for org.tensorflow.op.core.GradientsTest.
tf_java_test(
    name = "GradientsTest",
    size = "small",
    srcs = [
        "src/test/java/org/tensorflow/op/core/GradientsTest.java",
    ],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.op.core.GradientsTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)
# Java test target for org.tensorflow.op.core.ZerosTest.
tf_java_test(
    name = "ZerosTest",
    size = "small",
    srcs = [
        "src/test/java/org/tensorflow/op/core/ZerosTest.java",
    ],
    javacopts = JAVACOPTS,
    test_class = "org.tensorflow.op.core.ZerosTest",
    deps = [
        ":tensorflow",
        ":testutil",
        "@junit",
    ],
)
# Java files used as inputs when testing the annotation processor: the
# Operator annotation source plus everything under the test resources tree.
filegroup(
    name = "processor_test_resources",
    # Operator.java is named explicitly rather than inside glob(): a glob
    # pattern without a wildcard silently matches nothing if the file is moved
    # or deleted, whereas an explicit source makes the build fail. It is listed
    # first to keep the order the original sorted glob produced.
    srcs = [
        "src/main/java/org/tensorflow/op/annotation/Operator.java",
    ] + glob([
        "src/test/resources/org/tensorflow/**/*.java",
    ]),
)
# C++ test for the source writer in :java_op_gen_lib; test.java.snippet is
# provided to the test as a data file.
tf_cc_test(
    name = "source_writer_test",
    size = "small",
    srcs = ["src/gen/cc/source_writer_test.cc"],
    data = ["src/gen/resources/test.java.snippet"],
    deps = [
        ":java_op_gen_lib",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)
# Platform-independent alias for the JNI shared library: resolves to the .dll
# on Windows, the .dylib on macOS and the .so everywhere else.
filegroup(
    name = "libtensorflow_jni",
    srcs = select({
        "//tensorflow:windows": [
            ":tensorflow_jni.dll",
        ],
        "//tensorflow:macos": [
            ":libtensorflow_jni.dylib",
        ],
        "//conditions:default": [
            ":libtensorflow_jni.so",
        ],
    }),
    visibility = ["//visibility:public"],
)
# Linker inputs that restrict which symbols the JNI library exports: the
# version script is used by the default link branch below, the exported-symbols
# list by the macOS branch.
LINKER_VERSION_SCRIPT = ":config/version_script.lds"

LINKER_EXPORTED_SYMBOLS = ":config/exported_symbols.lds"

# The JNI shared library itself, built from //tensorflow/java/src/main/native.
tf_cc_binary(
    name = "tensorflow_jni",
    # Set linker options to strip out anything except the JNI
    # symbols from the library. This reduces the size of the library
    # considerably (~50% as of January 2017).
    linkopts = select({
        "//tensorflow:debug": [],  # Disable all custom linker options in debug mode
        "//tensorflow:macos": [
            "-Wl,-exported_symbols_list,$(location {})".format(LINKER_EXPORTED_SYMBOLS),
        ],
        "//tensorflow:windows": [],
        "//conditions:default": [
            "-z defs",
            "-s",
            "-Wl,--version-script,$(location {})".format(LINKER_VERSION_SCRIPT),
        ],
    }),
    linkshared = 1,
    linkstatic = 1,
    # NOTE(review): presumably makes the macro emit the per-OS outputs that
    # :libtensorflow_jni selects between (.dll / .dylib / .so); confirm in
    # //tensorflow:tensorflow.bzl.
    per_os_targets = True,
    deps = [
        "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
        "//tensorflow/java/src/main/native",
        # The linker scripts are listed here so that the $(location ...)
        # references in linkopts above can be resolved.
        LINKER_VERSION_SCRIPT,
        LINKER_EXPORTED_SYMBOLS,
    ],
)
# Produces pom.xml by running :generate_pom and redirecting its standard
# output into the rule's single output file.
genrule(
    name = "pom",
    outs = [
        "pom.xml",
    ],
    cmd = "$(location generate_pom) >$@",
    output_to_bindir = 1,
    # The generator binary plus whatever extra inputs
    # tf_binary_additional_srcs() reports for TensorFlow binaries.
    tools = [
        ":generate_pom",
    ] + tf_binary_additional_srcs(),
)
# Helper binary invoked by the :pom genrule; built from generate_pom.cc
# against the TensorFlow C API.
tf_cc_binary(
    name = "generate_pom",
    srcs = [
        "generate_pom.cc",
    ],
    deps = [
        "//tensorflow/c:c_api",
    ],
)