op {
graph_op_name: "GatherNd"
in_arg {
name: "params"
description: <<END
The tensor from which to gather values.
END
}
in_arg {
name: "indices"
description: <<END
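An integer tensor of indices. Each vector along its last dimension indexes into
the first `indices.shape[-1]` dimensions of `params`.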
END
}
out_arg {
name: "output"
description: <<END
Values from `params` gathered from indices given by `indices`, with
shape `indices.shape[:-1] + params.shape[indices.shape[-1]:]`.
END
}
summary: "Gather slices from `params` into a Tensor with shape specified by `indices`."
description: <<END
`indices` is a K-dimensional integer tensor, best thought of as a
(K-1)-dimensional tensor of indices into `params`, where each element defines a
slice of `params`:

    output[\\(i_0, ..., i_{K-2}\\)] = params[indices[\\(i_0, ..., i_{K-2}\\)]]
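
To make the formula concrete, here is a minimal pure-Python sketch. It assumes `params` and `indices` are given as rectangular nested lists rather than tensors. `gather_nd` here is a hypothetical helper, not TensorFlow's implementation. It raises `IndexError` on out-of-range indices instead of letting Python wrap negative ones.

```python
def gather_nd(params, indices):
    """Reference sketch of GatherNd semantics on nested Python lists.

    The innermost lists of `indices` are index vectors; every enclosing level
    of `indices` is kept as-is, so the result has shape
    indices.shape[:-1] + params.shape[indices.shape[-1]:].
    """
    # Base case: `indices` is one index vector, e.g. [1, 0].
    if all(isinstance(i, int) for i in indices):
        value = params
        for i in indices:
            # Reject negative or too-large indices instead of letting Python
            # wrap negative ones around.
            if not 0 <= i < len(value):
                raise IndexError(f"index {indices} is out of bounds")
            value = value[i]  # consume one leading dimension of params
        return value
    # Recursive case: map over the leading dimensions of `indices`.
    return [gather_nd(params, sub) for sub in indices]


params = [['a', 'b'], ['c', 'd']]
print(gather_nd(params, [[0, 0], [1, 1]]))  # ['a', 'd']
print(gather_nd(params, [[1], [0]]))        # [['c', 'd'], ['a', 'b']]
```

The recursion mirrors the formula: the outer levels of `indices` only pick *where* in the output a result goes, and the innermost vector picks *what* is read from `params`.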

Whereas in `tf.gather`, `indices` defines slices into the `axis`
dimension of `params`, in `tf.gather_nd`, `indices` defines slices into the
first `N` dimensions of `params`, where `N = indices.shape[-1]`.
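
The contrast can be seen with two small hypothetical helpers over nested lists, `gather_axis0` and `gather_nd_rows`. Both are illustrative sketches, not TensorFlow code. With `N = 1` they agree. Larger `N` reaches further into `params`.

```python
def gather_axis0(params, indices):
    """Sketch of tf.gather along axis 0 for a flat list of scalar indices."""
    return [params[i] for i in indices]


def gather_nd_rows(params, index_vectors):
    """Sketch of gather_nd for a flat list of index vectors (shape [M, N])."""
    out = []
    for vec in index_vectors:
        value = params
        for i in vec:  # the vector consumes the first N = len(vec) dimensions
            value = value[i]
        out.append(value)
    return out


params = [[['a0', 'b0'], ['c0', 'd0']],
          [['a1', 'b1'], ['c1', 'd1']]]

# N = 1: only the first dimension is indexed, exactly like gather on axis 0.
print(gather_nd_rows(params, [[1], [0]]) == gather_axis0(params, [1, 0]))  # True
# N = 2: each vector reaches one level deeper and selects a single row.
print(gather_nd_rows(params, [[0, 1], [1, 0]]))  # [['c0', 'd0'], ['a1', 'b1']]
```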

The last dimension of `indices` can be at most the rank of
`params`:

    indices.shape[-1] <= params.rank

Each index vector along the last dimension of `indices` selects either a single
element of `params` (if `indices.shape[-1] == params.rank`) or a slice of
`params` (if `indices.shape[-1] < params.rank`). A slice spans dimensions
`indices.shape[-1]` through the last dimension of `params`. The output tensor
has shape

    indices.shape[:-1] + params.shape[indices.shape[-1]:]
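
The rank constraint and the shape rule can be written as a small function on shape tuples. `gather_nd_output_shape` is a hypothetical name used only for illustration.

```python
from typing import Sequence, Tuple


def gather_nd_output_shape(indices_shape: Sequence[int],
                           params_shape: Sequence[int]) -> Tuple[int, ...]:
    """Output shape of GatherNd: indices.shape[:-1] + params.shape[N:]."""
    n = indices_shape[-1]  # length N of each index vector
    if n > len(params_shape):
        # Enforce indices.shape[-1] <= params.rank.
        raise ValueError(
            f"indices.shape[-1] = {n} exceeds params rank {len(params_shape)}")
    return tuple(indices_shape[:-1]) + tuple(params_shape[n:])


print(gather_nd_output_shape([2, 2], [2, 2]))        # (2,)
print(gather_nd_output_shape([2, 1, 1], [2, 2, 2]))  # (2, 1, 2, 2)
```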

Note that on CPU, if an out-of-bounds index is found, an error is returned.
On GPU, if an out-of-bounds index is found, a 0 is stored in the
corresponding output value.
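
The two out-of-bounds behaviors can be modeled as two policies. The sketch below is hypothetical. TensorFlow exposes no `on_oob` argument, because the behavior depends on the device the op runs on. Numeric `params` are used so that a zero fill is meaningful. For a slice, the sketch fills the whole slice with zeros.

```python
def _shape(x):
    """Shape of a rectangular nested list (assumes equal-length siblings)."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return dims


def _zeros(shape):
    """Nested list of zeros with the given shape (a scalar 0 for shape [])."""
    if not shape:
        return 0
    return [_zeros(shape[1:]) for _ in range(shape[0])]


def gather_nd_oob(params, index_vectors, on_oob="error"):
    """Gather with a selectable out-of-bounds policy (hypothetical helper).

    on_oob="error" models the documented CPU behavior (raise an error);
    on_oob="zero" models the documented GPU behavior (store zeros).
    """
    if on_oob not in ("error", "zero"):
        raise ValueError(f"unknown on_oob policy: {on_oob!r}")
    dims = _shape(params)
    out = []
    for vec in index_vectors:
        if not all(0 <= i < d for i, d in zip(vec, dims)):
            if on_oob == "error":
                raise IndexError(f"index {vec} is out of bounds for shape {dims}")
            # Zero-fill an output entry of the expected shape.
            out.append(_zeros(dims[len(vec):]))
            continue
        value = params
        for i in vec:
            value = value[i]
        out.append(value)
    return out


params = [[1, 2], [3, 4]]
print(gather_nd_oob(params, [[0, 1], [5, 0]], on_oob="zero"))  # [2, 0]
print(gather_nd_oob(params, [[2]], on_oob="zero"))             # [[0, 0]]
```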

Some examples follow.

Simple indexing into a matrix:
```python
indices = [[0, 0], [1, 1]]
params = [['a', 'b'], ['c', 'd']]
output = ['a', 'd']
```
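
Because `indices.shape[-1] == params.rank` here, each index vector selects one element. With nested lists standing in for tensors, this is the comprehension below.

```python
params = [['a', 'b'], ['c', 'd']]
indices = [[0, 0], [1, 1]]

# Each length-2 index vector [i, j] addresses the single element params[i][j].
output = [params[i][j] for i, j in indices]
print(output)  # ['a', 'd']
```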
Slice indexing into a matrix:
```python
indices = [[1], [0]]
params = [['a', 'b'], ['c', 'd']]
output = [['c', 'd'], ['a', 'b']]
```
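
Here `indices.shape[-1] == 1 < params.rank`, so each index vector selects a whole row. In plain Python, again with nested lists standing in for tensors:

```python
params = [['a', 'b'], ['c', 'd']]
indices = [[1], [0]]

# Each length-1 index vector [i] addresses the whole row params[i].
output = [params[i] for (i,) in indices]
print(output)  # [['c', 'd'], ['a', 'b']]
```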
Indexing into a 3-tensor:
```python
# Length-1 index vectors select whole 2x2 matrices.
indices = [[1]]
params = [[['a0', 'b0'], ['c0', 'd0']],
          [['a1', 'b1'], ['c1', 'd1']]]
output = [[['a1', 'b1'], ['c1', 'd1']]]

# Length-2 index vectors select rows.
indices = [[0, 1], [1, 0]]
params = [[['a0', 'b0'], ['c0', 'd0']],
          [['a1', 'b1'], ['c1', 'd1']]]
output = [['c0', 'd0'], ['a1', 'b1']]

# Length-3 index vectors select scalar elements.
indices = [[0, 0, 1], [1, 0, 1]]
params = [[['a0', 'b0'], ['c0', 'd0']],
          [['a1', 'b1'], ['c1', 'd1']]]
output = ['b0', 'b1']
```
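
Walking one dimension per index component is a left fold of `operator.getitem` over the index vector. Shorter vectors therefore stop higher up and return larger slices. The sketch below reproduces the three 3-tensor examples over nested lists. `lookup` is a hypothetical helper name.

```python
import operator
from functools import reduce

params = [[['a0', 'b0'], ['c0', 'd0']],
          [['a1', 'b1'], ['c1', 'd1']]]


def lookup(params, index_vector):
    """Return params[i0][i1]...; fewer indices give a larger slice."""
    return reduce(operator.getitem, index_vector, params)


out1 = [lookup(params, v) for v in [[1]]]
out2 = [lookup(params, v) for v in [[0, 1], [1, 0]]]
out3 = [lookup(params, v) for v in [[0, 0, 1], [1, 0, 1]]]
print(out3)  # ['b0', 'b1']
```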
Batched indexing into a matrix:
```python
indices = [[[0, 0]], [[0, 1]]]
params = [['a', 'b'], ['c', 'd']]
output = [['a'], ['b']]
```
Batched slice indexing into a matrix:
```python
indices = [[[1]], [[0]]]
params = [['a', 'b'], ['c', 'd']]
output = [[['c', 'd']], [['a', 'b']]]
```
Batched indexing into a 3-tensor:
```python
# indices has shape [2, 1, 1]: length-1 vectors, each selecting a 2x2 matrix.
indices = [[[1]], [[0]]]
params = [[['a0', 'b0'], ['c0', 'd0']],
          [['a1', 'b1'], ['c1', 'd1']]]
output = [[[['a1', 'b1'], ['c1', 'd1']]],
          [[['a0', 'b0'], ['c0', 'd0']]]]

# indices has shape [2, 2, 2]: length-2 vectors, each selecting a row.
indices = [[[0, 1], [1, 0]], [[0, 0], [1, 1]]]
params = [[['a0', 'b0'], ['c0', 'd0']],
          [['a1', 'b1'], ['c1', 'd1']]]
output = [[['c0', 'd0'], ['a1', 'b1']],
          [['a0', 'b0'], ['c1', 'd1']]]

# indices has shape [2, 2, 3]: length-3 vectors, each selecting a scalar.
indices = [[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]]
params = [[['a0', 'b0'], ['c0', 'd0']],
          [['a1', 'b1'], ['c1', 'd1']]]
output = [['b0', 'b1'], ['d0', 'c1']]
```
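
The batched examples follow one pattern. The leading dimensions of `indices` behave like batch dimensions, so a batched gather equals an unbatched gather applied to each batch element. The sketch below shows this over nested lists. `gather_nd_flat` is a hypothetical helper for index arrays of shape `[M, N]`.

```python
def gather_nd_flat(params, index_vectors):
    """Unbatched gather: `index_vectors` has shape [M, N]."""
    out = []
    for vec in index_vectors:
        value = params
        for i in vec:  # consume the first N dimensions of params
            value = value[i]
        out.append(value)
    return out


params = [[['a0', 'b0'], ['c0', 'd0']],
          [['a1', 'b1'], ['c1', 'd1']]]
indices = [[[0, 1], [1, 0]], [[0, 0], [1, 1]]]

# Gather each batch element independently and stack the results.
output = [gather_nd_flat(params, batch) for batch in indices]
print(output)
# [[['c0', 'd0'], ['a1', 'b1']], [['a0', 'b0'], ['c1', 'd1']]]
```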
See also `tf.gather` and `tf.batch_gather`.
END
}