diff --git a/include/xtensor/views/index_mapper.hpp b/include/xtensor/views/index_mapper.hpp new file mode 100644 index 000000000..ee435ab43 --- /dev/null +++ b/include/xtensor/views/index_mapper.hpp @@ -0,0 +1,371 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht * + * Copyright (c) QuantStack * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. * + ****************************************************************************/ + +#ifndef XTENSOR_INDEX_MAPPER_HPP +#define XTENSOR_INDEX_MAPPER_HPP + +#include "xview.hpp" + +namespace xt +{ + + template + struct index_mapper; + + /** + * @class index_mapper + * @brief A helper class for mapping indices between views and their underlying containers. + * + * The `index_mapper` class provides functionality to convert indices from a view's coordinate system + * to the corresponding indices in the underlying container. This is particularly useful for views + * that contain integral slices (fixed indices), as these slices reduce the dimensionality of the view. + * + * @tparam UndefinedView The primary template parameter, specialized for `xt::xview` types. + * + * @note This class is specialized for `xt::xview` types only. + * Other view types will trigger a compilation error. + * + * @example + * @code + * xt::xarray a = xt::arange(24).reshape({2, 3, 4}); + * auto view1 = xt::view(a, 1, xt::all(), xt::all()); // Fixed first dimension + * index_mapper mapper; + * + * // Map view indices (i,j) to container indices (1,i,j) + * double val = mapper.map(a, view1, 0, 0); // Returns a(1, 0, 0) + * double val2 = mapper.map(a, view1, 1, 2); // Returns a(1, 1, 2) + * @endcode + */ + template + class index_mapper> + { + /// @brief Total number of explicitly passed slices in the view + static constexpr size_t n_slices = sizeof...(Slices); + + /// @brief Number of slices that are integral constants (fixed indices) + static constexpr size_t nb_integral_slices = (std::is_integral_v + ...); + + /// @brief Number of slices that are xt::newaxis (insert a dimension) + static constexpr size_t nb_new_axis_slices = (xt::detail::is_newaxis::value + ...); + + /** + * Compute how many indices are needed to address the underlying container + * when given N indices in the view. + */ + template + static constexpr size_t n_indices_full_v = size_t( + sizeof...(Indices) + nb_integral_slices - nb_new_axis_slices + ); + + public: + + /// @brief The view type this mapper works with + using view_type = xt::xview; + + ///< @brief Value type of the underlying container + using value_type = typename xt::xview::value_type; + + private: + + /// @brief Helper type alias for the I-th slice type + template + using ith_slice_type = std::tuple_element_t>; + + /// @brief True if the I-th slice is an integral slice (fixed index) + template + static consteval bool is_ith_slice_integral(); + + /// @brief True if the I-th slice is a newaxis slice + template + static consteval bool is_ith_slice_new_axis(); + + /** + * Helper metafunction to build an index_sequence that skips + * newaxis slices. + * + * The resulting sequence contains only the indices that + * correspond to real container dimensions. + */ + template + struct indices_sequence_helper + { + // we add the current axis + using not_new_axis_type = typename indices_sequence_helper::Type; + + // we skip the current axis + using new_axis_type = typename indices_sequence_helper::Type; + + // NOTE: is_ith_slice_new_axis works even if first >= sizeof...(Slices) + using Type = std::conditional_t(), new_axis_type, not_new_axis_type>; + }; + + /// @brief Base case: recursion termination + template + struct indices_sequence_helper + { + using Type = std::index_sequence; + }; + + ///< @brief Index sequence of non-newaxis slices + template + using indices_sequence = indices_sequence_helper<0, bound>::Type; + + /** + * @brief Maps an index for a specific slice to the corresponding index in the underlying container. + * + * For integral slices (fixed indices), returns the fixed index value. + * For non-integral slices (like `xt::all()`), applies the slice transformation to the index. + * + * @tparam I The slice index to map. + * @tparam Index Type of the index (must be integral). + * @param view The view object containing slice information. + * @param i The index within the slice to map. + * @return size_t The mapped index in the underlying container. + * + * @throws Assertion failure if `i != 0` for integral slices. + * @throws Assertion failure if `i >= slice.size()` for non-integral slices. + */ + template + size_t map_ith_index(const view_type& view, const Index i) const; + + /** + * @brief Maps all indices and accesses the container. + * + * @tparam Is Index sequence for parameter pack expansion. + * @param container The underlying container to access. + * @param view The view providing slice information. + * @param indices Array of indices for all slices. + * @return value_type The value at the mapped location in the container. + */ + template + value_type map_all_indices( + const UnderlyingContainer& container, + const view_type& view, + std::index_sequence, + const std::array& indices + ) const; + + /** + * @brief Maps all indices and accesses the container with bounds checking. + * + * Same as `map_all_indices` but uses `container.at()` which performs bounds checking. + * + * @tparam Is Index sequence for parameter pack expansion. + * @param container The underlying container to access. + * @param view The view providing slice information. + * @param indices Array of indices for all slices. + * @return value_type The value at the mapped location in the container. + * + * @throws std::out_of_range if any index is out of bounds. + */ + template + value_type map_at_all_indices( + const UnderlyingContainer& container, + const view_type& view, + std::index_sequence, + const std::array& indices + ) const; + + /// @brief Expand view indices into a full index array, inserting dummy indices for integral slices + template + std::array> get_indices_full(const Indices... indices) const; + + public: + + /** + * @brief Maps view indices to container indices and returns the value. + * + * Converts the provided indices (for the free dimensions of the view) to + * the corresponding indices in the underlying container and returns the value. + * + * @tparam Indices Types of the indices (must be integral). + * @param container The underlying container to access. + * @param view The view providing slice information. + * @param indices The indices for the free dimensions of the view. + * @return value_type The value at the mapped location in the container. + * + * @example + * @code + * // For view(a, 1, all(), all()): + * mapper.map(a, view, i, j); // Maps to a(1, i, j) + * @endcode + */ + template + value_type + map(const UnderlyingContainer& container, const view_type& view, const Indices... indices) const; + + /** + * @brief Maps view indices to container indices with bounds checking. + * + * Same as `map()` but uses bounds-checked access via `container.at()`. + * + * @tparam Indices Types of the indices (must be integral). + * @param container The underlying container to access. + * @param view The view providing slice information. + * @param indices The indices for the free dimensions of the view. + * @return value_type The value at the mapped location in the container. + * + * @throws std::out_of_range if any mapped index is out of bounds. + */ + template + value_type + map_at(const UnderlyingContainer& container, const view_type& view, const Indices... indices) const; + + /// @brief Return the dimensionality of the view + size_t dimension(const UnderlyingContainer& container) const; + }; + + /******************************* + * index_mapper implementation * + *******************************/ + + template + template + consteval bool index_mapper>::is_ith_slice_integral() + { + if constexpr (I < sizeof...(Slices)) + { + return std::is_integral_v>; + } + else + { + return false; + } + } + + template + template + consteval bool index_mapper>::is_ith_slice_new_axis() + { + if constexpr (I < sizeof...(Slices)) + { + return xt::detail::is_newaxis>::value; + } + else + { + return false; + } + } + + template + template + auto + index_mapper>::get_indices_full(const Indices... indices) const + -> std::array> + { + constexpr size_t n_indices_full = n_indices_full_v; + + std::array args{size_t(indices)...}; + std::array args_full; + + const auto fill_args_full = [&args_full, &args](std::index_sequence) + { + auto it = std::cbegin(args); + + ((args_full[Is] = (is_ith_slice_integral()) ? size_t(0) : *it++), ...); + }; + + fill_args_full(std::make_index_sequence{}); + + return args_full; + } + + template + template + auto index_mapper>::map( + const UnderlyingContainer& container, + const view_type& view, + const Indices... indices + ) const -> value_type + { + constexpr size_t n_indices_full = n_indices_full_v; + + return map_all_indices(container, view, indices_sequence{}, get_indices_full(indices...)); + } + + template + template + auto index_mapper>::map_at( + const UnderlyingContainer& container, + const view_type& view, + const Indices... indices + ) const -> value_type + { + constexpr size_t n_indices_full = n_indices_full_v; + + return map_at_all_indices(container, view, indices_sequence{}, get_indices_full(indices...)); + } + + template + template + auto index_mapper>::map_all_indices( + const UnderlyingContainer& container, + const view_type& view, + std::index_sequence, + const std::array& indices + ) const -> value_type + { + return container(map_ith_index(view, indices[Is])...); + } + + template + template + auto index_mapper>::map_at_all_indices( + const UnderlyingContainer& container, + const view_type& view, + std::index_sequence, + const std::array& indices + ) const -> value_type + { + return container.at(map_ith_index(view, indices[Is])...); + } + + template + template + auto + index_mapper>::map_ith_index(const view_type& view, const Index i) const + -> size_t + { + if constexpr (I < sizeof...(Slices)) + { + // if the slice is explicitly specified, use it + using current_slice = std::tuple_element_t>; + + static_assert(not xt::detail::is_newaxis::value); + + const auto& slice = std::get(view.slices()); + + if constexpr (std::is_integral_v) + { + assert(i == 0); + return size_t(slice); + } + else + { + assert(i < slice.size()); + return size_t(slice(i)); + } + } + else + { + // else assume xt::all + return i; + } + } + + template + auto index_mapper>::dimension(const UnderlyingContainer& container + ) const -> size_t + { + return container.dimension() - nb_integral_slices + nb_new_axis_slices; + } + +} // namespace xt + +#endif // XTENSOR_INDEX_MAPPER_HPP diff --git a/test/test_xview.cpp b/test/test_xview.cpp index b2afa3c56..64452079b 100644 --- a/test/test_xview.cpp +++ b/test/test_xview.cpp @@ -28,6 +28,7 @@ #include "xtensor/generators/xbuilder.hpp" #include "xtensor/generators/xrandom.hpp" #include "xtensor/misc/xmanipulation.hpp" +#include "xtensor/views/index_mapper.hpp" #include "xtensor/views/xstrided_view.hpp" #include "xtensor/views/xview.hpp" @@ -143,6 +144,76 @@ namespace xt } } + TEST(xview_mapping, simple) + { + view_shape_type shape = {3, 4}; + xarray a(shape); + std::vector data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + std::copy(data.cbegin(), data.cend(), a.template begin()); + + auto view1 = view(a, 1, range(1, 4)); + + index_mapper mapper1; + + EXPECT_EQ(a(1, 1), mapper1.map(a, view1, 0)); + EXPECT_EQ(a(1, 2), mapper1.map(a, view1, 1)); + EXPECT_EQ(size_t(1), mapper1.dimension(a)); + XT_EXPECT_ANY_THROW(mapper1.map_at(a, view1, 10)); + + auto view0 = view(a, 0, range(0, 3)); + index_mapper mapper0; + + EXPECT_EQ(a(0, 0), mapper0.map(a, view0, 0)); + EXPECT_EQ(a(0, 1), mapper0.map(a, view0, 1)); + EXPECT_EQ(size_t(1), mapper0.dimension(a)); + + auto view2 = view(a, range(0, 2), 2); + index_mapper mapper2; + EXPECT_EQ(a(0, 2), mapper2.map(a, view2, 0)); + EXPECT_EQ(a(1, 2), mapper2.map(a, view2, 1)); + EXPECT_EQ(size_t(1), mapper2.dimension(a)); + + auto view4 = view(a, 1); + index_mapper mapper4; + EXPECT_EQ(size_t(1), mapper4.dimension(a)); + + auto view5 = view(view4, 1); + index_mapper mapper5; + EXPECT_EQ(size_t(0), mapper5.dimension(view4)); + + auto view6 = view(a, 1, all()); + index_mapper mapper6; + EXPECT_EQ(a(1, 0), mapper6.map(a, view6, 0)); + EXPECT_EQ(a(1, 1), mapper6.map(a, view6, 1)); + EXPECT_EQ(a(1, 2), mapper6.map(a, view6, 2)); + EXPECT_EQ(a(1, 3), mapper6.map(a, view6, 3)); + + auto view7 = view(a, all(), 2); + index_mapper mapper7; + EXPECT_EQ(a(0, 2), mapper7.map(a, view7, 0)); + EXPECT_EQ(a(1, 2), mapper7.map(a, view7, 1)); + EXPECT_EQ(a(2, 2), mapper7.map(a, view7, 2)); + } + + TEST(xview_mapping, indices) + { + xarray a = {{1., 2., 3.}, {4., 5., 6.}}; + + auto view1 = view(a, all(), all()); + index_mapper mapper1; + + EXPECT_EQ(a(0, 2), mapper1.map(a, view1, 0, 2)); + EXPECT_EQ(a(0, 2), mapper1.map(a, view1, 2)); + EXPECT_EQ(a(1, 2), mapper1.map(a, view1, 1, 1, 2)); + + auto view2 = view(a, all()); + index_mapper mapper2; + + EXPECT_EQ(a(0, 2), mapper2.map(a, view2, 0, 2)); + EXPECT_EQ(a(0, 2), mapper2.map(a, view2, 2)); + EXPECT_EQ(a(1, 2), mapper2.map(a, view2, 1, 1, 2)); + } + TEST(xview, negative_index) { view_shape_type shape = {3, 4}; @@ -269,6 +340,28 @@ namespace xt EXPECT_EQ(a(1, 1, 1), view1.element(idx.cbegin(), idx.cend())); } + TEST(xview_mapping, three_dimensional) + { + view_shape_type shape = {3, 4, 2}; + std::vector data = {1, 2, 3, 4, 5, 6, 7, 8, + + 9, 10, 11, 12, 21, 22, 23, 24, + + 25, 26, 27, 28, 29, 210, 211, 212}; + xarray a(shape); + std::copy(data.cbegin(), data.cend(), a.template begin()); + + auto view1 = view(a, 1, all(), all()); + index_mapper mapper1; + + EXPECT_EQ(size_t(2), mapper1.dimension(a)); + EXPECT_EQ(a(1, 0, 0), mapper1.map(a, view1, 0, 0)); + EXPECT_EQ(a(1, 0, 1), mapper1.map(a, view1, 0, 1)); + EXPECT_EQ(a(1, 1, 0), mapper1.map(a, view1, 1, 0)); + EXPECT_EQ(a(1, 1, 1), mapper1.map(a, view1, 1, 1)); + XT_EXPECT_ANY_THROW(mapper1.map_at(a, view1, 10, 10)); + } + TEST(xview, integral_count) { size_t squeeze1 = integral_count>();