Skip to content

Commit 59bc9c1

Browse files
committed
UMTensor: implement Array level copy operations
Mirrors the BTAS logic for now for testing, needs to be cleaned up.
1 parent 238c0a8 commit 59bc9c1

File tree

1 file changed

+117
-0
lines changed

1 file changed

+117
-0
lines changed

src/TiledArray/device/um_tensor.h

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
#include <TiledArray/fwd.h>
3232

33+
#include <TiledArray/conversions/to_new_tile_type.h>
3334
#include <TiledArray/device/blas.h>
3435
#include <TiledArray/device/device_array.h>
3536
#include <TiledArray/device/kernel/mult_kernel.h>
@@ -728,6 +729,122 @@ T abs_min(const UMTensor<T> &arg) {
728729
return result;
729730
}
730731

732+
/// convert array from UMTensor to TiledArray::Tensor
733+
template <typename UMT, typename TATensor, typename Policy>
734+
TiledArray::DistArray<TATensor, Policy>
735+
um_tensor_to_ta_tensor(const TiledArray::DistArray<UMT, Policy> &um_array) {
736+
if constexpr (std::is_same_v<UMT, TATensor>) {
737+
// No-op if UMTensor is the same type as TATensor type
738+
return um_array;
739+
} else {
740+
const auto convert_tile_memcpy = [](const UMT &tile) {
741+
TATensor result(tile.range());
742+
743+
auto stream = device::stream_for(result.range());
744+
DeviceSafeCall(
745+
device::memcpyAsync(result.data(), tile.data(),
746+
tile.size() * sizeof(typename TATensor::value_type),
747+
device::MemcpyDefault, stream));
748+
device::sync_madness_task_with(stream);
749+
750+
return result;
751+
};
752+
753+
const auto convert_tile_um = [](const UMT &tile) {
754+
TATensor result(tile.range());
755+
using std::begin;
756+
const auto n = tile.size();
757+
758+
auto stream = device::stream_for(tile.range());
759+
760+
TiledArray::to_execution_space<TiledArray::ExecutionSpace::Host>(
761+
tile, stream);
762+
763+
std::copy_n(tile.data(), n, result.data());
764+
765+
return result;
766+
};
767+
768+
const char *use_legacy_conversion =
769+
std::getenv("TA_DEVICE_LEGACY_UM_CONVERSION");
770+
auto ta_array = use_legacy_conversion
771+
? to_new_tile_type(um_array, convert_tile_um)
772+
: to_new_tile_type(um_array, convert_tile_memcpy);
773+
774+
um_array.world().gop.fence();
775+
return ta_array;
776+
}
777+
}
778+
779+
/// convert array from TiledArray::Tensor to UMTensor
780+
template <typename UMT, typename TATensor, typename Policy>
781+
TiledArray::DistArray<UMT, Policy>
782+
ta_tensor_to_um_tensor(const TiledArray::DistArray<TATensor, Policy> &array) {
783+
if constexpr (std::is_same_v<UMT, TATensor>) {
784+
// No-op if array is the same as return type
785+
return array;
786+
} else {
787+
using inT = typename TATensor::value_type;
788+
using outT = typename UMT::value_type;
789+
// check if element conversion is necessary
790+
constexpr bool T_conversion = !std::is_same_v<inT, outT>;
791+
792+
// this is safe even when need to convert element types, but less efficient
793+
auto convert_tile_um = [](const TATensor &tile) {
794+
/// UMTensor must be wrapped into TA::Tile
795+
UMT result(tile.range());
796+
797+
const auto n = tile.size();
798+
std::copy_n(tile.data(), n, result.data());
799+
800+
auto stream = device::stream_for(result.range());
801+
802+
TiledArray::to_execution_space<TiledArray::ExecutionSpace::Device>(
803+
result, stream);
804+
805+
// N.B. move! without it have D-to-H transfer due to calling UM
806+
// allocator construct() on the host
807+
return std::move(result);
808+
};
809+
810+
TiledArray::DistArray<UMT, Policy> um_array;
811+
if constexpr (T_conversion) {
812+
um_array = to_new_tile_type(array, convert_tile_um);
813+
} else {
814+
// this is more efficient for copying:
815+
// - avoids copy on host followed by UM transfer, instead uses direct copy
816+
// - replaced unneeded copy (which also caused D-to-H transfer due to
817+
// calling UM allocator construct() on the host) by move
818+
// This eliminates all spurious UM traffic in (T) W3 contractions
819+
auto convert_tile_memcpy = [](const TATensor &tile) {
820+
/// UMTensor must be wrapped into TA::Tile .. Why?
821+
822+
auto stream = device::stream_for(tile.range());
823+
UMT result(tile.range());
824+
825+
DeviceSafeCall(
826+
device::memcpyAsync(result.data(), tile.data(),
827+
tile.size() * sizeof(typename UMT::value_type),
828+
device::MemcpyDefault, stream));
829+
830+
device::sync_madness_task_with(stream);
831+
// N.B. move! without it have D-to-H transfer due to calling UM
832+
// allocator construct() on the host
833+
return std::move(result);
834+
};
835+
836+
const char *use_legacy_conversion =
837+
std::getenv("TA_DEVICE_LEGACY_UM_CONVERSION");
838+
um_array = use_legacy_conversion
839+
? to_new_tile_type(array, convert_tile_um)
840+
: to_new_tile_type(array, convert_tile_memcpy);
841+
}
842+
843+
array.world().gop.fence();
844+
return um_array;
845+
}
846+
}
847+
731848
} // namespace TiledArray
732849

733850
/// Serialization support

0 commit comments

Comments
 (0)