Repository: spark
Updated Branches:
  refs/heads/master 21fd12cb1 -> 922338812


[SPARK-9681] [ML] Support R feature interactions in RFormula

This integrates the Interaction feature transformer with SparkR R formula 
support (i.e. support `:`).

To generate reasonable ML attribute names for feature interactions, it was 
necessary to add the ability to read the original attribute names back from 
`StructField`, and also to rewrite the attribute-name prefixes in the feature 
vector produced by `VectorAssembler`. This also has the side benefit of 
cleaning up the double underscores in the attributes generated for 
non-interaction terms.

mengxr

Author: Eric Liang <[email protected]>

Closes #8830 from ericl/interaction-2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/92233881
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/92233881
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/92233881

Branch: refs/heads/master
Commit: 922338812c03eba43f2f1a6c414d1b6b049811cf
Parents: 21fd12c
Author: Eric Liang <[email protected]>
Authored: Fri Sep 25 00:43:22 2015 -0700
Committer: Xiangrui Meng <[email protected]>
Committed: Fri Sep 25 00:43:22 2015 -0700

----------------------------------------------------------------------
 R/pkg/R/mllib.R                                 |   2 +-
 R/pkg/inst/tests/test_mllib.R                   |  10 +-
 .../apache/spark/ml/attribute/attributes.scala  |  16 ++-
 .../apache/spark/ml/feature/Interaction.scala   |  12 +-
 .../org/apache/spark/ml/feature/RFormula.scala  | 113 +++++++++++++++----
 .../spark/ml/feature/RFormulaParser.scala       |  97 ++++++++++++----
 .../apache/spark/ml/feature/StringIndexer.scala |   5 +-
 .../spark/ml/feature/RFormulaParserSuite.scala  |  89 ++++++++++++++-
 .../apache/spark/ml/feature/RFormulaSuite.scala |  76 ++++++++++++-
 python/pyspark/ml/feature.py                    |   2 +-
 10 files changed, 362 insertions(+), 60 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/92233881/R/pkg/R/mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index cea3d76..474ada5 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -27,7 +27,7 @@ setClass("PipelineModel", representation(model = "jobj"))
 #' Fits a generalized linear model, similarly to R's glm(). Also see the 
glmnet package.
 #'
 #' @param formula A symbolic description of the model to be fitted. Currently 
only a few formula
-#'                operators are supported, including '~', '+', '-', and '.'.
+#'                operators are supported, including '~', '.', ':', '+', and 
'-'.
 #' @param data DataFrame for training
 #' @param family Error distribution. "gaussian" -> linear regression, 
"binomial" -> logistic reg.
 #' @param lambda Regularization parameter

http://git-wip-us.apache.org/repos/asf/spark/blob/92233881/R/pkg/inst/tests/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R
index f272de7..032f8ec 100644
--- a/R/pkg/inst/tests/test_mllib.R
+++ b/R/pkg/inst/tests/test_mllib.R
@@ -49,6 +49,14 @@ test_that("dot minus and intercept vs native glm", {
   expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
 })
 
+test_that("feature interaction vs native glm", {
+  training <- createDataFrame(sqlContext, iris)
+  model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training)
+  vals <- collect(select(predict(model, training), "prediction"))
+  rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris)
+  expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
+})
+
 test_that("summary coefficients match with native glm", {
   training <- createDataFrame(sqlContext, iris)
   stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training))
@@ -57,5 +65,5 @@ test_that("summary coefficients match with native glm", {
   expect_true(all(abs(rCoefs - coefs) < 1e-6))
   expect_true(all(
     as.character(stats$features) ==
-    c("(Intercept)", "Sepal_Length", "Species__versicolor", 
"Species__virginica")))
+    c("(Intercept)", "Sepal_Length", "Species_versicolor", 
"Species_virginica")))
 })

