hivemall-issues mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From takuti <...@git.apache.org>
Subject [GitHub] incubator-hivemall pull request #121: [HIVEMALL-151] Support Matrix conversi...
Date Thu, 12 Oct 2017 06:16:03 GMT
Github user takuti commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/121#discussion_r144199528
  
    --- Diff: core/src/main/java/hivemall/math/matrix/MatrixUtils.java ---
    @@ -70,4 +77,259 @@ public void apply(int i, int value) {
             return which.getValue();
         }
     
    +    /**
    +     * @param data non-zero entries
    +     */
    +    @Nonnull
    +    public static CSRMatrix coo2csr(@Nonnull final int[] rows, @Nonnull final int[] cols,
    +            @Nonnull final double[] data, @Nonnegative final int numRows,
    +            @Nonnegative final int numCols, final boolean sortColumns) {
    +        final int nnz = data.length;
    +        Preconditions.checkArgument(rows.length == nnz);
    +        Preconditions.checkArgument(cols.length == nnz);
    +
    +        final int[] rowPointers = new int[numRows + 1];
    +        final int[] colIndicies = new int[nnz];
    +        final double[] values = new double[nnz];
    +
    +        coo2csr(rows, cols, data, rowPointers, colIndicies, values, numRows, numCols,
nnz);
    +
    +        if (sortColumns) {
    +            sortIndicies(rowPointers, colIndicies, values);
    +        }
    +        return new CSRMatrix(rowPointers, colIndicies, values, numCols);
    +    }
    +
    +    /**
    +     * @param data non-zero entries
    +     */
    +    @Nonnull
    +    public static CSRFloatMatrix coo2csr(@Nonnull final int[] rows, @Nonnull final int[]
cols,
    +            @Nonnull final float[] data, @Nonnegative final int numRows,
    +            @Nonnegative final int numCols, final boolean sortColumns) {
    +        final int nnz = data.length;
    +        Preconditions.checkArgument(rows.length == nnz);
    +        Preconditions.checkArgument(cols.length == nnz);
    +
    +        final int[] rowPointers = new int[numRows + 1];
    +        final int[] colIndicies = new int[nnz];
    +        final float[] values = new float[nnz];
    +
    +        coo2csr(rows, cols, data, rowPointers, colIndicies, values, numRows, numCols,
nnz);
    +
    +        if (sortColumns) {
    +            sortIndicies(rowPointers, colIndicies, values);
    +        }
    +        return new CSRFloatMatrix(rowPointers, colIndicies, values, numCols);
    +    }
    +
    +    @Nonnull
    +    public static CSCMatrix coo2csc(@Nonnull final int[] rows, @Nonnull final int[] cols,
    +            @Nonnull final double[] data, @Nonnegative final int numRows,
    +            @Nonnegative final int numCols, final boolean sortRows) {
    +        final int nnz = data.length;
    +        Preconditions.checkArgument(rows.length == nnz);
    +        Preconditions.checkArgument(cols.length == nnz);
    +
    +        final int[] columnPointers = new int[numCols + 1];
    +        final int[] rowIndicies = new int[nnz];
    +        final double[] values = new double[nnz];
    +
    +        coo2csr(cols, rows, data, columnPointers, rowIndicies, values, numCols, numRows,
nnz);
    +
    +        if (sortRows) {
    +            sortIndicies(columnPointers, rowIndicies, values);
    +        }
    +        return new CSCMatrix(columnPointers, rowIndicies, values, numRows, numCols);
    +    }
    +
    +    @Nonnull
    +    public static CSCFloatMatrix coo2csc(@Nonnull final int[] rows, @Nonnull final int[]
cols,
    +            @Nonnull final float[] data, @Nonnegative final int numRows,
    +            @Nonnegative final int numCols, final boolean sortRows) {
    +        final int nnz = data.length;
    +        Preconditions.checkArgument(rows.length == nnz);
    +        Preconditions.checkArgument(cols.length == nnz);
    +
    +        final int[] columnPointers = new int[numCols + 1];
    +        final int[] rowIndicies = new int[nnz];
    +        final float[] values = new float[nnz];
    +
    +        coo2csr(cols, rows, data, columnPointers, rowIndicies, values, numCols, numRows,
nnz);
    +
    +        if (sortRows) {
    +            sortIndicies(columnPointers, rowIndicies, values);
    +        }
    +
    +        return new CSCFloatMatrix(columnPointers, rowIndicies, values, numRows, numCols);
    +    }
    +
    +    private static void coo2csr(@Nonnull final int[] rows, @Nonnull final int[] cols,
    +            @Nonnull final double[] data, @Nonnull final int[] rowPointers,
    +            @Nonnull final int[] colIndicies, @Nonnull final double[] values,
    +            @Nonnegative final int numRows, @Nonnegative final int numCols, final int
nnz) {
    +        // compute nnz per for each row to get rowPointers
    +        for (int n = 0; n < nnz; n++) {
    +            rowPointers[rows[n]]++;
    +        }
    +        for (int i = 0, sum = 0; i < numRows; i++) {
    +            int curr = rowPointers[i];
    +            rowPointers[i] = sum;
    +            sum += curr;
    +        }
    +        rowPointers[numRows] = nnz;
    +
    +        // copy cols, data to colIndicies, csrValues
    +        for (int n = 0; n < nnz; n++) {
    +            int row = rows[n];
    +            int dst = rowPointers[row];
    +
    +            colIndicies[dst] = cols[n];
    +            values[dst] = data[n];
    +
    +            rowPointers[row]++;
    +        }
    +
    +        for (int i = 0, last = 0; i <= numRows; i++) {
    +            int tmp = rowPointers[i];
    +            rowPointers[i] = last;
    +            last = tmp;
    +        }
    +    }
    +
    +    private static void coo2csr(@Nonnull final int[] rows, @Nonnull final int[] cols,
    +            @Nonnull final float[] data, @Nonnull final int[] rowPointers,
    +            @Nonnull final int[] colIndicies, @Nonnull final float[] values,
    +            @Nonnegative final int numRows, @Nonnegative final int numCols, final int
nnz) {
    +        // compute nnz per for each row to get rowPointers
    +        for (int n = 0; n < nnz; n++) {
    +            rowPointers[rows[n]]++;
    +        }
    +        for (int i = 0, sum = 0; i < numRows; i++) {
    +            int curr = rowPointers[i];
    +            rowPointers[i] = sum;
    +            sum += curr;
    +        }
    +        rowPointers[numRows] = nnz;
    +
    +        // copy cols, data to colIndicies, csrValues
    +        for (int n = 0; n < nnz; n++) {
    +            int row = rows[n];
    +            int dst = rowPointers[row];
    +
    +            colIndicies[dst] = cols[n];
    +            values[dst] = data[n];
    +
    +            rowPointers[row]++;
    +        }
    +
    +        for (int i = 0, last = 0; i <= numRows; i++) {
    +            int tmp = rowPointers[i];
    +            rowPointers[i] = last;
    +            last = tmp;
    +        }
    +    }
    +
    +    private static void sortIndicies(@Nonnull final int[] rowPointers,
    +            @Nonnull final int[] colIndicies, @Nonnull final double[] values) {
    +        final int numRows = rowPointers.length - 1;
    +        if (numRows <= 1) {
    +            return;
    +        }
    +
    +        for (int i = 0; i < numRows; i++) {
    +            final int rowStart = rowPointers[i];
    +            final int rowEnd = rowPointers[i + 1];
    +
    +            final int numCols = rowEnd - rowStart;
    +            if (numCols == 0) {
    +                continue;
    +            } else if (numCols < 0) {
    +                throw new IllegalArgumentException(
    +                    "numCols SHOULD be greater than zero. numCols = rowEnd - rowStart
= " + rowEnd
    +                            + " - " + rowStart + " = " + numCols + " at i=" + i);
    +            }
    +
    +            final IntDoublePair[] pairs = new IntDoublePair[numCols];
    --- End diff --
    
    Why don't you use the existing `hivemall.utils.struct.Pair` instead of the newly introduced `IntDoublePair`
(and `IntFloatPair`) classes?
    
    We frequently use pair/tuple-like data structures, so I feel that using the same interface,
`struct.Pair`, wherever we can is a better idea.
    
    ```java
    final List<Pair<Integer, Double>> pairs = new ArrayList<Pair<Integer,
Double>>();
    for (int jj = rowStart; jj < rowEnd; jj++) {
        pairs.add(Pair.of(colIndicies[jj], values[jj]));
    }
    
    Collections.sort(pairs, new Comparator<Pair<Integer, Double>>() {
        @Override
        public int compare(Pair<Integer, Double> x, Pair<Integer, Double> y) {
            return Integer.compare(x.getKey(), y.getKey());
        }
    });
    
    for (int jj = rowStart, n = 0; jj < rowEnd; jj++, n++) {
        Pair<Integer, Double> tmp = pairs.get(n);
        colIndicies[jj] = tmp.getKey();
        values[jj] = tmp.getValue();
    }
    ```


---

Mime
View raw message