Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/infinicore/graph/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class GraphManager;
// A tensor captured for graph recording. Constructed as a blob snapshot of
// the source tensor (via Tensor::to_blob(), see GraphTensor's constructor
// definition), so the graph holds its own stable copy of the data.
class GraphTensor : public Tensor {
public:
    // Snapshot the given tensor into a blob owned by this graph tensor.
    GraphTensor(const Tensor &);
    // Re-register this tensor's blob memory with the device allocator so the
    // recorded buffer is live again (delegates to Tensor::resume_from_blob_).
    void resume() const;
};

class GraphOperator {
Expand Down
2 changes: 2 additions & 0 deletions include/infinicore/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ class Tensor {
Tensor(std::shared_ptr<TensorImpl> impl) : impl_(std::move(impl)) {}
std::shared_ptr<TensorImpl> impl_;
friend class TensorImpl;

void resume_from_blob_() const;
};

class TensorImpl : public std::enable_shared_from_this<TensorImpl> {
Expand Down
10 changes: 10 additions & 0 deletions src/infinicore/context/allocators/pinnable_block_allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,16 @@ void PinnableBlockAllocator::deallocate(std::byte *ptr) {
}
}

// Force-set the in_use flag on the block that owns `ptr`.
// Internal use only (see header). Returns the size of the block.
// Throws std::runtime_error if `ptr` was not allocated by this allocator.
size_t PinnableBlockAllocator::mark_in_use_(void *ptr, bool in_use) {
    // Acquire the lock BEFORE touching all_blocks_: other members (e.g.
    // trim()) mutate the map under mutex_, so an unguarded find() here
    // would be a data race. (The original locked only after the lookup.)
    std::lock_guard<std::mutex> lock(mutex_);
    // `ptr` is already void*; no reinterpret_cast is needed for the lookup.
    auto it = all_blocks_.find(ptr);
    if (it == all_blocks_.end()) {
        throw std::runtime_error("Pointer not allocated by this allocator");
    }
    it->second->in_use = in_use;
    return it->second->size;
}

// ------------------- trim -------------------
void PinnableBlockAllocator::trim() {
std::lock_guard<std::mutex> lock(mutex_);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ class PinnableBlockAllocator : public MemoryAllocator {
// Switch pinned/graph mode. Only flips the pinned_mode_ flag here;
// NOTE(review): presumably allocate()/deallocate() consult this flag to
// decide whether freed blocks are cached — confirm against those methods.
void set_pin_mode(bool pinned) { pinned_mode_ = pinned; }

// internal use only, force set in_use flag for a mem block
// return the size of the block
size_t mark_in_use_(void *ptr, bool in_use);

// trim cached blocks back to GPU (not pinned)
void trim();

Expand Down
7 changes: 7 additions & 0 deletions src/infinicore/context/context_impl.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "context_impl.hpp"
#include "internal.hpp"

#include "../utils.hpp"

Expand Down Expand Up @@ -194,6 +195,12 @@ void addGraphOperator(std::shared_ptr<graph::GraphOperator> op) {
// Stop graph recording on the current runtime and hand back the
// captured graph.
std::shared_ptr<graph::Graph> stopGraphRecording() {
    auto runtime = ContextImpl::singleton().getCurrentRuntime();
    return runtime->stopGraphRecording();
}

// Re-register a blob's memory with the runtime that owns its device.
// The active device is switched first so the runtime call targets the
// device the blob was allocated on.
std::shared_ptr<Memory> reinstantiateBlob(std::shared_ptr<Memory> blob) {
    setDevice(blob->device());
    auto runtime = ContextImpl::singleton().getCurrentRuntime();
    return runtime->reinstantiateBlob(blob);
}

} // namespace context

} // namespace infinicore
10 changes: 10 additions & 0 deletions src/infinicore/context/internal.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#pragma once

#include "infinicore/device.hpp"
#include "infinicore/memory.hpp"

#include "infinicore/graph/graph.hpp"

namespace infinicore::context {

// Re-register a recorded blob's memory with its device allocator and wrap
// it in a Memory whose deleter returns the block to that allocator.
// NOTE(review): see Runtime::reinstantiateBlob for the actual semantics.
std::shared_ptr<Memory> reinstantiateBlob(std::shared_ptr<Memory> blob);

} // namespace infinicore::context
9 changes: 9 additions & 0 deletions src/infinicore/context/runtime/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ std::shared_ptr<Memory> Runtime::allocatePinnedHostMemory(size_t size) {
true);
}

// Re-activate a previously recorded blob: flag its block as in-use in the
// device allocator (throws if `blob` was not allocated by it) and return a
// new Memory wrapping the same pointer, whose deleter hands the block back
// to this runtime's allocator.
std::shared_ptr<Memory> Runtime::reinstantiateBlob(std::shared_ptr<Memory> blob) {
    // `->` on the smart pointer directly; `.get()->` was redundant.
    // The block size returned by mark_in_use_ is intentionally unused —
    // the new Memory keeps the blob's own recorded size.
    device_memory_allocator_->mark_in_use_(blob->data(), true);
    // NOTE(review): the lambda captures the allocator as a raw pointer, so
    // the returned Memory must not outlive this Runtime's allocator.
    return std::make_shared<Memory>(
        blob->data(), blob->size(), device_,
        [alloc = device_memory_allocator_.get()](std::byte *p) {
            alloc->deallocate(p);
        });
}

void Runtime::memcpyH2D(void *dst, const void *src, size_t size, bool async) {
if (async) {
INFINICORE_CHECK_ERROR(infinirtMemcpyAsync(dst, src, size, INFINIRT_MEMCPY_H2D, stream_));
Expand Down
1 change: 1 addition & 0 deletions src/infinicore/context/runtime/runtime.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class Runtime {

std::shared_ptr<Memory> allocateMemory(size_t size);
std::shared_ptr<Memory> allocatePinnedHostMemory(size_t size);
std::shared_ptr<Memory> reinstantiateBlob(std::shared_ptr<Memory> blob);

void memcpyH2D(void *dst, const void *src, size_t size, bool async = true);
void memcpyD2H(void *dst, const void *src, size_t size);
Expand Down
4 changes: 4 additions & 0 deletions src/infinicore/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ namespace infinicore::graph {
// Build a graph tensor as a blob snapshot of the source tensor, so the
// graph owns its own copy of the data.
// NOTE(review): relies on Tensor::to_blob() semantics — not visible here.
GraphTensor::GraphTensor(const Tensor &tensor)
    : Tensor(tensor->to_blob()) {}

// Re-register this tensor's blob memory with the device allocator
// (delegates to Tensor::resume_from_blob_), making the recorded buffer
// live again after graph capture.
void GraphTensor::resume() const {
    resume_from_blob_();
}

/* =========================
* GraphOperator
* ========================= */
Expand Down
5 changes: 5 additions & 0 deletions src/infinicore/tensor/tensor.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "infinicore/tensor.hpp"
#include "../context/internal.hpp"
#include "../utils.hpp"
#include "infinicore/context/context.hpp"
#include "infinicore/dtype.hpp"
Expand Down Expand Up @@ -64,6 +65,10 @@ Tensor::operator bool() const {
return impl_ != nullptr;
}

// Hand this tensor's backing memory to the context so the device allocator
// marks the block as in-use again; the Memory returned by the context call
// is deliberately discarded here.
void Tensor::resume_from_blob_() const {
    auto &blob_memory = impl_->data_.memory;
    context::reinstantiateBlob(blob_memory);
}

TensorMetaData::TensorMetaData(const Shape &_shape, const Strides &_strides, const DataType &_dtype)
: shape(_shape), strides(_strides), dtype(_dtype) {
INFINICORE_CHECK_ERROR(infiniopCreateTensorDescriptor(&desc, shape.size(), shape.data(), strides.data(), (infiniDtype_t)dtype));
Expand Down