Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
{
"files.associations": {
"\"*.cu\"": "\"cpp\"",
"*.cu": "cuda-cpp",
"array": "cpp",
"bit": "cpp",
"cctype": "cpp",
"clocale": "cpp",
"cmath": "cpp",
"compare": "cpp",
"concepts": "cpp",
"cstdarg": "cpp",
"cstddef": "cpp",
"cstdint": "cpp",
"cstdio": "cpp",
"cstdlib": "cpp",
"cstring": "cpp",
"cwchar": "cpp",
"cwctype": "cpp",
"unordered_map": "cpp",
"vector": "cpp",
"exception": "cpp",
"functional": "cpp",
"initializer_list": "cpp",
"iosfwd": "cpp",
"iostream": "cpp",
"istream": "cpp",
"limits": "cpp",
"new": "cpp",
"numbers": "cpp",
"ostream": "cpp",
"stdexcept": "cpp",
"streambuf": "cpp",
"string": "cpp",
"string_view": "cpp",
"system_error": "cpp",
"tuple": "cpp",
"type_traits": "cpp",
"typeinfo": "cpp",
"utility": "cpp"
}
}
130 changes: 127 additions & 3 deletions src/kernels.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include <vector>
#include <cuda_fp16.h>
#include <cooperative_groups.h>


#include "../tester/utils.h"

Expand All @@ -17,12 +19,134 @@
* @param cols Number of columns in the matrix.
* @return The trace (sum of diagonal values) of the matrix.
*/

#define FULLMASK 0xffffffff

template <typename T>
T trace(const std::vector<T>& h_input, size_t rows, size_t cols) {
// TODO: Implement the trace function
return T(-1);
__device__ T WarpReduce(T val)
{

#pragma unroll
for (int offset = warpSize / 2; offset > 0; offset >>= 1)
{
val += __shfl_down_sync(FULLMASK, val, offset);
}
return val;
}


namespace cg = cooperative_groups;

template <typename T>
__global__ void reduce_kernel(const T* d_input, T* d_output, T* d_workspace, const size_t N)
{
cg::grid_group grid = cg::this_grid();
T sum = (T)0;

const size_t tid = threadIdx.x;
const size_t bid = blockIdx.x;
const size_t idx = tid + bid * blockDim.x;
const size_t laneID = tid % warpSize;
const size_t warpID = tid / warpSize;


for (size_t i = idx; i < N; i += gridDim.x * blockDim.x)
{
sum += d_input[i];
}

// warp 内的和
T warp_sum = WarpReduce(sum);

__shared__ T smem[32];
if (laneID == 0) smem[warpID] = warp_sum;
__syncthreads();

// 用一个warp 对整个block内的所有warp之和归约
if (warpID == 0)
{
T block_sum = (tid < ((blockDim.x + warpSize - 1) / warpSize)) ? smem[laneID] : 0;
block_sum = WarpReduce(block_sum);
if(tid == 0) smem[0] = block_sum;
}
__syncthreads();

if (tid == 0) d_workspace[bid] = smem[0];
grid.sync();

// 用一个block对d_workspace中所有元素归约
if (bid == 0)
{
T final_sum = 0;
for (size_t i = idx; i < gridDim.x; i += blockDim.x)
{
final_sum += d_workspace[i];
}
final_sum = WarpReduce(final_sum);

if (laneID == 0) smem[warpID] = final_sum;
__syncthreads();

if (warpID == 0)
{
final_sum = (tid < ((blockDim.x + warpSize - 1) / warpSize)) ? smem[laneID] : 0;
final_sum = WarpReduce(final_sum);
if (tid == 0) *d_output = final_sum;
}
}

}

template <typename T>
T trace(const std::vector<T>& h_input, size_t rows, size_t cols) {

const size_t n_diag = (rows < cols) ? rows : cols;
std::vector<T> temp;
temp.reserve(n_diag);

#pragma unroll
for (size_t i = 0; i < n_diag; i++) {
temp.push_back(h_input[(size_t)i * cols + i]);
}


size_t N = temp.size();

T *d_input, *d_output, *d_workspace;
cudaMalloc(&d_input, sizeof(T) * N);
cudaMalloc(&d_output, sizeof(T));
cudaMemcpy(d_input, temp.data(), sizeof(T) * N, cudaMemcpyHostToDevice);
cudaMemset(d_output, 0, sizeof(T));

int dev = 0;
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, dev);

int threadsPerBlock = prop.maxThreadsPerBlock;
size_t smem_size = (threadsPerBlock / 32) * sizeof(T);

int numBlocksPerSm = 0;
auto kernel_func = reduce_kernel<T>;

cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, kernel_func, threadsPerBlock, smem_size);

int maxActiveBlocks = numBlocksPerSm * prop.multiProcessorCount;
int blocks = (N + threadsPerBlock - 1) / threadsPerBlock;
if (blocks > maxActiveBlocks) blocks = maxActiveBlocks;

cudaMalloc(&d_workspace, blocks * sizeof(T));

// 参数包
void* kernelArgs[] = { &d_input, &d_output, &d_workspace, &N };

cudaLaunchCooperativeKernel((void*)kernel_func, dim3(blocks), dim3(threadsPerBlock), kernelArgs, smem_size, 0);

T res;
cudaMemcpy(&res, d_output, sizeof(T), cudaMemcpyDeviceToHost);

cudaFree(d_input); cudaFree(d_output); cudaFree(d_workspace);
return res;
}
/**
* @brief Computes flash attention for given query, key, and value tensors.
*
Expand Down
Binary file added src/kernels.o
Binary file not shown.
Binary file added test_kernels
Binary file not shown.