1#ifndef _BLOCK_SPGEMM_H_
2#define _BLOCK_SPGEMM_H_
21 std::vector<std::vector<SpParMat<IT, NTA, DERA>>> A_blocks_;
22 std::vector<std::vector<SpParMat<IT, NTB, DERB>>> B_blocks_;
23 int br_, bc_, bi_, cur_block_;
37 br_(br), bc_(bc), bi_(bi), cur_block_(0)
39 A_blocks_ =
A.BlockSplit(br_, bi_);
40 B_blocks_ =
B.BlockSplit(bi_, bc_);
55 int rbid = cur_block_ / bc_;
56 int cbid = cur_block_ % bc_;
61 roffset = (std::min(
static_cast<IT>(rbid), r)*(bs+1)) +
62 ((rbid < r ? 0 : rbid-r)*bs);
67 coffset = (std::min(
static_cast<IT>(cbid), r)*(bs+1)) +
68 ((cbid < r ? 0 : cbid-r)*bs);
71 return Mult_AnXBn_DoubleBuff<SR, NTC, DERC>
72 (A_blocks_[rbid][0], B_blocks_[0][cbid],
false,
false);
80 return cur_block_ < br_*bc_;
95 roffset = (std::min(
static_cast<IT>(rbid), r)*(bs+1)) +
96 ((rbid < r ? 0 : rbid-r)*bs);
101 coffset = (std::min(
static_cast<IT>(cbid), r)*(bs+1)) +
102 ((cbid < r ? 0 : cbid-r)*bs);
105 return Mult_AnXBn_DoubleBuff<SR, NTC, DERC>
106 (A_blocks_[rbid][0], B_blocks_[0][cbid],
false,
false);
122 int nblocks = (is_row ? br_ : bc_);
123 std::vector<IT> offsets(nblocks+1);
124 for (
int b = 0; b < nblocks; ++b)
125 offsets[b] = (std::min(
static_cast<IT>(b), r)*(bs+1)) +
126 ((b < r ? 0 : b-r)*bs);
127 offsets[nblocks] = (is_row ? nr_ : nc_);
SelectMaxSRing< bool, int64_t > SR
SpParMat< IT, NTC, DERC > getBlockId(int rbid, int cbid, IT &roffset, IT &coffset)
std::vector< IT > getBlockOffsets(bool is_row)
BlockSpGEMM(SpParMat< IT, NTA, DERA > &A, SpParMat< IT, NTB, DERB > &B, int br, int bc, int bi=1)
SpParMat< IT, NTC, DERC > getNextBlock(IT &roffset, IT &coffset)