diff --git a/include/infinicore/graph/graph.hpp b/include/infinicore/graph/graph.hpp
index c63b3272d..63a3b4621 100644
--- a/include/infinicore/graph/graph.hpp
+++ b/include/infinicore/graph/graph.hpp
@@ -12,6 +12,7 @@ class GraphManager;
 class GraphTensor : public Tensor {
 public:
     GraphTensor(const Tensor &);
+    void resume() const;
 };
 
 class GraphOperator {
diff --git a/include/infinicore/tensor.hpp b/include/infinicore/tensor.hpp
index e9f210186..c11dbbbab 100644
--- a/include/infinicore/tensor.hpp
+++ b/include/infinicore/tensor.hpp
@@ -90,6 +90,8 @@ class Tensor {
     Tensor(std::shared_ptr<TensorImpl> impl) : impl_(std::move(impl)) {}
     std::shared_ptr<TensorImpl> impl_;
     friend class TensorImpl;
+
+    void resume_from_blob_() const;
 };
 
 class TensorImpl : public std::enable_shared_from_this<TensorImpl> {
diff --git a/src/infinicore/context/allocators/pinnable_block_allocator.cc b/src/infinicore/context/allocators/pinnable_block_allocator.cc
index f41800d7c..680611cad 100644
--- a/src/infinicore/context/allocators/pinnable_block_allocator.cc
+++ b/src/infinicore/context/allocators/pinnable_block_allocator.cc
@@ -125,6 +125,17 @@ void PinnableBlockAllocator::deallocate(std::byte *ptr) {
     }
 }
 
+size_t PinnableBlockAllocator::mark_in_use_(void *ptr, bool in_use) {
+    // Lock BEFORE the lookup: deallocate()/trim() mutate all_blocks_ under this mutex.
+    std::lock_guard lock(mutex_);
+    auto it = all_blocks_.find(reinterpret_cast<std::byte *>(ptr));
+    if (it == all_blocks_.end()) {
+        throw std::runtime_error("Pointer not allocated by this allocator");
+    }
+    it->second->in_use = in_use;
+    return it->second->size;
+}
+
 // ------------------- trim -------------------
 void PinnableBlockAllocator::trim() {
     std::lock_guard lock(mutex_);
diff --git a/src/infinicore/context/allocators/pinnable_block_allocator.hpp b/src/infinicore/context/allocators/pinnable_block_allocator.hpp
index 8911d2a6d..4ab4b4a31 100644
--- a/src/infinicore/context/allocators/pinnable_block_allocator.hpp
+++ b/src/infinicore/context/allocators/pinnable_block_allocator.hpp
@@ -32,6 +32,10 @@ class PinnableBlockAllocator : public MemoryAllocator {
 
     // Switch pinned/graph mode
     void set_pin_mode(bool pinned) { pinned_mode_ = pinned; }
 
+    // internal use only, force set in_use flag for a mem block
+    // return the size of the block
+    size_t mark_in_use_(void *ptr, bool in_use);
+
     // trim cached blocks back to GPU (not pinned)
     void trim();
diff --git a/src/infinicore/context/context_impl.cc b/src/infinicore/context/context_impl.cc
index 6ed79af54..67472b067 100644
--- a/src/infinicore/context/context_impl.cc
+++ b/src/infinicore/context/context_impl.cc
@@ -1,4 +1,5 @@
 #include "context_impl.hpp"
+#include "internal.hpp"
 
 #include "../utils.hpp"
 
@@ -194,6 +195,12 @@ void addGraphOperator(std::shared_ptr<GraphOperator> op) {
 
 std::shared_ptr<Graph> stopGraphRecording() {
     return ContextImpl::singleton().getCurrentRuntime()->stopGraphRecording();
 }
+
+std::shared_ptr<Memory> reinstantiateBlob(std::shared_ptr<Memory> blob) {
+    setDevice(blob->device());
+    return ContextImpl::singleton().getCurrentRuntime()->reinstantiateBlob(blob);
+}
+
 } // namespace context
 } // namespace infinicore
diff --git a/src/infinicore/context/internal.hpp b/src/infinicore/context/internal.hpp
new file mode 100644
index 000000000..aeecaff51
--- /dev/null
+++ b/src/infinicore/context/internal.hpp
@@ -0,0 +1,10 @@
+#pragma once
+
+#include "infinicore/device.hpp"
+#include "infinicore/memory.hpp"
+
+#include "infinicore/graph/graph.hpp"
+
+namespace infinicore::context {
+std::shared_ptr<Memory> reinstantiateBlob(std::shared_ptr<Memory> blob);
+} // namespace infinicore::context
diff --git a/src/infinicore/context/runtime/runtime.cc b/src/infinicore/context/runtime/runtime.cc
index a6dd7eb7e..5a6f5b9c3 100644
--- a/src/infinicore/context/runtime/runtime.cc
+++ b/src/infinicore/context/runtime/runtime.cc
@@ -77,6 +77,16 @@ std::shared_ptr<Memory> Runtime::allocatePinnedHostMemory(size_t size) {
         true);
 }
 
+std::shared_ptr<Memory> Runtime::reinstantiateBlob(std::shared_ptr<Memory> blob) {
+    device_memory_allocator_->mark_in_use_(blob->data(), true);
+    // NOTE(review): the deleter captures the allocator by raw pointer — assumes
+    // the allocator outlives every reinstantiated blob; confirm ownership model.
+    return std::make_shared<Memory>(
+        blob->data(), blob->size(), device_,
+        [alloc = device_memory_allocator_.get()](std::byte *p) {
+            alloc->deallocate(p);
+        });
+}
+
 void Runtime::memcpyH2D(void *dst, const void *src, size_t size, bool async) {
     if (async) {
         INFINICORE_CHECK_ERROR(infinirtMemcpyAsync(dst, src, size, INFINIRT_MEMCPY_H2D, stream_));
diff --git a/src/infinicore/context/runtime/runtime.hpp b/src/infinicore/context/runtime/runtime.hpp
index 58d8bd424..b5a90a602 100644
--- a/src/infinicore/context/runtime/runtime.hpp
+++ b/src/infinicore/context/runtime/runtime.hpp
@@ -37,6 +37,7 @@ class Runtime {
 
     std::shared_ptr<Memory> allocateMemory(size_t size);
     std::shared_ptr<Memory> allocatePinnedHostMemory(size_t size);
+    std::shared_ptr<Memory> reinstantiateBlob(std::shared_ptr<Memory> blob);
 
     void memcpyH2D(void *dst, const void *src, size_t size, bool async = true);
     void memcpyD2H(void *dst, const void *src, size_t size);
diff --git a/src/infinicore/graph/graph.cc b/src/infinicore/graph/graph.cc
index 86944af36..1f1806139 100644
--- a/src/infinicore/graph/graph.cc
+++ b/src/infinicore/graph/graph.cc
@@ -11,6 +11,10 @@ namespace infinicore::graph {
 GraphTensor::GraphTensor(const Tensor &tensor) : Tensor(tensor->to_blob()) {
 }
 
+void GraphTensor::resume() const {
+    resume_from_blob_();
+}
+
 /* =========================
  * GraphOperator
  * ========================= */
diff --git a/src/infinicore/tensor/tensor.cc b/src/infinicore/tensor/tensor.cc
index 2acc6dec8..5523fe353 100644
--- a/src/infinicore/tensor/tensor.cc
+++ b/src/infinicore/tensor/tensor.cc
@@ -1,4 +1,5 @@
 #include "infinicore/tensor.hpp"
+#include "../context/internal.hpp"
 #include "../utils.hpp"
 #include "infinicore/context/context.hpp"
 #include "infinicore/dtype.hpp"
@@ -64,6 +65,10 @@
 Tensor::operator bool() const {
     return impl_ != nullptr;
 }
 
+void Tensor::resume_from_blob_() const {
+    context::reinstantiateBlob(impl_->data_.memory);
+}
+
 TensorMetaData::TensorMetaData(const Shape &_shape, const Strides &_strides, const DataType &_dtype) : shape(_shape), strides(_strides), dtype(_dtype) {
     INFINICORE_CHECK_ERROR(infiniopCreateTensorDescriptor(&desc, shape.size(), shape.data(), strides.data(), (infiniDtype_t)dtype));