Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/infinicore/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "ops/add_rms_norm.hpp"
#include "ops/attention.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/embedding.hpp"
#include "ops/matmul.hpp"
#include "ops/ones.hpp"
#include "ops/paged_attention.hpp"
Expand Down
8 changes: 6 additions & 2 deletions include/infinicore/ops/embedding.hpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"

namespace infinicore::op {

Tensor embedding(Tensor input, Tensor weight);
void embedding_(Tensor out, Tensor input, Tensor weight);
// Declares the graph-op class `Embedding`. The type list (Tensor, const Tensor &,
// const Tensor &) matches the (out, input, weight) parameters of the constructor
// and of Embedding::execute defined in ops/embedding/embedding.cc.
// NOTE(review): the exact members generated depend on the macro definition,
// which is not visible here — confirm there.
INFINICORE_GRAPH_OP_CLASS(Embedding, Tensor, const Tensor &, const Tensor &);

// Embedding lookup that allocates and returns its own output tensor.
//   input  - index tensor (the implementation describes it as a LongTensor of
//            arbitrary shape containing the indices to extract).
//   weight - embedding matrix of shape (V, embedding_dim).
// Returns the tensor holding the gathered embeddings.
Tensor embedding(const Tensor &input, const Tensor &weight);
// Embedding lookup into a caller-provided output tensor `out`.
// The implementation forwards (out, input, weight) to Embedding::execute.
void embedding_(Tensor out, const Tensor &input, const Tensor &weight);
} // namespace infinicore::op
1 change: 1 addition & 0 deletions include/infiniop.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "infiniop/ops/clip.h"
#include "infiniop/ops/conv.h"
#include "infiniop/ops/dequantize_awq.h"
#include "infiniop/ops/embedding.h"
#include "infiniop/ops/gelu.h"
#include "infiniop/ops/gemm.h"
#include "infiniop/ops/layer_norm.h"
Expand Down
26 changes: 26 additions & 0 deletions include/infiniop/ops/embedding.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef __INFINIOP_EMBEDDING_API_H__
#define __INFINIOP_EMBEDDING_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopEmbeddingDescriptor_t;

/**
 * Creates an embedding-lookup descriptor from the given tensor descriptors.
 *
 * @param handle       Operator handle; the CPU backend reads its device and device id.
 * @param desc_ptr     Receives the newly created descriptor.
 * @param output_desc  Descriptor of the output tensor.
 * @param input_desc   Descriptor of the index tensor.
 * @param weight_desc  Descriptor of the embedding table.
 * @return An infiniStatus_t status code.
 *
 * The CPU backend in this change validates that:
 *   - weight is 2-D, i.e. (vocab_size, embedding_dim);
 *   - output rank == input rank + 1, output's last dim == embedding_dim and its
 *     leading dims equal the input shape;
 *   - input dtype is I32 or I64; weight dtype is F32, F16 or BF16; output dtype
 *     equals weight dtype.
 * NOTE(review): other backends may apply different checks — confirm per backend.
 */
__C __export infiniStatus_t infiniopCreateEmbeddingDescriptor(
infiniopHandle_t handle,
infiniopEmbeddingDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t output_desc,
infiniopTensorDescriptor_t input_desc,
infiniopTensorDescriptor_t weight_desc);

/**
 * Runs the embedding lookup described by `desc`.
 *
 * @param desc    Descriptor created by infiniopCreateEmbeddingDescriptor.
 * @param output  Output buffer (written).
 * @param input   Index buffer (read-only).
 * @param weight  Embedding table buffer (read-only).
 * @param stream  Execution stream; the CPU backend does not use it.
 * @return An infiniStatus_t status code.
 *
 * NOTE(review): the CPU backend addresses rows as densely packed and silently
 * skips indices outside [0, vocab_size), leaving those output rows unwritten.
 */
__C __export infiniStatus_t infiniopEmbedding(
infiniopEmbeddingDescriptor_t desc,
void *output,
const void *input,
const void *weight,
void *stream);

/**
 * Destroys an embedding descriptor.
 *
 * @param desc  Descriptor to destroy (presumably one obtained from
 *              infiniopCreateEmbeddingDescriptor — confirm against the implementation).
 * @return An infiniStatus_t status code.
 */
__C __export infiniStatus_t infiniopDestroyEmbeddingDescriptor(
infiniopEmbeddingDescriptor_t desc);

#endif

5 changes: 2 additions & 3 deletions python/infinicore/nn/functional/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ def embedding(
and (sparse is False)
), "Unsupported parameters."

assert "cpu" == input.device.type, (
"The device of 'input' variable must be on the CPU."
)
# Note: embedding now supports device-side input for graph recording
# The C++ implementation handles both CPU and device-side inputs

if out is None:
return Tensor(_infinicore.embedding(input._underlying, weight._underlying))
Expand Down
6 changes: 6 additions & 0 deletions src/infinicore/nn/embedding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ Embedding::Embedding(size_t num_embeddings,
}

Tensor Embedding::forward(const Tensor &indices) const {
// TODO: Implement on-device embedding for all devices, then remove the condition and the classic approach
if (device_ == Device::Type::NVIDIA || device_ == Device::Type::ILUVATAR || device_ == Device::Type::METAX || device_ == Device::Type::MOORE) {
// Use op::embedding which supports device-side input and batch dimension
return op::embedding(indices->contiguous()->to(device_), weight_);
}

// Get the shape of indices
auto indices_shape = indices->shape();

Expand Down
84 changes: 16 additions & 68 deletions src/infinicore/ops/embedding/embedding.cc
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
#include "infinicore/ops/embedding.hpp"
#include "infinicore/context/context.hpp"
#include <cstring>

#include "../../utils.hpp"

namespace infinicore::op {
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Embedding);

// Constructs the Embedding graph op for (out, input, weight).
// First asserts that all three tensors live on the same device, then dispatches
// on the device type of `out`.
// NOTE(review): what the dispatch macro stores on the op (planned state, runner,
// deleter, ...) is defined by the macro and not visible here.
Embedding::Embedding(Tensor out, const Tensor &input, const Tensor &weight) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, input, weight);
INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), out, input, weight);
}

Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the indices to extract
Tensor weight // Weight: Embedding matrix of floating point type with shape (V, embedding_dim), where V = maximum index + 1
// Entry point used by embedding_(): hands (out, input, weight) to the
// record-or-run macro for the Embedding op.
// NOTE(review): per the macro name this either records the op into a graph or
// runs it immediately — confirm the exact condition in the macro definition.
void Embedding::execute(Tensor out, const Tensor &input, const Tensor &weight) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(Embedding, out, input, weight);
}

Tensor embedding(const Tensor &input, // LongTensor of arbitrary shape containing the indices to extract
const Tensor &weight // Weight: Embedding matrix of floating point type with shape (V, embedding_dim), where V = maximum index + 1
) {
auto input_shape = input->shape();
auto weight_shape = weight->shape();
// auto vocab_size = weight_shape[0];
auto embedding_dim = weight_shape[1];

// Assign memory to out variables
Expand All @@ -21,69 +30,8 @@ Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the i
return inputs_embeds;
}

void embedding_(Tensor out, Tensor input, Tensor weight) {
assert(infinicore::DataType::I64 == input->dtype() || (infinicore::DataType::I32 == input->dtype()));
assert(infinicore::Device::Type::CPU == input->device().getType());

auto input_shape = input->shape();
auto weight_shape = weight->shape();
auto embedding_dim = weight_shape[1];

// Calculate the number of token
Size counts = 1;
for (auto &v : input_shape) {
counts *= v;
}

// the bytes of one token
const Size bytes = dsize(weight->dtype()) * embedding_dim;
auto *weight_ptr = weight->data();
auto *out_ptr = out->data();

// copies
if (weight->device().getType() == Device::Type::CPU) {
if (infinicore::DataType::I64 == input->dtype()) {
const int64_t *input_arr = reinterpret_cast<const int64_t *>(input->data());
for (Size i = 0; i < counts; ++i) {
int64_t idx = input_arr[i];
assert((idx >= 0) && (idx < weight_shape[0]));
std::memcpy(out_ptr + i * bytes,
weight_ptr + idx * bytes,
bytes);
}
} else if (infinicore::DataType::I32 == input->dtype()) {
const int32_t *input_arr = reinterpret_cast<const int32_t *>(input->data());

for (Size i = 0; i < counts; ++i) {
int32_t idx = input_arr[i];
assert((idx >= 0) && (idx < weight_shape[0]));
std::memcpy(out_ptr + i * bytes,
weight_ptr + idx * bytes,
bytes);
}
}

} else {
if (infinicore::DataType::I64 == input->dtype()) {
const int64_t *input_arr = reinterpret_cast<const int64_t *>(input->data());
for (Size i = 0; i < counts; ++i) {
int64_t idx = input_arr[i];
assert((idx >= 0) && (idx < weight_shape[0]));
context::memcpyD2D(out_ptr + i * bytes,
weight_ptr + idx * bytes,
bytes);
}
} else if (infinicore::DataType::I32 == input->dtype()) {
const int32_t *input_arr = reinterpret_cast<const int32_t *>(input->data());
for (Size i = 0; i < counts; ++i) {
int32_t idx = input_arr[i];
assert((idx >= 0) && (idx < weight_shape[0]));
context::memcpyD2D(out_ptr + i * bytes,
weight_ptr + idx * bytes,
bytes);
}
}
}
// In-place variant of embedding(): writes the lookup result into the
// caller-provided `out` by forwarding to Embedding::execute.
void embedding_(Tensor out, const Tensor &input, const Tensor &weight) {
Embedding::execute(out, input, weight);
}

} // namespace infinicore::op
44 changes: 44 additions & 0 deletions src/infinicore/ops/embedding/embedding_infiniop.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#include "../infiniop_impl.hpp"
#include "infinicore/ops/embedding.hpp"

namespace infinicore::op::embedding_impl::infiniop {

// Introduces the cached-descriptor type `Descriptor` for the infiniop Embedding
// operator; plan() below obtains instances through the matching GET_OR_CREATE macro.
// NOTE(review): the trailing 100 is presumably the cache capacity — confirm
// against the macro definition.
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Embedding, 100);

// State produced by plan() and consumed by run(): the infiniop descriptor plus
// the three tensors of one embedding call.
// Member order matters: plan() brace-initialises this aggregate positionally.
struct PlannedMeta {
    std::shared_ptr<Descriptor> descriptor; // descriptor obtained in plan()
    graph::GraphTensor out;                 // destination tensor
    graph::GraphTensor input;               // index tensor
    graph::GraphTensor weight;              // embedding table
};

// Planning step for the infiniop-backed Embedding op.
// Hashes (out, input, weight) into a key, obtains the descriptor for that key via
// the GET_OR_CREATE macro, and returns a heap-allocated PlannedMeta that bundles
// the descriptor with graph-tensor wrappers of the three tensors.
// The returned pointer is owned by the caller and is released by cleanup().
void *plan(Tensor out, const Tensor &input, const Tensor &weight) {
    size_t seed = hash_combine(out, input, weight);

    // Declares and fills the local `descriptor`.
    INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
        Descriptor, descriptor, Embedding,
        seed, out->desc(), input->desc(), weight->desc());

    return new PlannedMeta{
        descriptor,
        graph::GraphTensor(out),
        graph::GraphTensor(input),
        graph::GraphTensor(weight)};
}

// Execution step: launches infiniopEmbedding with the descriptor and tensor
// data pointers captured by plan(), on the current context stream.
// `planned_meta` must be a pointer previously returned by plan().
void run(void *planned_meta) {
    // void* -> object pointer needs only static_cast.
    auto *planned = static_cast<PlannedMeta *>(planned_meta);

    INFINICORE_CHECK_ERROR(infiniopEmbedding(
        planned->descriptor->desc,
        planned->out->data(),
        planned->input->data(),
        planned->weight->data(),
        context::getStream()));
}

// Releases the PlannedMeta allocated by plan() and resets the caller's slot to
// nullptr so it cannot be freed twice.
// `planned_meta_ptr` points at the void* slot that holds the plan() result.
void cleanup(void **planned_meta_ptr) {
    // Read the slot as the void* it really is, then convert the value.
    // Casting the slot itself to PlannedMeta** and dereferencing would access a
    // void* object through an lvalue of a different pointer type.
    delete static_cast<PlannedMeta *>(*planned_meta_ptr);
    *planned_meta_ptr = nullptr;
}

// Registers the infiniop-backed plan/run/cleanup callbacks for the Embedding op
// (per the macro name, for every device type — confirm in the macro definition).
// All three callbacks are passed with an explicit address-of for consistency.
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Embedding, &plan, &run, &cleanup);

} // namespace infinicore::op::embedding_impl::infiniop
109 changes: 109 additions & 0 deletions src/infiniop/ops/embedding/cpu/embedding_cpu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#include "embedding_cpu.h"
#include "../../../../utils.h"
#include "../../../handle.h"
#include "../../../tensor.h"
#include <cstring>

namespace op::embedding::cpu {

// Backend-private state of the descriptor. It is empty for the CPU backend;
// create() still allocates one instance and the destructor deletes `_opaque`.
struct Descriptor::Opaque {};

// Frees the Opaque object held in `_opaque` (presumably the instance allocated
// in create() and handed to the constructor — confirm in the DESCRIPTOR macro).
Descriptor::~Descriptor() {
delete _opaque;
}

// Validates the tensor descriptors of an embedding lookup and, on success,
// stores a new CPU Descriptor in *desc_ptr.
//
// Checks are performed in this order and the first failure is returned:
//   1. weight is 2-D                                   -> BAD_TENSOR_SHAPE
//   2. output rank == input rank + 1                   -> BAD_TENSOR_SHAPE
//   3. output's last dim == weight_shape[1]            -> BAD_TENSOR_SHAPE
//   4. output's leading dims == input shape            -> BAD_TENSOR_SHAPE
//   5. input dtype is I32 or I64                       -> BAD_TENSOR_DTYPE
//   6. weight dtype is F32, F16 or BF16                -> BAD_TENSOR_DTYPE
//   7. output dtype == weight dtype                    -> BAD_TENSOR_DTYPE
//
// Parameters:
//   handle       - supplies the device and device id recorded in the descriptor.
//   desc_ptr     - receives the heap-allocated Descriptor.
//   output_desc / input_desc / weight_desc - descriptors of the three tensors.
// Returns INFINI_STATUS_SUCCESS when the descriptor was created.
//
// NOTE(review): strides are not inspected here, while calculate() addresses rows
// as densely packed — presumably callers pass contiguous tensors; confirm.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t output_desc,
    infiniopTensorDescriptor_t input_desc,
    infiniopTensorDescriptor_t weight_desc) {

    const auto &input_shape = input_desc->shape();
    const auto &weight_shape = weight_desc->shape();

    // --- shape validation ---
    CHECK_OR_RETURN(weight_shape.size() == 2, INFINI_STATUS_BAD_TENSOR_SHAPE);
    CHECK_OR_RETURN(output_desc->shape().size() == input_shape.size() + 1, INFINI_STATUS_BAD_TENSOR_SHAPE);

    const auto &output_shape = output_desc->shape();
    const size_t embedding_dim = weight_shape[1];
    CHECK_OR_RETURN(output_shape.back() == embedding_dim, INFINI_STATUS_BAD_TENSOR_SHAPE);

    for (size_t i = 0; i < input_shape.size(); ++i) {
        CHECK_OR_RETURN(output_shape[i] == input_shape[i], INFINI_STATUS_BAD_TENSOR_SHAPE);
    }

    // --- dtype validation ---
    const auto input_dtype = input_desc->dtype();
    const auto weight_dtype = weight_desc->dtype();
    CHECK_OR_RETURN(input_dtype == INFINI_DTYPE_I32 || input_dtype == INFINI_DTYPE_I64,
                    INFINI_STATUS_BAD_TENSOR_DTYPE);
    CHECK_OR_RETURN(weight_dtype == INFINI_DTYPE_F32 || weight_dtype == INFINI_DTYPE_F16 || weight_dtype == INFINI_DTYPE_BF16, INFINI_STATUS_BAD_TENSOR_DTYPE);
    CHECK_OR_RETURN(output_desc->dtype() == weight_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE);

    // Total number of indices = product of all input dimensions
    // (1 for a rank-0 input, 0 if any dimension is 0).
    size_t num_indices = 1;
    for (size_t axis = 0; axis < input_shape.size(); ++axis) {
        num_indices *= input_shape[axis];
    }

    const size_t vocab_size = weight_shape[0];

    *desc_ptr = new Descriptor(
        num_indices,
        embedding_dim,
        vocab_size,
        input_dtype,
        weight_dtype,
        new Opaque{},
        handle->device,
        handle->device_id);

    return INFINI_STATUS_SUCCESS;
}

