参照示例,我写了一个 HGEMM kernel,其中 SMEM → RMEM 使用 ldmatrix.sync.aligned.x4.trans(SM75_U16x8_LDSM_T)。SMEM atom layout 为 (128, 8),MN-major,并使用 Swizzle<3,3,4>。用 Nsight Compute 进行 profile 之后,发现仍然存在一些 bank conflict,不清楚它们是如何产生的,也不知道该如何消除。
测试代码如下(只包含SMEM → RMEM):
namespace {

/// Compile-time integer base-2 logarithm, rounded down.
///
/// Counts how many times `n` can be halved (integer division) before it
/// drops to 1 or below. Any `n <= 1`, including zero and negative values,
/// yields 0.
constexpr int constexpr_log2(int n) {
  int halvings = 0;
  while (n > 1) {
    n /= 2;
    ++halvings;
  }
  return halvings;
}

}  // namespace
/// Debug kernel that exercises only the SMEM -> register step of an HGEMM
/// A-operand load.
///
/// Each CTA fills a statically sized shared-memory buffer with the pattern
/// smem_A[i] = half(i), copies it into the per-thread MMA A fragment through
/// `smem_tiled_copy_A`, and then the thread selected by `cute::thread0()`
/// prints every element of its fragment.
///
/// Launch requirements: a 1-D thread block; only `threadIdx.x` is read.
/// `blockIdx` is never read, so every CTA in the grid performs the same work.
///
/// @tparam ASmemLayout     Layout type of the shared-memory A tile; its cosize
///                         determines the size of the static buffer.
/// @tparam SmemTiledCopyA  Tiled-copy type used for the SMEM -> register copy.
/// @tparam TiledMMA        Tiled-MMA type that defines the A register fragment.
/// @param smem_layout_A      Layout applied on top of the shared-memory buffer.
/// @param smem_tiled_copy_A  Tiled copy used to move data SMEM -> registers.
/// @param tiled_mma          Tiled MMA whose A partitioning is reproduced.
template <typename ASmemLayout, typename SmemTiledCopyA, typename TiledMMA>
__global__ void kernel_ldmatrix(ASmemLayout smem_layout_A, SmemTiledCopyA smem_tiled_copy_A, TiledMMA tiled_mma) {
  using T = cute::half_t;

  // ldmatrix requires every row address it is given to be 16-byte aligned,
  // while an array of 2-byte elements is only guaranteed 2-byte alignment.
  // Request the alignment explicitly instead of relying on where the compiler
  // happens to place the buffer. 128 B also matches one full sweep of the
  // 32 x 4-byte shared-memory banks.
  __shared__ __align__(128) T smem_A[cute::cosize_v<ASmemLayout>];

  // Fill the buffer with its own linear index so that a printed value
  // identifies which shared-memory element it was loaded from.
  // NOTE(review): fp16 presumably represents consecutive integers exactly only
  // up to 2048; larger indices would be stored rounded, so two neighbouring
  // elements may print the same value. Confirm before relying on exact values.
  constexpr int total = cute::cosize_v<ASmemLayout>;
  for (int i = threadIdx.x; i < total; i += blockDim.x) {
    smem_A[i] = T(static_cast<float>(i));
  }
  // All writes must be visible before any thread reads through the copy below.
  __syncthreads();

  auto tensor_smem_A = cute::make_tensor(cute::make_smem_ptr(smem_A), smem_layout_A);  // (bM, bK)

  // ---- SMEM -> REG via the tiled copy ----
  auto thr_mma = tiled_mma.get_slice(threadIdx.x);
  auto mma_tCrA = thr_mma.partition_fragment_A(tensor_smem_A);  // (MMA, MMA_M, MMA_K)

  auto smem_thr_copy_A = smem_tiled_copy_A.get_slice(threadIdx.x);
  auto smem_tCsA = smem_thr_copy_A.partition_S(tensor_smem_A);  // (CPY, CPY_M, CPY_K)
  // Re-view the MMA fragment in the copy's register layout so the copy can
  // write straight into the registers the MMA would consume.
  auto smem_tCrA_view = smem_thr_copy_A.retile_D(mma_tCrA);  // (CPY, CPY_M, CPY_K)

  cute::copy(smem_tiled_copy_A, smem_tCsA, smem_tCrA_view);

  // Dump one thread's fragment for inspection (debug only; device printf is slow).
  if (cute::thread0()) {
    printf("\n=== mma_tCrA values (thread 0, after ldmatrix) ===\n");
    for (int i = 0; i < cute::size(mma_tCrA); i++) {
      printf(" mma_tCrA[%2d] = %.0f\n", i, float(mma_tCrA(i)));
    }
  }
}
/// Host driver for the SMEM -> register (ldmatrix) experiment.
///
/// Builds the swizzled shared-memory layout for a (128 x 32) A tile, a tiled
/// MMA based on SM80_16x8x16_F16F16F16F16_TN with a (2, 2, 1) atom layout, and
/// the matching tiled copy based on SM75_U16x8_LDSM_T. It prints the tiled MMA
/// and the tiled copy, launches `kernel_ldmatrix` on a single CTA, and reports
/// any launch or execution error on std::cerr.
void test_ldmatrix() {
  using T = cute::half_t;

  // CTA tile: bM = 128 along M, bK = 32 along K.
  constexpr auto bM = cute::Int<128>{};
  constexpr auto bK = cute::Int<32>{};

  // ---- Derivation of Swizzle<BBits, MBase, SShift> ----
  // Kept consistent with test_async_cp_ldmatrix.cu; the intent is to remove
  // shared-memory bank conflicts for ldmatrix.
  //
  // MBase  = log2(elements per 128-bit vector)            = log2(16 B / 2 B)          = 3
  // BBits  = log2(elements per sweep of all banks) - MBase = log2(32 * 4 B / 2 B) - 3 = 3
  // SShift = log2(bM) - MBase                              = 7 - 3                     = 4
  // Result: Swizzle<3,3,4>
  constexpr auto MBase_A = constexpr_log2(sizeof(cute::uint128_t) / sizeof(T));  // 3
  constexpr auto BBits_A = constexpr_log2(32 * 4 / sizeof(T)) - MBase_A;         // 3
  constexpr auto SShift_A = constexpr_log2(bM) - MBase_A;                        // 4
  constexpr auto swizzle_A = cute::Swizzle<BBits_A, MBase_A, SShift_A>{};        // Swizzle<3,3,4>

  // SMEM layout: atom = (128 M, 8 K), M-major; the swizzle is composed on top
  // of the atom, which is then tiled to (128 M, 32 K).
  constexpr auto smem_atom_layout_A = cute::make_layout(cute::make_shape(bM, cute::Int<8>{}));  // (128,8):(1,128)
  constexpr auto smem_atom_layout_A_swizzled = cute::composition(swizzle_A, smem_atom_layout_A);
  constexpr auto smem_layout_A = cute::tile_to_shape(smem_atom_layout_A_swizzled, cute::make_shape(bM, bK));

  // ---- TiledMMA: SM80_16x8x16, atom layout (2 M, 2 N, 1 K), mma_tile = (32 M, 32 N, 16 K) ----
  using MMATraits = cute::MMA_Traits<cute::SM80_16x8x16_F16F16F16F16_TN>;
  using MMAAtomShape = MMATraits::Shape_MNK;
  constexpr auto mma_atom = cute::MMA_Atom<MMATraits>{};
  constexpr auto mma_atom_shape = MMAAtomShape{};
  constexpr int MMA_LAYOUT_M = 2, MMA_LAYOUT_N = 2, MMA_LAYOUT_K = 1;
  constexpr int NUM_MMA_TILE_M = 1, NUM_MMA_TILE_N = 2, NUM_MMA_TILE_K = 1;
  constexpr auto MMA_TILE_M = cute::get<0>(mma_atom_shape) * NUM_MMA_TILE_M * MMA_LAYOUT_M;  // 16*1*2=32
  constexpr auto MMA_TILE_N = cute::get<1>(mma_atom_shape) * NUM_MMA_TILE_N * MMA_LAYOUT_N;  // 8*2*2=32
  constexpr auto MMA_TILE_K = cute::get<2>(mma_atom_shape) * NUM_MMA_TILE_K * MMA_LAYOUT_K;  // 16*1*1=16
  constexpr auto mma_layout =
      cute::make_layout(cute::make_shape(cute::Int<MMA_LAYOUT_M>{}, cute::Int<MMA_LAYOUT_N>{}, cute::Int<MMA_LAYOUT_K>{}));
  constexpr auto mma_tile = cute::make_tile(cute::Int<MMA_TILE_M>{}, cute::Int<MMA_TILE_N>{}, cute::Int<MMA_TILE_K>{});
  constexpr auto tiled_mma = cute::make_tiled_mma(mma_atom, mma_layout, mma_tile);  // 128 threads
  std::cout << "=== tiled_mma ===\n";
  cute::print(tiled_mma);

  // ---- SMEM -> REG tiled copy via ldmatrix ----
  // SM75_U16x8_LDSM_T = ldmatrix.sync.aligned.x4.trans (transposing load, used
  // here for the A operand).
  using Copy_Atom_A = cute::Copy_Atom<cute::SM75_U16x8_LDSM_T, T>;
  constexpr auto smem_tiled_copy_A = cute::make_tiled_copy_A(Copy_Atom_A{}, tiled_mma);
  std::cout << "\n=== smem_tiled_copy_A ===\n";
  cute::print(smem_tiled_copy_A);
  std::cout << "\n";

  // One CTA of 128 threads (4 warps). The kernel never reads blockIdx, so
  // launching more CTAs would only repeat the same work and multiply every
  // profiler counter (including reported bank conflicts) by the CTA count.
  dim3 block(cute::size(tiled_mma));  // 128
  dim3 grid(1);
  kernel_ldmatrix<<<grid, block>>>(smem_layout_A, smem_tiled_copy_A, tiled_mma);

  // A kernel launch does not return a status; configuration errors are only
  // visible through cudaGetLastError().
  if (cudaError_t err = cudaGetLastError(); err != cudaSuccess) {
    std::cerr << "CUDA error: " << cudaGetErrorString(err) << std::endl;
    return;
  }
  // Errors raised while the kernel runs surface at the next synchronizing call.
  if (cudaError_t err = cudaDeviceSynchronize(); err != cudaSuccess) {
    std::cerr << "CUDA error: " << cudaGetErrorString(err) << std::endl;
  }
}
// Program entry point: runs the ldmatrix SMEM -> register experiment once.
// The exit status is always 0; errors are only reported on std::cerr by
// test_ldmatrix().
int main() {
  test_ldmatrix();
  // Reaching the end of main is equivalent to `return 0;`.
}