ice 发表于 2013-10-10 15:52
请给出您的完整代码以便分析。
/*
 * sgemmNN: single-precision C = A * B for column-major matrices (no transposes).
 *
 * Each thread accumulates one row of a 64x16 tile of C: the row index is
 * blockIdx.x*64 + threadIdx.y*16 + threadIdx.x and the 16 columns start at
 * blockIdx.y*16. C is overwritten (there is no alpha/beta scaling).
 *
 * Parameters:
 *   A, lda  - left operand and its leading dimension (element (r,p) at A[r + p*lda]).
 *   B, ldb  - right operand and its leading dimension (element (p,j) at B[p + j*ldb]).
 *   C, ldc  - output and its leading dimension (element (r,j) at C[r + j*ldc]).
 *   k       - shared (inner) dimension of the product.
 *
 * Launch configuration expected by the indexing below:
 *   blockDim = (16, 4)  -> 64 threads, which fill the 16x16 tile of B in 4 strided loads each
 *   gridDim  = (rows(C)/64, cols(C)/16)
 * Shared memory: 16*17 floats, statically allocated.
 *
 * Preconditions (not checked; there are no bounds guards):
 *   rows(C) is a multiple of 64, cols(C) a multiple of 16, k a multiple of 16.
 *   For k <= 0 the main loop is skipped and the tile of C is set to zero.
 */
__global__ void sgemmNN(const float* A, int lda, const float* B, int ldb, float* C, int ldc, int k)
{
    // Linear thread id inside the 16x4 block; doubles as the row offset in the 64-row tile.
    const int tid = threadIdx.y * 16 + threadIdx.x;

    // A -> this thread's row, first column of A.
    A += blockIdx.x * 64 + tid;
    // B -> row threadIdx.x, column (blockIdx.y*16 + threadIdx.y) of B.
    B += threadIdx.x + (blockIdx.y * 16 + threadIdx.y) * ldb;
    // C -> this thread's row, first column of the block's 16-column tile.
    C += blockIdx.x * 64 + tid + blockIdx.y * 16 * ldc;

    // Inner dimension padded to 17 so the column-wise stores below land in distinct banks.
    __shared__ float bs[16][17];

    // Per-thread accumulators for the 16 output columns.
    float c[16] = {0.0f};

    // B advances by 16 rows per tile; every thread has the same distance to Blast,
    // so the loop trip count (and thus each __syncthreads) is uniform across the block.
    const float* Blast = B + k;
    while (B < Blast)
    {
        // Stage a 16x16 tile of B: bs[p][j] = B(p, j) for the current 16 rows of B.
#pragma unroll
        for (int i = 0; i < 16; i += 4)
            bs[threadIdx.x][threadIdx.y + i] = B[i * ldb];
        B += 16;
        __syncthreads();  // tile fully written before anyone reads it

        // Rank-1 updates: walk 16 columns of A, one element of this thread's row each.
#pragma unroll
        for (int i = 0; i < 16; i++, A += lda)
        {
            const float a = A[0];
#pragma unroll
            for (int j = 0; j < 16; j++)
                c[j] += a * bs[i][j];
        }
        __syncthreads();  // everyone done reading before the tile is overwritten
    }

    // Write this thread's row of the 64x16 result tile, one column at a time.
#pragma unroll
    for (int i = 0; i < 16; i++)
    {
        *(C + i * ldc) = c[i];
    }
}