http://git-wip-us.apache.org/repos/asf/spark/blob/92233881/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala 
b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
index e479f16..a7c1033 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
@@ -124,18 +124,28 @@ private[attribute] trait AttributeFactory {
   private[attribute] def fromMetadata(metadata: Metadata): Attribute
 
   /**
-   * Creates an [[Attribute]] from a [[StructField]] instance.
+   * Creates an [[Attribute]] from a [[StructField]] instance, optionally 
preserving name.
    */
-  def fromStructField(field: StructField): Attribute = {
+  private[ml] def decodeStructField(field: StructField, preserveName: 
Boolean): Attribute = {
     require(field.dataType.isInstanceOf[NumericType])
     val metadata = field.metadata
     val mlAttr = AttributeKeys.ML_ATTR
     if (metadata.contains(mlAttr)) {
-      fromMetadata(metadata.getMetadata(mlAttr)).withName(field.name)
+      val attr = fromMetadata(metadata.getMetadata(mlAttr))
+      if (preserveName) {
+        attr
+      } else {
+        attr.withName(field.name)
+      }
     } else {
       UnresolvedAttribute
     }
   }
+
+  /**
+   * Creates an [[Attribute]] from a [[StructField]] instance.
+   */
+  def fromStructField(field: StructField): Attribute = 
decodeStructField(field, false)
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/92233881/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
index 9194763..37f7862 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
@@ -149,8 +149,14 @@ class Interaction(override val uid: String) extends 
Transformer
     features.reverse.foreach { f =>
       val encodedAttrs = f.dataType match {
         case _: NumericType | BooleanType =>
-          val attr = Attribute.fromStructField(f)
-          encodedFeatureAttrs(Seq(attr), None)
+          val attr = Attribute.decodeStructField(f, preserveName = true)
+          if (attr == UnresolvedAttribute) {
+            
encodedFeatureAttrs(Seq(NumericAttribute.defaultAttr.withName(f.name)), None)
+          } else if (!attr.name.isDefined) {
+            encodedFeatureAttrs(Seq(attr.withName(f.name)), None)
+          } else {
+            encodedFeatureAttrs(Seq(attr), None)
+          }
         case _: VectorUDT =>
           val group = AttributeGroup.fromStructField(f)
           encodedFeatureAttrs(group.attributes.get, Some(group.name))
@@ -221,7 +227,7 @@ class Interaction(override val uid: String) extends 
Transformer
  *                    count is equal to the number of categories. For numeric 
features the count
  *                    should be set to 1.
  */
-private[ml] class FeatureEncoder(numFeatures: Array[Int]) {
+private[ml] class FeatureEncoder(numFeatures: Array[Int]) extends Serializable 
{
   assert(numFeatures.forall(_ > 0), "Features counts must all be positive.")
 
   /** The size of the output vector. */

http://git-wip-us.apache.org/repos/asf/spark/blob/92233881/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index dcd6fe3..f9b8400 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, 
PipelineStage, Transformer}
 import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
@@ -42,8 +43,8 @@ private[feature] trait RFormulaBase extends HasFeaturesCol 
with HasLabelCol {
 /**
  * :: Experimental ::
  * Implements the transforms required for fitting a dataset against an R model 
formula. Currently
- * we support a limited subset of the R operators, including '.', '~', '+', 
and '-'. Also see the
- * R formula docs here: 
http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
+ * we support a limited subset of the R operators, including '~', '.', ':', 
'+', and '-'. Also see
+ * the R formula docs here: 
http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
  */
 @Experimental
 class RFormula(override val uid: String) extends Estimator[RFormulaModel] with 
RFormulaBase {
@@ -82,36 +83,54 @@ class RFormula(override val uid: String) extends 
Estimator[RFormulaModel] with R
     require(isDefined(formula), "Formula must be defined first.")
     val parsedFormula = RFormulaParser.parse($(formula))
     val resolvedFormula = parsedFormula.resolve(dataset.schema)
-    // StringType terms and terms representing interactions need to be encoded 
before assembly.
-    // TODO(ekl) add support for feature interactions
     val encoderStages = ArrayBuffer[PipelineStage]()
+
+    val prefixesToRewrite = mutable.Map[String, String]()
     val tempColumns = ArrayBuffer[String]()
-    val takenNames = mutable.Set(dataset.columns: _*)
-    val encodedTerms = resolvedFormula.terms.map { term =>
+    def tmpColumn(category: String): String = {
+      val col = Identifiable.randomUID(category)
+      tempColumns += col
+      col
+    }
+
+    // First we index each string column referenced by the input terms.
+    val indexed: Map[String, String] = 
resolvedFormula.terms.flatten.distinct.map { term =>
       dataset.schema(term) match {
         case column if column.dataType == StringType =>
-          val indexCol = term + "_idx_" + uid
-          val encodedCol = {
-            var tmp = term
-            while (takenNames.contains(tmp)) {
-              tmp += "_"
-            }
-            tmp
-          }
-          takenNames.add(indexCol)
-          takenNames.add(encodedCol)
-          encoderStages += new 
StringIndexer().setInputCol(term).setOutputCol(indexCol)
-          encoderStages += new 
OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol)
-          tempColumns += indexCol
-          tempColumns += encodedCol
-          encodedCol
+          val indexCol = tmpColumn("stridx")
+          encoderStages += new StringIndexer()
+            .setInputCol(term)
+            .setOutputCol(indexCol)
+          (term, indexCol)
         case _ =>
-          term
+          (term, term)
       }
+    }.toMap
+
+    // Then we handle one-hot encoding and interactions between terms.
+    val encodedTerms = resolvedFormula.terms.map {
+      case Seq(term) if dataset.schema(term).dataType == StringType =>
+        val encodedCol = tmpColumn("onehot")
+        encoderStages += new OneHotEncoder()
+          .setInputCol(indexed(term))
+          .setOutputCol(encodedCol)
+        prefixesToRewrite(encodedCol + "_") = term + "_"
+        encodedCol
+      case Seq(term) =>
+        term
+      case terms =>
+        val interactionCol = tmpColumn("interaction")
+        encoderStages += new Interaction()
+          .setInputCols(terms.map(indexed).toArray)
+          .setOutputCol(interactionCol)
+        prefixesToRewrite(interactionCol + "_") = ""
+        interactionCol
     }
+
     encoderStages += new VectorAssembler(uid)
       .setInputCols(encodedTerms.toArray)
       .setOutputCol($(featuresCol))
+    encoderStages += new VectorAttributeRewriter($(featuresCol), 
prefixesToRewrite.toMap)
     encoderStages += new ColumnPruner(tempColumns.toSet)
     val pipelineModel = new 
Pipeline(uid).setStages(encoderStages.toArray).fit(dataset)
     copyValues(new RFormulaModel(uid, resolvedFormula, 
pipelineModel).setParent(this))
@@ -218,3 +237,53 @@ private class ColumnPruner(columnsToPrune: Set[String]) 
extends Transformer {
 
   override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra)
 }
+
+/**
+ * Utility transformer that rewrites Vector attribute names via prefix 
replacement. For example,
+ * it can rewrite attribute names starting with 'foo_' to start with 'bar_' 
instead.
+ *
+ * @param vectorCol name of the vector column to rewrite.
+ * @param prefixesToRewrite the map of string prefixes to their replacement 
values. Each attribute
+ *                          name defined in vectorCol will be checked against 
the keys of this
+ *                          map. When a key prefixes a name, the matching 
prefix will be replaced
+ *                          by the value in the map.
+ */
+private class VectorAttributeRewriter(
+    vectorCol: String,
+    prefixesToRewrite: Map[String, String])
+  extends Transformer {
+
+  override val uid = Identifiable.randomUID("vectorAttrRewriter")
+
+  override def transform(dataset: DataFrame): DataFrame = {
+    val metadata = {
+      val group = AttributeGroup.fromStructField(dataset.schema(vectorCol))
+      val attrs = group.attributes.get.map { attr =>
+        if (attr.name.isDefined) {
+          val name = attr.name.get
+          val replacement = prefixesToRewrite.filter { case (k, _) => 
name.startsWith(k) }
+          if (replacement.nonEmpty) {
+            val (k, v) = replacement.headOption.get
+            attr.withName(v + name.stripPrefix(k))
+          } else {
+            attr
+          }
+        } else {
+          attr
+        }
+      }
+      new AttributeGroup(vectorCol, attrs).toMetadata()
+    }
+    val otherCols = dataset.columns.filter(_ != vectorCol).map(dataset.col)
+    val rewrittenCol = dataset.col(vectorCol).as(vectorCol, metadata)
+    dataset.select((otherCols :+ rewrittenCol): _*)
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    StructType(
+      schema.fields.filter(_.name != vectorCol) ++
+      schema.fields.filter(_.name == vectorCol))
+  }
+
+  override def copy(extra: ParamMap): VectorAttributeRewriter = 
defaultCopy(extra)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/92233881/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
index 1ca3b92..4079b38 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.ml.feature
 
+import scala.collection.mutable
 import scala.util.parsing.combinator.RegexParsers
 
 import org.apache.spark.mllib.linalg.VectorUDT
@@ -31,27 +32,35 @@ private[ml] case class ParsedRFormula(label: ColumnRef, 
terms: Seq[Term]) {
    * of the special '.' term. Duplicate terms will be removed during 
resolution.
    */
   def resolve(schema: StructType): ResolvedRFormula = {
-    var includedTerms = Seq[String]()
+    val dotTerms = expandDot(schema)
+    var includedTerms = Seq[Seq[String]]()
     terms.foreach {
+      case col: ColumnRef =>
+        includedTerms :+= Seq(col.value)
+      case ColumnInteraction(cols) =>
+        includedTerms ++= expandInteraction(schema, cols)
       case Dot =>
-        includedTerms ++= simpleTypes(schema).filter(_ != label.value)
-      case ColumnRef(value) =>
-        includedTerms :+= value
+        includedTerms ++= dotTerms.map(Seq(_))
       case Deletion(term: Term) =>
         term match {
-          case ColumnRef(value) =>
-            includedTerms = includedTerms.filter(_ != value)
+          case inner: ColumnRef =>
+            includedTerms = includedTerms.filter(_ != Seq(inner.value))
+          case ColumnInteraction(cols) =>
+            val fromInteraction = expandInteraction(schema, cols).map(_.toSet)
+            includedTerms = includedTerms.filter(t => 
!fromInteraction.contains(t.toSet))
           case Dot =>
             // e.g. "- .", which removes all first-order terms
-            val fromSchema = simpleTypes(schema)
-            includedTerms = includedTerms.filter(fromSchema.contains(_))
+            includedTerms = includedTerms.filter {
+              case Seq(t) => !dotTerms.contains(t)
+              case _ => true
+            }
           case _: Deletion =>
-            assert(false, "Deletion terms cannot be nested")
+            throw new RuntimeException("Deletion terms cannot be nested")
           case _: Intercept =>
         }
       case _: Intercept =>
     }
-    ResolvedRFormula(label.value, includedTerms.distinct)
+    ResolvedRFormula(label.value, includedTerms.distinct, hasIntercept)
   }
 
   /** Whether this formula specifies fitting with an intercept term. */
@@ -67,19 +76,54 @@ private[ml] case class ParsedRFormula(label: ColumnRef, 
terms: Seq[Term]) {
     intercept
   }
 
+  // expands the Dot operators in interaction terms
+  private def expandInteraction(
+      schema: StructType, terms: Seq[InteractableTerm]): Seq[Seq[String]] = {
+    if (terms.isEmpty) {
+      return Seq(Nil)
+    }
+
+    val rest = expandInteraction(schema, terms.tail)
+    val validInteractions = (terms.head match {
+      case Dot =>
+        expandDot(schema).flatMap { t =>
+          rest.map { r =>
+            Seq(t) ++ r
+          }
+        }
+      case ColumnRef(value) =>
+        rest.map(Seq(value) ++ _)
+    }).map(_.distinct)
+
+    // Deduplicates feature interactions, for example, a:b is the same as b:a.
+    var seen = mutable.Set[Set[String]]()
+    validInteractions.flatMap {
+      case t if seen.contains(t.toSet) =>
+        None
+      case t =>
+        seen += t.toSet
+        Some(t)
+    }.sortBy(_.length)
+  }
+
   // the dot operator excludes complex column types
-  private def simpleTypes(schema: StructType): Seq[String] = {
+  private def expandDot(schema: StructType): Seq[String] = {
     schema.fields.filter(_.dataType match {
       case _: NumericType | StringType | BooleanType | _: VectorUDT => true
       case _ => false
-    }).map(_.name)
+    }).map(_.name).filter(_ != label.value)
   }
 }
 
 /**
  * Represents a fully evaluated and simplified R formula.
+ * @param label the column name of the R formula label (response variable).
+ * @param terms the simplified terms of the R formula. Interactions terms are 
represented as Seqs
+ *              of column names; non-interaction terms as length 1 Seqs.
+ * @param hasIntercept whether the formula specifies fitting with an intercept.
  */
-private[ml] case class ResolvedRFormula(label: String, terms: Seq[String])
+private[ml] case class ResolvedRFormula(
+  label: String, terms: Seq[Seq[String]], hasIntercept: Boolean)
 
 /**
  * R formula terms. See the R formula docs here for more information:
@@ -87,11 +131,17 @@ private[ml] case class ResolvedRFormula(label: String, 
terms: Seq[String])
  */
 private[ml] sealed trait Term
 
+/** A term that may be part of an interaction, e.g. 'x' in 'x:y' */
+private[ml] sealed trait InteractableTerm extends Term
+
 /* R formula reference to all available columns, e.g. "." in a formula */
-private[ml] case object Dot extends Term
+private[ml] case object Dot extends InteractableTerm
 
 /* R formula reference to a column, e.g. "+ Species" in a formula */
-private[ml] case class ColumnRef(value: String) extends Term
+private[ml] case class ColumnRef(value: String) extends InteractableTerm
+
+/* R formula interaction of several columns, e.g. "Sepal_Length:Species" in a 
formula */
+private[ml] case class ColumnInteraction(terms: Seq[InteractableTerm]) extends 
Term
 
 /* R formula intercept toggle, e.g. "+ 0" in a formula */
 private[ml] case class Intercept(enabled: Boolean) extends Term
@@ -100,25 +150,30 @@ private[ml] case class Intercept(enabled: Boolean) 
extends Term
 private[ml] case class Deletion(term: Term) extends Term
 
 /**
- * Limited implementation of R formula parsing. Currently supports: '~', '+', 
'-', '.'.
+ * Limited implementation of R formula parsing. Currently supports: '~', '+', 
'-', '.', ':'.
  */
 private[ml] object RFormulaParser extends RegexParsers {
-  def intercept: Parser[Intercept] =
+  private val intercept: Parser[Intercept] =
     "([01])".r ^^ { case a => Intercept(a == "1") }
 
-  def columnRef: Parser[ColumnRef] =
+  private val columnRef: Parser[ColumnRef] =
     "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r ^^ { case a => ColumnRef(a) }
 
-  def term: Parser[Term] = intercept | columnRef | "\\.".r ^^ { case _ => Dot }
+  private val dot: Parser[InteractableTerm] = "\\.".r ^^ { case _ => Dot }
+
+  private val interaction: Parser[List[InteractableTerm]] = rep1sep(columnRef 
| dot, ":")
+
+  private val term: Parser[Term] = intercept |
+    interaction ^^ { case terms => ColumnInteraction(terms) } | dot | columnRef
 
-  def terms: Parser[List[Term]] = (term ~ rep("+" ~ term | "-" ~ term)) ^^ {
+  private val terms: Parser[List[Term]] = (term ~ rep("+" ~ term | "-" ~ 
term)) ^^ {
     case op ~ list => list.foldLeft(List(op)) {
       case (left, "+" ~ right) => left ++ Seq(right)
       case (left, "-" ~ right) => left ++ Seq(Deletion(right))
     }
   }
 
-  def formula: Parser[ParsedRFormula] =
+  private val formula: Parser[ParsedRFormula] =
     (columnRef ~ "~" ~ terms) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) }
 
   def parse(value: String): ParsedRFormula = parseAll(formula, value) match {

http://git-wip-us.apache.org/repos/asf/spark/blob/92233881/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 2b15929..486274c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -147,9 +147,8 @@ class StringIndexerModel (
       }
     }
 
-    val outputColName = $(outputCol)
     val metadata = NominalAttribute.defaultAttr
-      .withName(outputColName).withValues(labels).toMetadata()
+      .withName($(inputCol)).withValues(labels).toMetadata()
     // If we are skipping invalid records, filter them out.
     val filteredDataset = (getHandleInvalid) match {
       case "skip" => {
@@ -161,7 +160,7 @@ class StringIndexerModel (
       case _ => dataset
     }
     filteredDataset.select(col("*"),
-      indexer(dataset($(inputCol)).cast(StringType)).as(outputColName, 
metadata))
+      indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), 
metadata))
   }
 
   override def transformSchema(schema: StructType): StructType = {

http://git-wip-us.apache.org/repos/asf/spark/blob/92233881/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
index 436e66b..53798c6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala
@@ -25,16 +25,24 @@ class RFormulaParserSuite extends SparkFunSuite {
       formula: String,
       label: String,
       terms: Seq[String],
-      schema: StructType = null) {
+      schema: StructType = new StructType) {
     val resolved = RFormulaParser.parse(formula).resolve(schema)
     assert(resolved.label == label)
-    assert(resolved.terms == terms)
+    val simpleTerms = terms.map { t =>
+      if (t.contains(":")) {
+        t.split(":").toSeq
+      } else {
+        Seq(t)
+      }
+    }
+    assert(resolved.terms == simpleTerms)
   }
 
   test("parse simple formulas") {
     checkParse("y ~ x", "y", Seq("x"))
     checkParse("y ~ x + x", "y", Seq("x"))
-    checkParse("y ~   ._foo  ", "y", Seq("._foo"))
+    checkParse("y~x+z", "y", Seq("x", "z"))
+    checkParse("y ~   ._fo..o  ", "y", Seq("._fo..o"))
     checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123"))
   }
 
@@ -79,4 +87,79 @@ class RFormulaParserSuite extends SparkFunSuite {
     assert(!RFormulaParser.parse("a ~ b - 1").hasIntercept)
     assert(!RFormulaParser.parse("a ~ b + 1 - 1").hasIntercept)
   }
+
+  test("parse interactions") {
+    checkParse("y ~ a:b", "y", Seq("a:b"))
+    checkParse("y ~ ._a:._x", "y", Seq("._a:._x"))
+    checkParse("y ~ foo:bar", "y", Seq("foo:bar"))
+    checkParse("y ~ a : b : c", "y", Seq("a:b:c"))
+    checkParse("y ~ q + a:b:c + b:c + c:d + z", "y", Seq("q", "a:b:c", "b:c", 
"c:d", "z"))
+  }
+
+  test("parse basic interactions with dot") {
+    val schema = (new StructType)
+      .add("a", "int", true)
+      .add("b", "long", false)
+      .add("c", "string", true)
+      .add("d", "string", true)
+    checkParse("a ~ .:b", "a", Seq("b", "c:b", "d:b"), schema)
+    checkParse("a ~ b:.", "a", Seq("b", "b:c", "b:d"), schema)
+    checkParse("a ~ .:b:.:.:c:d:.", "a", Seq("b:c:d"), schema)
+  }
+
+  // Test data generated in R with terms.formula(y ~ .:., data = iris)
+  test("parse all to all iris interactions") {
+    val schema = (new StructType)
+      .add("Sepal.Length", "double", true)
+      .add("Sepal.Width", "double", true)
+      .add("Petal.Length", "double", true)
+      .add("Petal.Width", "double", true)
+      .add("Species", "string", true)
+    checkParse(
+      "y ~ .:.",
+      "y",
+      Seq(
+        "Sepal.Length",
+        "Sepal.Width",
+        "Petal.Length",
+        "Petal.Width",
+        "Species",
+        "Sepal.Length:Sepal.Width",
+        "Sepal.Length:Petal.Length",
+        "Sepal.Length:Petal.Width",
+        "Sepal.Length:Species",
+        "Sepal.Width:Petal.Length",
+        "Sepal.Width:Petal.Width",
+        "Sepal.Width:Species",
+        "Petal.Length:Petal.Width",
+        "Petal.Length:Species",
+        "Petal.Width:Species"),
+      schema)
+  }
+
+  // Test data generated in R with terms.formula(y ~ .:. - Species:., data = 
iris)
+  test("parse interaction negation with iris") {
+    val schema = (new StructType)
+      .add("Sepal.Length", "double", true)
+      .add("Sepal.Width", "double", true)
+      .add("Petal.Length", "double", true)
+      .add("Petal.Width", "double", true)
+      .add("Species", "string", true)
+    checkParse("y ~ .:. - .:.", "y", Nil, schema)
+    checkParse(
+      "y ~ .:. - Species:.",
+      "y",
+      Seq(
+        "Sepal.Length",
+        "Sepal.Width",
+        "Petal.Length",
+        "Petal.Width",
+        "Sepal.Length:Sepal.Width",
+        "Sepal.Length:Petal.Length",
+        "Sepal.Length:Petal.Width",
+        "Sepal.Width:Petal.Length",
+        "Sepal.Width:Petal.Width",
+        "Petal.Length:Petal.Width"),
+      schema)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/92233881/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 6aed324..b560130 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -118,9 +118,81 @@ class RFormulaSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     val expectedAttrs = new AttributeGroup(
       "features",
       Array(
-        new BinaryAttribute(Some("a__bar"), Some(1)),
-        new BinaryAttribute(Some("a__foo"), Some(2)),
+        new BinaryAttribute(Some("a_bar"), Some(1)),
+        new BinaryAttribute(Some("a_foo"), Some(2)),
         new NumericAttribute(Some("b"), Some(3))))
     assert(attrs === expectedAttrs)
   }
+
+  test("numeric interaction") {
+    val formula = new RFormula().setFormula("a ~ b:c:d")
+    val original = sqlContext.createDataFrame(
+      Seq((1, 2, 4, 2), (2, 3, 4, 1))
+    ).toDF("a", "b", "c", "d")
+    val model = formula.fit(original)
+    val result = model.transform(original)
+    val expected = sqlContext.createDataFrame(
+      Seq(
+        (1, 2, 4, 2, Vectors.dense(16.0), 1.0),
+        (2, 3, 4, 1, Vectors.dense(12.0), 2.0))
+      ).toDF("a", "b", "c", "d", "features", "label")
+    assert(result.collect() === expected.collect())
+    val attrs = AttributeGroup.fromStructField(result.schema("features"))
+    val expectedAttrs = new AttributeGroup(
+      "features",
+      Array[Attribute](new NumericAttribute(Some("b:c:d"), Some(1))))
+    assert(attrs === expectedAttrs)
+  }
+
+  test("factor numeric interaction") {
+    val formula = new RFormula().setFormula("id ~ a:b")
+    val original = sqlContext.createDataFrame(
+      Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, 
"baz", 5), (4, "baz", 5))
+    ).toDF("id", "a", "b")
+    val model = formula.fit(original)
+    val result = model.transform(original)
+    val expected = sqlContext.createDataFrame(
+      Seq(
+        (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0),
+        (2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0),
+        (3, "bar", 5, Vectors.dense(0.0, 5.0, 0.0), 3.0),
+        (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0),
+        (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0),
+        (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0))
+      ).toDF("id", "a", "b", "features", "label")
+    assert(result.collect() === expected.collect())
+    val attrs = AttributeGroup.fromStructField(result.schema("features"))
+    val expectedAttrs = new AttributeGroup(
+      "features",
+      Array[Attribute](
+        new NumericAttribute(Some("a_baz:b"), Some(1)),
+        new NumericAttribute(Some("a_bar:b"), Some(2)),
+        new NumericAttribute(Some("a_foo:b"), Some(3))))
+    assert(attrs === expectedAttrs)
+  }
+
+  test("factor factor interaction") {
+    val formula = new RFormula().setFormula("id ~ a:b")
+    val original = sqlContext.createDataFrame(
+      Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz"))
+    ).toDF("id", "a", "b")
+    val model = formula.fit(original)
+    val result = model.transform(original)
+    val expected = sqlContext.createDataFrame(
+      Seq(
+        (1, "foo", "zq", Vectors.dense(0.0, 0.0, 1.0, 0.0), 1.0),
+        (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0),
+        (3, "bar", "zz", Vectors.dense(0.0, 1.0, 0.0, 0.0), 3.0))
+      ).toDF("id", "a", "b", "features", "label")
+    assert(result.collect() === expected.collect())
+    val attrs = AttributeGroup.fromStructField(result.schema("features"))
+    val expectedAttrs = new AttributeGroup(
+      "features",
+      Array[Attribute](
+        new NumericAttribute(Some("a_bar:b_zq"), Some(1)),
+        new NumericAttribute(Some("a_bar:b_zz"), Some(2)),
+        new NumericAttribute(Some("a_foo:b_zq"), Some(3)),
+        new NumericAttribute(Some("a_foo:b_zz"), Some(4))))
+    assert(attrs === expectedAttrs)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/92233881/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index f41d72f..a4e60f9 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -1850,7 +1850,7 @@ class RFormula(JavaEstimator, HasFeaturesCol, 
HasLabelCol):
 
     Implements the transforms required for fitting a dataset against an
     R model formula. Currently we support a limited subset of the R
-    operators, including '~', '+', '-', and '.'. Also see the R formula
+    operators, including '~', '.', ':', '+', and '-'. Also see the R formula
     docs:
     http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to