Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 170 additions & 0 deletions include/experimental/__p1684_bits/mdarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,127 @@ namespace {
};
}

namespace impl {

template<class CArray, std::size_t ... Indices> requires (
std::is_array_v<CArray> &&
std::rank_v<CArray> == 1u
)
constexpr std::array<std::remove_all_extents_t<CArray>, std::extent_v<CArray, 0>>
carray_to_array_impl(CArray& values, std::index_sequence<Indices...>)
{
return std::array{values[Indices]...};
}

template<class CArray> requires (
std::is_array_v<CArray> &&
std::rank_v<CArray> == 1u
)
constexpr std::array<std::remove_all_extents_t<CArray>, std::extent_v<CArray, 0>>
carray_to_array(CArray& values)
{
return carray_to_array_impl(values,
std::make_index_sequence<std::extent_v<CArray, 0>>());
}

template<class ElementType, std::size_t Size, std::size_t ... Indices>
constexpr std::array<std::remove_cv_t<ElementType>, Size>
ptr_to_array_impl(ElementType values[],
std::integral_constant<std::size_t, Size>,
std::index_sequence<Indices...>)
{
static_assert(! std::is_array_v<ElementType>);
return {values[Indices]...};
}

template<std::size_t Index>
constexpr auto tail(std::index_sequence<Index>) {
return std::index_sequence<>{};
}

template<std::size_t First, std::size_t ... Rest>
constexpr auto tail(std::index_sequence<First, Rest...>) {
return std::index_sequence<Rest...>{};
}

// Forward declaration is necessary for "recursion"
// inside carray_to_array_impl.
template<class CArray> requires (
std::is_array_v<CArray> &&
std::rank_v<CArray> > 1u
)
constexpr auto
carray_to_array(CArray& values);

template<class CArray, std::size_t ... Indices> requires (
std::is_array_v<CArray> &&
std::rank_v<CArray> > 1u
)
constexpr auto
carray_to_array_impl(CArray& values, std::index_sequence<Indices...> seq)
-> std::array<
std::remove_all_extents_t<CArray>,
((std::extent_v<CArray, Indices>) * ...)
>
{
constexpr std::size_t rank = std::rank_v<CArray>;
constexpr std::size_t size = ((std::extent_v<CArray, Indices>) * ...);
if constexpr (size == 0) {
return {}; // &values[0] is UB if values has zero length
}
else {
std::array<std::remove_all_extents_t<CArray>, size> result;
auto seq_tail = tail(seq);

std::size_t curpos = 0;
for (std::size_t row = 0; row < std::extent_v<CArray, 0>; ++row) {
// For rank > 1, &values[row] is an array of one less rank, not
// a pointer to the beginning of the data. Multidimensional
// "raw" (C) arrays aren't guaranteed to be contiguous anyway,
// so we can't just copy `values` as a flat array).
std::array values_row = carray_to_array(values[row]);
for (std::size_t k = 0; k < values_row.size(); ++k, ++curpos) {
result[curpos] = values_row[k];
}
}
return result;
}
}

template<class CArray> requires (
std::is_array_v<CArray> &&
std::rank_v<CArray> > 1u
)
constexpr auto
carray_to_array(CArray& values)
{
return carray_to_array_impl(values,
std::make_index_sequence<std::rank_v<CArray>>());
}

template<class CArray, std::size_t ... Indices> requires (
std::is_array_v<CArray> &&
std::rank_v<CArray> >= 1u
)
constexpr auto
extents_of_carray_impl(CArray&, std::index_sequence<Indices...>) ->
extents<std::size_t, std::extent_v<CArray, Indices>...>
{
return {};
};

template<class CArray> requires (
std::is_array_v<CArray> &&
std::rank_v<CArray> >= 1u
)
constexpr auto
extents_of_carray(CArray& values)
{
return extents_of_carray_impl(values, std::make_index_sequence<std::rank_v<CArray>>());
};

} // namespace impl

template <
class ElementType,
class Extents,
Expand Down Expand Up @@ -243,6 +364,27 @@ class mdarray {
static_assert( std::is_constructible<extents_type, OtherExtents>::value, "");
}

#if 0
// Corresponds to deduction guide from std::initializer_list<value_type>
MDSPAN_INLINE_FUNCTION
constexpr mdarray(std::initializer_list<value_type> values)
: map_(extents_type{}), ctr_{values}
{}
#endif // 0

// Corresponds to deduction guide from C array
MDSPAN_TEMPLATE_REQUIRES(
class CArray,
/* requires */ (
std::is_array_v<CArray> &&
std::rank_v<CArray> >= 1u
)
)
MDSPAN_INLINE_FUNCTION
constexpr mdarray(CArray& values)
: map_(extents_type{}), ctr_{impl::carray_to_array(values)}
{}

MDSPAN_INLINE_FUNCTION_DEFAULTED constexpr mdarray& operator= (const mdarray&) = default;
MDSPAN_INLINE_FUNCTION_DEFAULTED constexpr mdarray& operator= (mdarray&&) = default;
MDSPAN_INLINE_FUNCTION_DEFAULTED
Expand Down Expand Up @@ -455,6 +597,34 @@ class mdarray {
friend class mdarray;
};

// Rank >= 1 C array -> layout_right mdarray
// with container_type = std::array
template<class CArray>
requires (std::is_array_v<CArray> && std::rank_v<CArray> >= 1u)
mdarray(CArray& values) -> mdarray<
std::remove_all_extents_t<CArray>,
decltype(impl::extents_of_carray(values)),
layout_right,
decltype(impl::carray_to_array(values))
>;

#if 0
// Adding a deduction guide from initializer_list<T> breaks the above
// C array deduction guide, because construction from
// initializer_list<T> catches every possible creation of an mdarray
// from curly braces.
//
// We may not know values.size() at compile time,
// so we have to use a dynamically allocated container.
template<class T>
mdarray(std::initializer_list<T> values) ->
mdarray<
std::remove_cvref_t<T>,
dextents<std::size_t, 1>,
layout_right,
std::vector<std::remove_cvref_t<T>>
>;
#endif // 0

} // end namespace MDSPAN_IMPL_PROPOSED_NAMESPACE
} // end namespace MDSPAN_IMPL_STANDARD_NAMESPACE
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ if(NOT MDSPAN_ENABLE_CUDA AND NOT MDSPAN_ENABLE_HIP)
mdspan_add_test(test_mdarray_ctors)
mdspan_add_test(test_mdarray_to_mdspan)
endif()
mdspan_add_test(test_mdarray_carray_ctad)

if(CMAKE_CXX_STANDARD GREATER_EQUAL 20)
if((CMAKE_CXX_COMPILER_ID STREQUAL Clang) OR ((CMAKE_CXX_COMPILER_ID STREQUAL GNU) AND (CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 12.0.0)))
Expand Down
Loading