Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 74 additions & 107 deletions include/xtensor/core/xfunction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -355,24 +355,6 @@ namespace xt

private:

template <std::size_t... I>
layout_type layout_impl(std::index_sequence<I...>) const noexcept;

template <std::size_t... I, class... Args>
const_reference access_impl(std::index_sequence<I...>, Args... args) const;

template <std::size_t... I, class... Args>
const_reference unchecked_impl(std::index_sequence<I...>, Args... args) const;

template <std::size_t... I, class It>
const_reference element_access_impl(std::index_sequence<I...>, It first, It last) const;

template <std::size_t... I>
const_reference data_element_impl(std::index_sequence<I...>, size_type i) const;

template <class align, class requested_type, std::size_t N, std::size_t... I>
auto load_simd_impl(std::index_sequence<I...>, size_type i) const;

template <class Func, std::size_t... I>
const_stepper build_stepper(Func&& f, std::index_sequence<I...>) const noexcept;

Expand Down Expand Up @@ -437,9 +419,6 @@ namespace xt

using data_type = std::tuple<decltype(xt::linear_begin(std::declval<const std::decay_t<CT>>()))...>;

template <std::size_t... I>
reference deref_impl(std::index_sequence<I...>) const;

template <std::size_t... I>
difference_type
tuple_max_diff(std::index_sequence<I...>, const data_type& lhs, const data_type& rhs) const;
Expand Down Expand Up @@ -500,12 +479,6 @@ namespace xt

private:

template <std::size_t... I>
reference deref_impl(std::index_sequence<I...>) const;

template <class T, std::size_t... I>
simd_return_type<T> step_simd_impl(std::index_sequence<I...>);

const xfunction_type* p_f;
std::tuple<typename std::decay_t<CT>::const_stepper...> m_st;
};
Expand Down Expand Up @@ -593,7 +566,13 @@ namespace xt
template <class F, class... CT>
inline layout_type xfunction<F, CT...>::layout() const noexcept
{
return layout_impl(std::make_index_sequence<sizeof...(CT)>());
return std::apply(
[&](auto&... e)
{
return compute_layout(e.layout()...);
},
m_e
);
}

template <class F, class... CT>
Expand Down Expand Up @@ -628,7 +607,16 @@ namespace xt
{
// The static cast prevents the compiler from instantiating the template methods with signed integers,
// leading to warning about signed/unsigned conversions in the deeper layers of the access methods
return access_impl(std::make_index_sequence<sizeof...(CT)>(), static_cast<size_type>(args)...);

return std::apply(
[&](auto&... e)
{
XTENSOR_TRY(check_index(shape(), args...));
XTENSOR_CHECK_DIMENSION(shape(), args...);
return m_f(e(args...)...);
},
m_e
);
}

