From f396003bc1572c3fe874a08dd236cf6a7d47bf96 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Thu, 22 Jan 2026 09:55:54 -0800 Subject: [PATCH 1/2] [slimtensor] Add CUDA Storage with DeviceTraits and memory allocation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pull Request resolved: https://github.com/pytorch/executorch/pull/16769 This diff adds CUDA storage infrastructure to SlimTensor, enabling GPU memory allocation and management. **Key changes:** 1. **`cuda/Guard.h`** - CUDAGuard RAII class: - Saves current CUDA device on construction, restores on destruction - Exception-safe device context switching - Constructors accept device index or Device object 2. **`core/Storage.h`** - Extended for CUDA support: - Added `DeviceTraits` specialization with: - `allocate()` - Uses cudaMalloc with CUDAGuard for device selection - `free()` - Uses cudaFree with warning on error - `memcpy()` - Supports Host↔Device and Device↔Device copies - Added `DEFAULT_CUDA_DEVICE` constant - Updated `MaybeOwningStorage` constructor to handle CUDA devices - Stub implementation when `CUDA_AVAILABLE` is not defined (throws error) ghstack-source-id: 335102161 @exported-using-ghexport Differential Revision: [D91202899](https://our.internmc.facebook.com/intern/diff/D91202899/) --- backends/aoti/slim/c10/cuda/Exception.h | 40 +++ backends/aoti/slim/c10/cuda/TARGETS | 6 + backends/aoti/slim/c10/cuda/targets.bzl | 16 + backends/aoti/slim/core/Storage.h | 141 +++++++- backends/aoti/slim/core/targets.bzl | 4 +- backends/aoti/slim/core/test/targets.bzl | 37 +- backends/aoti/slim/core/test/test_storage.cpp | 332 ++++++++++++++---- backends/cuda/runtime/TARGETS | 22 ++ 8 files changed, 518 insertions(+), 80 deletions(-) create mode 100644 backends/aoti/slim/c10/cuda/Exception.h create mode 100644 backends/aoti/slim/c10/cuda/TARGETS create mode 100644 backends/aoti/slim/c10/cuda/targets.bzl diff --git a/backends/aoti/slim/c10/cuda/Exception.h b/backends/aoti/slim/c10/cuda/Exception.h new file mode 100644 index 00000000000..33d8414e661 --- /dev/null +++ b/backends/aoti/slim/c10/cuda/Exception.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#ifdef CUDA_AVAILABLE + +#include +#include + +#include +#include +#include + +/// Checks a CUDA expression and aborts on error. +/// @param EXPR The CUDA expression to check. +#define ET_CUDA_CHECK(EXPR) \ + do { \ + const cudaError_t __err = EXPR; \ + ET_CHECK_MSG( \ + __err == cudaSuccess, "CUDA error: %s", cudaGetErrorString(__err)); \ + } while (0) + +/// Checks a CUDA expression and logs a warning on error (non-fatal). +/// @param EXPR The CUDA expression to check. +#define ET_CUDA_LOG_WARN(EXPR) \ + do { \ + const cudaError_t __err = EXPR; \ + if (SLIMTENSOR_UNLIKELY(__err != cudaSuccess)) { \ + [[maybe_unused]] auto error_unused = cudaGetLastError(); \ + ET_LOG(Error, "CUDA warning: %s", cudaGetErrorString(__err)); \ + } \ + } while (0) + +#endif // CUDA_AVAILABLE diff --git a/backends/aoti/slim/c10/cuda/TARGETS b/backends/aoti/slim/c10/cuda/TARGETS new file mode 100644 index 00000000000..08e83a5f3c4 --- /dev/null +++ b/backends/aoti/slim/c10/cuda/TARGETS @@ -0,0 +1,6 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/backends/aoti/slim/c10/cuda/targets.bzl b/backends/aoti/slim/c10/cuda/targets.bzl new file mode 100644 index 00000000000..1d44bd1f032 --- /dev/null +++ b/backends/aoti/slim/c10/cuda/targets.bzl @@ -0,0 +1,16 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + """Define targets for SlimTensor CUDA exception handling module.""" + + runtime.cxx_library( + name = "exception", + exported_headers = [ + "Exception.h", + ], + visibility = ["@EXECUTORCH_CLIENTS"], + exported_deps = [ + "//executorch/backends/aoti/slim/c10/macros:macros", + "//executorch/runtime/platform:platform", + ], + ) diff --git a/backends/aoti/slim/core/Storage.h b/backends/aoti/slim/core/Storage.h index d122e86c1d4..6718f04cb51 100644 --- a/backends/aoti/slim/core/Storage.h +++ b/backends/aoti/slim/core/Storage.h @@ -10,12 +10,18 @@ #include +#ifdef CUDA_AVAILABLE +#include +#include +#endif + #include #include #include #include #include #include +#include namespace executorch::backends::aoti::slim { @@ -30,6 +36,10 @@ inline void noop(void*) {} /// Default CPU device constant. inline const c10::Device CPU_DEVICE = c10::Device(c10::DeviceType::CPU, 0); +/// Default CUDA device constant. +inline const c10::Device DEFAULT_CUDA_DEVICE = + c10::Device(c10::DeviceType::CUDA, 0); + /// DeviceTraits template for device-specific operations. /// Device-specific implementations provide allocate(), free(), and memcpy(). template @@ -74,6 +84,119 @@ struct DeviceTraits { } }; +#ifdef CUDA_AVAILABLE +/// CUDA specialization of DeviceTraits. +/// Provides CUDA memory allocation and copy operations using +/// cudaMallocAsync/cudaFreeAsync with proper stream handling. +/// +/// IMPORTANT: Callers are expected to set the correct CUDA device and stream +/// using CUDAStreamGuard before calling these methods. This is consistent +/// with PyTorch's CUDACachingAllocator design pattern where the allocator +/// assumes the caller has already set the correct device context. +template <> +struct DeviceTraits { + /// Allocates CUDA device memory on the current stream. + /// Uses cudaMallocAsync for asynchronous allocation on the stream + /// that is currently set via CUDAStreamGuard, similar to how + /// PyTorch's CUDACachingAllocator works. + /// + /// NOTE: Caller must ensure the correct device is already set via + /// CUDAStreamGuard. This function does NOT create a device guard internally. + /// + /// @param nbytes Number of bytes to allocate. + /// @param device The target CUDA device (used to get the stream). + /// @return Pointer to allocated device memory. + static void* allocate(size_t nbytes, const c10::Device& device) { + // Get the current stream for this device (set by CUDAStreamGuard if any) + // This follows PyTorch's pattern where the allocator assumes the caller + // has already set the correct device via CUDAStreamGuard. + auto stream_result = + executorch::backends::cuda::getCurrentCUDAStream(device.index()); + ET_CHECK_MSG( + stream_result.ok(), + "Failed to get current CUDA stream for device %d", + static_cast(device.index())); + + cudaStream_t stream = stream_result.get(); + void* data = nullptr; + ET_CUDA_CHECK(cudaMallocAsync(&data, nbytes, stream)); + return data; + } + + /// Frees CUDA device memory on the current stream. + /// @param ptr Pointer to device memory to free. + static void free(void* ptr) { + // Get the current stream for the current device + auto stream_result = executorch::backends::cuda::getCurrentCUDAStream(-1); + if (stream_result.ok()) { + ET_CUDA_LOG_WARN(cudaFreeAsync(ptr, stream_result.get())); + } else { + // Fallback to synchronous free if we can't get the stream + ET_CUDA_LOG_WARN(cudaFree(ptr)); + } + } + + /// Copies memory between CPU and CUDA or CUDA and CUDA. + /// @param dst Destination pointer. + /// @param src Source pointer. + /// @param nbytes Number of bytes to copy. + /// @param dst_device Destination device. + /// @param src_device Source device. + static void memcpy( + void* dst, + const void* src, + size_t nbytes, + const c10::Device& dst_device, + const c10::Device& src_device) { + cudaMemcpyKind direction = cudaMemcpyDeviceToDevice; + + if (src_device.is_cpu()) { + direction = cudaMemcpyHostToDevice; + } else if (dst_device.is_cpu()) { + direction = cudaMemcpyDeviceToHost; + } else { + ET_CHECK_MSG( + src_device.index() == dst_device.index(), + "CUDA memcpy across different device indices not supported: %d != %d", + static_cast(src_device.index()), + static_cast(dst_device.index())); + } + + ET_CUDA_CHECK(cudaMemcpy(dst, src, nbytes, direction)); + } +}; +#else +/// CUDA stub when CUDA_AVAILABLE is not defined. +/// All operations abort with an error message. +template <> +struct DeviceTraits { + static void* allocate(size_t nbytes, const c10::Device& device) { + (void)nbytes; + (void)device; + ET_CHECK_MSG(false, "Build with CUDA_AVAILABLE=1 to enable CUDA support"); + } + + static void free(void* ptr) { + (void)ptr; + ET_LOG(Error, "Build with CUDA_AVAILABLE=1 to enable CUDA support"); + } + + static void memcpy( + void* dst, + const void* src, + size_t nbytes, + const c10::Device& dst_device, + const c10::Device& src_device) { + (void)dst; + (void)src; + (void)nbytes; + (void)dst_device; + (void)src_device; + ET_CHECK_MSG(false, "Build with CUDA_AVAILABLE=1 to enable CUDA support"); + } +}; +#endif // CUDA_AVAILABLE + /** * MaybeOwningStorage - A storage class that manages tensor data memory. * @@ -93,17 +216,19 @@ struct DeviceTraits { class MaybeOwningStorage { public: /// Constructs owning storage with allocated memory. - /// @param device The device for storage (must be CPU). + /// @param device The device for storage (CPU or CUDA). /// @param nbytes Number of bytes to allocate. MaybeOwningStorage(const c10::Device& device, size_t nbytes) : device_(device), capacity_(nbytes), is_owning_(true) { - ET_CHECK_MSG( - device.is_cpu(), - "Only CPU device is currently supported, got: %s", - device.str().c_str()); - - data_ = DeviceTraits::allocate(nbytes, device); - deleter_ = DeviceTraits::free; + if (device.is_cpu()) { + data_ = DeviceTraits::allocate(nbytes, device); + deleter_ = DeviceTraits::free; + } else if (device.is_cuda()) { + data_ = DeviceTraits::allocate(nbytes, device); + deleter_ = DeviceTraits::free; + } else { + ET_CHECK_MSG(false, "Unsupported device type: %s", device.str().c_str()); + } } /// Default constructor is deleted - storage must have a device. diff --git a/backends/aoti/slim/core/targets.bzl b/backends/aoti/slim/core/targets.bzl index 2056b8c6866..0fc898c5598 100644 --- a/backends/aoti/slim/core/targets.bzl +++ b/backends/aoti/slim/core/targets.bzl @@ -17,10 +17,12 @@ def define_common_targets(): "//executorch/backends/aoti/slim/util:shared_ptr", "//executorch/backends/aoti/slim/util:size_util", "//executorch/runtime/platform:platform", + "//executorch/backends/aoti/slim/c10/cuda:exception", + "//executorch/backends/cuda/runtime:guard", ], ) - # Header-only library for SlimTensor + # Header-only library for SlimTensor (CPU-only for now) runtime.cxx_library( name = "slimtensor", headers = [ diff --git a/backends/aoti/slim/core/test/targets.bzl b/backends/aoti/slim/core/test/targets.bzl index c7debd46836..3a7e99dd37c 100644 --- a/backends/aoti/slim/core/test/targets.bzl +++ b/backends/aoti/slim/core/test/targets.bzl @@ -1,17 +1,36 @@ +load("@fbcode_macros//build_defs/lib:re_test_utils.bzl", "re_test_utils") load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +def get_backend_mode(): + """Get the supported backend mode of slimtensor.""" + return ["cuda", "cpu"] + def define_common_targets(): """Define test targets for SlimTensor core module.""" - runtime.cxx_test( - name = "test_storage", - srcs = [ - "test_storage.cpp", - ], - deps = [ - "//executorch/backends/aoti/slim/core:storage", - ], - ) + # GPU storage test with CUDA support + for backend_mode in get_backend_mode(): + backend_suffix = "_" + backend_mode if backend_mode == "cuda" else "" + + backend_kwargs = { + "external_deps": [("cuda", None, "cuda-lazy")], + "preprocessor_flags": ["-DCUDA_AVAILABLE=1"], + "keep_gpu_sections": True, + "remote_execution": re_test_utils.remote_execution( + platform = "gpu-remote-execution", + ), + } if backend_mode == "cuda" else {} + + runtime.cxx_test( + name = "test_storage" + backend_suffix, + srcs = [ + "test_storage.cpp", + ], + deps = [ + "//executorch/backends/aoti/slim/core:storage", + ], + **backend_kwargs + ) runtime.cxx_test( name = "test_slimtensor_basic", diff --git a/backends/aoti/slim/core/test/test_storage.cpp b/backends/aoti/slim/core/test/test_storage.cpp index 42a8678c888..5ff3d6620be 100644 --- a/backends/aoti/slim/core/test/test_storage.cpp +++ b/backends/aoti/slim/core/test/test_storage.cpp @@ -10,8 +10,29 @@ #include +#ifdef CUDA_AVAILABLE +#include +#endif + namespace executorch::backends::aoti::slim { +// ============================================================================= +// Test Device Helpers +// ============================================================================= + +inline std::vector getTestDevices() { + std::vector devices = {CPU_DEVICE}; +#ifdef CUDA_AVAILABLE + devices.push_back(DEFAULT_CUDA_DEVICE); +#endif + return devices; +} + +inline std::string deviceToString( + const testing::TestParamInfo& info) { + return info.param.is_cpu() ? "CPU" : "CUDA"; +} + // ============================================================================= // DeviceTraits Tests // ============================================================================= @@ -52,48 +73,39 @@ TEST(DeviceTraitsCPUTest, MemcpyCPUToCPU) { } // ============================================================================= -// MaybeOwningStorage Tests - Owning Mode +// MaybeOwningStorage Parameterized Tests (CPU and CUDA) // ============================================================================= -TEST(MaybeOwningStorageTest, ConstructOwning) { +class MaybeOwningStorageParamTest : public testing::TestWithParam { + protected: + c10::Device device() const { + return GetParam(); + } +}; + +TEST_P(MaybeOwningStorageParamTest, ConstructOwning) { constexpr size_t kNbytes = 512; - MaybeOwningStorage storage(CPU_DEVICE, kNbytes); + MaybeOwningStorage storage(device(), kNbytes); EXPECT_NE(storage.data(), nullptr); EXPECT_EQ(storage.nbytes(), kNbytes); - EXPECT_TRUE(storage.device().is_cpu()); + EXPECT_EQ(storage.device().type(), device().type()); EXPECT_TRUE(storage.is_owning()); EXPECT_TRUE(storage.is_resizable()); } -TEST(MaybeOwningStorageTest, ConstructOwningZeroBytes) { - MaybeOwningStorage storage(CPU_DEVICE, 0); +TEST_P(MaybeOwningStorageParamTest, ConstructOwningZeroBytes) { + MaybeOwningStorage storage(device(), 0); EXPECT_EQ(storage.data(), nullptr); EXPECT_EQ(storage.nbytes(), 0); - EXPECT_TRUE(storage.device().is_cpu()); + EXPECT_EQ(storage.device().type(), device().type()); EXPECT_TRUE(storage.is_owning()); } -TEST(MaybeOwningStorageTest, DataPersistence) { - constexpr size_t kNumFloats = 64; - constexpr size_t kNbytes = kNumFloats * sizeof(float); - MaybeOwningStorage storage(CPU_DEVICE, kNbytes); - - float* data = static_cast(storage.data()); - for (size_t i = 0; i < kNumFloats; ++i) { - data[i] = static_cast(i) * 2.0f; - } - - float* read_data = static_cast(storage.data()); - for (size_t i = 0; i < kNumFloats; ++i) { - EXPECT_FLOAT_EQ(read_data[i], static_cast(i) * 2.0f); - } -} - -TEST(MaybeOwningStorageTest, MoveConstruct) { +TEST_P(MaybeOwningStorageParamTest, MoveConstruct) { constexpr size_t kNbytes = 256; - MaybeOwningStorage original(CPU_DEVICE, kNbytes); + MaybeOwningStorage original(device(), kNbytes); void* original_data = original.data(); MaybeOwningStorage moved(std::move(original)); @@ -101,17 +113,18 @@ TEST(MaybeOwningStorageTest, MoveConstruct) { EXPECT_EQ(moved.data(), original_data); EXPECT_EQ(moved.nbytes(), kNbytes); EXPECT_TRUE(moved.is_owning()); + EXPECT_EQ(moved.device().type(), device().type()); EXPECT_EQ(original.data(), nullptr); EXPECT_EQ(original.nbytes(), 0); EXPECT_FALSE(original.is_owning()); } -TEST(MaybeOwningStorageTest, MoveAssign) { +TEST_P(MaybeOwningStorageParamTest, MoveAssign) { constexpr size_t kNbytes1 = 256; constexpr size_t kNbytes2 = 512; - MaybeOwningStorage storage1(CPU_DEVICE, kNbytes1); - MaybeOwningStorage storage2(CPU_DEVICE, kNbytes2); + MaybeOwningStorage storage1(device(), kNbytes1); + MaybeOwningStorage storage2(device(), kNbytes2); void* storage2_data = storage2.data(); storage1 = std::move(storage2); @@ -125,7 +138,33 @@ TEST(MaybeOwningStorageTest, MoveAssign) { EXPECT_FALSE(storage2.is_owning()); } -TEST(MaybeOwningStorageTest, Clone) { +INSTANTIATE_TEST_SUITE_P( + DeviceTests, + MaybeOwningStorageParamTest, + testing::ValuesIn(getTestDevices()), + deviceToString); + +// ============================================================================= +// MaybeOwningStorage CPU-Only Tests (require direct data access) +// ============================================================================= + +TEST(MaybeOwningStorageCPUTest, DataPersistence) { + constexpr size_t kNumFloats = 64; + constexpr size_t kNbytes = kNumFloats * sizeof(float); + MaybeOwningStorage storage(CPU_DEVICE, kNbytes); + + float* data = static_cast(storage.data()); + for (size_t i = 0; i < kNumFloats; ++i) { + data[i] = static_cast(i) * 2.0f; + } + + float* read_data = static_cast(storage.data()); + for (size_t i = 0; i < kNumFloats; ++i) { + EXPECT_FLOAT_EQ(read_data[i], static_cast(i) * 2.0f); + } +} + +TEST(MaybeOwningStorageCPUTest, Clone) { constexpr size_t kNumFloats = 32; constexpr size_t kNbytes = kNumFloats * sizeof(float); MaybeOwningStorage original(CPU_DEVICE, kNbytes); @@ -150,7 +189,7 @@ TEST(MaybeOwningStorageTest, Clone) { EXPECT_FLOAT_EQ(cloned_data[0], 0.0f); } -TEST(MaybeOwningStorageTest, CopyFunction) { +TEST(MaybeOwningStorageCPUTest, CopyFunction) { constexpr size_t kNumFloats = 16; constexpr size_t kNbytes = kNumFloats * sizeof(float); MaybeOwningStorage src_storage(CPU_DEVICE, kNbytes); @@ -171,23 +210,30 @@ TEST(MaybeOwningStorageTest, CopyFunction) { } // ============================================================================= -// Storage (SharedPtr) Tests +// Storage (SharedPtr) Parameterized Tests // ============================================================================= -TEST(StorageSharedPtrTest, BasicUsage) { +class StorageSharedPtrParamTest : public testing::TestWithParam { + protected: + c10::Device device() const { + return GetParam(); + } +}; + +TEST_P(StorageSharedPtrParamTest, BasicUsage) { constexpr size_t kNbytes = 128; - Storage storage(new MaybeOwningStorage(CPU_DEVICE, kNbytes)); + Storage storage(new MaybeOwningStorage(device(), kNbytes)); EXPECT_NE(storage.get(), nullptr); EXPECT_NE(storage->data(), nullptr); EXPECT_EQ(storage->nbytes(), kNbytes); - EXPECT_TRUE(storage->device().is_cpu()); + EXPECT_EQ(storage->device().type(), device().type()); EXPECT_EQ(storage.use_count(), 1); } -TEST(StorageSharedPtrTest, SharedOwnership) { +TEST_P(StorageSharedPtrParamTest, SharedOwnership) { constexpr size_t kNbytes = 128; - Storage storage1(new MaybeOwningStorage(CPU_DEVICE, kNbytes)); + Storage storage1(new MaybeOwningStorage(device(), kNbytes)); void* data_ptr = storage1->data(); Storage storage2 = storage1; // Copy, not reference - increments ref count @@ -198,7 +244,52 @@ TEST(StorageSharedPtrTest, SharedOwnership) { EXPECT_EQ(storage2->data(), data_ptr); } -TEST(StorageSharedPtrTest, SharedOwnershipModification) { +TEST_P(StorageSharedPtrParamTest, ReferenceCountDecrement) { + constexpr size_t kNbytes = 64; + Storage storage1(new MaybeOwningStorage(device(), kNbytes)); + EXPECT_EQ(storage1.use_count(), 1); + + { + Storage storage2 = storage1; + EXPECT_EQ(storage1.use_count(), 2); + } + + EXPECT_EQ(storage1.use_count(), 1); +} + +TEST_P(StorageSharedPtrParamTest, MoveSemantics) { + constexpr size_t kNbytes = 64; + Storage storage1(new MaybeOwningStorage(device(), kNbytes)); + void* data_ptr = storage1->data(); + + Storage storage2 = std::move(storage1); + + EXPECT_EQ(storage1.get(), nullptr); + EXPECT_EQ(storage2->data(), data_ptr); + EXPECT_EQ(storage2.use_count(), 1); +} + +TEST_P(StorageSharedPtrParamTest, MakeShared) { + constexpr size_t kNbytes = 256; + Storage storage = make_shared(device(), kNbytes); + + EXPECT_NE(storage.get(), nullptr); + EXPECT_NE(storage->data(), nullptr); + EXPECT_EQ(storage->nbytes(), kNbytes); + EXPECT_EQ(storage.use_count(), 1); +} + +INSTANTIATE_TEST_SUITE_P( + DeviceTests, + StorageSharedPtrParamTest, + testing::ValuesIn(getTestDevices()), + deviceToString); + +// ============================================================================= +// Storage CPU-Only Tests (require direct data access) +// ============================================================================= + +TEST(StorageSharedPtrCPUTest, SharedOwnershipModification) { constexpr size_t kNumFloats = 8; constexpr size_t kNbytes = kNumFloats * sizeof(float); Storage storage1(new MaybeOwningStorage(CPU_DEVICE, kNbytes)); @@ -208,7 +299,7 @@ TEST(StorageSharedPtrTest, SharedOwnershipModification) { data[i] = 0.0f; } - const Storage& storage2 = storage1; + Storage storage2 = storage1; float* data2 = static_cast(storage2->data()); for (size_t i = 0; i < kNumFloats; ++i) { @@ -221,39 +312,156 @@ TEST(StorageSharedPtrTest, SharedOwnershipModification) { } } -TEST(StorageSharedPtrTest, ReferenceCountDecrement) { - constexpr size_t kNbytes = 64; - Storage storage1(new MaybeOwningStorage(CPU_DEVICE, kNbytes)); - EXPECT_EQ(storage1.use_count(), 1); +#ifdef CUDA_AVAILABLE - { - Storage storage2 = storage1; // Copy increments ref count - EXPECT_EQ(storage1.use_count(), 2); - } // storage2 destroyed, ref count decrements +// ============================================================================= +// DeviceTraits Tests +// ============================================================================= - EXPECT_EQ(storage1.use_count(), 1); +TEST(DeviceTraitsCUDATest, AllocateAndFree) { + constexpr size_t kSize = 1024; + void* ptr = + DeviceTraits::allocate(kSize, DEFAULT_CUDA_DEVICE); + ASSERT_NE(ptr, nullptr); + + DeviceTraits::free(ptr); } -TEST(StorageSharedPtrTest, MoveSemantics) { - constexpr size_t kNbytes = 64; - Storage storage1(new MaybeOwningStorage(CPU_DEVICE, kNbytes)); - void* data_ptr = storage1->data(); +TEST(DeviceTraitsCUDATest, AllocateZeroBytes) { + void* ptr = + DeviceTraits::allocate(0, DEFAULT_CUDA_DEVICE); + DeviceTraits::free(ptr); +} - Storage storage2 = std::move(storage1); +TEST(DeviceTraitsCUDATest, MemcpyCPUToCUDA) { + constexpr size_t kSize = 256; + float* cpu_src = static_cast( + DeviceTraits::allocate(kSize * sizeof(float))); + float* cuda_dst = + static_cast(DeviceTraits::allocate( + kSize * sizeof(float), DEFAULT_CUDA_DEVICE)); + float* cpu_verify = static_cast( + DeviceTraits::allocate(kSize * sizeof(float))); - EXPECT_EQ(storage1.get(), nullptr); - EXPECT_EQ(storage2->data(), data_ptr); - EXPECT_EQ(storage2.use_count(), 1); + for (size_t i = 0; i < kSize; ++i) { + cpu_src[i] = static_cast(i) * 2.5f; + } + + // Copy CPU -> CUDA + DeviceTraits::memcpy( + cuda_dst, + cpu_src, + kSize * sizeof(float), + DEFAULT_CUDA_DEVICE, + CPU_DEVICE); + + // Copy CUDA -> CPU to verify + DeviceTraits::memcpy( + cpu_verify, + cuda_dst, + kSize * sizeof(float), + CPU_DEVICE, + DEFAULT_CUDA_DEVICE); + + for (size_t i = 0; i < kSize; ++i) { + EXPECT_FLOAT_EQ(cpu_verify[i], static_cast(i) * 2.5f); + } + + DeviceTraits::free(cpu_src); + DeviceTraits::free(cuda_dst); + DeviceTraits::free(cpu_verify); } -TEST(StorageSharedPtrTest, MakeShared) { - constexpr size_t kNbytes = 256; - Storage storage = make_shared(CPU_DEVICE, kNbytes); +TEST(DeviceTraitsCUDATest, MemcpyCUDAToCPU) { + constexpr size_t kSize = 128; + float* cpu_src = static_cast( + DeviceTraits::allocate(kSize * sizeof(float))); + float* cuda_mem = + static_cast(DeviceTraits::allocate( + kSize * sizeof(float), DEFAULT_CUDA_DEVICE)); + float* cpu_dst = static_cast( + DeviceTraits::allocate(kSize * sizeof(float))); - EXPECT_NE(storage.get(), nullptr); - EXPECT_NE(storage->data(), nullptr); - EXPECT_EQ(storage->nbytes(), kNbytes); - EXPECT_EQ(storage.use_count(), 1); + for (size_t i = 0; i < kSize; ++i) { + cpu_src[i] = static_cast(i) + 100.0f; + } + + // Copy CPU -> CUDA + DeviceTraits::memcpy( + cuda_mem, + cpu_src, + kSize * sizeof(float), + DEFAULT_CUDA_DEVICE, + CPU_DEVICE); + + // Copy CUDA -> CPU + DeviceTraits::memcpy( + cpu_dst, + cuda_mem, + kSize * sizeof(float), + CPU_DEVICE, + DEFAULT_CUDA_DEVICE); + + for (size_t i = 0; i < kSize; ++i) { + EXPECT_FLOAT_EQ(cpu_dst[i], static_cast(i) + 100.0f); + } + + DeviceTraits::free(cpu_src); + DeviceTraits::free(cuda_mem); + DeviceTraits::free(cpu_dst); } +TEST(DeviceTraitsCUDATest, MemcpyCUDAToCUDA) { + constexpr size_t kSize = 64; + float* cpu_src = static_cast( + DeviceTraits::allocate(kSize * sizeof(float))); + float* cuda_src = + static_cast(DeviceTraits::allocate( + kSize * sizeof(float), DEFAULT_CUDA_DEVICE)); + float* cuda_dst = + static_cast(DeviceTraits::allocate( + kSize * sizeof(float), DEFAULT_CUDA_DEVICE)); + float* cpu_verify = static_cast( + DeviceTraits::allocate(kSize * sizeof(float))); + + for (size_t i = 0; i < kSize; ++i) { + cpu_src[i] = static_cast(i) * 3.0f; + } + + // Copy CPU -> CUDA src + DeviceTraits::memcpy( + cuda_src, + cpu_src, + kSize * sizeof(float), + DEFAULT_CUDA_DEVICE, + CPU_DEVICE); + + // Copy CUDA src -> CUDA dst + DeviceTraits::memcpy( + cuda_dst, + cuda_src, + kSize * sizeof(float), + DEFAULT_CUDA_DEVICE, + DEFAULT_CUDA_DEVICE); + + // Copy CUDA dst -> CPU to verify + DeviceTraits::memcpy( + cpu_verify, + cuda_dst, + kSize * sizeof(float), + CPU_DEVICE, + DEFAULT_CUDA_DEVICE); + + for (size_t i = 0; i < kSize; ++i) { + EXPECT_FLOAT_EQ(cpu_verify[i], static_cast(i) * 3.0f); + } + + DeviceTraits::free(cpu_src); + DeviceTraits::free(cuda_src); + DeviceTraits::free(cuda_dst); + DeviceTraits::free(cpu_verify); +} + +#endif // CUDA_AVAILABLE + } // namespace executorch::backends::aoti::slim diff --git a/backends/cuda/runtime/TARGETS b/backends/cuda/runtime/TARGETS index 532ab5544ab..024418d31a6 100644 --- a/backends/cuda/runtime/TARGETS +++ b/backends/cuda/runtime/TARGETS @@ -3,6 +3,28 @@ load("//tools/build/buck:nvcc_flags.bzl", "get_nvcc_arch_args") oncall("executorch") +runtime.cxx_library( + name = "guard", + srcs = [ + "guard.cpp", + ], + headers = [ + "guard.h", + "utils.h", + ], + visibility = ["PUBLIC"], + deps = [ + "//executorch/runtime/platform:platform", + ], + exported_deps = [ + "//executorch/runtime/core:core", + "//executorch/runtime/core/exec_aten:lib", + ], + external_deps = [ + ("cuda", None, "cuda-lazy"), + ], +) + runtime.cxx_library( name = "cuda_platform", srcs = [ From 60b9279c09f092755fc642dd87719a9883821144 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Thu, 22 Jan 2026 09:55:57 -0800 Subject: [PATCH 2/2] [slimtensor] Add CUDA slimtensor creation with basic functionality Pull Request resolved: https://github.com/pytorch/executorch/pull/16770 This diff enables CUDA tensor creation with basic tensor functionality and factory function support **Key changes:* 1. **`core/SlimTensor.h`** - Extended for CUDA support: - Added `is_cuda()` method to check if tensor is on CUDA device 2. **`factory/Empty.h`** - Supports CUDA: - `empty_strided()` and `empty()` work with CUDA device via `new_storage()` - Device routing is handled by `MaybeOwningStorage` constructor ghstack-source-id: 335102160 @exported-using-ghexport Differential Revision: [D91202897](https://our.internmc.facebook.com/intern/diff/D91202897/) --- backends/aoti/slim/core/SlimTensor.h | 7 + backends/aoti/slim/core/targets.bzl | 2 +- backends/aoti/slim/core/test/targets.bzl | 21 +- .../slim/core/test/test_slimtensor_basic.cpp | 15 +- backends/aoti/slim/factory/Empty.h | 6 +- backends/aoti/slim/factory/test/targets.bzl | 37 ++- .../aoti/slim/factory/test/test_empty.cpp | 257 ++++++++++++++++++ 7 files changed, 317 insertions(+), 28 deletions(-) diff --git a/backends/aoti/slim/core/SlimTensor.h b/backends/aoti/slim/core/SlimTensor.h index f3ab9f3fec3..c662202493d 100644 --- a/backends/aoti/slim/core/SlimTensor.h +++ b/backends/aoti/slim/core/SlimTensor.h @@ -227,6 +227,13 @@ class SlimTensor { return device().is_cpu(); } + /** + * Check if the tensor is on CUDA. + */ + bool is_cuda() const { + return device().is_cuda(); + } + /** * Check if the tensor is defined (has valid storage). */ diff --git a/backends/aoti/slim/core/targets.bzl b/backends/aoti/slim/core/targets.bzl index 0fc898c5598..cc74b01b444 100644 --- a/backends/aoti/slim/core/targets.bzl +++ b/backends/aoti/slim/core/targets.bzl @@ -22,7 +22,6 @@ def define_common_targets(): ], ) - # Header-only library for SlimTensor (CPU-only for now) runtime.cxx_library( name = "slimtensor", headers = [ @@ -37,6 +36,7 @@ def define_common_targets(): "//executorch/backends/aoti/slim/c10/core:sizes_and_strides", "//executorch/backends/aoti/slim/util:array_ref_util", "//executorch/backends/aoti/slim/util:size_util", + "//executorch/backends/aoti/slim/c10/cuda:exception", "//executorch/runtime/platform:platform", ], ) diff --git a/backends/aoti/slim/core/test/targets.bzl b/backends/aoti/slim/core/test/targets.bzl index 3a7e99dd37c..3400fd943e8 100644 --- a/backends/aoti/slim/core/test/targets.bzl +++ b/backends/aoti/slim/core/test/targets.bzl @@ -32,16 +32,17 @@ def define_common_targets(): **backend_kwargs ) - runtime.cxx_test( - name = "test_slimtensor_basic", - srcs = [ - "test_slimtensor_basic.cpp", - ], - deps = [ - "//executorch/backends/aoti/slim/core:slimtensor", - "//executorch/backends/aoti/slim/core:storage", - ], - ) + runtime.cxx_test( + name = "test_slimtensor_basic" + backend_suffix, + srcs = [ + "test_slimtensor_basic.cpp", + ], + deps = [ + "//executorch/backends/aoti/slim/core:slimtensor", + "//executorch/backends/aoti/slim/core:storage", + ], + **backend_kwargs + ) runtime.cxx_test( name = "test_slimtensor_copy", diff --git a/backends/aoti/slim/core/test/test_slimtensor_basic.cpp b/backends/aoti/slim/core/test/test_slimtensor_basic.cpp index dc60427c467..d70db1e4ae2 100644 --- a/backends/aoti/slim/core/test/test_slimtensor_basic.cpp +++ b/backends/aoti/slim/core/test/test_slimtensor_basic.cpp @@ -21,6 +21,9 @@ namespace executorch::backends::aoti::slim { inline std::vector get_test_devices() { std::vector devices; devices.push_back(CPU_DEVICE); +#ifdef CUDA_AVAILABLE + devices.push_back(DEFAULT_CUDA_DEVICE); +#endif return devices; } @@ -52,7 +55,9 @@ INSTANTIATE_TEST_SUITE_P( DeviceTests, SlimTensorBasicDeviceTest, ::testing::ValuesIn(get_test_devices()), - [](const ::testing::TestParamInfo& info) { return "CPU"; }); + [](const ::testing::TestParamInfo& info) { + return info.param.is_cuda() ? "CUDA" : "CPU"; + }); // ============================================================================= // Constructor Tests (Device-Parameterized) @@ -144,11 +149,11 @@ TEST_P(SlimTensorBasicDeviceTest, Dtype) { TEST_P(SlimTensorBasicDeviceTest, Device) { SlimTensor tensor = make_2x3_tensor(); - // We only support CPU for now - EXPECT_TRUE(tensor.is_cpu()); - EXPECT_EQ(tensor.device_type(), c10::DeviceType::CPU); - + // Check device type and index + EXPECT_EQ(tensor.device_type(), device().type()); EXPECT_EQ(tensor.device_index(), device().index()); + EXPECT_EQ(tensor.is_cpu(), device().is_cpu()); + EXPECT_EQ(tensor.is_cuda(), device().is_cuda()); } TEST_P(SlimTensorBasicDeviceTest, Numel) { diff --git a/backends/aoti/slim/factory/Empty.h b/backends/aoti/slim/factory/Empty.h index 24b4f53a647..c0ab9d7248d 100644 --- a/backends/aoti/slim/factory/Empty.h +++ b/backends/aoti/slim/factory/Empty.h @@ -23,7 +23,7 @@ namespace executorch::backends::aoti::slim { /// @param sizes The sizes of each dimension. /// @param strides The strides of each dimension. /// @param dtype The scalar type of tensor elements. -/// @param device The target device (must be CPU). +/// @param device The target device. /// @return A new SlimTensor with allocated but uninitialized storage. inline SlimTensor empty_strided( IntArrayRef sizes, @@ -41,7 +41,7 @@ inline SlimTensor empty_strided( /// /// @param sizes The sizes of each dimension. /// @param dtype The scalar type of tensor elements. -/// @param device The target device (must be CPU). +/// @param device The target device. /// @return A new SlimTensor with contiguous strides and uninitialized storage. inline SlimTensor empty( IntArrayRef sizes, @@ -59,7 +59,7 @@ inline SlimTensor empty( /// /// @param sizes The sizes of each dimension as an initializer list. /// @param dtype The scalar type of tensor elements. -/// @param device The target device (must be CPU). +/// @param device The target device. /// @return A new SlimTensor with contiguous strides and uninitialized storage. inline SlimTensor empty( std::initializer_list sizes, diff --git a/backends/aoti/slim/factory/test/targets.bzl b/backends/aoti/slim/factory/test/targets.bzl index a64510b2af1..7bad3067cc0 100644 --- a/backends/aoti/slim/factory/test/targets.bzl +++ b/backends/aoti/slim/factory/test/targets.bzl @@ -1,14 +1,33 @@ +load("@fbcode_macros//build_defs/lib:re_test_utils.bzl", "re_test_utils") load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +def get_backend_mode(): + """Get the supported backend mode of slimtensor.""" + return ["cuda", "cpu"] + def define_common_targets(): """Define test targets for SlimTensor factory module.""" - runtime.cxx_test( - name = "test_empty", - srcs = [ - "test_empty.cpp", - ], - deps = [ - "//executorch/backends/aoti/slim/factory:empty", - ], - ) + # GPU empty test with CUDA support + for backend_mode in get_backend_mode(): + backend_suffix = "_" + backend_mode if backend_mode == "cuda" else "" + + backend_kwargs = { + "external_deps": [("cuda", None, "cuda-lazy")], + "preprocessor_flags": ["-DCUDA_AVAILABLE=1"], + "keep_gpu_sections": True, + "remote_execution": re_test_utils.remote_execution( + platform = "gpu-remote-execution", + ), + } if backend_mode == "cuda" else {} + + runtime.cxx_test( + name = "test_empty" + backend_suffix, + srcs = [ + "test_empty.cpp", + ], + deps = [ + "//executorch/backends/aoti/slim/factory:empty", + ], + **backend_kwargs + ) diff --git a/backends/aoti/slim/factory/test/test_empty.cpp b/backends/aoti/slim/factory/test/test_empty.cpp index 7d7c9cafc34..18e7ead14ef 100644 --- a/backends/aoti/slim/factory/test/test_empty.cpp +++ b/backends/aoti/slim/factory/test/test_empty.cpp @@ -10,6 +10,10 @@ #include +#ifdef CUDA_AVAILABLE +#include +#endif + namespace executorch::backends::aoti::slim { // ============================================================================= @@ -229,4 +233,257 @@ TEST(EmptyTest, CanWriteAndReadData) { } } +#ifdef CUDA_AVAILABLE + +// ============================================================================= +// CUDA Empty Tensor Tests +// Tests are skipped at runtime if CUDA hardware is not available. +// ============================================================================= + +// ============================================================================= +// empty_strided() CUDA Tests +// ============================================================================= + +TEST(EmptyStridedCUDATest, Basic2x3Tensor) { + std::vector sizes = {2, 3}; + std::vector strides = {3, 1}; + + SlimTensor tensor = empty_strided( + makeArrayRef(sizes), + makeArrayRef(strides), + c10::ScalarType::Float, + DEFAULT_CUDA_DEVICE); + + EXPECT_TRUE(tensor.defined()); + EXPECT_EQ(tensor.dim(), 2u); + EXPECT_EQ(tensor.numel(), 6u); + EXPECT_EQ(tensor.dtype(), c10::ScalarType::Float); + EXPECT_TRUE(tensor.is_cuda()); + EXPECT_FALSE(tensor.is_cpu()); + + auto result_sizes = tensor.sizes(); + EXPECT_EQ(result_sizes[0], 2); + EXPECT_EQ(result_sizes[1], 3); + + auto result_strides = tensor.strides(); + EXPECT_EQ(result_strides[0], 3); + EXPECT_EQ(result_strides[1], 1); +} + +TEST(EmptyStridedCUDATest, ContiguousTensor) { + std::vector sizes = {2, 3, 4}; + std::vector strides = {12, 4, 1}; + + SlimTensor tensor = empty_strided( + makeArrayRef(sizes), + makeArrayRef(strides), + c10::ScalarType::Float, + DEFAULT_CUDA_DEVICE); + + EXPECT_TRUE(tensor.is_contiguous()); + EXPECT_EQ(tensor.numel(), 24u); + EXPECT_EQ(tensor.nbytes(), 24 * sizeof(float)); + EXPECT_TRUE(tensor.is_cuda()); +} + +TEST(EmptyStridedCUDATest, NonContiguousTensor) { + std::vector sizes = {3, 2}; + std::vector strides = {1, 3}; + + SlimTensor tensor = empty_strided( + makeArrayRef(sizes), + makeArrayRef(strides), + c10::ScalarType::Float, + DEFAULT_CUDA_DEVICE); + + EXPECT_FALSE(tensor.is_contiguous()); + EXPECT_EQ(tensor.numel(), 6u); + EXPECT_TRUE(tensor.is_cuda()); +} + +TEST(EmptyStridedCUDATest, OneDimensional) { + std::vector sizes = {10}; + std::vector strides = {1}; + + SlimTensor tensor = empty_strided( + makeArrayRef(sizes), + makeArrayRef(strides), + c10::ScalarType::Float, + DEFAULT_CUDA_DEVICE); + + EXPECT_EQ(tensor.dim(), 1u); + EXPECT_EQ(tensor.numel(), 10u); + EXPECT_TRUE(tensor.is_contiguous()); + EXPECT_TRUE(tensor.is_cuda()); +} + +TEST(EmptyStridedCUDATest, ZeroSizedTensor) { + std::vector sizes = {0, 3}; + std::vector strides = {3, 1}; + + SlimTensor tensor = empty_strided( + makeArrayRef(sizes), + makeArrayRef(strides), + c10::ScalarType::Float, + DEFAULT_CUDA_DEVICE); + + EXPECT_TRUE(tensor.defined()); + EXPECT_EQ(tensor.numel(), 0u); + EXPECT_TRUE(tensor.is_empty()); + EXPECT_TRUE(tensor.is_cuda()); +} + +TEST(EmptyStridedCUDATest, LargeDimensionalTensor) { + std::vector sizes = {2, 3, 4, 5}; + std::vector strides = {60, 20, 5, 1}; + + SlimTensor tensor = empty_strided( + makeArrayRef(sizes), + makeArrayRef(strides), + c10::ScalarType::Float, + DEFAULT_CUDA_DEVICE); + + EXPECT_EQ(tensor.dim(), 4u); + EXPECT_EQ(tensor.numel(), 120u); + EXPECT_TRUE(tensor.is_contiguous()); + EXPECT_TRUE(tensor.is_cuda()); +} + +// ============================================================================= +// empty() CUDA Tests +// ============================================================================= + +TEST(EmptyCUDATest, BasicWithArrayRef) { + std::vector sizes = {2, 3, 4}; + + SlimTensor tensor = + empty(makeArrayRef(sizes), c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + + EXPECT_TRUE(tensor.defined()); + EXPECT_EQ(tensor.dim(), 3u); + EXPECT_EQ(tensor.numel(), 24u); + EXPECT_TRUE(tensor.is_contiguous()); + EXPECT_TRUE(tensor.is_cuda()); +} + +TEST(EmptyCUDATest, VerifiesContiguousStrides) { + std::vector sizes = {2, 3, 4}; + + SlimTensor tensor = + empty(makeArrayRef(sizes), c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + + auto strides = tensor.strides(); + EXPECT_EQ(strides[0], 12); + EXPECT_EQ(strides[1], 4); + EXPECT_EQ(strides[2], 1); + EXPECT_TRUE(tensor.is_cuda()); +} + +TEST(EmptyCUDATest, InitializerListOverload) { + SlimTensor tensor = + empty({4, 5, 6}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + + EXPECT_EQ(tensor.dim(), 3u); + EXPECT_EQ(tensor.numel(), 120u); + EXPECT_TRUE(tensor.is_contiguous()); + EXPECT_TRUE(tensor.is_cuda()); + + auto sizes = tensor.sizes(); + EXPECT_EQ(sizes[0], 4); + EXPECT_EQ(sizes[1], 5); + EXPECT_EQ(sizes[2], 6); +} + +TEST(EmptyCUDATest, OneDimensional) { + SlimTensor tensor = empty({10}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + + EXPECT_EQ(tensor.dim(), 1u); + EXPECT_EQ(tensor.numel(), 10u); + EXPECT_EQ(tensor.stride(0), 1); + EXPECT_TRUE(tensor.is_cuda()); +} + +TEST(EmptyCUDATest, ZeroSized) { + SlimTensor tensor = + empty({0, 5}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + + EXPECT_TRUE(tensor.is_empty()); + EXPECT_EQ(tensor.numel(), 0u); + EXPECT_TRUE(tensor.is_cuda()); +} + +// ============================================================================= +// empty_like() CUDA Tests +// ============================================================================= + +TEST(EmptyLikeCUDATest, CopiesMetadata) { + std::vector sizes = {2, 3, 4}; + std::vector strides = {12, 4, 1}; + + SlimTensor original = empty_strided( + makeArrayRef(sizes), + makeArrayRef(strides), + c10::ScalarType::Float, + DEFAULT_CUDA_DEVICE); + SlimTensor copy = empty_like(original); + + EXPECT_EQ(copy.dim(), original.dim()); + EXPECT_EQ(copy.numel(), original.numel()); + EXPECT_EQ(copy.dtype(), original.dtype()); + EXPECT_EQ(copy.is_cuda(), original.is_cuda()); + EXPECT_EQ(copy.is_contiguous(), original.is_contiguous()); + + for (size_t i = 0; i < copy.dim(); i++) { + EXPECT_EQ(copy.size(i), original.size(i)); + EXPECT_EQ(copy.stride(i), original.stride(i)); + } +} + +TEST(EmptyLikeCUDATest, HasDifferentStorage) { + SlimTensor original = + empty({2, 3}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + SlimTensor copy = empty_like(original); + + EXPECT_NE(original.data_ptr(), copy.data_ptr()); + EXPECT_TRUE(copy.is_cuda()); +} + +TEST(EmptyLikeCUDATest, NonContiguousTensor) { + std::vector sizes = {3, 2}; + std::vector strides = {1, 3}; + + SlimTensor original = empty_strided( + makeArrayRef(sizes), + makeArrayRef(strides), + c10::ScalarType::Float, + DEFAULT_CUDA_DEVICE); + SlimTensor copy = empty_like(original); + + EXPECT_FALSE(copy.is_contiguous()); + EXPECT_EQ(copy.stride(0), 1); + EXPECT_EQ(copy.stride(1), 3); + EXPECT_TRUE(copy.is_cuda()); +} + +// ============================================================================= +// CUDA Data Access Tests +// ============================================================================= + +TEST(EmptyCUDATest, DataPtrIsValid) { + SlimTensor tensor = + empty({2, 3}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + + void* data = tensor.data_ptr(); + EXPECT_NE(data, nullptr); +} + +TEST(EmptyCUDATest, DeviceIndex) { + SlimTensor tensor = + empty({2, 3}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE); + + EXPECT_EQ(tensor.device().index(), 0); +} + +#endif // CUDA_AVAILABLE + } // namespace executorch::backends::aoti::slim