|
30 | 30 |
|
31 | 31 | #include <TiledArray/fwd.h> |
32 | 32 |
|
| 33 | +#include <TiledArray/conversions/to_new_tile_type.h> |
33 | 34 | #include <TiledArray/device/blas.h> |
34 | 35 | #include <TiledArray/device/device_array.h> |
35 | 36 | #include <TiledArray/device/kernel/mult_kernel.h> |
@@ -728,6 +729,122 @@ T abs_min(const UMTensor<T> &arg) { |
728 | 729 | return result; |
729 | 730 | } |
730 | 731 |
|
| 732 | +/// convert array from UMTensor to TiledArray::Tensor |
| 733 | +template <typename UMT, typename TATensor, typename Policy> |
| 734 | +TiledArray::DistArray<TATensor, Policy> |
| 735 | +um_tensor_to_ta_tensor(const TiledArray::DistArray<UMT, Policy> &um_array) { |
| 736 | + if constexpr (std::is_same_v<UMT, TATensor>) { |
| 737 | + // No-op if UMTensor is the same type as TATensor type |
| 738 | + return um_array; |
| 739 | + } else { |
| 740 | + const auto convert_tile_memcpy = [](const UMT &tile) { |
| 741 | + TATensor result(tile.range()); |
| 742 | + |
| 743 | + auto stream = device::stream_for(result.range()); |
| 744 | + DeviceSafeCall( |
| 745 | + device::memcpyAsync(result.data(), tile.data(), |
| 746 | + tile.size() * sizeof(typename TATensor::value_type), |
| 747 | + device::MemcpyDefault, stream)); |
| 748 | + device::sync_madness_task_with(stream); |
| 749 | + |
| 750 | + return result; |
| 751 | + }; |
| 752 | + |
| 753 | + const auto convert_tile_um = [](const UMT &tile) { |
| 754 | + TATensor result(tile.range()); |
| 755 | + using std::begin; |
| 756 | + const auto n = tile.size(); |
| 757 | + |
| 758 | + auto stream = device::stream_for(tile.range()); |
| 759 | + |
| 760 | + TiledArray::to_execution_space<TiledArray::ExecutionSpace::Host>( |
| 761 | + tile, stream); |
| 762 | + |
| 763 | + std::copy_n(tile.data(), n, result.data()); |
| 764 | + |
| 765 | + return result; |
| 766 | + }; |
| 767 | + |
| 768 | + const char *use_legacy_conversion = |
| 769 | + std::getenv("TA_DEVICE_LEGACY_UM_CONVERSION"); |
| 770 | + auto ta_array = use_legacy_conversion |
| 771 | + ? to_new_tile_type(um_array, convert_tile_um) |
| 772 | + : to_new_tile_type(um_array, convert_tile_memcpy); |
| 773 | + |
| 774 | + um_array.world().gop.fence(); |
| 775 | + return ta_array; |
| 776 | + } |
| 777 | +} |
| 778 | + |
| 779 | +/// convert array from TiledArray::Tensor to UMTensor |
| 780 | +template <typename UMT, typename TATensor, typename Policy> |
| 781 | +TiledArray::DistArray<UMT, Policy> |
| 782 | +ta_tensor_to_um_tensor(const TiledArray::DistArray<TATensor, Policy> &array) { |
| 783 | + if constexpr (std::is_same_v<UMT, TATensor>) { |
| 784 | + // No-op if array is the same as return type |
| 785 | + return array; |
| 786 | + } else { |
| 787 | + using inT = typename TATensor::value_type; |
| 788 | + using outT = typename UMT::value_type; |
| 789 | + // check if element conversion is necessary |
| 790 | + constexpr bool T_conversion = !std::is_same_v<inT, outT>; |
| 791 | + |
| 792 | + // this is safe even when need to convert element types, but less efficient |
| 793 | + auto convert_tile_um = [](const TATensor &tile) { |
| 794 | + /// UMTensor must be wrapped into TA::Tile |
| 795 | + UMT result(tile.range()); |
| 796 | + |
| 797 | + const auto n = tile.size(); |
| 798 | + std::copy_n(tile.data(), n, result.data()); |
| 799 | + |
| 800 | + auto stream = device::stream_for(result.range()); |
| 801 | + |
| 802 | + TiledArray::to_execution_space<TiledArray::ExecutionSpace::Device>( |
| 803 | + result, stream); |
| 804 | + |
| 805 | + // N.B. move! without it have D-to-H transfer due to calling UM |
| 806 | + // allocator construct() on the host |
| 807 | + return std::move(result); |
| 808 | + }; |
| 809 | + |
| 810 | + TiledArray::DistArray<UMT, Policy> um_array; |
| 811 | + if constexpr (T_conversion) { |
| 812 | + um_array = to_new_tile_type(array, convert_tile_um); |
| 813 | + } else { |
| 814 | + // this is more efficient for copying: |
| 815 | + // - avoids copy on host followed by UM transfer, instead uses direct copy |
| 816 | + // - replaced unneeded copy (which also caused D-to-H transfer due to |
| 817 | + // calling UM allocator construct() on the host) by move |
| 818 | + // This eliminates all spurious UM traffic in (T) W3 contractions |
| 819 | + auto convert_tile_memcpy = [](const TATensor &tile) { |
| 820 | + /// UMTensor must be wrapped into TA::Tile .. Why? |
| 821 | + |
| 822 | + auto stream = device::stream_for(tile.range()); |
| 823 | + UMT result(tile.range()); |
| 824 | + |
| 825 | + DeviceSafeCall( |
| 826 | + device::memcpyAsync(result.data(), tile.data(), |
| 827 | + tile.size() * sizeof(typename UMT::value_type), |
| 828 | + device::MemcpyDefault, stream)); |
| 829 | + |
| 830 | + device::sync_madness_task_with(stream); |
| 831 | + // N.B. move! without it have D-to-H transfer due to calling UM |
| 832 | + // allocator construct() on the host |
| 833 | + return std::move(result); |
| 834 | + }; |
| 835 | + |
| 836 | + const char *use_legacy_conversion = |
| 837 | + std::getenv("TA_DEVICE_LEGACY_UM_CONVERSION"); |
| 838 | + um_array = use_legacy_conversion |
| 839 | + ? to_new_tile_type(array, convert_tile_um) |
| 840 | + : to_new_tile_type(array, convert_tile_memcpy); |
| 841 | + } |
| 842 | + |
| 843 | + array.world().gop.fence(); |
| 844 | + return um_array; |
| 845 | + } |
| 846 | +} |
| 847 | + |
731 | 848 | } // namespace TiledArray |
732 | 849 |
|
733 | 850 | /// Serialization support |
|
0 commit comments