Sgemm [128,128,8]

主要参考 论文 Huang, 2018 (arxiv.org)

性能可达到 cublas的 96%
目前只贴下源码,注释还是蛮多的。
之前搞了个分支,速度下降了10%,有些过分了。

#include <algorithm>
#include <cublas_v2.h>
#include <cuda_device_runtime_api.h>
#include <device_launch_parameters.h>
#include <iomanip>
#include <iostream>
#include <random>
#include <stdio.h>
#include <stdlib.h>
#include <string>
#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <thrust/gather.h>
#include <thrust/host_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/random.h>
#include <thrust/scan.h>
#include <vector>
/**
 * @brief
 *  gemm: C= alpha*A*B+beta*C
 * 每个 thread block计算 C 的tile是 [128*128]
 * 子迭代计算 Atile[128,8] Btile[8,128]
 * 所有矩阵按照行主序计算。
 * 是有数据预取,向量化读取等
 * block的threads是 256个
 * 一个 warp要计算的tile是 [32,64]
 * 一个 thread计算多 tile是 [8,8]
 * 因为要读取合并,对C采用循环分发的方式
 */

/**********************/
/* cuBLAS ERROR CHECK */
/**********************/
#ifndef cublasSafeCall
#define cublasSafeCall(err) __cublasSafeCall(err, __FILE__, __LINE__)
#endif

inline void __cublasSafeCall(cublasStatus_t err, const char *file,
                             const int line)
{
    if (CUBLAS_STATUS_SUCCESS != err)
    {
        fprintf(stderr,
                "CUBLAS error in file '%s', line %d\n \nerror %d \nterminating!\n",
                __FILE__, __LINE__, err);
        // getch();
        cudaDeviceReset();
        assert(0);
    }
}
template <typename T>
struct Type4;

template <>
struct Type4<float>
{
    using type = float4;
};

template <typename T>
using Type4t = typename Type4<T>::type;

#define A(i, j) A[(i)*lda + (j)]
#define B(i, j) B[(i)*ldb + (j)]
#define C(i, j) C[(i)*ldc + (j)]

#define ptrA(i, j) ptrA[(i)*lda + (j)]
#define ptrB(i, j) ptrB[(i)*ldb + (j)]
#define MS 128
#define NS 128
#define KS 8

template <typename T>
__device__ __forceinline__ void
vscal_fma(Type4t<T> &dst_vec, const Type4t<T> &src_vec, const T &scale)
{
    dst_vec.x += src_vec.x * scale;
    dst_vec.y += src_vec.y * scale;
    dst_vec.z += src_vec.z * scale;
    dst_vec.w += src_vec.w * scale;
}

template <typename T>
__device__ __forceinline__ void simd_axpby(Type4t<T> &dst_vec, T alpha,
                                           const Type4t<T> &srca_vec, T beta,
                                           const Type4t<T> &srcb_vec)
{
    dst_vec.x = alpha * srca_vec.x + beta * srcb_vec.x;
    dst_vec.y = alpha * srca_vec.y + beta * srcb_vec.y;
    dst_vec.z = alpha * srca_vec.z + beta * srcb_vec.z;
    dst_vec.w = alpha * srca_vec.w + beta * srcb_vec.w;
}

template <typename T>
__device__ __forceinline__ void vload(Type4t<T> &dst_vec, const T *addr)
{
    dst_vec = *((Type4t<T> *)(addr));
}

template <typename T>
__device__ __forceinline__ void vstore(T *addr, const Type4t<T> &src_vec)
{
    *((Type4t<T> *)(addr)) = src_vec;
}

