spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-13597][PYSPARK][ML] Python API for GeneralizedLinearRegression
Date Tue, 12 Apr 2016 18:29:13 GMT
Repository: spark
Updated Branches:
  refs/heads/master 101663f1a -> 7f024c474


[SPARK-13597][PYSPARK][ML] Python API for GeneralizedLinearRegression

## What changes were proposed in this pull request?

Python API for GeneralizedLinearRegression
JIRA: https://issues.apache.org/jira/browse/SPARK-13597

## How was this patch tested?

The patch is tested with Python doctest.

Author: Kai Jiang <jiangkai@gmail.com>

Closes #11468 from vectorijk/spark-13597.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7f024c47
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7f024c47
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7f024c47

Branch: refs/heads/master
Commit: 7f024c47441a2f84fcc34a6021b976f036ea24c4
Parents: 101663f
Author: Kai Jiang <jiangkai@gmail.com>
Authored: Tue Apr 12 11:29:12 2016 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Tue Apr 12 11:29:12 2016 -0700

----------------------------------------------------------------------
 python/pyspark/ml/regression.py | 145 +++++++++++++++++++++++++++++++++++
 1 file changed, 145 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7f024c47/python/pyspark/ml/regression.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 1c18df3..bc88f88 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -28,6 +28,7 @@ from pyspark.sql import DataFrame
 __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
            'DecisionTreeRegressor', 'DecisionTreeRegressionModel',
            'GBTRegressor', 'GBTRegressionModel',
+           'GeneralizedLinearRegression', 'GeneralizedLinearRegressionModel'
            'IsotonicRegression', 'IsotonicRegressionModel',
            'LinearRegression', 'LinearRegressionModel',
            'LinearRegressionSummary', 'LinearRegressionTrainingSummary',
@@ -1197,6 +1198,150 @@ class AFTSurvivalRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
         return self._call_java("predict", features)
 
 
+@inherit_doc
+class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, HasPredictionCol,
+                                  HasFitIntercept, HasMaxIter, HasTol, HasRegParam, HasWeightCol,
+                                  HasSolver, JavaMLWritable, JavaMLReadable):
+    """
+    Generalized Linear Regression.
+
+    Fit a Generalized Linear Model specified by giving a symbolic description of the linear
+    predictor (link function) and a description of the error distribution (family). It supports
+    "gaussian", "binomial", "poisson" and "gamma" as family. Valid link functions for each family
+    is listed below. The first link function of each family is the default one.
+    - "gaussian" -> "identity", "log", "inverse"
+    - "binomial" -> "logit", "probit", "cloglog"
+    - "poisson"  -> "log", "identity", "sqrt"
+    - "gamma"    -> "inverse", "identity", "log"
+
+    .. seealso:: `GLM <https://en.wikipedia.org/wiki/Generalized_linear_model>`_
+
+    >>> from pyspark.mllib.linalg import Vectors
+    >>> df = sqlContext.createDataFrame([
+    ...     (1.0, Vectors.dense(0.0, 0.0)),
+    ...     (1.0, Vectors.dense(1.0, 2.0)),
+    ...     (2.0, Vectors.dense(0.0, 0.0)),
+    ...     (2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"])
+    >>> glr = GeneralizedLinearRegression(family="gaussian", link="identity")
+    >>> model = glr.fit(df)
+    >>> abs(model.transform(df).head().prediction - 1.5) < 0.001
+    True
+    >>> model.coefficients
+    DenseVector([1.5..., -1.0...])
+    >>> abs(model.intercept - 1.5) < 0.001
+    True
+    >>> glr_path = temp_path + "/glr"
+    >>> glr.save(glr_path)
+    >>> glr2 = GeneralizedLinearRegression.load(glr_path)
+    >>> glr.getFamily() == glr2.getFamily()
+    True
+    >>> model_path = temp_path + "/glr_model"
+    >>> model.save(model_path)
+    >>> model2 = GeneralizedLinearRegressionModel.load(model_path)
+    >>> model.intercept == model2.intercept
+    True
+    >>> model.coefficients[0] == model2.coefficients[0]
+    True
+
+    .. versionadded:: 2.0.0
+    """
+
+    family = Param(Params._dummy(), "family", "The name of family which is a description of " +
+                   "the error distribution to be used in the model. Supported options: " +
+                   "gaussian(default), binomial, poisson and gamma.")
+    link = Param(Params._dummy(), "link", "The name of link function which provides the " +
+                 "relationship between the linear predictor and the mean of the distribution " +
+                 "function. Supported options: identity, log, inverse, logit, probit, cloglog " +
+                 "and sqrt.")
+
+    @keyword_only
+    def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",
+                 family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
+                 regParam=0.0, weightCol=None, solver="irls"):
+        """
+        __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
+                 family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
+                 regParam=0.0, weightCol=None, solver="irls")
+        """
+        super(GeneralizedLinearRegression, self).__init__()
+        self._java_obj = self._new_java_obj(
+            "org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid)
+        self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls")
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    @since("2.0.0")
+    def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction",
+                  family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
+                  regParam=0.0, weightCol=None, solver="irls"):
+        """
+        setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
+                  family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
+                  regParam=0.0, weightCol=None, solver="irls")
+        Sets params for generalized linear regression.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set(**kwargs)
+
+    def _create_model(self, java_model):
+        return GeneralizedLinearRegressionModel(java_model)
+
+    @since("2.0.0")
+    def setFamily(self, value):
+        """
+        Sets the value of :py:attr:`family`.
+        """
+        self._paramMap[self.family] = value
+        return self
+
+    @since("2.0.0")
+    def getFamily(self):
+        """
+        Gets the value of family or its default value.
+        """
+        return self.getOrDefault(self.family)
+
+    @since("2.0.0")
+    def setLink(self, value):
+        """
+        Sets the value of :py:attr:`link`.
+        """
+        self._paramMap[self.link] = value
+        return self
+
+    @since("2.0.0")
+    def getLink(self):
+        """
+        Gets the value of link or its default value.
+        """
+        return self.getOrDefault(self.link)
+
+
+class GeneralizedLinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
+    """
+    Model fitted by GeneralizedLinearRegression.
+
+    .. versionadded:: 2.0.0
+    """
+
+    @property
+    @since("2.0.0")
+    def coefficients(self):
+        """
+        Model coefficients.
+        """
+        return self._call_java("coefficients")
+
+    @property
+    @since("2.0.0")
+    def intercept(self):
+        """
+        Model intercept.
+        """
+        return self._call_java("intercept")
+
+
 if __name__ == "__main__":
     import doctest
     import pyspark.ml.regression


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message