op_scan {keras3}    R Documentation
Scan a function over leading array axes while carrying along state.
Description
When the type of xs is an array type or NULL, and the type of ys is an
array type, the semantics of op_scan() are given roughly by this
implementation:
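op_scan <- function(f, init, xs = NULL, length = NULL) {
  # With no xs, loop `length` times, passing NULL as each slice
  xs <- xs %||% vector("list", length)
  # Split a tensor into slices along its leading axis
  if (!is.list(xs))
    xs <- op_unstack(xs)
  ys <- vector("list", length(xs))
  carry <- init
  for (i in seq_along(xs)) {
    # f returns a pair: the new carry and this iteration's output
    c(carry, y) %<-% f(carry, xs[[i]])
    ys[[i]] <- y
  }
  # Final carry, plus the per-iteration outputs stacked along a new leading axis
  list(carry, op_stack(ys))
}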
The loop-carried value carry (initialized from init) must keep a fixed
shape and dtype across all iterations.
With the TensorFlow backend, y must match carry in shape and dtype.
Other backends do not require this.
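To see the loop semantics without a backend, here is a minimal base-R sketch of the pseudocode above. It assumes plain R vectors or lists in place of tensors and scalar outputs from f (so "stacking" is just unlist()). It also models reverse by visiting the leading axis back to front while storing each output at the position of its input slice. scan_ref() is a hypothetical helper for illustration, not part of keras3; unroll is omitted because it does not change the results.

```r
# Hypothetical base-R reference model of op_scan() (not part of keras3).
# Assumes plain R vectors/lists for xs and scalar outputs from f.
scan_ref <- function(f, init, xs = NULL, length = NULL, reverse = FALSE) {
  # With no xs, run `length` iterations and pass NULL as each slice.
  if (is.null(xs)) xs <- vector("list", length)
  n <- base::length(xs)
  # reverse = TRUE visits the leading axis from last to first.
  idx <- if (reverse) rev(seq_len(n)) else seq_len(n)
  ys <- vector("list", n)
  carry <- init
  for (i in idx) {
    out <- f(carry, xs[[i]])  # f returns list(new_carry, y)
    carry <- out[[1]]
    ys[[i]] <- out[[2]]       # keep y aligned with the slice it came from
  }
  list(carry, unlist(ys))     # "stack" scalar outputs into one vector
}

sum_fn <- function(carry, x) list(carry + x, carry + x)

fwd <- scan_ref(sum_fn, 0L, 1:5)
fwd[[1]]  # 15
fwd[[2]]  # 1 3 6 10 15

bwd <- scan_ref(sum_fn, 0L, 1:5, reverse = TRUE)
bwd[[2]]  # 15 14 12 9 5

# No xs: `length` alone sets the number of iterations.
dbl <- scan_ref(function(carry, x) list(carry * 2, carry), 1, length = 4)
dbl[[1]]  # 16
dbl[[2]]  # 1 2 4 8
```

Indexing with xs[[i]] works for both atomic vectors and lists, so the length-only case simply walks a list of NULLs, mirroring xs %||% vector("list", length) in the pseudocode.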
Usage
op_scan(f, init, xs = NULL, length = NULL, reverse = FALSE, unroll = 1L)
Arguments
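f
A function that defines the logic for each loop iteration. It accepts
two arguments: the first is the current value of the loop carry and
the second is a slice of xs along its leading axis. It must return a
pair: the new loop carry and the output for this iteration.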
init
The initial loop carry value. This can be a scalar, tensor, or any
nested structure. It must match the structure of the first element
returned by f.
xs
Optional value to scan along its leading axis. This can be a tensor
or any nested structure. If xs is NULL, length must be supplied to set
the number of loop iterations.
length
Optional integer specifying the number of loop iterations.
If length is NULL, it defaults to the size of the leading axis of the
arrays in xs.
reverse
Optional boolean specifying whether to run the scan iteration
forward or in reverse, equivalent to reversing the leading axes of
the arrays in both xs and ys.
unroll
Optional positive integer or boolean specifying how many scan
iterations to unroll within a single iteration of a loop. If an
integer is provided, it determines how many unrolled loop iterations
to run within a single rolled iteration of the loop. If a boolean is
provided, it determines whether the loop is completely unrolled
(unroll = TRUE) or not unrolled at all (unroll = FALSE).
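The equivalence stated for reverse can be checked in base R for the special case where f emits its new carry as the output, as in the cumulative-sum example. In that case base R's Reduce(accumulate = TRUE) computes the same forward scan. This is an analogy on plain R vectors, not a call into keras3.

```r
xs <- 1:5

# Forward scan where y equals the new carry: a running sum.
fwd <- Reduce(`+`, xs, accumulate = TRUE)

# reverse = TRUE reverses the leading axis of both xs and ys:
# reverse the input, scan forward, then reverse the stacked outputs.
bwd <- rev(Reduce(`+`, rev(xs), accumulate = TRUE))

# Base R can also accumulate from the right directly.
bwd_right <- Reduce(`+`, xs, accumulate = TRUE, right = TRUE)

print(fwd)                 # 1 3 6 10 15
print(bwd)                 # 15 14 12 9 5
identical(bwd, bwd_right)  # TRUE
```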
Value
A pair where the first element represents the final loop carry value and
the second element represents the stacked outputs of f when scanned
over the leading axis of the inputs.
Examples
sum_fn <- function(c, x) list(c + x, c + x)
init <- op_array(0L)
xs <- op_array(1:5)
c(carry, result) %<-% op_scan(sum_fn, init, xs)
carry
## tf.Tensor(15, shape=(), dtype=int32)
result
## tf.Tensor([ 1 3 6 10 15], shape=(5), dtype=int32)
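The same cumulative sum can be run with reverse = TRUE, and a scan can also be driven by length alone with xs left as NULL. This is a minimal sketch, assuming keras3 is installed with a configured backend. Per the description of reverse, the reversed outputs stay aligned with their input positions; double_fn is an illustrative name, and it ignores its slice argument, so the example does not depend on what f receives when xs is NULL.

```r
library(keras3)

sum_fn <- function(c, x) list(c + x, c + x)

# Reverse scan: iterate from the last element; outputs keep input order.
c(carry, result) %<-% op_scan(sum_fn, op_array(0L), op_array(1:5),
                              reverse = TRUE)
op_convert_to_numpy(result)  # 15 14 12 9 5

# No xs: length sets the iteration count; the slice x is unused here.
# The output y has the same shape and dtype as the carry.
double_fn <- function(c, x) list(c * 2, c)
c(final, powers) %<-% op_scan(double_fn, op_array(1), length = 4L)
op_convert_to_numpy(powers)  # 1 2 4 8
```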
See Also
Other core ops:
op_cast()
op_cond()
op_convert_to_numpy()
op_convert_to_tensor()
op_custom_gradient()
op_dtype()
op_fori_loop()
op_is_tensor()
op_map()
op_scatter()
op_scatter_update()
op_shape()
op_slice()
op_slice_update()
op_stop_gradient()
op_switch()
op_unstack()
op_vectorized_map()
op_while_loop()
Other ops:
op_abs()
op_add()
op_all()
op_any()
op_append()
op_arange()
op_arccos()
op_arccosh()
op_arcsin()
op_arcsinh()
op_arctan()
op_arctan2()
op_arctanh()
op_argmax()
op_argmin()
op_argpartition()
op_argsort()
op_array()
op_average()
op_average_pool()
op_batch_normalization()
op_binary_crossentropy()
op_bincount()
op_broadcast_to()
op_cast()
op_categorical_crossentropy()
op_ceil()
op_cholesky()
op_clip()
op_concatenate()
op_cond()
op_conj()
op_conv()
op_conv_transpose()
op_convert_to_numpy()
op_convert_to_tensor()
op_copy()
op_correlate()
op_cos()
op_cosh()
op_count_nonzero()
op_cross()
op_ctc_decode()
op_ctc_loss()
op_cumprod()
op_cumsum()
op_custom_gradient()
op_depthwise_conv()
op_det()
op_diag()
op_diagonal()
op_diff()
op_digitize()
op_divide()
op_divide_no_nan()
op_dot()
op_dtype()
op_eig()
op_eigh()
op_einsum()
op_elu()
op_empty()
op_equal()
op_erf()
op_erfinv()
op_exp()
op_expand_dims()
op_expm1()
op_extract_sequences()
op_eye()
op_fft()
op_fft2()
op_flip()
op_floor()
op_floor_divide()
op_fori_loop()
op_full()
op_full_like()
op_gelu()
op_get_item()
op_greater()
op_greater_equal()
op_hard_sigmoid()
op_hard_silu()
op_hstack()
op_identity()
op_imag()
op_image_affine_transform()
op_image_crop()
op_image_extract_patches()
op_image_hsv_to_rgb()
op_image_map_coordinates()
op_image_pad()
op_image_resize()
op_image_rgb_to_grayscale()
op_image_rgb_to_hsv()
op_in_top_k()
op_inv()
op_irfft()
op_is_tensor()
op_isclose()
op_isfinite()
op_isinf()
op_isnan()
op_istft()
op_leaky_relu()
op_less()
op_less_equal()
op_linspace()
op_log()
op_log10()
op_log1p()
op_log2()
op_log_sigmoid()
op_log_softmax()
op_logaddexp()
op_logical_and()
op_logical_not()
op_logical_or()
op_logical_xor()
op_logspace()
op_logsumexp()
op_lstsq()
op_lu_factor()
op_map()
op_matmul()
op_max()
op_max_pool()
op_maximum()
op_mean()
op_median()
op_meshgrid()
op_min()
op_minimum()
op_mod()
op_moments()
op_moveaxis()
op_multi_hot()
op_multiply()
op_nan_to_num()
op_ndim()
op_negative()
op_nonzero()
op_norm()
op_normalize()
op_not_equal()
op_one_hot()
op_ones()
op_ones_like()
op_outer()
op_pad()
op_power()
op_prod()
op_psnr()
op_qr()
op_quantile()
op_ravel()
op_real()
op_reciprocal()
op_relu()
op_relu6()
op_repeat()
op_reshape()
op_rfft()
op_roll()
op_round()
op_rsqrt()
op_scatter()
op_scatter_update()
op_segment_max()
op_segment_sum()
op_select()
op_selu()
op_separable_conv()
op_shape()
op_sigmoid()
op_sign()
op_silu()
op_sin()
op_sinh()
op_size()
op_slice()
op_slice_update()
op_slogdet()
op_softmax()
op_softplus()
op_softsign()
op_solve()
op_solve_triangular()
op_sort()
op_sparse_categorical_crossentropy()
op_split()
op_sqrt()
op_square()
op_squeeze()
op_stack()
op_std()
op_stft()
op_stop_gradient()
op_subtract()
op_sum()
op_svd()
op_swapaxes()
op_switch()
op_take()
op_take_along_axis()
op_tan()
op_tanh()
op_tensordot()
op_tile()
op_top_k()
op_trace()
op_transpose()
op_tri()
op_tril()
op_triu()
op_unstack()
op_var()
op_vdot()
op_vectorize()
op_vectorized_map()
op_vstack()
op_where()
op_while_loop()
op_zeros()
op_zeros_like()