__device__ __forceinline__ void print(const float4 &vec)
{
    printf("%f, %f, %f, %f\n", vec.x, vec.y, vec.z, vec.w);
}
/**
 * template <typename AccessType>
struct global_load<AccessType,
                   16
                  > {
  CUTLASS_DEVICE
  global_load(AccessType &D, void const *ptr, bool pred_guard) {
  uint4 &data = reinterpret_cast<uint4 &>(D);
    asm volatile(
        "{\n"
        "  .reg .pred p;\n"
        "  setp.ne.b32 p, %5, 0;\n"
        "  mov.b32 %0, %6;\n"
        "  mov.b32 %1, %7;\n"
        "  mov.b32 %2, %8;\n"
        "  mov.b32 %3, %9;\n"
#if CUTLASS_ENABLE_L2_PREFETCH
        "  @p ld.global.L2::128B.v4.u32 {%0, %1, %2, %3}, [%4];\n"
#else
        "  @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n"
#endif
        "}\n"
        : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
        : "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y), "r"(data.z),
"r"(data.w));
  }
};

**/
__global__ __launch_bounds__(256)
    // assert M%128 ==0 && N%128 ==0 && K%8 ==0
    void gemm_kernel(int M, int N, int K, float alpha, const float *A,
                     const float *B, float beta, float *C)
{

    const int lda = K;
    const int ldb = N;
    const int ldc = N;
    const int tx = threadIdx.x;
    const int bx = blockIdx.x, by = blockIdx.y;
    // bx 是对应col,by是对应 row 的 行主序

    const int warp_id = tx >> 5;
    const int lane_id = tx & 31;
    // 共 8个warp, 分成 4行 2列
    // 因为 一个warp处理的元素是[32,64] 128/32 = 4,128/64=2 ,所以需要
    // 4行warp,2列 warp。
    const int warp_col = warp_id & 1;
    const int warp_row = warp_id >> 1;
    // 一个warp负责 4* 8,8* 8 的 Ctile //但需要两次读取
    //  lane_id 分成 [4,8]的方块布局
    const int col_w = lane_id & 7;
    const int row_w = lane_id >> 3;
    //每个 tx 首先由 8行4列的C tile,再循环分发一次,共 8行8列
    // 因为 C是行主序,col索引是连续的,连续的col索引应该给相邻的tx(float4 是
    // 16B,不满足32B的最小L1,L2内存交易单位)
    // C的列索引需要循环分发,行索引不需要循环分发[ col0-col3 ,col64-col67:
    // row0-row7], 一个warp是 [32,64] 所以 rowc 中 warp_row<< 5 ,warp_col<<6
    //但列需要循环分发,要分成两次读取,因此这里是  (warp_col << 5) + (col_w << 2)
    //;
    const int row_c = (warp_row << 5) + (row_w << 3); // row0-row7
    const int col_c = (warp_col << 5) + (col_w << 2); // col0-col3 ;col64-col67
    // col_w<< 2 是因为 一个tx一次读写or 4列。需要循环分发,col0-col3 ;col64-col67

    // rowa rowb cola colb 用于从global memory读取到share memory
    // Atile [128,8] 每个线程读写到 share memory是 4个值,
    // 256个 tx 分成 [128,2]布局
    const int col_a = (tx & 1) << 2;
    const int row_a = tx >> 1;
    // Btile [8,128] 每个线程读写到 share memory是 4个值
    // 256个 tx 分成 [8,32] 布局
    const int col_b = (tx & 31) << 2; // 4个值所以是 <<2
    const int row_b = tx >> 5;

    // 该block处理的 A,B,C相对地址
    A += (by << 7) * lda;
    B += (bx << 7);
    C += (by << 7) * ldc + (bx << 7);

    __shared__ float
        smema[2][8][128]; // Atile[128][8]但需要转置以方便计算的时候读取A的4行1列
                          // //8行1列
    __shared__ float smemb[2][8]
                          [128]; // Btile[8][128] 计算的时候读取 B的 1行4列  //
                                 // 1行 8列 // 2是双缓存策略以减少 一次sync
    // auto *ptr_smema = &smema[0];
    // float *[8][128] type(ptr_smema)
    // auto *ptr_smemb = &smemb[0];
    float4 Av1[2], Av2[2], Bv1[2], Bv2[2], Cv[16], Cres[16];
    // Av[2]正好8个值,Av1[0],Av1[1]用于预取策略的交换。 Cv[16]共 16*4= 8*8个值。
    float4 pref_Av, pref_Bv; //从global memory中预取的值
                             // 不直接使用 A,B 大概是因为
                             // 预取的时候用过Alia防止最后一次循环预取越界的情况。
    memset(Cres, 0, sizeof(Cres));
    // 循环之前先来次预取, 分别是 global memory的预取,和share memory的预取
    vload(pref_Av, &(A(row_a, col_a)));
    vload(pref_Bv, &(B(row_b, col_b)));
    int buffer_switch = 0; // two sharememoy buffer switch

    vstore(&(smemb[0][row_b][col_b]), pref_Bv);
    smema[0][col_a][row_a] = pref_Av.x;
    smema[0][col_a + 1][row_a] = pref_Av.y;
    smema[0][col_a + 2][row_a] = pref_Av.z;
    smema[0][col_a + 3][row_a] = pref_Av.w;
    // 写入到sharememoy后 进行sync

    __syncthreads();

    //读取 Atile [row_c,k] k(0-7), 因为转置,Atil的row是连续的
    //读取 Btile [k,col_c] col是循环分发的,分配到的 colc 0-3,64-67
    vload(Av1[0], &smema[0][0][row_c]);
    vload(Av2[0], &smema[0][0][row_c + 4]);
    vload(Bv1[0], &smemb[0][0][col_c]);
    vload(Bv2[0], &smemb[0][0][col_c + 64]);
    for (int global_k = KS; global_k < K; global_k += KS)
    {

        // global_k起始值是 KS, 因为循环之前已经预取了一次,为避免global
        // access越界,这里只到global_k<K,也可以global_k<=K,然后 取余 内存值
        // global_k%K 把其限制在 [0,K)里
        //但这样会计算 k_iteration-1次,还有一次剩余的需要计算
        // main loop 在 K上循环,每次迭代 KS

        //读取下一次计算的global memory
        //重新定位 ptrA,ptrB等

        // 这么大块的语句会不会分支跳转??
        A += KS;
        // global_K为上移动
        B += (KS * ldb);
        vload(pref_Av, &(A(row_a, col_a)));
        // global memory 的读取实际上一个tx 一次,[128,8] [8,128]
        // (整个Block的tx映射到 [128,8] [8,128])
        vload(pref_Bv, &(B(row_b, col_b)));
        // 最后一次 迭代就会越界, 这里条件判断还是用  @p
        // 比较靠谱,不如直接ptx汇编,要不类似 cutlass那样inline ptx gloabal load

        int reg_switch = 0;
#pragma unroll
        for (int inner_k_count = 0; inner_k_count < KS; inner_k_count++)
        {
            int next_inner_k_count = (inner_k_count + 1) & 7; //(1...7;1)取余
            // prefecth data from smem  to register for nex iter compute
            int next_reg = reg_switch ^ 1;
            vload(Av1[next_reg], &smema[buffer_switch][next_inner_k_count][row_c]);
            vload(Av2[next_reg],
                  &smema[buffer_switch][next_inner_k_count][row_c + 4]);
            vload(Bv1[next_reg], &smemb[buffer_switch][next_inner_k_count][col_c]);
            vload(Bv2[next_reg],
                  &smemb[buffer_switch][next_inner_k_count][col_c + 64]);
            // next_inner_k_count& 1用以切换预取的register
            //行主序,一行是连续的则 索引col是连续的
            vscal_fma(Cres[0], Bv1[reg_switch], Av1[reg_switch].x);

            vscal_fma(Cres[1], Bv1[reg_switch],
                      Av1[reg_switch].y); // [rowc+1,colc0-colc3]
            vscal_fma(Cres[2], Bv1[reg_switch],
                      Av1[reg_switch].z); // [rowc+2,colc0-colc3]
            vscal_fma(Cres[3], Bv1[reg_switch],
                      Av1[reg_switch].w); // [rowc+3,colc0-colc3]
            vscal_fma(Cres[4], Bv1[reg_switch], Av2[reg_switch].x);
            // [row_c+4,col_c0],[row_c,colc_1],[row_c,colc_2],[row_c,colc_3]
            vscal_fma(Cres[5], Bv1[reg_switch],
                      Av2[reg_switch].y); // [rowc+5,colc0-colc3]
            vscal_fma(Cres[6], Bv1[reg_switch],
                      Av2[reg_switch].z); // [rowc+6,colc0-colc3]
            vscal_fma(Cres[7], Bv1[reg_switch],
                      Av2[reg_switch].w); // [rowc+7,colc0-colc3]
            vscal_fma(Cres[8], Bv2[reg_switch], Av1[reg_switch].x);

            // [row_c,col_c64],[row_c,colc_65],[row_c,colc_66],[row_c,colc_67]
            vscal_fma(Cres[9], Bv2[reg_switch],
                      Av1[reg_switch].y); // [rowc+1,colc64-colc67]
            vscal_fma(Cres[10], Bv2[reg_switch],
                      Av1[reg_switch].z); // [rowc+2,colc64-colc67]
            vscal_fma(Cres[11], Bv2[reg_switch],
                      Av1[reg_switch].w); // [rowc+3,colc64-colc67]
            vscal_fma(
                Cres[12], Bv2[reg_switch],
                Av2[reg_switch]
                    .x); // [row_c+4,col_c64],[row_c,colc_65],[row_c,colc_66],[row_c,colc_67]
            vscal_fma(Cres[13], Bv2[reg_switch],
                      Av2[reg_switch].y); // [rowc+5,colc64-colc67]
            vscal_fma(Cres[14], Bv2[reg_switch],
                      Av2[reg_switch].z); // [rowc+6,colc64-colc67]
            vscal_fma(Cres[15], Bv2[reg_switch],
                      Av2[reg_switch].w); // [rowc+7,colc64--colc67]
            reg_switch ^= 1;
        }

        buffer_switch ^= 1;
        // two sharememoy buffer switch
        // store memoy in buffer
        vstore(&(smemb[buffer_switch][row_b][col_b]), pref_Bv);
        smema[buffer_switch][col_a][row_a] = pref_Av.x;
        smema[buffer_switch][col_a + 1][row_a] = pref_Av.y;
        smema[buffer_switch][col_a + 2][row_a] = pref_Av.z;
        smema[buffer_switch][col_a + 3][row_a] = pref_Av.w;
        __syncthreads();
        //从 sharememoy 读值
        vload(Av1[0], &smema[buffer_switch][0][row_c]);
        vload(Av2[0], &smema[buffer_switch][0][row_c + 4]);
        vload(Bv1[0], &smemb[buffer_switch][0][col_c]);
        vload(Bv2[0], &smemb[buffer_switch][0][col_c + 64]);

        // 为下一次子循环预取值
    }

// 这个版本去掉了分支,然后手动加一次计算子循环,性能提升了10%
int reg_switch = 0;
#pragma unroll
        for (int inner_k_count = 0; inner_k_count < KS; inner_k_count++)
        {
            int next_inner_k_count = (inner_k_count + 1) & 7; //(1...7;1)取余
            // prefecth data from smem  to register for nex iter compute
            int next_reg = reg_switch ^ 1;
            vload(Av1[next_reg], &smema[buffer_switch][next_inner_k_count][row_c]);
            vload(Av2[next_reg],
                  &smema[buffer_switch][next_inner_k_count][row_c + 4]);
            vload(Bv1[next_reg], &smemb[buffer_switch][next_inner_k_count][col_c]);
            vload(Bv2[next_reg],
                  &smemb[buffer_switch][next_inner_k_count][col_c + 64]);
            // next_inner_k_count& 1用以切换预取的register
            //行主序,一行是连续的则 索引col是连续的
            vscal_fma(Cres[0], Bv1[reg_switch], Av1[reg_switch].x);

            vscal_fma(Cres[1], Bv1[reg_switch],
                      Av1[reg_switch].y); // [rowc+1,colc0-colc3]
            vscal_fma(Cres[2], Bv1[reg_switch],
                      Av1[reg_switch].z); // [rowc+2,colc0-colc3]
            vscal_fma(Cres[3], Bv1[reg_switch],
                      Av1[reg_switch].w); // [rowc+3,colc0-colc3]
            vscal_fma(Cres[4], Bv1[reg_switch], Av2[reg_switch].x);
            // [row_c+4,col_c0],[row_c,colc_1],[row_c,colc_2],[row_c,colc_3]
            vscal_fma(Cres[5], Bv1[reg_switch],
                      Av2[reg_switch].y); // [rowc+5,colc0-colc3]
            vscal_fma(Cres[6], Bv1[reg_switch],
                      Av2[reg_switch].z); // [rowc+6,colc0-colc3]
            vscal_fma(Cres[7], Bv1[reg_switch],
                      Av2[reg_switch].w); // [rowc+7,colc0-colc3]
            vscal_fma(Cres[8], Bv2[reg_switch], Av1[reg_switch].x);

            // [row_c,col_c64],[row_c,colc_65],[row_c,colc_66],[row_c,colc_67]
            vscal_fma(Cres[9], Bv2[reg_switch],
                      Av1[reg_switch].y); // [rowc+1,colc64-colc67]
            vscal_fma(Cres[10], Bv2[reg_switch],
                      Av1[reg_switch].z); // [rowc+2,colc64-colc67]
            vscal_fma(Cres[11], Bv2[reg_switch],
                      Av1[reg_switch].w); // [rowc+3,colc64-colc67]
            vscal_fma(
                Cres[12], Bv2[reg_switch],
                Av2[reg_switch]
                    .x); // [row_c+4,col_c64],[row_c,colc_65],[row_c,colc_66],[row_c,colc_67]
            vscal_fma(Cres[13], Bv2[reg_switch],
                      Av2[reg_switch].y); // [rowc+5,colc64-colc67]
            vscal_fma(Cres[14], Bv2[reg_switch],
                      Av2[reg_switch].z); // [rowc+6,colc64-colc67]
            vscal_fma(Cres[15], Bv2[reg_switch],
                      Av2[reg_switch].w); // [rowc+7,colc64--colc67]
            reg_switch ^= 1;
        }

    // 上面在global_k上的主循环实际上少迭代一次,因为预取相关的问题,对全局内存不能越界
    // load Ctile and accumulate the Cres
    vload(Cv[0], &C(row_c, col_c));
    vload(Cv[1], &C(row_c + 1, col_c));
    vload(Cv[2], &C(row_c + 2, col_c));
    vload(Cv[3], &C(row_c + 3, col_c));
    vload(Cv[4], &C(row_c + 4, col_c));
    vload(Cv[5], &C(row_c + 5, col_c));
    vload(Cv[6], &C(row_c + 6, col_c));
    vload(Cv[7], &C(row_c + 7, col_c));
    vload(Cv[8], &C(row_c, col_c + 64));
    vload(Cv[9], &C(row_c + 1, col_c + 64));
    vload(Cv[10], &C(row_c + 2, col_c + 64));
    vload(Cv[11], &C(row_c + 3, col_c + 64));
    vload(Cv[12], &C(row_c + 4, col_c + 64));
    vload(Cv[13], &C(row_c + 5, col_c + 64));
    vload(Cv[14], &C(row_c + 6, col_c + 64));
    vload(Cv[15], &C(row_c + 7, col_c + 64));

    simd_axpby(Cres[0], alpha, Cres[0], beta, Cv[0]);
    simd_axpby(Cres[1], alpha, Cres[1], beta, Cv[1]);
    simd_axpby(Cres[2], alpha, Cres[2], beta, Cv[2]);
    simd_axpby(Cres[3], alpha, Cres[3], beta, Cv[3]);

    simd_axpby(Cres[4], alpha, Cres[4], beta, Cv[4]);
    simd_axpby(Cres[5], alpha, Cres[5], beta, Cv[5]);
    simd_axpby(Cres[6], alpha, Cres[6], beta, Cv[6]);
    simd_axpby(Cres[7], alpha, Cres[7], beta, Cv[7]);

    simd_axpby(Cres[8], alpha, Cres[8], beta, Cv[8]);
    simd_axpby(Cres[9], alpha, Cres[9], beta, Cv[9]);
    simd_axpby(Cres[10], alpha, Cres[10], beta, Cv[10]);
    simd_axpby(Cres[11], alpha, Cres[11], beta, Cv[11]);

    simd_axpby(Cres[12], alpha, Cres[12], beta, Cv[12]);
    simd_axpby(Cres[13], alpha, Cres[13], beta, Cv[13]);
    simd_axpby(Cres[14], alpha, Cres[14], beta, Cv[14]);
    simd_axpby(Cres[15], alpha, Cres[15], beta, Cv[15]);

    vstore(&C(row_c, col_c), Cres[0]);
    vstore(&C(row_c + 1, col_c), Cres[1]);
    vstore(&C(row_c + 2, col_c), Cres[2]);
    vstore(&C(row_c + 3, col_c), Cres[3]);
    vstore(&C(row_c + 4, col_c), Cres[4]);
    vstore(&C(row_c + 5, col_c), Cres[5]);
    vstore(&C(row_c + 6, col_c), Cres[6]);
    vstore(&C(row_c + 7, col_c), Cres[7]);

    vstore(&C(row_c, col_c + 64), Cres[8]);
    vstore(&C(row_c + 1, col_c + 64), Cres[9]);
    vstore(&C(row_c + 2, col_c + 64), Cres[10]);
    vstore(&C(row_c + 3, col_c + 64), Cres[11]);
    vstore(&C(row_c + 4, col_c + 64), Cres[12]);
    vstore(&C(row_c + 5, col_c + 64), Cres[13]);
    vstore(&C(row_c + 6, col_c + 64), Cres[14]);
    vstore(&C(row_c + 7, col_c + 64), Cres[15]);
}
// 3750.55 3591.11 3529.28 3460.77
template <typename T>
bool verify_res(size_t m, size_t n, const thrust::device_vector<T> &ref_data,
                const thrust::device_vector<T> &res_data,
                T abs_error = T(1e-2))
{
    thrust::host_vector<T> href_data = ref_data;
    thrust::host_vector<T> hres_data = res_data;

    T max_error = std::numeric_limits<T>::lowest();
    int num_errors = 0;
    for (size_t i = 0; i < m; i++)
    {
        for (size_t j = 0; j < n; j++)
        {
            auto tmp_error = std::abs(hres_data[i * n + j] - href_data[i * n + j]);
            // std::cout<<tmp_error<<"\n";
            if (tmp_error > abs_error)
            {
                num_errors++;
                max_error = max_error < tmp_error ? tmp_error : max_error;
            }
        }
    }
    std::cout << "num_error: " << num_errors << " max error= " << max_error
              << " \n";
    return num_errors == 0;
}