namespace {

// Copies one `row_bytes`-sized row of `weight` into `output` for each of the
// `num_indices` entries of `indices`; output row i receives weight row indices[i].
// Entries outside [0, vocab_size) are skipped, so the corresponding output row
// is left exactly as the caller provided it.
template <typename IndexT>
void gatherRows(std::byte *output,
                const IndexT *indices,
                const std::byte *weight,
                size_t num_indices,
                size_t vocab_size,
                size_t row_bytes) {
    for (size_t i = 0; i < num_indices; ++i) {
        const IndexT idx = indices[i];
        // Check the sign first so the cast to size_t is only applied to
        // non-negative values.
        if (idx >= 0 && static_cast<size_t>(idx) < vocab_size) {
            std::memcpy(output + i * row_bytes,
                        weight + static_cast<size_t>(idx) * row_bytes,
                        row_bytes);
        }
    }
}

} // namespace

// CPU embedding lookup: for every index in `input`, copies the matching row of
// `weight` into `output`.
//
// Parameters:
//   output - destination buffer; row i is written at byte offset i * row_bytes.
//   input  - index buffer, read as int32_t or int64_t according to the dtype
//            recorded at descriptor creation.
//   weight - embedding table, read as rows of _embedding_dim elements.
//   stream - unused by the CPU backend.
// Returns INFINI_STATUS_SUCCESS, or INFINI_STATUS_BAD_TENSOR_DTYPE when the
// recorded index dtype is neither I32 nor I64. An empty index set returns
// success before the dtype is examined.
//
// NOTE(review): out-of-range indices are skipped without an error, leaving those
// output rows unwritten; rows are addressed as densely packed (no strides).
infiniStatus_t Descriptor::calculate(
    void *output,
    const void *input,
    const void *weight,
    void *stream) const {

    if (_num_indices == 0) {
        return INFINI_STATUS_SUCCESS;
    }

    const size_t row_bytes = _embedding_dim * infiniSizeOf(_weight_dtype);
    auto *out_ptr = static_cast<std::byte *>(output);
    const auto *weight_ptr = static_cast<const std::byte *>(weight);

    if (_input_dtype == INFINI_DTYPE_I32) {
        gatherRows(out_ptr, static_cast<const int32_t *>(input), weight_ptr,
                   _num_indices, _vocab_size, row_bytes);
    } else if (_input_dtype == INFINI_DTYPE_I64) {
        gatherRows(out_ptr, static_cast<const int64_t *>(input), weight_ptr,
                   _num_indices, _vocab_size, row_bytes);
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    return INFINI_STATUS_SUCCESS;
}

} // namespace op::embedding::cpu
8 changes: 8 additions & 0 deletions src/infiniop/ops/embedding/cpu/embedding_cpu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __EMBEDDING_CPU_H__
#define __EMBEDDING_CPU_H__

#include "../embedding.h"

// Declares the CPU backend's Descriptor class; its members (create, calculate,
// destructor, Opaque) are defined in embedding_cpu.cc under op::embedding::cpu.
// NOTE(review): the expansion presumably comes from ../embedding.h — confirm there.
DESCRIPTOR(cpu)

#endif // __EMBEDDING_CPU_H__
Loading