Author: Aart Bik Date: 2021-01-20T14:30:13-08:00 New Revision: 5959c28f24856f3d4a1db6b4743c66bdc6dcd735
URL: https://github.com/llvm/llvm-project/commit/5959c28f24856f3d4a1db6b4743c66bdc6dcd735 DIFF: https://github.com/llvm/llvm-project/commit/5959c28f24856f3d4a1db6b4743c66bdc6dcd735.diff LOG: [mlir][sparse] add asserts on reading in tensor data Rationale: Since I made the argument that metadata helps with extra verification checks, I had better actually do that ;-) Reviewed By: penpornk Differential Revision: https://reviews.llvm.org/D95072 Added: Modified: mlir/lib/ExecutionEngine/SparseUtils.cpp Removed: ################################################################################ diff --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp index 376b989975b5..d1962661fe79 100644 --- a/mlir/lib/ExecutionEngine/SparseUtils.cpp +++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp @@ -48,9 +48,9 @@ namespace { /// and a rank-5 tensor element like /// ({i,j,k,l,m}, a[i,j,k,l,m]) struct Element { - Element(const std::vector<int64_t> &ind, double val) + Element(const std::vector<uint64_t> &ind, double val) : indices(ind), value(val){}; - std::vector<int64_t> indices; + std::vector<uint64_t> indices; double value; }; @@ -61,9 +61,15 @@ struct Element { /// formats require the elements to appear in lexicographic index order). struct SparseTensor { public: - SparseTensor(int64_t capacity) : pos(0) { elements.reserve(capacity); } + SparseTensor(const std::vector<uint64_t> &szs, uint64_t capacity) + : sizes(szs), pos(0) { + elements.reserve(capacity); + } // Add element as indices and value. - void add(const std::vector<int64_t> &ind, double val) { + void add(const std::vector<uint64_t> &ind, double val) { + assert(sizes.size() == ind.size()); + for (int64_t r = 0, rank = sizes.size(); r < rank; r++) + assert(ind[r] < sizes[r]); // within bounds elements.emplace_back(Element(ind, val)); } // Sort elements lexicographically by index. 
@@ -82,6 +88,8 @@ struct SparseTensor { } return false; } + + std::vector<uint64_t> sizes; // per-rank dimension sizes std::vector<Element> elements; uint64_t pos; }; @@ -225,20 +233,24 @@ extern "C" void *openTensorC(char *filename, uint64_t *idata) { fprintf(stderr, "Unknown format %s\n", filename); exit(1); } - // Read all nonzero elements. + // Prepare sparse tensor object with per-rank dimension sizes + // and the number of nonzeros as initial capacity. uint64_t rank = idata[0]; uint64_t nnz = idata[1]; - SparseTensor *tensor = new SparseTensor(nnz); - std::vector<int64_t> indices(rank); - double value; + std::vector<uint64_t> indices(rank); + for (uint64_t r = 0; r < rank; r++) + indices[r] = idata[2 + r]; + SparseTensor *tensor = new SparseTensor(indices, nnz); + // Read all nonzero elements. for (uint64_t k = 0; k < nnz; k++) { for (uint64_t r = 0; r < rank; r++) { - if (fscanf(file, "%" PRId64, &indices[r]) != 1) { + if (fscanf(file, "%" PRIu64, &indices[r]) != 1) { fprintf(stderr, "Cannot find next index in %s\n", filename); exit(1); } indices[r]--; // 0-based index } + double value; if (fscanf(file, "%lg\n", &value) != 1) { fprintf(stderr, "Cannot find next value in %s\n", filename); exit(1); _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits