diff --git a/src/main/java/org/apache/sysds/hops/estim/EstimatorLayeredGraph.java b/src/main/java/org/apache/sysds/hops/estim/EstimatorLayeredGraph.java index cc636c6a91d..f997db6503a 100644 --- a/src/main/java/org/apache/sysds/hops/estim/EstimatorLayeredGraph.java +++ b/src/main/java/org/apache/sysds/hops/estim/EstimatorLayeredGraph.java @@ -33,6 +33,7 @@ import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; +import java.util.stream.Stream; /** * This estimator implements an approach based on a so-called layered graph, @@ -43,7 +44,7 @@ */ public class EstimatorLayeredGraph extends SparsityEstimator { - private static final int ROUNDS = 32; + private static final int ROUNDS = 512; private final int _rounds; public EstimatorLayeredGraph() { @@ -57,21 +58,47 @@ public EstimatorLayeredGraph(int rounds) { @Override public DataCharacteristics estim(MMNode root) { List leafs = getMatrices(root, new ArrayList<>()); - long nnz = new LayeredGraph(leafs, _rounds).estimateNnz(); + List ops = getOps(root, new ArrayList<>()); + List LGs = new ArrayList<>(); + LayeredGraph ret = traverse(root); + long nnz = ret.estimateNnz(); return root.setDataCharacteristics(new MatrixCharacteristics( - leafs.get(0).getNumRows(), leafs.get(leafs.size()-1).getNumColumns(), nnz)); + ret._nodes.get(0).length, ret._nodes.get(ret._nodes.size() - 1).length, nnz)); + } + + public LayeredGraph traverse(MMNode node) { + if(node.getLeft() == null || node.getRight() == null) return null; + LayeredGraph retL = traverse(node.getLeft()); + LayeredGraph retR = traverse(node.getRight()); + LayeredGraph ret, left, right; + + left = (node.getLeft().getData() == null) + ? retL : new LayeredGraph(node.getLeft().getData(), _rounds); + right = (node.getRight().getData() == null) + ? retR : new LayeredGraph(node.getRight().getData(), _rounds); + + ret = estimInternal(left, right, node.getOp()); + + return ret; } @Override public double estim(MatrixBlock m1, MatrixBlock m2, OpCode op) { if( op == OpCode.MM ) return estim(m1, m2); - throw new NotImplementedException(); + LayeredGraph lg1 = new LayeredGraph(m1, _rounds); + LayeredGraph lg2 = new LayeredGraph(m2, _rounds); + LayeredGraph output = estimInternal(lg1, lg2, op); + return OptimizerUtils.getSparsity( + output._nodes.get(0).length, output._nodes.get(output._nodes.size() - 1).length, output.estimateNnz()); } @Override public double estim(MatrixBlock m, OpCode op) { - throw new NotImplementedException(); + LayeredGraph lg1 = new LayeredGraph(m, _rounds); + LayeredGraph output = estimInternal(lg1, null, op); + return OptimizerUtils.getSparsity( + output._nodes.get(0).length, output._nodes.get(output._nodes.size() - 1).length, output.estimateNnz()); } @Override @@ -80,6 +107,23 @@ public double estim(MatrixBlock m1, MatrixBlock m2) { return OptimizerUtils.getSparsity( m1.getNumRows(), m2.getNumColumns(), graph.estimateNnz()); } + + private static LayeredGraph estimInternal(LayeredGraph lg1, LayeredGraph lg2, OpCode op) { + switch(op) { + case MM: return lg1.matMult(lg2); + case MULT: return lg1.and(lg2); + case PLUS: return lg1.or(lg2); + case RBIND: return lg1.rbind(lg2); + case CBIND: return lg1.cbind(lg2); +// case NEQZERO: +// case EQZERO: + case TRANS: return lg1.transpose(); + case DIAG: return lg1.diag(); +// case RESHAPE: + default: + throw new NotImplementedException(); + } + } private List getMatrices(MMNode node, List leafs) { //NOTE: this extraction is only correct and efficient for chains, no DAGs @@ -92,6 +136,18 @@ private List getMatrices(MMNode node, List leafs) { return leafs; } + private List getOps(MMNode node, List ops) { + //NOTE: this extraction is only correct and efficient for chains, no DAGs + if(node.isLeaf()) { + } + else { + getOps(node.getLeft(), ops); + getOps(node.getRight(), ops); + ops.add(node.getOp()); + } + return ops; + } + public static class LayeredGraph { private final List _nodes; //nodes partitioned by graph level private final int _rounds; //length of propagated r-vectors @@ -101,6 +157,12 @@ public LayeredGraph(List chain, int r) { _rounds = r; chain.forEach(i -> buildNext(i)); } + + public LayeredGraph(MatrixBlock m, int r) { + _nodes = new ArrayList<>(); + _rounds = r; + buildNext(m); + } public void buildNext(MatrixBlock mb) { if( mb.isEmpty() ) @@ -168,7 +230,267 @@ private static double calcNNZ(double[] inpvec, int rounds) { return (inpvec != null && inpvec.length > 0) ? (rounds - 1) / Arrays.stream(inpvec).sum() : 0; } - + + public LayeredGraph rbind(LayeredGraph lg) { + LayeredGraph ret = new LayeredGraph(List.of(), _rounds); + + Node[] rows = new Node[_nodes.get(0).length + lg._nodes.get(0).length]; + Node[] columns = _nodes.get(1).clone(); + + System.arraycopy(_nodes.get(0), 0, rows, 0, _nodes.get(0).length); + + for (int i = _nodes.get(0).length; i < rows.length; i++) + rows[i] = new Node(); + + for(int i = 0; i < lg._nodes.get(0).length; i++) { + for(int j = 0; j < columns.length; j++) { + List edges = lg._nodes.get(1)[j].getInput(); + if(edges.contains(lg._nodes.get(0)[i])) { + columns[j].addInput(rows[i + _nodes.get(0).length]); + } + } + } + ret._nodes.add(rows); + ret._nodes.add(columns); + return ret; + } + + public LayeredGraph cbind(LayeredGraph lg) { + LayeredGraph ret = new LayeredGraph(List.of(), _rounds); + int colLength = _nodes.get(1).length + lg._nodes.get(1).length; + + Node[] rows = _nodes.get(0).clone(); + Node[] columns = new Node[colLength]; + + System.arraycopy(_nodes.get(1), 0, columns, 0, _nodes.get(1).length); + + for (int i = _nodes.get(1).length; i < columns.length; i++) + columns[i] = new Node(); + + for(int i = 0; i < rows.length; i++) { + for(int j = 0; j < lg._nodes.get(1).length; j++) { + List edges = lg._nodes.get(1)[j].getInput(); + if(edges.contains(lg._nodes.get(0)[i])) { + columns[j + _nodes.get(1).length].addInput(rows[i]); + } + } + } + ret._nodes.add(rows); + ret._nodes.add(columns); + return ret; + } + + public LayeredGraph matMult(LayeredGraph lg) { + List m = Stream.concat( + this.toMatrixBlockList().stream(), lg.toMatrixBlockList().stream()) + .collect(Collectors.toList()); + return new LayeredGraph(m, _rounds); + } + + public LayeredGraph or(LayeredGraph lg) { + LayeredGraph ret = new LayeredGraph(List.of(), _rounds); + Node[] rows = new Node[_nodes.get(0).length]; + for (int i = 0; i < _nodes.get(0).length; i++) + rows[i] = new Node(); + ret._nodes.add(rows); + + for(int x = 0; x < _nodes.size() - 1; x++) { + int y = x + 1; + rows = ret._nodes.get(x); + Node[] columns = new Node[_nodes.get(y).length]; + for (int i = 0; i < _nodes.get(y).length; i++) + columns[i] = new Node(); + + for(int i = 0; i < _nodes.get(x).length; i++) { + for(int j = 0; j < _nodes.get(y).length; j++) { + List edges1 = _nodes.get(y)[j].getInput(); + List edges2 = lg._nodes.get(y)[j].getInput(); + if(edges1.contains(_nodes.get(x)[i]) || edges2.contains(lg._nodes.get(x)[i])) + { + columns[j].addInput(rows[i]); + } + } + } + ret._nodes.add(columns); + } + return ret; + } + + public LayeredGraph and(LayeredGraph lg) { + LayeredGraph ret = new LayeredGraph(List.of(), _rounds); + Node[] rows = new Node[_nodes.get(0).length]; + for (int i = 0; i < _nodes.get(0).length; i++) + rows[i] = new Node(); + ret._nodes.add(rows); + + for(int x = 0; x < _nodes.size() - 1; x++) { + int y = x + 1; + rows = ret._nodes.get(x); + Node[] columns = new Node[_nodes.get(y).length]; + for (int i = 0; i < _nodes.get(y).length; i++) + columns[i] = new Node(); + + for(int i = 0; i < _nodes.get(x).length; i++) { + for(int j = 0; j < _nodes.get(y).length; j++) { + List edges1 = _nodes.get(y)[j].getInput(); + List edges2 = lg._nodes.get(y)[j].getInput(); + if(edges1.contains(_nodes.get(x)[i]) && edges2.contains(lg._nodes.get(x)[i])) + { + columns[j].addInput(rows[i]); + } + } + } + ret._nodes.add(columns); + } + return ret; + } + + public LayeredGraph transpose() { + LayeredGraph ret = new LayeredGraph(List.of(), _rounds); + Node[] rows = new Node[_nodes.get(_nodes.size() - 1).length]; + for (int i = 0; i < rows.length; i++) + rows[i] = new Node(); + ret._nodes.add(rows); + + for(int x = _nodes.size() - 1; x > 0; x--) { + rows = ret._nodes.get(ret._nodes.size() - 1); + Node[] columnsOld = _nodes.get(x); + Node[] rowsOld = _nodes.get(x - 1); + Node[] columns = new Node[rowsOld.length]; + + for (int i = 0; i < rowsOld.length; i++) + columns[i] = new Node(); + + for(int i = 0; i < rowsOld.length; i++) { + for(int j = 0; j < columnsOld.length; j++) { + List edges = columnsOld[j].getInput(); + if(edges.contains(rowsOld[i])) { + columns[i].addInput(rows[j]); + } + } + } + ret._nodes.add(columns); + } + return ret; + } + + public LayeredGraph diag() { + LayeredGraph ret = new LayeredGraph(List.of(), _rounds); + Node[] rowsOld = _nodes.get(0); + Node[] columnsOld = _nodes.get(1); + + if(_nodes.get(1).length == 1) { + Node[] rows = new Node[rowsOld.length]; + Node[] columns = new Node[rowsOld.length]; + + for (int i = 0; i < rowsOld.length; i++) + rows[i] = new Node(); + for (int i = 0; i < rowsOld.length; i++) + columns[i] = new Node(); + + List edges = columnsOld[0].getInput(); + for(int i = 0; i < rowsOld.length; i++) { + for(int j = 0; j < rowsOld.length; j++) { + if(edges.contains(rowsOld[i]) && i == j) { + columns[j].addInput(rows[i]); + } + } + } + ret._nodes.add(rows); + ret._nodes.add(columns); + return ret; + } + else if(_nodes.get(0).length == 1){ + Node[] rows = new Node[columnsOld.length]; + Node[] columns = new Node[columnsOld.length]; + + for (int i = 0; i < columnsOld.length; i++) + rows[i] = new Node(); + for (int i = 0; i < columnsOld.length; i++) + columns[i] = new Node(); + + for(int i = 0; i < columnsOld.length; i++) { + for(int j = 0; j < columnsOld.length; j++) { + List edges = columnsOld[j].getInput(); + if(edges.contains(rowsOld[0]) && i == j) { + columns[j].addInput(rows[i]); + } + } + } + ret._nodes.add(rows); + ret._nodes.add(columns); + return ret; + } + else { + Node[] rows = new Node[rowsOld.length]; + Node[] columns = new Node[1]; + for (int i = 0; i < rowsOld.length; i++) + rows[i] = new Node(); + for (int i = 0; i < 1; i++) + columns[i] = new Node(); + for(int i = 0; i < rowsOld.length; i++) { + for(int j = 0; j < columnsOld.length; j++) { + List edges = columnsOld[j].getInput(); + if(edges.contains(rowsOld[i]) && i == j) { + columns[0].addInput(rows[i]); + } + } + } + ret._nodes.add(rows); + ret._nodes.add(columns); + return ret; + } + } + + public MatrixBlock toMatrixBlock() { + List a = new ArrayList<>(); + int rows = _nodes.get(0).length; + int cols = _nodes.get(1).length; + for(int i = 0; i < rows * cols; i++) { + a.add(0.); + } + for(int i = 0; i < rows; i++) { + for(int j = 0; j < cols; j++) { + List edges = _nodes.get(1)[j].getInput(); + if(edges.contains(_nodes.get(0)[i])) { + a.set(i * cols + j, 1. + a.get(i * cols + j)); + } + else { + a.set(i * cols + j, 0.); + } + } + } + double[] arr = a.stream().mapToDouble(d -> d).toArray(); + return new MatrixBlock(rows, cols, arr); + } + + public List toMatrixBlockList() { + List m = new ArrayList<>(); + for(int x = 0; x < _nodes.size() - 1; x++) { + int y = x + 1; + List a = new ArrayList<>(); + int rows = _nodes.get(x).length; + int cols = _nodes.get(y).length; + for(int i = 0; i < rows * cols; i++) { + a.add(0.); + } + for(int i = 0; i < rows; i++) { + for(int j = 0; j < cols; j++) { + List edges = _nodes.get(y)[j].getInput(); + if(edges.contains(_nodes.get(x)[i])) { + a.set(i * cols + j, 1. + a.get(i * cols + j)); + } + else { + a.set(i * cols + j, 0.); + } + } + } + double[] arr = a.stream().mapToDouble(d -> d).toArray(); + m.add(new MatrixBlock(rows, cols, arr)); + } + return m; + } + private static class Node { private List _input = new ArrayList<>(); private double[] _rvect; diff --git a/src/main/java/org/apache/sysds/hops/estim/SparsityEstimator.java b/src/main/java/org/apache/sysds/hops/estim/SparsityEstimator.java index 6c106d7ccfc..edc6f13cafa 100644 --- a/src/main/java/org/apache/sysds/hops/estim/SparsityEstimator.java +++ b/src/main/java/org/apache/sysds/hops/estim/SparsityEstimator.java @@ -89,7 +89,12 @@ public static enum OpCode { protected boolean isExactMetadataOp(OpCode op) { return ArrayUtils.contains(EXACT_META_DATA_OPS, op); } - + + protected boolean isExactMetadataOp(OpCode op, int clen) { + return ArrayUtils.contains(EXACT_META_DATA_OPS, op) + && (op != OpCode.DIAG || clen == 1); + } + protected DataCharacteristics estimExactMetaData(DataCharacteristics dc1, DataCharacteristics dc2, OpCode op) { switch( op ) { case EQZERO: diff --git a/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java b/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java index 1e592be2387..d58e95553cf 100644 --- a/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java +++ b/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java @@ -24,6 +24,7 @@ import org.apache.sysds.hops.estim.EstimatorBasicWorst; import org.apache.sysds.hops.estim.EstimatorBitsetMM; import org.apache.sysds.hops.estim.EstimatorMatrixHistogram; +import org.apache.sysds.hops.estim.EstimatorLayeredGraph; import org.apache.sysds.hops.estim.MMNode; import org.apache.sysds.hops.estim.SparsityEstimator; import org.apache.sysds.hops.estim.SparsityEstimator.OpCode; @@ -102,6 +103,7 @@ public void testMNCCbind() { } //Bitset + @Test public void testBitsetCaserbind() { runSparsityEstimateTest(new EstimatorBitsetMM(), m, k, n, sparsity, rbind); } @@ -112,7 +114,7 @@ public void testBitsetCasecbind() { } //Layered Graph - /*@Test + @Test public void testLGCaserbind() { runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n, sparsity, rbind); } @@ -120,7 +122,7 @@ public void testLGCaserbind() { @Test public void testLGCasecbind() { runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n, sparsity, cbind); - }*/ + } private static void runSparsityEstimateTest(SparsityEstimator estim, int m, int k, int n, double[] sp, OpCode op) { @@ -157,6 +159,8 @@ private static void runSparsityEstimateTest(SparsityEstimator estim, int m, int throw new NotImplementedException(); } //compare estimated and real sparsity - TestUtils.compareScalars(est, m5.getSparsity(), (estim instanceof EstimatorBasicWorst) ? 5e-1 : 1e-2); + TestUtils.compareScalars(est, m5.getSparsity(), + (estim instanceof EstimatorBasicWorst) ? 5e-1 : + (estim instanceof EstimatorLayeredGraph) ? 5e-2 : 1e-2); } } diff --git a/src/test/java/org/apache/sysds/test/component/estim/OpBindTest.java b/src/test/java/org/apache/sysds/test/component/estim/OpBindTest.java index e36b5f6e0cb..3e7ad24fe86 100644 --- a/src/test/java/org/apache/sysds/test/component/estim/OpBindTest.java +++ b/src/test/java/org/apache/sysds/test/component/estim/OpBindTest.java @@ -24,6 +24,7 @@ import org.apache.sysds.hops.estim.EstimatorBasicWorst; import org.apache.sysds.hops.estim.EstimatorBitsetMM; import org.apache.sysds.hops.estim.EstimatorMatrixHistogram; +import org.apache.sysds.hops.estim.EstimatorLayeredGraph; import org.apache.sysds.hops.estim.SparsityEstimator; import org.apache.sysds.hops.estim.SparsityEstimator.OpCode; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -111,7 +112,7 @@ public void testBitsetCaserbind() { } //Layered Graph - /*@Test + @Test public void testLGCaserbind() { runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n, sparsity, rbind); } @@ -122,7 +123,7 @@ public void testLGCasecbind() { } //Sample - @Test + /*@Test public void testSampleCaserbind() { runSparsityEstimateTest(new EstimatorSample(), m, k, n, sparsity, rbind); } @@ -159,6 +160,8 @@ private static void runSparsityEstimateTest(SparsityEstimator estim, int m, int throw new NotImplementedException(); } //compare estimated and real sparsity - TestUtils.compareScalars(est, m3.getSparsity(), (estim instanceof EstimatorBasicWorst) ? 5e-1 : 1e-2); + TestUtils.compareScalars(est, m3.getSparsity(), + (estim instanceof EstimatorBasicWorst) ? 5e-1 : + (estim instanceof EstimatorLayeredGraph) ? 3e-2 : 1e-2); } } diff --git a/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java b/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java index f008026dc3a..a1b6594a927 100644 --- a/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java +++ b/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java @@ -23,6 +23,9 @@ import org.apache.sysds.hops.estim.EstimatorBasicAvg; import org.apache.sysds.hops.estim.EstimatorBasicWorst; import org.apache.sysds.hops.estim.EstimatorBitsetMM; +import org.apache.sysds.hops.estim.EstimatorLayeredGraph; +import org.apache.sysds.hops.estim.EstimatorMatrixHistogram; +import org.apache.sysds.hops.estim.EstimatorDensityMap; import org.apache.sysds.hops.estim.MMNode; import org.apache.sysds.hops.estim.SparsityEstimator; import org.apache.sysds.hops.estim.SparsityEstimator.OpCode; @@ -73,7 +76,7 @@ public void testWorstPlus() { } //DensityMap - /*@Test + @Test public void testDMMult() { runSparsityEstimateTest(new EstimatorDensityMap(), m, n, sparsity, mult); } @@ -92,7 +95,7 @@ public void testMNCMult() { @Test public void testMNCPlus() { runSparsityEstimateTest(new EstimatorMatrixHistogram(), m, n, sparsity, plus); - }*/ + } //Bitset @Test @@ -106,15 +109,15 @@ public void testBitsetPlus() { } //Layered Graph - /*@Test + @Test public void testLGCasemult() { - runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n, sparsity, mult); + runSparsityEstimateTest(new EstimatorLayeredGraph(), m, n, sparsity, mult); } @Test public void testLGCaseplus() { - runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n, sparsity, plus); - }*/ + runSparsityEstimateTest(new EstimatorLayeredGraph(), m, n, sparsity, plus); + } private static void runSparsityEstimateTest(SparsityEstimator estim, int m, int n, double[] sp, OpCode op) { @@ -148,6 +151,7 @@ private static void runSparsityEstimateTest(SparsityEstimator estim, int m, int throw new NotImplementedException(); } //compare estimated and real sparsity - TestUtils.compareScalars(est, m5.getSparsity(), (estim instanceof EstimatorBasicWorst) ? 9e-1 : 1e-2); + TestUtils.compareScalars(est, m5.getSparsity(), (estim instanceof EstimatorBasicWorst) ? 9e-1 : + (estim instanceof EstimatorLayeredGraph) ? 7e-2 : 1e-2); } } diff --git a/src/test/java/org/apache/sysds/test/component/estim/OpElemWTest.java b/src/test/java/org/apache/sysds/test/component/estim/OpElemWTest.java index dae0bb4ddc2..f8ddb91bcef 100644 --- a/src/test/java/org/apache/sysds/test/component/estim/OpElemWTest.java +++ b/src/test/java/org/apache/sysds/test/component/estim/OpElemWTest.java @@ -25,6 +25,7 @@ import org.apache.sysds.hops.estim.EstimatorBitsetMM; import org.apache.sysds.hops.estim.EstimatorDensityMap; import org.apache.sysds.hops.estim.EstimatorMatrixHistogram; +import org.apache.sysds.hops.estim.EstimatorLayeredGraph; import org.apache.sysds.hops.estim.EstimatorSample; import org.apache.sysds.hops.estim.SparsityEstimator; import org.apache.sysds.hops.estim.SparsityEstimator.OpCode; @@ -105,17 +106,17 @@ public void testBitsetMult() { public void testBitsetPlus() { runSparsityEstimateTest(new EstimatorBitsetMM(), m, n, sparsity, plus); } - /* + //Layered Graph @Test public void testLGCasemult() { - runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n, sparsity, mult); + runSparsityEstimateTest(new EstimatorLayeredGraph(), m, n, sparsity, mult); } @Test public void testLGCaseplus() { - runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n, sparsity, plus); - }*/ + runSparsityEstimateTest(new EstimatorLayeredGraph(), m, n, sparsity, plus); + } //Sample @Test @@ -153,6 +154,7 @@ private static void runSparsityEstimateTest(SparsityEstimator estim, int m, int throw new NotImplementedException(); } //compare estimated and real sparsity - TestUtils.compareScalars(est, m3.getSparsity(), (estim instanceof EstimatorBasicWorst) ? 5e-1 : 5e-3); + TestUtils.compareScalars(est, m3.getSparsity(), (estim instanceof EstimatorBasicWorst) ? 5e-1 : + (estim instanceof EstimatorLayeredGraph) ? 3e-2 : 5e-3); } } diff --git a/src/test/java/org/apache/sysds/test/component/estim/OpSingleTest.java b/src/test/java/org/apache/sysds/test/component/estim/OpSingleTest.java index ea34ac14329..d40f84c4fb3 100644 --- a/src/test/java/org/apache/sysds/test/component/estim/OpSingleTest.java +++ b/src/test/java/org/apache/sysds/test/component/estim/OpSingleTest.java @@ -19,11 +19,13 @@ package org.apache.sysds.test.component.estim; +import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; import org.junit.Test; import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.hops.estim.EstimatorBasicAvg; import org.apache.sysds.hops.estim.EstimatorBasicWorst; import org.apache.sysds.hops.estim.EstimatorBitsetMM; +import org.apache.sysds.hops.estim.EstimatorLayeredGraph; import org.apache.sysds.hops.estim.SparsityEstimator; import org.apache.sysds.hops.estim.SparsityEstimator.OpCode; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -39,7 +41,7 @@ public class OpSingleTest extends AutomatedTestBase private final static int k = 300; private final static double sparsity = 0.2; // private final static OpCode eqzero = OpCode.EQZERO; -// private final static OpCode diag = OpCode.DIAG; + private final static OpCode diag = OpCode.DIAG; private final static OpCode neqzero = OpCode.NEQZERO; private final static OpCode trans = OpCode.TRANS; private final static OpCode reshape = OpCode.RESHAPE; @@ -57,7 +59,7 @@ public void setUp() { // @Test // public void testAvgDiag() { -// runSparsityEstimateTest(new EstimatorBasicAvg(), m, k, sparsity, diag); +// runSparsityEstimateTest(new EstimatorBasicAvg(), m, m, sparsity, diag); // } @Test @@ -83,7 +85,7 @@ public void testAvgReshape() { // @Test // public void testWCasediag() { -// runSparsityEstimateTest(new EstimatorBasicWorst(), m, k, sparsity, diag); +// runSparsityEstimateTest(new EstimatorBasicWorst(), m, m, sparsity, diag); // } @Test @@ -109,7 +111,7 @@ public void testWorstReshape() { // // @Test // public void testDMCasediag() { -// runSparsityEstimateTest(new EstimatorDensityMap(), m, k, sparsity, diag); +// runSparsityEstimateTest(new EstimatorDensityMap(), m, m, sparsity, diag); // } // // @Test @@ -135,7 +137,7 @@ public void testWorstReshape() { // // @Test // public void testMNCCasediag() { -// runSparsityEstimateTest(new EstimatorDensityMap(), m, k, sparsity, diag); +// runSparsityEstimateTest(new EstimatorDensityMap(), m, m, sparsity, diag); // } // // @Test @@ -161,7 +163,7 @@ public void testWorstReshape() { // @Test // public void testBitsetCasediag() { -// runSparsityEstimateTest(new EstimatorBitsetMM(), m, k, sparsity, diag); +// runSparsityEstimateTest(new EstimatorBitsetMM(), m, m, sparsity, diag); // } @Test @@ -185,21 +187,26 @@ public void testBitsetReshape() { // runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, sparsity, eqzero); // } // -// @Test -// public void testLGCasediag() { -// runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, sparsity, diag); -// } + @Test + public void testLGCasediagM() { + runSparsityEstimateTest(new EstimatorLayeredGraph(), m, m, sparsity, diag); + } + + @Test + public void testLGCasediagV() { + runSparsityEstimateTest(new EstimatorLayeredGraph(), m, 1, sparsity, diag); + } // // @Test // public void testLGCaseneqzero() { // runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, sparsity, neqzero); // } // -// @Test -// public void testLGCasetans() { -// runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, sparsity, trans); -// } -// + @Test + public void testLGCasetrans() { + runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, sparsity, trans); + } + // @Test // public void testLGCasereshape() { // runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, sparsity, reshape); @@ -213,7 +220,7 @@ public void testBitsetReshape() { // // @Test // public void testSampleCasediag() { -// runSparsityEstimateTest(new EstimatorSample(), m, k, sparsity, diag); +// runSparsityEstimateTest(new EstimatorSample(), m, m, sparsity, diag); // } // // @Test @@ -239,6 +246,11 @@ private static void runSparsityEstimateTest(SparsityEstimator estim, int m, int case EQZERO: //TODO find out how to do eqzero case DIAG: + m2 = m1.getNumColumns() == 1 + ? LibMatrixReorg.diag(m1, new MatrixBlock(m1.getNumRows(), m1.getNumRows(), false)) + : LibMatrixReorg.diag(m1, new MatrixBlock(m1.getNumRows(), 1, false)); + est = estim.estim(m1, op); + break; case NEQZERO: m2 = m1; est = estim.estim(m1, op); @@ -255,6 +267,8 @@ private static void runSparsityEstimateTest(SparsityEstimator estim, int m, int throw new NotImplementedException(); } //compare estimated and real sparsity - TestUtils.compareScalars(est, m2.getSparsity(), (estim instanceof EstimatorBasicWorst) ? 5e-1 : 1e-2); + TestUtils.compareScalars(est, m2.getSparsity(), + (estim instanceof EstimatorBasicWorst) ? 5e-1 : + (estim instanceof EstimatorLayeredGraph) ? 3e-2 : 2e-2); } } diff --git a/src/test/java/org/apache/sysds/test/component/estim/OuterProductTest.java b/src/test/java/org/apache/sysds/test/component/estim/OuterProductTest.java index feb1e2ef055..fdc33d878db 100644 --- a/src/test/java/org/apache/sysds/test/component/estim/OuterProductTest.java +++ b/src/test/java/org/apache/sysds/test/component/estim/OuterProductTest.java @@ -25,6 +25,7 @@ import org.apache.sysds.hops.estim.EstimatorBitsetMM; import org.apache.sysds.hops.estim.EstimatorDensityMap; import org.apache.sysds.hops.estim.EstimatorMatrixHistogram; +import org.apache.sysds.hops.estim.EstimatorLayeredGraph; import org.apache.sysds.hops.estim.EstimatorSample; import org.apache.sysds.hops.estim.SparsityEstimator; import org.apache.sysds.runtime.instructions.InstructionUtils; @@ -138,7 +139,17 @@ public void testSampling20Case1() { public void testSampling20Case2() { runSparsityEstimateTest(new EstimatorSample(0.2), m, k, n, case2); } - + + @Test + public void testLayeredGraphCase1() { + runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n, case1); + } + + @Test + public void testLayeredGraphCase2() { + runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n, case2); + } + private static void runSparsityEstimateTest(SparsityEstimator estim, int m, int k, int n, double[] sp) { MatrixBlock m1 = MatrixBlock.randOperations(m, k, sp[0], 1, 1, "uniform", 3); MatrixBlock m2 = MatrixBlock.randOperations(k, n, sp[1], 1, 1, "uniform", 3); @@ -147,6 +158,6 @@ private static void runSparsityEstimateTest(SparsityEstimator estim, int m, int //compare estimated and real sparsity double est = estim.estim(m1, m2); - TestUtils.compareScalars(est, m3.getSparsity(), 1e-16); + TestUtils.compareScalars(est, m3.getSparsity(), (estim instanceof EstimatorLayeredGraph) ? 5e-2 : 1e-16); } } diff --git a/src/test/java/org/apache/sysds/test/component/estim/SelfProductTest.java b/src/test/java/org/apache/sysds/test/component/estim/SelfProductTest.java index 7167dc22267..20c39b10494 100644 --- a/src/test/java/org/apache/sysds/test/component/estim/SelfProductTest.java +++ b/src/test/java/org/apache/sysds/test/component/estim/SelfProductTest.java @@ -27,6 +27,7 @@ import org.apache.sysds.hops.estim.EstimatorBitsetMM; import org.apache.sysds.hops.estim.EstimatorDensityMap; import org.apache.sysds.hops.estim.EstimatorMatrixHistogram; +import org.apache.sysds.hops.estim.EstimatorLayeredGraph; import org.apache.sysds.hops.estim.EstimatorSample; import org.apache.sysds.hops.estim.SparsityEstimator; import org.apache.sysds.runtime.instructions.InstructionUtils; @@ -128,6 +129,16 @@ public void testSampling20Case1() { public void testSampling20Case2() { runSparsityEstimateTest(new EstimatorSample(0.2), m, sparsity2); } + + @Test + public void testLayeredGraphCase1() { + runSparsityEstimateTest(new EstimatorLayeredGraph(), m, sparsity1); + } + + @Test + public void testLayeredGraphCase2() { + runSparsityEstimateTest(new EstimatorLayeredGraph(), m, sparsity2); + } private static void runSparsityEstimateTest(SparsityEstimator estim, int n, double sp) { MatrixBlock m1 = MatrixBlock.randOperations(m, n, sp, 1, 1, "uniform", 3); diff --git a/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java b/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java index 25cd99ecaf7..8d0ab8ed8ed 100644 --- a/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java +++ b/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java @@ -25,6 +25,7 @@ import org.apache.sysds.hops.estim.EstimatorBitsetMM; import org.apache.sysds.hops.estim.EstimatorDensityMap; import org.apache.sysds.hops.estim.EstimatorMatrixHistogram; +import org.apache.sysds.hops.estim.EstimatorLayeredGraph; import org.apache.sysds.hops.estim.MMNode; import org.apache.sysds.hops.estim.SparsityEstimator; import org.apache.sysds.hops.estim.SparsityEstimator.OpCode; @@ -125,6 +126,16 @@ public void testMatrixHistogramExceptCase1() { public void testMatrixHistogramExceptCase2() { runSparsityEstimateTest(new EstimatorMatrixHistogram(true), m, k, n, n2, case2); } + + @Test + public void testLayeredGraphCase1() { + runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n, n2, case1); + } + + @Test + public void testLayeredGraphCase2() { + runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n, n2, case2); + } private static void runSparsityEstimateTest(SparsityEstimator estim, int m, int k, int n, int n2, double[] sp) { MatrixBlock m1 = MatrixBlock.randOperations(m, k, sp[0], 1, 1, "uniform", 1); diff --git a/src/test/java/org/apache/sysds/test/component/estim/SquaredProductTest.java b/src/test/java/org/apache/sysds/test/component/estim/SquaredProductTest.java index 8fb847af235..2a898f9c39f 100644 --- a/src/test/java/org/apache/sysds/test/component/estim/SquaredProductTest.java +++ b/src/test/java/org/apache/sysds/test/component/estim/SquaredProductTest.java @@ -25,6 +25,7 @@ import org.apache.sysds.hops.estim.EstimatorBitsetMM; import org.apache.sysds.hops.estim.EstimatorDensityMap; import org.apache.sysds.hops.estim.EstimatorMatrixHistogram; +import org.apache.sysds.hops.estim.EstimatorLayeredGraph; import org.apache.sysds.hops.estim.EstimatorSample; import org.apache.sysds.hops.estim.SparsityEstimator; import org.apache.sysds.runtime.instructions.InstructionUtils; @@ -123,7 +124,7 @@ public void testMatrixHistogramExceptCase1() { public void testMatrixHistogramExceptCase2() { runSparsityEstimateTest(new EstimatorMatrixHistogram(true), m, k, n, case2); } - + @Test public void testSamplingDefCase1() { runSparsityEstimateTest(new EstimatorSample(), m, k, n, case1); @@ -143,6 +144,16 @@ public void testSampling20Case1() { public void testSampling20Case2() { runSparsityEstimateTest(new EstimatorSample(0.2), m, k, n, case2); } + + @Test + public void testLayeredGraphCase1() { + runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n, case1); + } + + @Test + public void testLayeredGraphCase2() { + runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n, case2); + } private static void runSparsityEstimateTest(SparsityEstimator estim, int m, int k, int n, double[] sp) { MatrixBlock m1 = MatrixBlock.randOperations(m, k, sp[0], 1, 1, "uniform", 3);