/**
Expand All @@ -643,7 +631,13 @@ namespace xt
template <class F, class... CT>
inline auto xfunction<F, CT...>::flat(size_type index) const -> const_reference
{
return data_element_impl(std::make_index_sequence<sizeof...(CT)>(), index);
return std::apply(
[&](auto&... e)
{
return m_f(e.data_element(index)...);
},
m_e
);
}

/**
Expand Down Expand Up @@ -671,7 +665,13 @@ namespace xt
{
// The static cast prevents the compiler from instantiating the template methods with signed integers,
// leading to warning about signed/unsigned conversions in the deeper layers of the access methods
return unchecked_impl(std::make_index_sequence<sizeof...(CT)>(), static_cast<size_type>(args)...);
return std::apply(
[&](const auto&... e)
{
return m_f(e.unchecked(static_cast<size_type>(args)...)...);
},
m_e
);
}

/**
Expand All @@ -685,7 +685,14 @@ namespace xt
template <class It>
inline auto xfunction<F, CT...>::element(It first, It last) const -> const_reference
{
return element_access_impl(std::make_index_sequence<sizeof...(CT)>(), first, last);
return std::apply(
[&](auto&... e)
{
XTENSOR_TRY(check_element_index(shape(), first, last));
return m_f(e.element(first, last)...);
},
m_e
);
}

//@}
Expand Down Expand Up @@ -819,7 +826,13 @@ namespace xt
template <class F, class... CT>
inline auto xfunction<F, CT...>::data_element(size_type i) const -> const_reference
{
return data_element_impl(std::make_index_sequence<sizeof...(CT)>(), i);
return std::apply(
[&](auto&... e)
{
return m_f(e.data_element(i)...);
},
m_e
);
}

template <class F, class... CT>
Expand All @@ -833,7 +846,13 @@ namespace xt
template <class align, class requested_type, std::size_t N>
inline auto xfunction<F, CT...>::load_simd(size_type i) const -> simd_return_type<requested_type>
{
return load_simd_impl<align, requested_type, N>(std::make_index_sequence<sizeof...(CT)>(), i);
return std::apply(
[&](auto&... e)
{
return m_f.simd_apply((e.template load_simd<align, requested_type>(i))...);
},
m_e
);
}

template <class F, class... CT>
Expand All @@ -848,55 +867,6 @@ namespace xt
return m_f;
}

template <class F, class... CT>
template <std::size_t... I>
inline layout_type xfunction<F, CT...>::layout_impl(std::index_sequence<I...>) const noexcept
{
return compute_layout(std::get<I>(m_e).layout()...);
}

template <class F, class... CT>
template <std::size_t... I, class... Args>
inline auto xfunction<F, CT...>::access_impl(std::index_sequence<I...>, Args... args) const
-> const_reference
{
XTENSOR_TRY(check_index(shape(), args...));
XTENSOR_CHECK_DIMENSION(shape(), args...);
return m_f(std::get<I>(m_e)(args...)...);
}

template <class F, class... CT>
template <std::size_t... I, class... Args>
inline auto xfunction<F, CT...>::unchecked_impl(std::index_sequence<I...>, Args... args) const
-> const_reference
{
return m_f(std::get<I>(m_e).unchecked(args...)...);
}

template <class F, class... CT>
template <std::size_t... I, class It>
inline auto xfunction<F, CT...>::element_access_impl(std::index_sequence<I...>, It first, It last) const
-> const_reference
{
XTENSOR_TRY(check_element_index(shape(), first, last));
return m_f((std::get<I>(m_e).element(first, last))...);
}

template <class F, class... CT>
template <std::size_t... I>
inline auto xfunction<F, CT...>::data_element_impl(std::index_sequence<I...>, size_type i) const
-> const_reference
{
return m_f((std::get<I>(m_e).data_element(i))...);
}

template <class F, class... CT>
template <class align, class requested_type, std::size_t N, std::size_t... I>
inline auto xfunction<F, CT...>::load_simd_impl(std::index_sequence<I...>, size_type i) const
{
return m_f.simd_apply((std::get<I>(m_e).template load_simd<align, requested_type>(i))...);
}

template <class F, class... CT>
template <class Func, std::size_t... I>
inline auto xfunction<F, CT...>::build_stepper(Func&& f, std::index_sequence<I...>) const noexcept
Expand Down Expand Up @@ -987,7 +957,13 @@ namespace xt
template <class F, class... CT>
inline auto xfunction_iterator<F, CT...>::operator*() const -> reference
{
return deref_impl(std::make_index_sequence<sizeof...(CT)>());
return std::apply(
[&](auto&... it)
{
return (p_f->m_f)(*it...);
},
m_it
);
}

template <class F, class... CT>
Expand All @@ -1010,13 +986,6 @@ namespace xt
return std::get<index>(m_it) < std::get<index>(rhs.m_it);
}

template <class F, class... CT>
template <std::size_t... I>
inline auto xfunction_iterator<F, CT...>::deref_impl(std::index_sequence<I...>) const -> reference
{
return (p_f->m_f)(*std::get<I>(m_it)...);
}

template <class F, class... CT>
template <std::size_t... I>
inline auto xfunction_iterator<F, CT...>::tuple_max_diff(
Expand Down Expand Up @@ -1140,28 +1109,26 @@ namespace xt
template <class F, class... CT>
inline auto xfunction_stepper<F, CT...>::operator*() const -> reference
{
return deref_impl(std::make_index_sequence<sizeof...(CT)>());
}

template <class F, class... CT>
template <std::size_t... I>
inline auto xfunction_stepper<F, CT...>::deref_impl(std::index_sequence<I...>) const -> reference
{
return (p_f->m_f)(*std::get<I>(m_st)...);
}

template <class F, class... CT>
template <class T, std::size_t... I>
inline auto xfunction_stepper<F, CT...>::step_simd_impl(std::index_sequence<I...>) -> simd_return_type<T>
{
return (p_f->m_f.simd_apply)(std::get<I>(m_st).template step_simd<T>()...);
return std::apply(
[&](auto&... e)
{
return (p_f->m_f)(*e...);
},
m_st
);
}

template <class F, class... CT>
template <class T>
inline auto xfunction_stepper<F, CT...>::step_simd() -> simd_return_type<T>
{
return step_simd_impl<T>(std::make_index_sequence<sizeof...(CT)>());
return std::apply(
[&](auto&... st)
{
return (p_f->m_f.simd_apply)(st.template step_simd<T>()...);
},
m_st
);
}

template <class F, class... CT>
Expand Down
Loading