template <typename T>
void host_gemm(int M, int N, int K, std::vector<T> &A, std::vector<T> &B,
               std::vector<T> &C, T alpha, T beta)
{
    for (int m = 0; m < M; m++)
    {
        for (int n = 0; n < N; n++)
        {
            T accum = 0;
            for (int k = 0; k < K; k++)
            {
                accum += A[m * K + k] * B[k * N + n];
            }
            C[m * N + n] = alpha * accum + beta * C[m * N + n];
        }
    }
}

// print an M-by-N array
template <typename T>
void print(size_t m, size_t n, thrust::device_vector<T> &d_data)
{
    thrust::host_vector<T> h_data = d_data;

    for (size_t i = 0; i < m; i++)
    {
        for (size_t j = 0; j < n; j++)
            std::cout << std::setw(1) << h_data[i * n + j] << " ";
        std::cout << "\n";
    }
}

int main(int argc, char **argv)
{
    const int M = 6144;
    const int N = 6144;
    const int K = 6144;
    constexpr int Ms = 128;
    constexpr int Ns = 128;
    constexpr int Ks = 8;
    using Element = float;
    std::vector<Element> hA(M * K);
    std::vector<Element> hB(K * N);
    std::vector<Element> hC(M * N);
    std::random_device rd;  // 将用于获得随机数引擎的种子
    std::mt19937 gen(rd()); // 以 rd() 播种的标准 mersenne_twister_engine
    std::uniform_real_distribution<Element> dis(1, 10);
    std::generate(hA.begin(), hA.end(), [&rd, &gen, &dis]()
                  { return dis(gen); });
    std::generate(hB.begin(), hB.end(), [&rd, &gen, &dis]()
                  { return dis(gen); });
    thrust::device_vector<Element> dA = hA;
    thrust::device_vector<Element> dB = hB;
    thrust::device_vector<Element> dC(M * N);
    thrust::device_vector<Element> drefC(M * N);
    Element *dA_ptr = thrust::raw_pointer_cast(dA.data());
    Element *dB_ptr = thrust::raw_pointer_cast(dB.data());
    Element *dC_ptr = thrust::raw_pointer_cast(dC.data());
    Element *dCref_ptr = thrust::raw_pointer_cast(drefC.data());

    cublasHandle_t handle;
    cublasSafeCall(cublasCreate(&handle));
    float alpha = 1.;
    float beta = 0.;
    cublasSafeCall(cublasSgemm_v2(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K,
                                  &alpha, dB_ptr, N, dA_ptr, K, &beta, dCref_ptr,
                                  N));

    const dim3 block(256);
    const dim3 grid((M + 127) / 128, (N + 127) / 128);
    gemm_kernel<<<grid, block>>>(M, N, K, alpha, dA_ptr, dB_ptr, beta, dC_ptr);

    gemm_kernel<<<grid, block>>>(M, N, K, alpha, dA_ptr, dB_ptr, beta, dC_ptr);
    std::cout << cudaGetErrorString(cudaGetLastError()) << "\n";
    verify_res(M, N, dC, drefC);

    // host_gemm(M, N, K, hA, hB, hC, alpha, beta);
    // verify_res(M, N, thrust::device_vector<Element>(hC), drefC);
    // std::cout << "hc verify dc \n";
    // verify_res(M, N, thrust::device_vector<Element>(hC), dC);
    // print(M, N, dC);

    // print(M, N, drefC);
}

//
//
// nvcc -arch=sm_75 -O3   ./gemm.cu -o gemmtest -lcublas
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容

  • 我是黑夜里大雨纷飞的人啊 1 “又到一年六月,有人笑有人哭,有人欢乐有人忧愁,有人惊喜有人失落,有的觉得收获满满有...
    陌忘宇阅读 8,614评论 28 53
  • 首先介绍下自己的背景: 我11年左右入市到现在,也差不多有4年时间,看过一些关于股票投资的书籍,对于巴菲特等股神的...
    瞎投资阅读 5,802评论 3 8
  • ![Flask](...
    极客学院Wiki阅读 7,420评论 0 3
  • 不知不觉易趣客已经在路上走了快一年了,感觉也该让更多朋友认识知道易趣客,所以就谢了这篇简介,已做创业记事。 易趣客...
    Physher阅读 3,452评论 1 2