Repository: spark Updated Branches: refs/heads/master f5ff4a84c -> 65afd3ce8
[SPARK-7474] [MLLIB] update ParamGridBuilder doctest

Multiline commands are properly handled in this PR.

oefirouz

Author: Xiangrui Meng <[email protected]>

Closes #6001 from mengxr/SPARK-7474 and squashes the following commits:

b94b11d [Xiangrui Meng] update ParamGridBuilder doctest

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/65afd3ce
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/65afd3ce
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/65afd3ce

Branch: refs/heads/master
Commit: 65afd3ce8b8a0b00f4ea8294eac14b72e964872d
Parents: f5ff4a8
Author: Xiangrui Meng <[email protected]>
Authored: Fri May 8 11:16:04 2015 -0700
Committer: Xiangrui Meng <[email protected]>
Committed: Fri May 8 11:16:04 2015 -0700

----------------------------------------------------------------------
 python/pyspark/ml/tuning.py | 28 +++++++++++++---------------
 1 file changed, 13 insertions(+), 15 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/65afd3ce/python/pyspark/ml/tuning.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 1e04c37..28e3727 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -27,24 +27,22 @@ __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel']
 
 
 class ParamGridBuilder(object):
-    """
+    r"""
     Builder for a param grid used in grid search-based model selection.
 
-    >>> from classification import LogisticRegression
+    >>> from pyspark.ml.classification import LogisticRegression
     >>> lr = LogisticRegression()
-    >>> output = ParamGridBuilder().baseOn({lr.labelCol: 'l'}) \
-            .baseOn([lr.predictionCol, 'p']) \
-            .addGrid(lr.regParam, [1.0, 2.0, 3.0]) \
-            .addGrid(lr.maxIter, [1, 5]) \
-            .addGrid(lr.featuresCol, ['f']) \
-            .build()
-    >>> expected = [ \
-{lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
-{lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
-{lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
-{lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
-{lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
-{lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]
+    >>> output = ParamGridBuilder() \
+    ...     .baseOn({lr.labelCol: 'l'}) \
+    ...     .baseOn([lr.predictionCol, 'p']) \
+    ...     .addGrid(lr.regParam, [1.0, 2.0]) \
+    ...     .addGrid(lr.maxIter, [1, 5]) \
+    ...     .build()
+    >>> expected = [
+    ...     {lr.regParam: 1.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
+    ...     {lr.regParam: 2.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
+    ...     {lr.regParam: 1.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'},
+    ...     {lr.regParam: 2.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]
     >>> len(output) == len(expected)
     True
     >>> all([m in expected for m in output])

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
