我尝试实现了一个block-level的softmax:
输入shape(1000*1024),Grid(1000);Block(1024);
想让一个block处理一行数据,kernel实现如下(相关疑问见**部分):
__global__ void BlockSoftmax(const float* src, float* dst, int rows, int cols) {
__shared__ float shared_max[32];
__shared__ float shared_sum[32];
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int row = blockIdx.x;
float thread_max = -INFINITY;
float thread_sum = 0;
for (int i = threadIdx.x + threadIdx.y * blockDim.x; i < cols; i += blockDim.x * blockDim.y) {
float val = src[row * cols + i];
thread_max = fmax(thread_max, val);
}
// Reduce max using __shfl_xor_sync
for (int mask = warpSize / 2; mask > 0; mask /= 2) {
thread_max = fmax(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask));
}
if (tid % warpSize == 0) {
shared_max[tid / warpSize] = thread_max;
}
__syncthreads();
**//经过以下代码后,在我的理解是只有前32线程的寄存器里的thread_max 是正确的**
if (tid < warpSize) {
thread_max = shared_max[tid];
}
__syncthreads();
**//经过以下代码后,在我的理解是只有第一个线程的寄存器里的thread_max 是正确的**
for (int mask = warpSize / 2; mask > 0; mask /= 2) {
thread_max = fmax(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask));
}
**//所以我使用一下注释代码广播第一个线程的thread_max 给block里的其他线程**
**//但是没有下面的代码测试也是成功的,我不明白为什么,同理thread_sum也是一样的现象**
// if(tid==0){
// shared_max[0] = thread_max;
// }
// __syncthreads();
// float max_val = shared_max[0];
for (int i = threadIdx.x + threadIdx.y * blockDim.x; i < cols; i += blockDim.x * blockDim.y) {
float val = src[row * cols + i];
float exp_val = exp(val - thread_max);
dst[row * cols + i] = exp_val;
thread_sum += exp_val;
}
// Reduce sum using __shfl_xor_sync
for (int mask = warpSize / 2; mask > 0; mask /= 2) {
thread_sum += __shfl_xor_sync(0xFFFFFFFF, thread_sum, mask);
}
if (tid % warpSize == 0) {
shared_sum[tid / warpSize] = thread_sum;
}
__syncthreads();
if (tid < warpSize) {
thread_sum = shared_sum[tid];
}
__syncthreads();
for (int mask = warpSize / 2; mask > 0; mask /= 2) {
thread_sum += __shfl_xor_sync(0xFFFFFFFF, thread_sum, mask);
}
//float sum_val = thread_sum;
// if(tid==0){
// shared_sum[0] = thread_sum;
// }
// __syncthreads();
// float sum_val = shared_sum[0];
for (int i = threadIdx.x + threadIdx.y * blockDim.x; i < cols; i += blockDim.x * blockDim.y) {
dst[row * cols + i] /= thread_sum;
}
}