A little question about the reduce in softmax

I tried to implement a block-level softmax:
Input shape: (1000, 1024); Grid(1000); Block(1024).
The idea is that one block processes one row of data. The kernel is below (my questions are marked in the ** parts):

__global__ void BlockSoftmax(const float* src, float* dst, int rows, int cols) {
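    // Per-warp partial results: one slot per warp, enough for up to 32 warps (1024 threads)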
    __shared__ float shared_max[32];
    __shared__ float shared_sum[32];

    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    int row = blockIdx.x;

    float thread_max = -INFINITY;
    float thread_sum = 0;

    for (int i = threadIdx.x + threadIdx.y * blockDim.x; i < cols; i += blockDim.x * blockDim.y) {
        float val = src[row * cols + i];
        thread_max = fmax(thread_max, val);
    }

    // Reduce max using __shfl_xor_sync
    for (int mask = warpSize / 2; mask > 0; mask /= 2) {
        thread_max = fmax(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask));
    }
   
    if (tid % warpSize == 0) {
        shared_max[tid / warpSize] = thread_max;
    }
    __syncthreads();
**// After the following code runs, my understanding is that only the first 32 threads hold the correct thread_max in their registers**
    if (tid < warpSize) {
        thread_max = shared_max[tid];
    }
    __syncthreads();
**// After the following code runs, my understanding is that only the first thread holds the correct thread_max in its register**
    for (int mask = warpSize / 2; mask > 0; mask /= 2) {
        thread_max = fmax(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask));
    }
**// So I use the commented-out code below to broadcast the first thread's thread_max to the other threads in the block**
**// But the test (see the host-side sketch at the end of the post) also passes without this code, and I don't understand why; thread_sum shows the same phenomenon**
    // if(tid==0){
    //     shared_max[0] = thread_max;
    // }
    // __syncthreads();
    // float max_val = shared_max[0];

    for (int i = threadIdx.x + threadIdx.y * blockDim.x; i < cols; i += blockDim.x * blockDim.y) {
        float val = src[row * cols + i];
        float exp_val = exp(val - thread_max);
        dst[row * cols + i] = exp_val;
        thread_sum += exp_val;
    }

    // Reduce sum using __shfl_xor_sync
    for (int mask = warpSize / 2; mask > 0; mask /= 2) {
        thread_sum += __shfl_xor_sync(0xFFFFFFFF, thread_sum, mask);
    }

    if (tid % warpSize == 0) {
        shared_sum[tid / warpSize] = thread_sum;
    }
    __syncthreads();

    if (tid < warpSize) {
        thread_sum = shared_sum[tid];
    }
    __syncthreads();

    for (int mask = warpSize / 2; mask > 0; mask /= 2) {
        thread_sum += __shfl_xor_sync(0xFFFFFFFF, thread_sum, mask);
    }

    //float sum_val = thread_sum;
    // if(tid==0){
    //     shared_sum[0] = thread_sum;
    // }
    // __syncthreads();
    // float sum_val = shared_sum[0];

    for (int i = threadIdx.x + threadIdx.y * blockDim.x; i < cols; i += blockDim.x * blockDim.y) {
        dst[row * cols + i] /= thread_sum;
    }
}
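
For context, the "test" I mention is a host-side comparison against a CPU row-wise softmax, roughly like the sketch below. This is only a minimal reconstruction: the random input data and the 1e-5 tolerance are assumptions, not a verbatim copy of my harness; the launch shape is the Grid(1000)/Block(1024) described above.

#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cuda_runtime.h>

int main() {
    const int rows = 1000, cols = 1024;
    std::vector<float> h_src(rows * cols), h_dst(rows * cols);
    for (auto& v : h_src) v = static_cast<float>(rand()) / RAND_MAX;  // assumed input distribution

    float *d_src = nullptr, *d_dst = nullptr;
    cudaMalloc(&d_src, rows * cols * sizeof(float));
    cudaMalloc(&d_dst, rows * cols * sizeof(float));
    cudaMemcpy(d_src, h_src.data(), rows * cols * sizeof(float), cudaMemcpyHostToDevice);

    // Launch as described above: one block of 1024 threads per row.
    BlockSoftmax<<<rows, cols>>>(d_src, d_dst, rows, cols);
    cudaMemcpy(h_dst.data(), d_dst, rows * cols * sizeof(float), cudaMemcpyDeviceToHost);

    // CPU reference: row-wise softmax with the usual max-subtraction for stability.
    int mismatches = 0;
    for (int r = 0; r < rows; ++r) {
        float m = -INFINITY, s = 0.0f;
        for (int c = 0; c < cols; ++c) m = std::fmax(m, h_src[r * cols + c]);
        for (int c = 0; c < cols; ++c) s += std::exp(h_src[r * cols + c] - m);
        for (int c = 0; c < cols; ++c) {
            float ref = std::exp(h_src[r * cols + c] - m) / s;
            if (std::fabs(ref - h_dst[r * cols + c]) > 1e-5f) ++mismatches;  // assumed tolerance
        }
    }
    printf("mismatches: %d\n", mismatches);

    cudaFree(d_src);
    cudaFree(d_dst);
    return 0;
}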