Skip to content

Instantly share code, notes, and snippets.

@4DA
Created June 11, 2024 08:40
Show Gist options
  • Save 4DA/05c6bde7232d400ddea0ccd4eced916c to your computer and use it in GitHub Desktop.
Save 4DA/05c6bde7232d400ddea0ccd4eced916c to your computer and use it in GitHub Desktop.
Compute JDS (Jagged Diagonal Storage) matrix in column-major order (ECE408)
// compute JDS (Jagged Diagonal Storage) matrix in column-major order from
// matrix in CSR (Compressed Sparse Row) format
//
// Rows are permuted by descending non-zero count; rows with equal counts keep
// their original relative order, so the permutation is deterministic. JDS
// column k holds the k-th non-zero of every (permuted) row that has more than
// k non-zeros, stored contiguously starting at hostJDSColStart[k].
//
// All five output arrays are allocated here with malloc/calloc and are owned
// by the caller, who must release them with free(). Each array is at least as
// large as before (dim / valueCount / CSRColsSz entries) and is grown when the
// CSR row offsets describe more values than that, so writes never overflow.
void CSRToJDS(int dim, // number of rows
              int valueCount, // number of values in hostCSRData
              int CSRColsSz, // number of values in hostCSRCols
              int *hostCSRRows, // dim+1 offsets; row i spans [hostCSRRows[i], hostCSRRows[i+1]) of hostCSRData
              int *hostCSRCols, // each value is corresponding column of value in hostCSRData
              float *hostCSRData, // non-zero values of matrix in CSR format
              int **hostJDSRowPerm, // out: original row index of each JDS row
              int **hostJDSRows, // out: number of non-zeros in each JDS row
              int **hostJDSColStart, // out: each value is offset where a JDS column starts in hostJDSData
              int **hostJDSCols, // out: each value is corresponding column of value in hostJDSData
              float **hostJDSData // out: non-zero values of matrix in JDS format
              )
{
    struct RowDesc {
        int len; // number of non-zeros in the row
        int id;  // original index of the row
        int ptr; // offset of the row's first value in hostCSRData
    };

    // A negative row count is treated as an empty matrix.
    const int rowCount = std::max(dim, 0);

    std::vector<RowDesc> rowDesc;
    rowDesc.reserve(rowCount);
    for (int i = 0; i < rowCount; i++) {
        rowDesc.push_back({hostCSRRows[i + 1] - hostCSRRows[i], i, hostCSRRows[i]});
    }
    // sort rows by their non-zero length in descending order; stable_sort keeps
    // equal-length rows in original order (std::sort leaves that unspecified)
    std::stable_sort(rowDesc.begin(), rowDesc.end(),
                     [](const RowDesc &r1, const RowDesc &r2) { return r1.len > r2.len; });

    // Number of JDS columns (= longest row) and number of values emitted.
    // The empty check avoids reading rowDesc[0] when there are no rows.
    const int maxRowLen = rowDesc.empty() ? 0 : rowDesc.front().len;
    const int nnz = rowCount > 0 ? hostCSRRows[rowCount] - hostCSRRows[0] : 0;

    // Element count no smaller than either requested size (negatives clamp to 0).
    auto atLeast = [](int a, int b) -> size_t {
        return static_cast<size_t>(std::max({a, b, 0}));
    };

    int *rowPerm  = static_cast<int *>(malloc(rowCount * sizeof(int)));
    int *rowLens  = static_cast<int *>(malloc(rowCount * sizeof(int)));
    int *colStart = static_cast<int *>(malloc(atLeast(valueCount, maxRowLen) * sizeof(int)));
    int *jdsCols  = static_cast<int *>(malloc(atLeast(CSRColsSz, nnz) * sizeof(int)));
    // calloc zero-fills the whole buffer, including any slack past the last
    // emitted value (replaces malloc + memset)
    float *jdsData = static_cast<float *>(calloc(atLeast(valueCount, nnz), sizeof(float)));

    *hostJDSRowPerm  = rowPerm;
    *hostJDSRows     = rowLens;
    *hostJDSColStart = colStart;
    *hostJDSCols     = jdsCols;
    *hostJDSData     = jdsData;

    for (size_t r = 0; r < rowDesc.size(); r++) {
        rowPerm[r] = rowDesc[r].id;
        rowLens[r] = rowDesc[r].len;
    }

    // Emit the jagged diagonals in column-major order: JDS column `col` takes
    // the col-th value of every row long enough to have one. Rows are sorted
    // by length, so the first row that is too short ends the column.
    int oft = 0;
    for (int col = 0; col < maxRowLen; col++) {
        colStart[col] = oft;
        for (const RowDesc &row : rowDesc) {
            if (col >= row.len) {
                break;
            }
            const int src = row.ptr + col;
            jdsCols[oft] = hostCSRCols[src];
            jdsData[oft] = hostCSRData[src];
            oft++;
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment