op_subset {keras3}    R Documentation

Subset elements from a tensor

Description

Extract elements from a tensor using common R-style [ indexing idioms. This function can also be conveniently accessed via the syntax tensor@r[...].

Usage

op_subset(x, ...)

op_subset(x, ...) <- value

op_subset_set(x, ..., value)

Arguments

x

Input tensor.

...

Indices specifying elements to extract. Each argument in ... can be:

  • An integer scalar

  • A 1-d integer or logical vector

  • NULL or newaxis

  • The .. symbol

  • A slice expression using :

If only a single argument is supplied to ..., then ..1 can also be:

  • A logical array with the same shape as x

  • An integer matrix where ncol(..1) == op_rank(x)

value

The new value(s) to assign to the selected subset.

Details

While the semantics are similar to R's [, there are some differences; they are illustrated in the Examples below.

Value

A tensor containing the subset of elements.

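Differences from R's [:

  • Negative indices count from the end of the axis (Python-style) instead of dropping elements: x@r[-1] is the last element.

  • NULL or newaxis inserts a new axis of size 1.

  • If fewer indices than axes are supplied, the missing axes are implicitly selected in full: for a matrix x, x@r[1] is the first row.

  • .. expands to all unspecified axes.

  • Slices written with : accept NA for the start or the end, meaning "from the first element" or "through the last element", and negative endpoints count from the end of the axis.

  • A scalar index drops its axis, while a length-one slice such as 2:2 keeps it (like [, drop=FALSE]).

  • A logical array with the same shape as x selects elements in row-wise order and returns a 1-d tensor.

Similarities with R's [:

Similarities to R's [ (differences from Python's [):

  • Positive indices are 1-based.

  • Slices (start:end) include end.

  • An integer matrix where ncol equals op_rank(x) selects one element per row of index positions, as when subsetting an R array.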

Examples

(x <- op_arange(5L) + 10L)
## tf.Tensor([11 12 13 14 15], shape=(5), dtype=int32)

# Basic example, get first element
op_subset(x, 1)
## tf.Tensor(11, shape=(), dtype=int32)

# Use `@r[` syntax
x@r[1]           # same as `op_subset(x, 1)`
## tf.Tensor(11, shape=(), dtype=int32)

x@r[1:2]         # get the first 2 elements
## tf.Tensor([11 12], shape=(2), dtype=int32)

x@r[c(1, 3)]     # first and third element
## tf.Tensor([11 13], shape=(2), dtype=int32)

# Negative indices
x@r[-1]          # last element
## tf.Tensor(15, shape=(), dtype=int32)

x@r[-2]          # second to last element
## tf.Tensor(14, shape=(), dtype=int32)

x@r[c(-1, -2)]   # last and second to last elements
## tf.Tensor([15 14], shape=(2), dtype=int32)

x@r[c(-2, -1)]   # second to last and last elements
## tf.Tensor([14 15], shape=(2), dtype=int32)

x@r[c(1, -1)]    # first and last elements
## tf.Tensor([11 15], shape=(2), dtype=int32)

# Slices
x@r[1:3]          # first 3 elements
## tf.Tensor([11 12 13], shape=(3), dtype=int32)

x@r[NA:3]         # first 3 elements
## tf.Tensor([11 12 13], shape=(3), dtype=int32)

x@r[1:5]          # all elements
## tf.Tensor([11 12 13 14 15], shape=(5), dtype=int32)

x@r[1:-1]         # all elements
## tf.Tensor([11 12 13 14 15], shape=(5), dtype=int32)

x@r[NA:NA]        # all elements
## tf.Tensor([11 12 13 14 15], shape=(5), dtype=int32)

x@r[]             # all elements
## tf.Tensor([11 12 13 14 15], shape=(5), dtype=int32)

x@r[1:-2]         # drop last element
## tf.Tensor([11 12 13 14], shape=(4), dtype=int32)

x@r[NA:-2]        # drop last element
## tf.Tensor([11 12 13 14], shape=(4), dtype=int32)

x@r[2:NA]         # drop first element
## tf.Tensor([12 13 14 15], shape=(4), dtype=int32)

# 2D array examples
xr <- array(1:12, c(3, 4))
x <- op_convert_to_tensor(xr)

# Basic subsetting
x@r[1, ]      # first row
## tf.Tensor([ 1  4  7 10], shape=(4), dtype=int32)

x@r[1]        # also first row! Missing axes are implicitly inserted
## tf.Tensor([ 1  4  7 10], shape=(4), dtype=int32)

x@r[-1]       # last row
## tf.Tensor([ 3  6  9 12], shape=(4), dtype=int32)

x@r[, 2]      # second column
## tf.Tensor([4 5 6], shape=(3), dtype=int32)

x@r[, 2:2]    # second column, but shape preserved (like [, drop=FALSE])
## tf.Tensor(
## [[4]
##  [5]
##  [6]], shape=(3, 1), dtype=int32)

# Subsetting with a boolean array
# Note: extracted elements are selected row-wise, not column-wise
mask <- x >= 6
x@r[mask]             # returns a 1D tensor
## tf.Tensor([ 7 10  8 11  6  9 12], shape=(7), dtype=int32)

x.r <- as.array(x)
mask.r <- as.array(mask)
# x.r[mask.r] selects column-wise (R's column-major order). Transposing both
# arrays with `aperm()` reproduces the tensor's row-wise selection order.
all(aperm(x.r)[aperm(mask.r)] == as.array(x@r[mask]))
## [1] TRUE

# Subsetting with a matrix of index positions
indices <- rbind(c(1, 1), c(2, 2), c(3, 3))
x@r[indices] # get diagonal elements
## tf.Tensor([1 5 9], shape=(3), dtype=int32)

x.r[indices] # same as subsetting an R array
## [1] 1 5 9

# 3D array examples
# Image: 4x4 pixels, 3 colors (RGB)
# Tensor shape: (img_height, img_width, img_color_channels)
shp <- shape(4, 4, 3)
x <- op_arange(prod(shp)) |> op_reshape(shp)

# Convert to a batch of images by inserting a new axis
# New shape: (batch_size, img_height, img_width, img_color_channels)
x@r[newaxis, , , ] |> op_shape()
## shape(1, 4, 4, 3)

x@r[newaxis] |> op_shape()  # same as above
## shape(1, 4, 4, 3)

x@r[NULL] |> op_shape()     # same as above
## shape(1, 4, 4, 3)

x <- x@r[newaxis]
# Extract color channels
x@r[, , , 1]          # red channel
## tf.Tensor(
## [[[ 1.  4.  7. 10.]
##   [13. 16. 19. 22.]
##   [25. 28. 31. 34.]
##   [37. 40. 43. 46.]]], shape=(1, 4, 4), dtype=float32)

x@r[.., 1]            # red channel, same as above using .. shorthand
## tf.Tensor(
## [[[ 1.  4.  7. 10.]
##   [13. 16. 19. 22.]
##   [25. 28. 31. 34.]
##   [37. 40. 43. 46.]]], shape=(1, 4, 4), dtype=float32)

x@r[.., 2]            # green channel
## tf.Tensor(
## [[[ 2.  5.  8. 11.]
##   [14. 17. 20. 23.]
##   [26. 29. 32. 35.]
##   [38. 41. 44. 47.]]], shape=(1, 4, 4), dtype=float32)

x@r[.., 3]            # blue channel
## tf.Tensor(
## [[[ 3.  6.  9. 12.]
##   [15. 18. 21. 24.]
##   [27. 30. 33. 36.]
##   [39. 42. 45. 48.]]], shape=(1, 4, 4), dtype=float32)

# .. expands to all unspecified axes.
op_shape(x@r[])
## shape(1, 4, 4, 3)

op_shape(x@r[..])
## shape(1, 4, 4, 3)

op_shape(x@r[1, ..])
## shape(4, 4, 3)

op_shape(x@r[1, .., 1, 1])
## shape(4)

op_shape(x@r[1, 1, 1, .., 1])
## shape()

# op_subset<- uses the same semantics, but note that not all tensors
# support modification. E.g., TensorFlow constant tensors cannot be modified,
# while TensorFlow Variables can be.

(x <- tensorflow::tf$Variable(matrix(1, nrow = 2, ncol = 3)))
## <tf.Variable 'Variable:0' shape=(2, 3) dtype=float64, numpy=
## array([[1., 1., 1.],
##        [1., 1., 1.]])>

op_subset(x, 1) <- 9
x
## <tf.Variable 'UnreadVariable' shape=(2, 3) dtype=float64, numpy=
## array([[9., 9., 9.],
##        [1., 1., 1.]])>

x@r[1,1] <- 33
x
## <tf.Variable 'UnreadVariable' shape=(2, 3) dtype=float64, numpy=
## array([[33.,  9.,  9.],
##        [ 1.,  1.,  1.]])>

See Also

Other core ops:
op_associative_scan()
op_cast()
op_cond()
op_convert_to_numpy()
op_convert_to_tensor()
op_custom_gradient()
op_dtype()
op_fori_loop()
op_is_tensor()
op_map()
op_rearrange()
op_scan()
op_scatter()
op_scatter_update()
op_searchsorted()
op_shape()
op_slice()
op_slice_update()
op_stop_gradient()
op_switch()
op_unstack()
op_vectorized_map()
op_while_loop()

Other ops:
op_abs()
op_add()
op_all()
op_any()
op_append()
op_arange()
op_arccos()
op_arccosh()
op_arcsin()
op_arcsinh()
op_arctan()
op_arctan2()
op_arctanh()
op_argmax()
op_argmin()
op_argpartition()
op_argsort()
op_array()
op_associative_scan()
op_average()
op_average_pool()
op_batch_normalization()
op_binary_crossentropy()
op_bincount()
op_bitwise_and()
op_bitwise_invert()
op_bitwise_left_shift()
op_bitwise_not()
op_bitwise_or()
op_bitwise_right_shift()
op_bitwise_xor()
op_broadcast_to()
op_cast()
op_categorical_crossentropy()
op_ceil()
op_celu()
op_cholesky()
op_clip()
op_concatenate()
op_cond()
op_conj()
op_conv()
op_conv_transpose()
op_convert_to_numpy()
op_convert_to_tensor()
op_copy()
op_correlate()
op_cos()
op_cosh()
op_count_nonzero()
op_cross()
op_ctc_decode()
op_ctc_loss()
op_cumprod()
op_cumsum()
op_custom_gradient()
op_depthwise_conv()
op_det()
op_diag()
op_diagflat()
op_diagonal()
op_diff()
op_digitize()
op_divide()
op_divide_no_nan()
op_dot()
op_dot_product_attention()
op_dtype()
op_eig()
op_eigh()
op_einsum()
op_elu()
op_empty()
op_equal()
op_erf()
op_erfinv()
op_exp()
op_exp2()
op_expand_dims()
op_expm1()
op_extract_sequences()
op_eye()
op_fft()
op_fft2()
op_flip()
op_floor()
op_floor_divide()
op_fori_loop()
op_full()
op_full_like()
op_gelu()
op_get_item()
op_glu()
op_greater()
op_greater_equal()
op_hard_shrink()
op_hard_sigmoid()
op_hard_silu()
op_hard_tanh()
op_histogram()
op_hstack()
op_identity()
op_ifft2()
op_imag()
op_image_affine_transform()
op_image_crop()
op_image_extract_patches()
op_image_gaussian_blur()
op_image_hsv_to_rgb()
op_image_map_coordinates()
op_image_pad()
op_image_perspective_transform()
op_image_resize()
op_image_rgb_to_grayscale()
op_image_rgb_to_hsv()
op_in_top_k()
op_inner()
op_inv()
op_irfft()
op_is_tensor()
op_isclose()
op_isfinite()
op_isinf()
op_isnan()
op_istft()
op_leaky_relu()
op_left_shift()
op_less()
op_less_equal()
op_linspace()
op_log()
op_log10()
op_log1p()
op_log2()
op_log_sigmoid()
op_log_softmax()
op_logaddexp()
op_logdet()
op_logical_and()
op_logical_not()
op_logical_or()
op_logical_xor()
op_logspace()
op_logsumexp()
op_lstsq()
op_lu_factor()
op_map()
op_matmul()
op_max()
op_max_pool()
op_maximum()
op_mean()
op_median()
op_meshgrid()
op_min()
op_minimum()
op_mod()
op_moments()
op_moveaxis()
op_multi_hot()
op_multiply()
op_nan_to_num()
op_ndim()
op_negative()
op_nonzero()
op_norm()
op_normalize()
op_not_equal()
op_one_hot()
op_ones()
op_ones_like()
op_outer()
op_pad()
op_polar()
op_power()
op_prod()
op_psnr()
op_qr()
op_quantile()
op_ravel()
op_real()
op_rearrange()
op_reciprocal()
op_relu()
op_relu6()
op_repeat()
op_reshape()
op_rfft()
op_right_shift()
op_rms_normalization()
op_roll()
op_rot90()
op_round()
op_rsqrt()
op_saturate_cast()
op_scan()
op_scatter()
op_scatter_update()
op_searchsorted()
op_segment_max()
op_segment_sum()
op_select()
op_selu()
op_separable_conv()
op_shape()
op_sigmoid()
op_sign()
op_signbit()
op_silu()
op_sin()
op_sinh()
op_size()
op_slice()
op_slice_update()
op_slogdet()
op_soft_shrink()
op_softmax()
op_softplus()
op_softsign()
op_solve()
op_solve_triangular()
op_sort()
op_sparse_categorical_crossentropy()
op_sparse_plus()
op_sparsemax()
op_split()
op_sqrt()
op_square()
op_squareplus()
op_squeeze()
op_stack()
op_std()
op_stft()
op_stop_gradient()
op_subtract()
op_sum()
op_svd()
op_swapaxes()
op_switch()
op_take()
op_take_along_axis()
op_tan()
op_tanh()
op_tanh_shrink()
op_tensordot()
op_threshold()
op_tile()
op_top_k()
op_trace()
op_transpose()
op_tri()
op_tril()
op_triu()
op_trunc()
op_unravel_index()
op_unstack()
op_var()
op_vdot()
op_vectorize()
op_vectorized_map()
op_vstack()
op_where()
op_while_loop()
op_zeros()
op_zeros_like()


[Package keras3 version 1.4.0 Index]