求助:使用Swizzle之后,未能完全消除ldmatrix的bank conflicts

参照示例,我写了一个 HGEMM kernel,其中 SMEM → RMEM 使用 ldmatrix.sync.aligned.x4.trans(即 CuTe 中的 SM75_U16x8_LDSM_T)。SMEM atom layout 为 (128, 8),MN-major,使用 Swizzle<3,3,4>。使用 Nsight Compute 进行 profile 之后,发现仍然存在一些 bank conflict,不知道它们是如何产生的,以及应该如何消除?

测试代码如下(只包含SMEM → RMEM):

namespace {
// Compile-time integer logarithm: counts how many times `n` can be halved
// (integer division) before it drops to 1 or below. For n >= 1 this equals
// floor(log2(n)); every n <= 1 (including zero and negatives) yields 0.
constexpr int constexpr_log2(int n) {
  int exponent = 0;
  for (int value = n; value > 1; value /= 2) {
    ++exponent;
  }
  return exponent;
}
}  // namespace

// Test kernel for the SMEM -> register path only.
//
// Each CTA fills a statically allocated shared-memory tile with its own linear
// offsets, views that tile through `smem_layout_A`, copies it into the MMA
// A-operand register fragment with `smem_tiled_copy_A`, and finally lets one
// thread print the fragment it received.
//
// Template / function parameters:
//   smem_layout_A     - static CuTe layout of the shared-memory A tile; its
//                       cosize determines the size of the shared array.
//   smem_tiled_copy_A - tiled copy used for the SMEM -> register transfer.
//   tiled_mma         - tiled MMA whose A-fragment partitioning is used as the
//                       copy destination.
//
// Launch expectations: the kernel indexes both the MMA and the copy slices by
// threadIdx.x, so it presumably expects blockDim.x == size(tiled_mma) -- TODO
// confirm against the caller. No global memory is read or written.
template <typename ASmemLayout, typename SmemTiledCopyA, typename TiledMMA>
__global__ void kernel_ldmatrix(ASmemLayout smem_layout_A, SmemTiledCopyA smem_tiled_copy_A, TiledMMA tiled_mma) {
  using T = cute::half_t;

  // A plain array of half_t is only guaranteed 2-byte alignment, but ldmatrix
  // needs every row address to be 16-byte aligned, and a swizzle that is
  // designed around the 32 x 4-byte banks only lines up with the physical
  // banks when the tile starts on a 128-byte boundary. Request that alignment
  // explicitly instead of relying on where the compiler happens to place it.
  __shared__ __align__(128) T smem_A[cute::cosize_v<ASmemLayout>];

  // Fill the tile so that element i holds the value i; the values printed at
  // the end then identify which shared-memory offsets each register came from.
  // Note: fp16 cannot represent every integer above 2048 exactly, so large
  // offsets may be printed rounded.
  constexpr int total = cute::cosize_v<ASmemLayout>;
  for (int i = threadIdx.x; i < total; i += blockDim.x) {
    smem_A[i] = T(static_cast<float>(i));
  }
  // All writes must be visible before any thread issues its shared-memory loads.
  __syncthreads();

  // As launched by test_ldmatrix in this file the tile is (bM=128, bK=32).
  auto tensor_smem_A = cute::make_tensor(cute::make_smem_ptr(smem_A), smem_layout_A);

  // ---- SMEM -> REG via ldmatrix ----
  // Register fragment shaped the way the MMA expects its A operand.
  auto thr_mma  = tiled_mma.get_slice(threadIdx.x);
  auto mma_tCrA = thr_mma.partition_fragment_A(tensor_smem_A);  // (MMA, MMA_M, MMA_K)

  // Source partition of the shared tile and a re-tiled view of the same
  // registers, both laid out the way the copy atom expects.
  auto smem_thr_copy_A = smem_tiled_copy_A.get_slice(threadIdx.x);
  auto smem_tCsA       = smem_thr_copy_A.partition_S(tensor_smem_A);  // (CPY, CPY_M, CPY_K)
  auto smem_tCrA_view  = smem_thr_copy_A.retile_D(mma_tCrA);          // (CPY, CPY_M, CPY_K)

  cute::copy(smem_tiled_copy_A, smem_tCsA, smem_tCrA_view);

  // Debug output, gated by cute::thread0() so only one thread prints.
  if (cute::thread0()) {
    printf("\n=== mma_tCrA values (thread 0, after ldmatrix) ===\n");
    for (int i = 0; i < cute::size(mma_tCrA); i++) {
      printf("  mma_tCrA[%2d] = %.0f\n", i, float(mma_tCrA(i)));
    }
  }
}
// Host-side driver for the SMEM -> register (ldmatrix) test.
//
// Builds the swizzled shared-memory layout for A, the TiledMMA and the matching
// ldmatrix tiled copy, prints the latter two to stdout, launches
// kernel_ldmatrix and reports any CUDA launch or execution error on stderr.
// Takes no arguments and returns nothing; errors are reported, not propagated.
void test_ldmatrix() {
  // Problem size; only used to derive the launch grid below.
  constexpr int M = 4096, K = 4096;

  using T = cute::half_t;

  // CTA tile: bM=128 (M direction), bK=32 (K direction)
  constexpr auto bM = cute::Int<128>{};
  constexpr auto bK = cute::Int<32>{};

  // ---- Deriving the Swizzle<BBits, MBase, SShift> parameters ----
  // Kept consistent with test_async_cp_ldmatrix.cu; the intent is to remove
  // the shared-memory bank conflicts of ldmatrix.
  //
  // MBase  = log2(128-bit vector width / element width) = log2(16B / 2B) = 3
  // BBits  = log2(32 banks * 4B / element width) - MBase = log2(64) - 3 = 3
  // SShift = log2(bM) - MBase = 7 - 3 = 4
  // Result: Swizzle<3,3,4>
  constexpr auto MBase_A   = constexpr_log2(sizeof(cute::uint128_t) / sizeof(T));  // 3
  constexpr auto BBits_A   = constexpr_log2(32 * 4 / sizeof(T)) - MBase_A;         // 3
  constexpr auto SShift_A  = constexpr_log2(bM) - MBase_A;                         // 4
  constexpr auto swizzle_A = cute::Swizzle<BBits_A, MBase_A, SShift_A>{};          // Swizzle<3,3,4>

  // SMEM layout: atom = (128 M, 8 K), M-major; the swizzle is applied to the
  // atom, which is then tiled up to (128 M, 32 K).
  constexpr auto smem_atom_layout_A          = cute::make_layout(cute::make_shape(bM, cute::Int<8>{}));  // (128,8):(1,128)
  constexpr auto smem_atom_layout_A_swizzled = cute::composition(swizzle_A, smem_atom_layout_A);
  constexpr auto smem_layout_A               = cute::tile_to_shape(smem_atom_layout_A_swizzled, cute::make_shape(bM, bK));

  // ---- TiledMMA: SM80_16x8x16, warp arrangement (2 M, 2 N, 1 K), mma_tile = (32 M, 32 N, 16 K) ----
  using MMATraits               = cute::MMA_Traits<cute::SM80_16x8x16_F16F16F16F16_TN>;
  using MMAAtomShape            = MMATraits::Shape_MNK;
  constexpr auto mma_atom       = cute::MMA_Atom<MMATraits>{};
  constexpr auto mma_atom_shape = MMAAtomShape{};

  constexpr int MMA_LAYOUT_M = 2, MMA_LAYOUT_N = 2, MMA_LAYOUT_K = 1;
  constexpr int NUM_MMA_TILE_M = 1, NUM_MMA_TILE_N = 2, NUM_MMA_TILE_K = 1;
  constexpr auto MMA_TILE_M = cute::get<0>(mma_atom_shape) * NUM_MMA_TILE_M * MMA_LAYOUT_M;  // 16*1*2=32
  constexpr auto MMA_TILE_N = cute::get<1>(mma_atom_shape) * NUM_MMA_TILE_N * MMA_LAYOUT_N;  // 8*2*2=32
  constexpr auto MMA_TILE_K = cute::get<2>(mma_atom_shape) * NUM_MMA_TILE_K * MMA_LAYOUT_K;  // 16*1*1=16

  constexpr auto mma_layout =
    cute::make_layout(cute::make_shape(cute::Int<MMA_LAYOUT_M>{}, cute::Int<MMA_LAYOUT_N>{}, cute::Int<MMA_LAYOUT_K>{}));
  constexpr auto mma_tile  = cute::make_tile(cute::Int<MMA_TILE_M>{}, cute::Int<MMA_TILE_N>{}, cute::Int<MMA_TILE_K>{});
  constexpr auto tiled_mma = cute::make_tiled_mma(mma_atom, mma_layout, mma_tile);  // 128 threads

  std::cout << "=== tiled_mma ===\n";
  cute::print(tiled_mma);

  // ---- SMEM -> REG tiled copy via ldmatrix ----
  // SM75_U16x8_LDSM_T = ldmatrix.sync.aligned.x4.trans (transposed load, used
  // here for the A operand).
  using Copy_Atom_A                = cute::Copy_Atom<cute::SM75_U16x8_LDSM_T, T>;
  constexpr auto smem_tiled_copy_A = cute::make_tiled_copy_A(Copy_Atom_A{}, tiled_mma);

  std::cout << "\n=== smem_tiled_copy_A ===\n";
  cute::print(smem_tiled_copy_A);
  std::cout << "\n";

  // Launch configuration: 128 threads (4 warps) per CTA and one CTA per
  // (bM, bK) tile of the M x K problem, i.e. grid = (M/bM, K/bK) = (32, 128).
  // The kernel takes no global-memory arguments, so every CTA performs the
  // same SMEM -> register load.
  // NOTE(review): this was previously described as a single-CTA launch
  // (grid=1); use `dim3 grid(1)` if only one CTA should be profiled.
  dim3 block(cute::size(tiled_mma));  // 128
  dim3 grid(M / bM, K / bK);

  kernel_ldmatrix<<<grid, block>>>(smem_layout_A, smem_tiled_copy_A, tiled_mma);

  // A kernel launch returns no status itself; launch-configuration errors are
  // only visible through cudaGetLastError(), so check it before synchronizing.
  if (cudaError_t err = cudaGetLastError(); err != cudaSuccess) {
    std::cerr << "CUDA launch error: " << cudaGetErrorString(err) << std::endl;
    return;
  }

  // Wait for the kernel (and its device-side printf output) and surface any
  // error raised during execution.
  if (cudaError_t err = cudaDeviceSynchronize(); err != cudaSuccess) {
    std::cerr << "CUDA error: " << cudaGetErrorString(err) << std::endl;
  }
}

// Program entry point: runs the ldmatrix test once. The test reports CUDA
// errors itself, so the process exit code is always 0.
auto main() -> int {
  test_ldmatrix();
  return 0;
}