spark-commits mailing list archives

From jkbrad...@apache.org
Subject spark git commit: [SPARK-8965] [DOCS] Add ml-guide Python Example: Estimator, Transformer, and Param
Date Thu, 13 Aug 2015 16:18:44 GMT
Repository: spark
Updated Branches:
  refs/heads/master 2932e25da -> 7a539ef3b


[SPARK-8965] [DOCS] Add ml-guide Python Example: Estimator, Transformer, and Param

Added ml-guide Python Example: Estimator, Transformer, and Param
/docs/_site/ml-guide.html

Author: Rosstin <asterazul@gmail.com>

Closes #8081 from Rosstin/SPARK-8965.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7a539ef3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7a539ef3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7a539ef3

Branch: refs/heads/master
Commit: 7a539ef3b1792764f866fa88c84c78ad59903f21
Parents: 2932e25
Author: Rosstin <asterazul@gmail.com>
Authored: Thu Aug 13 09:18:39 2015 -0700
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Thu Aug 13 09:18:39 2015 -0700

----------------------------------------------------------------------
 docs/ml-guide.md | 68 +++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 68 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7a539ef3/docs/ml-guide.md
----------------------------------------------------------------------
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index b6ca50e..a03ab43 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -355,6 +355,74 @@ jsc.stop();
 {% endhighlight %}
 </div>
 
+<div data-lang="python">
+{% highlight python %}
+from pyspark import SparkContext
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.param import Param, Params
+from pyspark.sql import Row, SQLContext
+
+sc = SparkContext(appName="SimpleParamsExample")
+sqlContext = SQLContext(sc)
+
+# Prepare training data.
+# We use LabeledPoint.
+# Spark SQL can convert RDDs of LabeledPoints into DataFrames.
+training = sc.parallelize([LabeledPoint(1.0, [0.0, 1.1,  0.1]),
+                           LabeledPoint(0.0, [2.0, 1.0, -1.0]),
+                           LabeledPoint(0.0, [2.0, 1.3,  1.0]),
+                           LabeledPoint(1.0, [0.0, 1.2, -0.5])])
+
+# Create a LogisticRegression instance. This instance is an Estimator.
+lr = LogisticRegression(maxIter=10, regParam=0.01)
+# Print out the parameters, documentation, and any default values.
+print "LogisticRegression parameters:\n" + lr.explainParams() + "\n"
+
+# Learn a LogisticRegression model. This uses the parameters stored in lr.
+model1 = lr.fit(training.toDF())
+
+# Since model1 is a Model (i.e., a transformer produced by an Estimator),
+# we can view the parameters it used during fit().
+# This prints the parameter (name: value) pairs, where names are unique IDs for this
+# LogisticRegression instance.
+print "Model 1 was fit using parameters: "
+print model1.extractParamMap()
+
+# We may alternatively specify parameters using a Python dictionary as a paramMap
+paramMap = {lr.maxIter: 20}
+paramMap[lr.maxIter] = 30 # Specify 1 Param, overwriting the original maxIter.
+paramMap.update({lr.regParam: 0.1, lr.threshold: 0.55}) # Specify multiple Params.
+
+# You can combine paramMaps, which are python dictionaries.
+paramMap2 = {lr.probabilityCol: "myProbability"} # Change output column name
+paramMapCombined = paramMap.copy()
+paramMapCombined.update(paramMap2)
+
+# Now learn a new model using the paramMapCombined parameters.
+# paramMapCombined overrides all parameters set earlier via lr.set* methods.
+model2 = lr.fit(training.toDF(), paramMapCombined)
+print "Model 2 was fit using parameters: "
+print model2.extractParamMap()
+
+# Prepare test data
+test = sc.parallelize([LabeledPoint(1.0, [-1.0, 1.5,  1.3]),
+                       LabeledPoint(0.0, [ 3.0, 2.0, -0.1]),
+                       LabeledPoint(1.0, [ 0.0, 2.2, -1.5])])
+
+# Make predictions on test data using the Transformer.transform() method.
+# LogisticRegression.transform will only use the 'features' column.
+# Note that model2.transform() outputs a "myProbability" column instead of the usual
+# 'probability' column since we renamed the lr.probabilityCol parameter previously.
+prediction = model2.transform(test.toDF())
+selected = prediction.select("features", "label", "myProbability", "prediction")
+for row in selected.collect():
+    print row
+
+sc.stop()
+{% endhighlight %}
+</div>
+
 </div>
 
 ## Example: Pipeline
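
Note on the paramMap steps in the added example: they rely only on ordinary Python dict behavior (item assignment, update(), copy()). The sketch below is a minimal illustration of that merging logic and is not part of the commit. It assumes plain strings as hypothetical stand-ins for the Param objects (lr.maxIter, lr.regParam, lr.threshold, lr.probabilityCol) so that it runs without Spark.

```python
# Plain strings stand in for Param objects such as lr.maxIter.
# This is an assumption made so the sketch runs without Spark.
param_map = {"maxIter": 20}
param_map["maxIter"] = 30  # overwrite a single entry
param_map.update({"regParam": 0.1, "threshold": 0.55})  # set several entries

param_map2 = {"probabilityCol": "myProbability"}

# copy() leaves param_map untouched; update() lets later entries win on conflict.
param_map_combined = param_map.copy()
param_map_combined.update(param_map2)

for name, value in sorted(param_map_combined.items()):
    print(name, value)
```

Because the combined map is built from a copy, the original map can still be passed to a separate fit() call without the renamed output column.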



