http://git-wip-us.apache.org/repos/asf/spark/blob/a9b93e07/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala deleted file mode 100644 index c188c5b..0000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala +++ /dev/null @@ -1,933 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.parser - -import java.sql.Date - -import scala.collection.mutable.ArrayBuffer -import scala.util.matching.Regex - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.Count -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval -import org.apache.spark.util.random.RandomSampler - - -/** - * This class translates SQL to Catalyst [[LogicalPlan]]s or [[Expression]]s. - */ -private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) extends ParserInterface { - import ParserUtils._ - - /** - * The safeParse method allows a user to focus on the parsing/AST transformation logic. This - * method will take care of possible errors during the parsing process. - */ - protected def safeParse[T](sql: String, ast: ASTNode)(toResult: ASTNode => T): T = { - try { - toResult(ast) - } catch { - case e: MatchError => throw e - case e: AnalysisException => throw e - case e: Exception => - throw new AnalysisException(e.getMessage) - case e: NotImplementedError => - throw new AnalysisException( - s"""Unsupported language features in query - |== SQL == - |$sql - |== AST == - |${ast.treeString} - |== Error == - |$e - |== Stacktrace == - |${e.getStackTrace.head} - """.stripMargin) - } - } - - /** Creates LogicalPlan for a given SQL string. */ - def parsePlan(sql: String): LogicalPlan = - safeParse(sql, ParseDriver.parsePlan(sql, conf))(nodeToPlan) - - /** Creates Expression for a given SQL string. */ - def parseExpression(sql: String): Expression = - safeParse(sql, ParseDriver.parseExpression(sql, conf))(selExprNodeToExpr(_).get) - - /** Creates TableIdentifier for a given SQL string. */ - def parseTableIdentifier(sql: String): TableIdentifier = - safeParse(sql, ParseDriver.parseTableName(sql, conf))(extractTableIdent) - - /** - * SELECT MAX(value) FROM src GROUP BY k1, k2, k3 GROUPING SETS((k1, k2), (k2)) - * is equivalent to - * SELECT MAX(value) FROM src GROUP BY k1, k2 UNION SELECT MAX(value) FROM src GROUP BY k2 - * Check the following link for details. - * -https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C+Grouping+and+Rollup - * - * The bitmask denotes the grouping expressions validity for a grouping set, - * the bitmask also be called as grouping id (`GROUPING__ID`, the virtual column in Hive) - * e.g. In superset (k1, k2, k3), (bit 2: k1, bit 1: k2, and bit 0: k3), the grouping id of - * GROUPING SETS (k1, k2) and (k2) should be 1 and 5 respectively. - */ - protected def extractGroupingSet(children: Seq[ASTNode]): (Seq[Expression], Seq[Int]) = { - val (keyASTs, setASTs) = children.partition { - case Token("TOK_GROUPING_SETS_EXPRESSION", _) => false // grouping sets - case _ => true // grouping keys - } - - val keys = keyASTs.map(nodeToExpr) - val keyMap = keyASTs.zipWithIndex.toMap - - val mask = (1 << keys.length) - 1 - val bitmasks: Seq[Int] = setASTs.map { - case Token("TOK_GROUPING_SETS_EXPRESSION", columns) => - columns.foldLeft(mask)((bitmap, col) => { - val keyIndex = keyMap.find(_._1.treeEquals(col)).map(_._2).getOrElse( - throw new AnalysisException(s"${col.treeString} doesn't show up in the GROUP BY list")) - // 0 means that the column at the given index is a grouping column, 1 means it is not, - // so we unset the bit in bitmap. - bitmap & ~(1 << (keys.length - 1 - keyIndex)) - }) - case _ => sys.error("Expect GROUPING SETS clause") - } - - (keys, bitmasks) - } - - protected def nodeToPlan(node: ASTNode): LogicalPlan = node match { - case Token("TOK_SHOWFUNCTIONS", args) => - // Skip LIKE. - val pattern = args match { - case like :: nodes if like.text.toUpperCase == "LIKE" => nodes - case nodes => nodes - } - - // Extract Database and Function name - pattern match { - case Nil => - ShowFunctions(None, None) - case Token(name, Nil) :: Nil => - ShowFunctions(None, Some(unquoteString(cleanIdentifier(name)))) - case Token(db, Nil) :: Token(name, Nil) :: Nil => - ShowFunctions(Some(unquoteString(cleanIdentifier(db))), - Some(unquoteString(cleanIdentifier(name)))) - case _ => - noParseRule("SHOW FUNCTIONS", node) - } - - case Token("TOK_DESCFUNCTION", Token(functionName, Nil) :: isExtended) => - DescribeFunction(cleanIdentifier(functionName), isExtended.nonEmpty) - - case Token("TOK_QUERY", queryArgs @ Token("TOK_CTE" | "TOK_FROM" | "TOK_INSERT", _) :: _) => - val (fromClause: Option[ASTNode], insertClauses, cteRelations) = - queryArgs match { - case Token("TOK_CTE", ctes) :: Token("TOK_FROM", from) :: inserts => - val cteRelations = ctes.map { node => - val relation = nodeToRelation(node).asInstanceOf[SubqueryAlias] - relation.alias -> relation - } - (Some(from.head), inserts, Some(cteRelations.toMap)) - case Token("TOK_FROM", from) :: inserts => - (Some(from.head), inserts, None) - case Token("TOK_INSERT", _) :: Nil => - (None, queryArgs, None) - } - - // Return one query for each insert clause. - val queries = insertClauses.map { - case Token("TOK_INSERT", singleInsert) => - val ( - intoClause :: - destClause :: - selectClause :: - selectDistinctClause :: - whereClause :: - groupByClause :: - rollupGroupByClause :: - cubeGroupByClause :: - groupingSetsClause :: - orderByClause :: - havingClause :: - sortByClause :: - clusterByClause :: - distributeByClause :: - limitClause :: - lateralViewClause :: - windowClause :: Nil) = { - getClauses( - Seq( - "TOK_INSERT_INTO", - "TOK_DESTINATION", - "TOK_SELECT", - "TOK_SELECTDI", - "TOK_WHERE", - "TOK_GROUPBY", - "TOK_ROLLUP_GROUPBY", - "TOK_CUBE_GROUPBY", - "TOK_GROUPING_SETS", - "TOK_ORDERBY", - "TOK_HAVING", - "TOK_SORTBY", - "TOK_CLUSTERBY", - "TOK_DISTRIBUTEBY", - "TOK_LIMIT", - "TOK_LATERAL_VIEW", - "WINDOW"), - singleInsert) - } - - val relations = fromClause match { - case Some(f) => nodeToRelation(f) - case None => OneRowRelation - } - - val withLateralView = lateralViewClause.map { lv => - nodeToGenerate(lv.children.head, outer = false, relations) - }.getOrElse(relations) - - val withWhere = whereClause.map { whereNode => - val Seq(whereExpr) = whereNode.children - Filter(nodeToExpr(whereExpr), withLateralView) - }.getOrElse(withLateralView) - - val select = (selectClause orElse selectDistinctClause) - .getOrElse(sys.error("No select clause.")) - - val transformation = nodeToTransformation(select.children.head, withWhere) - - // The projection of the query can either be a normal projection, an aggregation - // (if there is a group by) or a script transformation. - val withProject: LogicalPlan = transformation.getOrElse { - val selectExpressions = - select.children.flatMap(selExprNodeToExpr).map(UnresolvedAlias(_)) - Seq( - groupByClause.map(e => e match { - case Token("TOK_GROUPBY", children) => - // Not a transformation so must be either project or aggregation. - Aggregate(children.map(nodeToExpr), selectExpressions, withWhere) - case _ => sys.error("Expect GROUP BY") - }), - groupingSetsClause.map(e => e match { - case Token("TOK_GROUPING_SETS", children) => - val(groupByExprs, masks) = extractGroupingSet(children) - GroupingSets(masks, groupByExprs, withWhere, selectExpressions) - case _ => sys.error("Expect GROUPING SETS") - }), - rollupGroupByClause.map(e => e match { - case Token("TOK_ROLLUP_GROUPBY", children) => - Aggregate( - Seq(Rollup(children.map(nodeToExpr))), - selectExpressions, - withWhere) - case _ => sys.error("Expect WITH ROLLUP") - }), - cubeGroupByClause.map(e => e match { - case Token("TOK_CUBE_GROUPBY", children) => - Aggregate( - Seq(Cube(children.map(nodeToExpr))), - selectExpressions, - withWhere) - case _ => sys.error("Expect WITH CUBE") - }), - Some(Project(selectExpressions, withWhere))).flatten.head - } - - // Handle HAVING clause. - val withHaving = havingClause.map { h => - val havingExpr = h.children match { case Seq(hexpr) => nodeToExpr(hexpr) } - // Note that we added a cast to boolean. If the expression itself is already boolean, - // the optimizer will get rid of the unnecessary cast. - Filter(Cast(havingExpr, BooleanType), withProject) - }.getOrElse(withProject) - - // Handle SELECT DISTINCT - val withDistinct = - if (selectDistinctClause.isDefined) Distinct(withHaving) else withHaving - - // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. - val withSort = - (orderByClause, sortByClause, distributeByClause, clusterByClause) match { - case (Some(totalOrdering), None, None, None) => - Sort(totalOrdering.children.map(nodeToSortOrder), global = true, withDistinct) - case (None, Some(perPartitionOrdering), None, None) => - Sort( - perPartitionOrdering.children.map(nodeToSortOrder), - global = false, withDistinct) - case (None, None, Some(partitionExprs), None) => - RepartitionByExpression( - partitionExprs.children.map(nodeToExpr), withDistinct) - case (None, Some(perPartitionOrdering), Some(partitionExprs), None) => - Sort( - perPartitionOrdering.children.map(nodeToSortOrder), global = false, - RepartitionByExpression( - partitionExprs.children.map(nodeToExpr), - withDistinct)) - case (None, None, None, Some(clusterExprs)) => - Sort( - clusterExprs.children.map(nodeToExpr).map(SortOrder(_, Ascending)), - global = false, - RepartitionByExpression( - clusterExprs.children.map(nodeToExpr), - withDistinct)) - case (None, None, None, None) => withDistinct - case _ => sys.error("Unsupported set of ordering / distribution clauses.") - } - - val withLimit = - limitClause.map(l => nodeToExpr(l.children.head)) - .map(Limit(_, withSort)) - .getOrElse(withSort) - - // Collect all window specifications defined in the WINDOW clause. - val windowDefinitions = windowClause.map(_.children.collect { - case Token("TOK_WINDOWDEF", - Token(windowName, Nil) :: Token("TOK_WINDOWSPEC", spec) :: Nil) => - windowName -> nodesToWindowSpecification(spec) - }.toMap) - // Handle cases like - // window w1 as (partition by p_mfgr order by p_name - // range between 2 preceding and 2 following), - // w2 as w1 - val resolvedCrossReference = windowDefinitions.map { - windowDefMap => windowDefMap.map { - case (windowName, WindowSpecReference(other)) => - (windowName, windowDefMap(other).asInstanceOf[WindowSpecDefinition]) - case o => o.asInstanceOf[(String, WindowSpecDefinition)] - } - } - - val withWindowDefinitions = - resolvedCrossReference.map(WithWindowDefinition(_, withLimit)).getOrElse(withLimit) - - // TOK_INSERT_INTO means to add files to the table. - // TOK_DESTINATION means to overwrite the table. - val resultDestination = - (intoClause orElse destClause).getOrElse(sys.error("No destination found.")) - val overwrite = intoClause.isEmpty - nodeToDest( - resultDestination, - withWindowDefinitions, - overwrite) - } - - // If there are multiple INSERTS just UNION them together into one query. - val query = if (queries.length == 1) queries.head else Union(queries) - - // return With plan if there is CTE - cteRelations.map(With(query, _)).getOrElse(query) - - case Token("TOK_UNIONALL", left :: right :: Nil) => - Union(nodeToPlan(left), nodeToPlan(right)) - case Token("TOK_UNIONDISTINCT", left :: right :: Nil) => - Distinct(Union(nodeToPlan(left), nodeToPlan(right))) - case Token("TOK_EXCEPT", left :: right :: Nil) => - Except(nodeToPlan(left), nodeToPlan(right)) - case Token("TOK_INTERSECT", left :: right :: Nil) => - Intersect(nodeToPlan(left), nodeToPlan(right)) - - case _ => - noParseRule("Plan", node) - } - - val allJoinTokens = "(TOK_.*JOIN)".r - val laterViewToken = "TOK_LATERAL_VIEW(.*)".r - protected def nodeToRelation(node: ASTNode): LogicalPlan = { - node match { - case Token("TOK_SUBQUERY", query :: Token(alias, Nil) :: Nil) => - SubqueryAlias(cleanIdentifier(alias), nodeToPlan(query)) - - case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) => - nodeToGenerate( - selectClause, - outer = isOuter.nonEmpty, - nodeToRelation(relationClause)) - - /* All relations, possibly with aliases or sampling clauses. */ - case Token("TOK_TABREF", clauses) => - // If the last clause is not a token then it's the alias of the table. - val (nonAliasClauses, aliasClause) = - if (clauses.last.text.startsWith("TOK")) { - (clauses, None) - } else { - (clauses.dropRight(1), Some(clauses.last)) - } - - val (Some(tableNameParts) :: - splitSampleClause :: - bucketSampleClause :: Nil) = { - getClauses(Seq("TOK_TABNAME", "TOK_TABLESPLITSAMPLE", "TOK_TABLEBUCKETSAMPLE"), - nonAliasClauses) - } - - val tableIdent = extractTableIdent(tableNameParts) - val alias = aliasClause.map { case Token(a, Nil) => cleanIdentifier(a) } - val relation = UnresolvedRelation(tableIdent, alias) - - // Apply sampling if requested. - (bucketSampleClause orElse splitSampleClause).map { - case Token("TOK_TABLESPLITSAMPLE", - Token("TOK_ROWCOUNT", Nil) :: Token(count, Nil) :: Nil) => - Limit(Literal(count.toInt), relation) - case Token("TOK_TABLESPLITSAMPLE", - Token("TOK_PERCENT", Nil) :: Token(fraction, Nil) :: Nil) => - // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling - // function takes X PERCENT as the input and the range of X is [0, 100], we need to - // adjust the fraction. - require( - fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon) - && fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon), - s"Sampling fraction ($fraction) must be on interval [0, 100]") - Sample(0.0, fraction.toDouble / 100, withReplacement = false, - (math.random * 1000).toInt, - relation)( - isTableSample = true) - case Token("TOK_TABLEBUCKETSAMPLE", - Token(numerator, Nil) :: - Token(denominator, Nil) :: Nil) => - val fraction = numerator.toDouble / denominator.toDouble - Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation)( - isTableSample = true) - case a => - noParseRule("Sampling", a) - }.getOrElse(relation) - - case Token(allJoinTokens(joinToken), relation1 :: relation2 :: other) => - if (!(other.size <= 1)) { - sys.error(s"Unsupported join operation: $other") - } - - val (joinType, joinCondition) = getJoinInfo(joinToken, other, node) - - Join(nodeToRelation(relation1), - nodeToRelation(relation2), - joinType, - joinCondition) - case _ => - noParseRule("Relation", node) - } - } - - protected def getJoinInfo( - joinToken: String, - joinConditionToken: Seq[ASTNode], - node: ASTNode): (JoinType, Option[Expression]) = { - val joinType = joinToken match { - case "TOK_JOIN" => Inner - case "TOK_CROSSJOIN" => Inner - case "TOK_RIGHTOUTERJOIN" => RightOuter - case "TOK_LEFTOUTERJOIN" => LeftOuter - case "TOK_FULLOUTERJOIN" => FullOuter - case "TOK_LEFTSEMIJOIN" => LeftSemi - case "TOK_UNIQUEJOIN" => noParseRule("Unique Join", node) - case "TOK_ANTIJOIN" => noParseRule("Anti Join", node) - case "TOK_NATURALJOIN" => NaturalJoin(Inner) - case "TOK_NATURALRIGHTOUTERJOIN" => NaturalJoin(RightOuter) - case "TOK_NATURALLEFTOUTERJOIN" => NaturalJoin(LeftOuter) - case "TOK_NATURALFULLOUTERJOIN" => NaturalJoin(FullOuter) - } - - joinConditionToken match { - case Token("TOK_USING", columnList :: Nil) :: Nil => - val colNames = columnList.children.collect { - case Token(name, Nil) => UnresolvedAttribute(name) - } - (UsingJoin(joinType, colNames), None) - /* Join expression specified using ON clause */ - case _ => (joinType, joinConditionToken.headOption.map(nodeToExpr)) - } - } - - protected def nodeToSortOrder(node: ASTNode): SortOrder = node match { - case Token("TOK_TABSORTCOLNAMEASC", sortExpr :: Nil) => - SortOrder(nodeToExpr(sortExpr), Ascending) - case Token("TOK_TABSORTCOLNAMEDESC", sortExpr :: Nil) => - SortOrder(nodeToExpr(sortExpr), Descending) - case _ => - noParseRule("SortOrder", node) - } - - val destinationToken = "TOK_DESTINATION|TOK_INSERT_INTO".r - protected def nodeToDest( - node: ASTNode, - query: LogicalPlan, - overwrite: Boolean): LogicalPlan = node match { - case Token(destinationToken(), - Token("TOK_DIR", - Token("TOK_TMP_FILE", Nil) :: Nil) :: Nil) => - query - - case Token(destinationToken(), - Token("TOK_TAB", - tableArgs) :: Nil) => - val Some(tableNameParts) :: partitionClause :: Nil = - getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) - - val tableIdent = extractTableIdent(tableNameParts) - - val partitionKeys = partitionClause.map(_.children.map { - // Parse partitions. We also make keys case insensitive. - case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> Some(unquoteString(value)) - case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> None - }.toMap).getOrElse(Map.empty) - - InsertIntoTable( - UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, ifNotExists = false) - - case Token(destinationToken(), - Token("TOK_TAB", - tableArgs) :: - Token("TOK_IFNOTEXISTS", - ifNotExists) :: Nil) => - val Some(tableNameParts) :: partitionClause :: Nil = - getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) - - val tableIdent = extractTableIdent(tableNameParts) - - val partitionKeys = partitionClause.map(_.children.map { - // Parse partitions. We also make keys case insensitive. - case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> Some(unquoteString(value)) - case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> None - }.toMap).getOrElse(Map.empty) - - InsertIntoTable( - UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, ifNotExists = true) - - case _ => - noParseRule("Destination", node) - } - - protected def selExprNodeToExpr(node: ASTNode): Option[Expression] = node match { - case Token("TOK_SELEXPR", e :: Nil) => - Some(nodeToExpr(e)) - - case Token("TOK_SELEXPR", e :: Token(alias, Nil) :: Nil) => - Some(Alias(nodeToExpr(e), cleanIdentifier(alias))()) - - case Token("TOK_SELEXPR", e :: aliasChildren) => - val aliasNames = aliasChildren.collect { - case Token(name, Nil) => cleanIdentifier(name) - } - Some(MultiAlias(nodeToExpr(e), aliasNames)) - - /* Hints are ignored */ - case Token("TOK_HINTLIST", _) => None - - case _ => - noParseRule("Select", node) - } - - /** - * Flattens the left deep tree with the specified pattern into a list. - */ - private def flattenLeftDeepTree(node: ASTNode, pattern: Regex): Seq[ASTNode] = { - val collected = ArrayBuffer[ASTNode]() - var rest = node - while (rest match { - case Token(pattern(), l :: r :: Nil) => - collected += r - rest = l - true - case _ => false - }) { - // do nothing - } - collected += rest - // keep them in the same order as in SQL - collected.reverse - } - - /** - * Creates a balanced tree that has similar number of nodes on left and right. - * - * This help to reduce the depth of the tree to prevent StackOverflow in analyzer/optimizer. - */ - private def balancedTree( - expr: Seq[Expression], - f: (Expression, Expression) => Expression): Expression = expr.length match { - case 1 => expr.head - case 2 => f(expr.head, expr(1)) - case l => f(balancedTree(expr.slice(0, l / 2), f), balancedTree(expr.slice(l / 2, l), f)) - } - - protected def nodeToExpr(node: ASTNode): Expression = node match { - /* Attribute References */ - case Token("TOK_TABLE_OR_COL", Token(name, Nil) :: Nil) => - UnresolvedAttribute.quoted(cleanIdentifier(name)) - case Token(".", qualifier :: Token(attr, Nil) :: Nil) => - nodeToExpr(qualifier) match { - case UnresolvedAttribute(nameParts) => - UnresolvedAttribute(nameParts :+ cleanIdentifier(attr)) - case other => UnresolvedExtractValue(other, Literal(cleanIdentifier(attr))) - } - case Token("TOK_SUBQUERY_EXPR", Token("TOK_SUBQUERY_OP", Nil) :: subquery :: Nil) => - ScalarSubquery(nodeToPlan(subquery)) - - /* Stars (*) */ - case Token("TOK_ALLCOLREF", Nil) => UnresolvedStar(None) - // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only - // has a single child which is tableName. - case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", target) :: Nil) if target.nonEmpty => - UnresolvedStar(Some(target.map(x => cleanIdentifier(x.text)))) - - /* Aggregate Functions */ - case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => - Count(args.map(nodeToExpr)).toAggregateExpression(isDistinct = true) - case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => - Count(Literal(1)).toAggregateExpression() - - /* Casts */ - case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), StringType) - case Token("TOK_FUNCTION", Token("TOK_VARCHAR", _) :: arg :: Nil) => - Cast(nodeToExpr(arg), StringType) - case Token("TOK_FUNCTION", Token("TOK_CHAR", _) :: arg :: Nil) => - Cast(nodeToExpr(arg), StringType) - case Token("TOK_FUNCTION", Token("TOK_INT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), IntegerType) - case Token("TOK_FUNCTION", Token("TOK_BIGINT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), LongType) - case Token("TOK_FUNCTION", Token("TOK_FLOAT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), FloatType) - case Token("TOK_FUNCTION", Token("TOK_DOUBLE", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DoubleType) - case Token("TOK_FUNCTION", Token("TOK_SMALLINT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), ShortType) - case Token("TOK_FUNCTION", Token("TOK_TINYINT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), ByteType) - case Token("TOK_FUNCTION", Token("TOK_BINARY", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), BinaryType) - case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), BooleanType) - case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: scale :: nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType(precision.text.toInt, scale.text.toInt)) - case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType(precision.text.toInt, 0)) - case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType.USER_DEFAULT) - case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), TimestampType) - case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DateType) - - /* Arithmetic */ - case Token("+", child :: Nil) => nodeToExpr(child) - case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child)) - case Token("~", child :: Nil) => BitwiseNot(nodeToExpr(child)) - case Token("+", left :: right:: Nil) => Add(nodeToExpr(left), nodeToExpr(right)) - case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right)) - case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right)) - case Token("/", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right)) - case Token(DIV(), left :: right:: Nil) => - Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType) - case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right)) - case Token("&", left :: right:: Nil) => BitwiseAnd(nodeToExpr(left), nodeToExpr(right)) - case Token("|", left :: right:: Nil) => BitwiseOr(nodeToExpr(left), nodeToExpr(right)) - case Token("^", left :: right:: Nil) => BitwiseXor(nodeToExpr(left), nodeToExpr(right)) - - /* Comparisons */ - case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right)) - case Token("==", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right)) - case Token("<=>", left :: right:: Nil) => EqualNullSafe(nodeToExpr(left), nodeToExpr(right)) - case Token("!=", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) - case Token("<>", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) - case Token(">", left :: right:: Nil) => GreaterThan(nodeToExpr(left), nodeToExpr(right)) - case Token(">=", left :: right:: Nil) => GreaterThanOrEqual(nodeToExpr(left), nodeToExpr(right)) - case Token("<", left :: right:: Nil) => LessThan(nodeToExpr(left), nodeToExpr(right)) - case Token("<=", left :: right:: Nil) => LessThanOrEqual(nodeToExpr(left), nodeToExpr(right)) - case Token(LIKE(), left :: right:: Nil) => Like(nodeToExpr(left), nodeToExpr(right)) - case Token(RLIKE(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right)) - case Token(REGEXP(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right)) - case Token("TOK_FUNCTION", Token("TOK_ISNOTNULL", Nil) :: child :: Nil) => - IsNotNull(nodeToExpr(child)) - case Token("TOK_FUNCTION", Token("TOK_ISNULL", Nil) :: child :: Nil) => - IsNull(nodeToExpr(child)) - case Token("TOK_FUNCTION", Token(IN(), Nil) :: value :: list) => - In(nodeToExpr(value), list.map(nodeToExpr)) - case Token("TOK_FUNCTION", - Token(BETWEEN(), Nil) :: - kw :: - target :: - minValue :: - maxValue :: Nil) => - - val targetExpression = nodeToExpr(target) - val betweenExpr = - And( - GreaterThanOrEqual(targetExpression, nodeToExpr(minValue)), - LessThanOrEqual(targetExpression, nodeToExpr(maxValue))) - kw match { - case Token("KW_FALSE", Nil) => betweenExpr - case Token("KW_TRUE", Nil) => Not(betweenExpr) - } - - /* Boolean Logic */ - case Token(AND(), left :: right:: Nil) => - balancedTree(flattenLeftDeepTree(node, AND).map(nodeToExpr), And) - case Token(OR(), left :: right:: Nil) => - balancedTree(flattenLeftDeepTree(node, OR).map(nodeToExpr), Or) - case Token(NOT(), child :: Nil) => Not(nodeToExpr(child)) - case Token("!", child :: Nil) => Not(nodeToExpr(child)) - - /* Case statements */ - case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) => - CaseWhen.createFromParser(branches.map(nodeToExpr)) - case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) => - val keyExpr = nodeToExpr(branches.head) - CaseKeyWhen(keyExpr, branches.drop(1).map(nodeToExpr)) - - /* Complex datatype manipulation */ - case Token("[", child :: ordinal :: Nil) => - UnresolvedExtractValue(nodeToExpr(child), nodeToExpr(ordinal)) - - /* Window Functions */ - case Token(text, args :+ Token("TOK_WINDOWSPEC", spec)) => - val function = nodeToExpr(node.copy(children = node.children.init)) - nodesToWindowSpecification(spec) match { - case reference: WindowSpecReference => - UnresolvedWindowExpression(function, reference) - case definition: WindowSpecDefinition => - WindowExpression(function, definition) - } - - /* UDFs - Must be last otherwise will preempt built in functions */ - case Token("TOK_FUNCTION", Token(name, Nil) :: args) => - UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = false) - // Aggregate function with DISTINCT keyword. - case Token("TOK_FUNCTIONDI", Token(name, Nil) :: args) => - UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = true) - case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) => - UnresolvedFunction(name, UnresolvedStar(None) :: Nil, isDistinct = false) - - /* Literals */ - case Token("TOK_NULL", Nil) => Literal.create(null, NullType) - case Token(TRUE(), Nil) => Literal.create(true, BooleanType) - case Token(FALSE(), Nil) => Literal.create(false, BooleanType) - case Token("TOK_STRINGLITERALSEQUENCE", strings) => - Literal(strings.map(s => ParseUtils.unescapeSQLString(s.text)).mkString) - - case ast if ast.tokenType == SparkSqlParser.TinyintLiteral => - Literal.create(ast.text.substring(0, ast.text.length() - 1).toByte, ByteType) - - case ast if ast.tokenType == SparkSqlParser.SmallintLiteral => - Literal.create(ast.text.substring(0, ast.text.length() - 1).toShort, ShortType) - - case ast if ast.tokenType == SparkSqlParser.BigintLiteral => - Literal.create(ast.text.substring(0, ast.text.length() - 1).toLong, LongType) - - case ast if ast.tokenType == SparkSqlParser.DoubleLiteral => - Literal(ast.text.toDouble) - - case ast if ast.tokenType == SparkSqlParser.Number => - val text = ast.text - text match { - case INTEGRAL() => - BigDecimal(text) match { - case v if v.isValidInt => - Literal(v.intValue()) - case v if v.isValidLong => - Literal(v.longValue()) - case v => Literal(v.underlying()) - } - case DECIMAL(_*) => - Literal(BigDecimal(text).underlying()) - case _ => - // Convert a scientifically notated decimal into a double. - Literal(text.toDouble) - } - case ast if ast.tokenType == SparkSqlParser.StringLiteral => - Literal(ParseUtils.unescapeSQLString(ast.text)) - - case ast if ast.tokenType == SparkSqlParser.TOK_DATELITERAL => - Literal(Date.valueOf(ast.text.substring(1, ast.text.length - 1))) - - case ast if ast.tokenType == SparkSqlParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => - Literal(CalendarInterval.fromYearMonthString(ast.children.head.text)) - - case ast if ast.tokenType == SparkSqlParser.TOK_INTERVAL_DAY_TIME_LITERAL => - Literal(CalendarInterval.fromDayTimeString(ast.children.head.text)) - - case Token("TOK_INTERVAL", elements) => - var interval = new CalendarInterval(0, 0) - var updated = false - elements.foreach { - // The interval node will always contain children for all possible time units. A child node - // is only useful when it contains exactly one (numeric) child. - case e @ Token(name, Token(value, Nil) :: Nil) => - val unit = name match { - case "TOK_INTERVAL_YEAR_LITERAL" => "year" - case "TOK_INTERVAL_MONTH_LITERAL" => "month" - case "TOK_INTERVAL_WEEK_LITERAL" => "week" - case "TOK_INTERVAL_DAY_LITERAL" => "day" - case "TOK_INTERVAL_HOUR_LITERAL" => "hour" - case "TOK_INTERVAL_MINUTE_LITERAL" => "minute" - case "TOK_INTERVAL_SECOND_LITERAL" => "second" - case "TOK_INTERVAL_MILLISECOND_LITERAL" => "millisecond" - case "TOK_INTERVAL_MICROSECOND_LITERAL" => "microsecond" - case _ => noParseRule(s"Interval($name)", e) - } - interval = interval.add(CalendarInterval.fromSingleUnitString(unit, value)) - updated = true - case _ => - } - if (!updated) { - throw new AnalysisException("at least one time unit should be given for interval literal") - } - Literal(interval) - - case _ => - noParseRule("Expression", node) - } - - /* Case insensitive matches for Window Specification */ - val PRECEDING = "(?i)preceding".r - val FOLLOWING = "(?i)following".r - val CURRENT = "(?i)current".r - protected def nodesToWindowSpecification(nodes: Seq[ASTNode]): WindowSpec = nodes match { - case Token(windowName, Nil) :: Nil => - // Refer to a window spec defined in the window clause. - WindowSpecReference(windowName) - case Nil => - // OVER() - WindowSpecDefinition( - partitionSpec = Nil, - orderSpec = Nil, - frameSpecification = UnspecifiedFrame) - case spec => - val (partitionClause :: rowFrame :: rangeFrame :: Nil) = - getClauses( - Seq( - "TOK_PARTITIONINGSPEC", - "TOK_WINDOWRANGE", - "TOK_WINDOWVALUES"), - spec) - - // Handle Partition By and Order By. - val (partitionSpec, orderSpec) = partitionClause.map { partitionAndOrdering => - val (partitionByClause :: orderByClause :: sortByClause :: clusterByClause :: Nil) = - getClauses( - Seq("TOK_DISTRIBUTEBY", "TOK_ORDERBY", "TOK_SORTBY", "TOK_CLUSTERBY"), - partitionAndOrdering.children) - - (partitionByClause, orderByClause.orElse(sortByClause), clusterByClause) match { - case (Some(partitionByExpr), Some(orderByExpr), None) => - (partitionByExpr.children.map(nodeToExpr), - orderByExpr.children.map(nodeToSortOrder)) - case (Some(partitionByExpr), None, None) => - (partitionByExpr.children.map(nodeToExpr), Nil) - case (None, Some(orderByExpr), None) => - (Nil, orderByExpr.children.map(nodeToSortOrder)) - case (None, None, Some(clusterByExpr)) => - val expressions = clusterByExpr.children.map(nodeToExpr) - (expressions, expressions.map(SortOrder(_, Ascending))) - case _ => - noParseRule("Partition & Ordering", partitionAndOrdering) - } - }.getOrElse { - (Nil, Nil) - } - - // Handle Window Frame - val windowFrame = - if (rowFrame.isEmpty && rangeFrame.isEmpty) { - UnspecifiedFrame - } else { - val frameType = rowFrame.map(_ => RowFrame).getOrElse(RangeFrame) - def nodeToBoundary(node: ASTNode): FrameBoundary = node match { - case Token(PRECEDING(), Token(count, Nil) :: Nil) => - if (count.toLowerCase() == "unbounded") { - UnboundedPreceding - } else { - ValuePreceding(count.toInt) - } - case Token(FOLLOWING(), Token(count, Nil) :: Nil) => - if (count.toLowerCase() == "unbounded") { - UnboundedFollowing - } else { - ValueFollowing(count.toInt) - } - case Token(CURRENT(), Nil) => CurrentRow - case _ => - noParseRule("Window Frame Boundary", node) - } - - rowFrame.orElse(rangeFrame).map { frame => - frame.children match { - case precedingNode :: followingNode :: Nil => - SpecifiedWindowFrame( - frameType, - nodeToBoundary(precedingNode), - nodeToBoundary(followingNode)) - case precedingNode :: Nil => - SpecifiedWindowFrame(frameType, nodeToBoundary(precedingNode), CurrentRow) - case _ => - noParseRule("Window Frame", frame) - } - }.getOrElse(sys.error(s"If you see this, please file a bug report with your query.")) - } - - WindowSpecDefinition(partitionSpec, orderSpec, windowFrame) - } - - protected def nodeToTransformation( - node: ASTNode, - child: LogicalPlan): Option[ScriptTransformation] = None - - val explode = "(?i)explode".r - val jsonTuple = "(?i)json_tuple".r - protected def nodeToGenerate(node: ASTNode, outer: Boolean, child: LogicalPlan): Generate = { - val Token("TOK_SELECT", Token("TOK_SELEXPR", clauses) :: Nil) = node - - val alias = cleanIdentifier(getClause("TOK_TABALIAS", clauses).children.head.text) - - val generator = clauses.head match { - case Token("TOK_FUNCTION", Token(explode(), Nil) :: childNode :: Nil) => - Explode(nodeToExpr(childNode)) - case Token("TOK_FUNCTION", Token(jsonTuple(), Nil) :: children) => - JsonTuple(children.map(nodeToExpr)) - case other => - nodeToGenerator(other) - } - - val attributes = clauses.collect { - case Token(a, Nil) => UnresolvedAttribute(cleanIdentifier(a.toLowerCase)) - } - - Generate( - generator, - join = true, - outer = outer, - Some(cleanIdentifier(alias.toLowerCase)), - attributes, - child) - } - - protected def nodeToGenerator(node: ASTNode): Generator = noParseRule("Generator", node) - -}
http://git-wip-us.apache.org/repos/asf/spark/blob/a9b93e07/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala index 21deb82..0b570c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.parser import scala.language.implicitConversions import scala.util.matching.Regex import scala.util.parsing.combinator.syntactical.StandardTokenParsers +import scala.util.parsing.input.CharArrayReader._ import org.apache.spark.sql.types._ @@ -117,3 +118,69 @@ private[sql] object DataTypeParser { /** The exception thrown from the [[DataTypeParser]]. */ private[sql] class DataTypeException(message: String) extends Exception(message) + +class SqlLexical extends scala.util.parsing.combinator.lexical.StdLexical { + case class DecimalLit(chars: String) extends Token { + override def toString: String = chars + } + + /* This is a work around to support the lazy setting */ + def initialize(keywords: Seq[String]): Unit = { + reserved.clear() + reserved ++= keywords + } + + /* Normal the keyword string */ + def normalizeKeyword(str: String): String = str.toLowerCase + + delimiters += ( + "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", + ",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~", "<=>" + ) + + protected override def processIdent(name: String) = { + val token = normalizeKeyword(name) + if (reserved contains token) Keyword(token) else Identifier(name) + } + + override lazy val token: Parser[Token] = + ( rep1(digit) ~ scientificNotation ^^ { case i ~ s => DecimalLit(i.mkString + s) } + | '.' ~> (rep1(digit) ~ scientificNotation) ^^ + { case i ~ s => DecimalLit("0." + i.mkString + s) } + | rep1(digit) ~ ('.' ~> digit.*) ~ scientificNotation ^^ + { case i1 ~ i2 ~ s => DecimalLit(i1.mkString + "." + i2.mkString + s) } + | digit.* ~ identChar ~ (identChar | digit).* ^^ + { case first ~ middle ~ rest => processIdent((first ++ (middle :: rest)).mkString) } + | rep1(digit) ~ ('.' ~> digit.*).? ^^ { + case i ~ None => NumericLit(i.mkString) + case i ~ Some(d) => DecimalLit(i.mkString + "." + d.mkString) + } + | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^ + { case chars => StringLit(chars mkString "") } + | '"' ~> chrExcept('"', '\n', EofCh).* <~ '"' ^^ + { case chars => StringLit(chars mkString "") } + | '`' ~> chrExcept('`', '\n', EofCh).* <~ '`' ^^ + { case chars => Identifier(chars mkString "") } + | EofCh ^^^ EOF + | '\'' ~> failure("unclosed string literal") + | '"' ~> failure("unclosed string literal") + | delim + | failure("illegal character") + ) + + override def identChar: Parser[Elem] = letter | elem('_') + + private lazy val scientificNotation: Parser[String] = + (elem('e') | elem('E')) ~> (elem('+') | elem('-')).? ~ rep1(digit) ^^ { + case s ~ rest => "e" + s.mkString + rest.mkString + } + + override def whitespace: Parser[Any] = + ( whitespaceChar + | '/' ~ '*' ~ comment + | '/' ~ '/' ~ chrExcept(EofCh, '\n').* + | '#' ~ chrExcept(EofCh, '\n').* + | '-' ~ '-' ~ chrExcept(EofCh, '\n').* + | '/' ~ '*' ~ failure("unclosed comment") + ).* +} http://git-wip-us.apache.org/repos/asf/spark/blob/a9b93e07/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 51cfc50..d013252 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -16,91 +16,106 @@ */ package org.apache.spark.sql.catalyst.parser -import scala.annotation.tailrec - -import org.antlr.runtime._ -import org.antlr.runtime.tree.CommonTree +import org.antlr.v4.runtime._ +import org.antlr.v4.runtime.atn.PredictionMode +import org.antlr.v4.runtime.misc.ParseCancellationException import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.types.DataType /** - * The ParseDriver takes a SQL command and turns this into an AST. - * - * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver + * Base SQL parsing infrastructure. */ -object ParseDriver extends Logging { - /** Create an LogicalPlan ASTNode from a SQL command. */ - def parsePlan(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser => - parser.statement().getTree - } +abstract class AbstractSqlParser extends ParserInterface with Logging { - /** Create an Expression ASTNode from a SQL command. */ - def parseExpression(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser => - parser.singleNamedExpression().getTree + /** Creates/Resolves DataType for a given SQL string. */ + def parseDataType(sqlText: String): DataType = parse(sqlText) { parser => + // TODO add this to the parser interface. + astBuilder.visitSingleDataType(parser.singleDataType()) } - /** Create an TableIdentifier ASTNode from a SQL command. */ - def parseTableName(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser => - parser.singleTableName().getTree + /** Creates Expression for a given SQL string. */ + override def parseExpression(sqlText: String): Expression = parse(sqlText) { parser => + astBuilder.visitSingleExpression(parser.singleExpression()) } - private def parse( - command: String, - conf: ParserConf)( - toTree: SparkSqlParser => CommonTree): ASTNode = { - logInfo(s"Parsing command: $command") + /** Creates TableIdentifier for a given SQL string. */ + override def parseTableIdentifier(sqlText: String): TableIdentifier = parse(sqlText) { parser => + astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier()) + } - // Setup error collection. - val reporter = new ParseErrorReporter() + /** Creates LogicalPlan for a given SQL string. */ + override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser => + astBuilder.visitSingleStatement(parser.singleStatement()) match { + case plan: LogicalPlan => plan + case _ => nativeCommand(sqlText) + } + } - // Create lexer. - val lexer = new SparkSqlLexer(new ANTLRNoCaseStringStream(command)) - val tokens = new TokenRewriteStream(lexer) - lexer.configure(conf, reporter) + /** Get the builder (visitor) which converts a ParseTree into a AST. */ + protected def astBuilder: AstBuilder - // Create the parser. - val parser = new SparkSqlParser(tokens) - parser.configure(conf, reporter) + /** Create a native command, or fail when this is not supported. */ + protected def nativeCommand(sqlText: String): LogicalPlan = { + val position = Origin(None, None) + throw new ParseException(Option(sqlText), "Unsupported SQL statement", position, position) + } - try { - val result = toTree(parser) + protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = { + logInfo(s"Parsing command: $command") - // Check errors. - reporter.checkForErrors() + val lexer = new SqlBaseLexer(new ANTLRNoCaseStringStream(command)) + lexer.removeErrorListeners() + lexer.addErrorListener(ParseErrorListener) - // Return the AST node from the result. - logInfo(s"Parse completed.") + val tokenStream = new CommonTokenStream(lexer) + val parser = new SqlBaseParser(tokenStream) + parser.addParseListener(PostProcessor) + parser.removeErrorListeners() + parser.addErrorListener(ParseErrorListener) - // Find the non null token tree in the result. - @tailrec - def nonNullToken(tree: CommonTree): CommonTree = { - if (tree.token != null || tree.getChildCount == 0) tree - else nonNullToken(tree.getChild(0).asInstanceOf[CommonTree]) + try { + try { + // first, try parsing with potentially faster SLL mode + parser.getInterpreter.setPredictionMode(PredictionMode.SLL) + toResult(parser) } - val tree = nonNullToken(result) - - // Make sure all boundaries are set. - tree.setUnknownTokenBoundaries() - - // Construct the immutable AST. - def createASTNode(tree: CommonTree): ASTNode = { - val children = (0 until tree.getChildCount).map { i => - createASTNode(tree.getChild(i).asInstanceOf[CommonTree]) - }.toList - ASTNode(tree.token, tree.getTokenStartIndex, tree.getTokenStopIndex, children, tokens) + catch { + case e: ParseCancellationException => + // if we fail, parse with LL mode + tokenStream.reset() // rewind input stream + parser.reset() + + // Try Again. + parser.getInterpreter.setPredictionMode(PredictionMode.LL) + toResult(parser) } - createASTNode(tree) } catch { - case e: RecognitionException => - logInfo(s"Parse failed.") - reporter.throwError(e) + case e: ParseException if e.command.isDefined => + throw e + case e: ParseException => + throw e.withCommand(command) + case e: AnalysisException => + val position = Origin(e.line, e.startPosition) + throw new ParseException(Option(command), e.message, position, position) } } } /** + * Concrete SQL parser for Catalyst-only SQL statements. + */ +object CatalystSqlParser extends AbstractSqlParser { + val astBuilder = new AstBuilder +} + +/** * This string stream provides the lexer with upper case characters only. This greatly simplifies * lexing the stream, while we can maintain the original command. * @@ -120,58 +135,104 @@ object ParseDriver extends Logging { * have the ANTLRNoCaseStringStream implementation. */ -private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRStringStream(input) { +private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRInputStream(input) { override def LA(i: Int): Int = { val la = super.LA(i) - if (la == 0 || la == CharStream.EOF) la + if (la == 0 || la == IntStream.EOF) la else Character.toUpperCase(la) } } /** - * Utility used by the Parser and the Lexer for error collection and reporting. + * The ParseErrorListener converts parse errors into AnalysisExceptions. */ -private[parser] class ParseErrorReporter { - val errors = scala.collection.mutable.Buffer.empty[ParseError] - - def report(br: BaseRecognizer, re: RecognitionException, tokenNames: Array[String]): Unit = { - errors += ParseError(br, re, tokenNames) +case object ParseErrorListener extends BaseErrorListener { + override def syntaxError( + recognizer: Recognizer[_, _], + offendingSymbol: scala.Any, + line: Int, + charPositionInLine: Int, + msg: String, + e: RecognitionException): Unit = { + val position = Origin(Some(line), Some(charPositionInLine)) + throw new ParseException(None, msg, position, position) } +} - def checkForErrors(): Unit = { - if (errors.nonEmpty) { - val first = errors.head - val e = first.re - throwError(e.line, e.charPositionInLine, first.buildMessage().toString, errors.tail) - } +/** + * A [[ParseException]] is an [[AnalysisException]] that is thrown during the parse process. It + * contains fields and an extended error message that make reporting and diagnosing errors easier. + */ +class ParseException( + val command: Option[String], + message: String, + val start: Origin, + val stop: Origin) extends AnalysisException(message, start.line, start.startPosition) { + + def this(message: String, ctx: ParserRuleContext) = { + this(Option(ParserUtils.command(ctx)), + message, + ParserUtils.position(ctx.getStart), + ParserUtils.position(ctx.getStop)) } - def throwError(e: RecognitionException): Nothing = { - throwError(e.line, e.charPositionInLine, e.toString, errors) + override def getMessage: String = { + val builder = new StringBuilder + builder ++= "\n" ++= message + start match { + case Origin(Some(l), Some(p)) => + builder ++= s"(line $l, pos $p)\n" + command.foreach { cmd => + val (above, below) = cmd.split("\n").splitAt(l) + builder ++= "\n== SQL ==\n" + above.foreach(builder ++= _ += '\n') + builder ++= (0 until p).map(_ => "-").mkString("") ++= "^^^\n" + below.foreach(builder ++= _ += '\n') + } + case _ => + command.foreach { cmd => + builder ++= "\n== SQL ==\n" ++= cmd + } + } + builder.toString } - private def throwError( - line: Int, - startPosition: Int, - msg: String, - errors: Seq[ParseError]): Nothing = { - val b = new StringBuilder - b.append(msg).append("\n") - errors.foreach(error => error.buildMessage(b).append("\n")) - throw new AnalysisException(b.toString, Option(line), Option(startPosition)) + def withCommand(cmd: String): ParseException = { + new ParseException(Option(cmd), message, start, stop) } } /** - * Error collected during the parsing process. - * - * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseError + * The post-processor validates & cleans-up the parse tree during the parse process. */ -private[parser] case class ParseError( - br: BaseRecognizer, - re: RecognitionException, - tokenNames: Array[String]) { - def buildMessage(s: StringBuilder = new StringBuilder): StringBuilder = { - s.append(br.getErrorHeader(re)).append(" ").append(br.getErrorMessage(re, tokenNames)) +case object PostProcessor extends SqlBaseBaseListener { + + /** Remove the back ticks from an Identifier. */ + override def exitQuotedIdentifier(ctx: SqlBaseParser.QuotedIdentifierContext): Unit = { + replaceTokenByIdentifier(ctx, 1) { token => + // Remove the double back ticks in the string. + token.setText(token.getText.replace("``", "`")) + token + } + } + + /** Treat non-reserved keywords as Identifiers. */ + override def exitNonReserved(ctx: SqlBaseParser.NonReservedContext): Unit = { + replaceTokenByIdentifier(ctx, 0)(identity) + } + + private def replaceTokenByIdentifier( + ctx: ParserRuleContext, + stripMargins: Int)( + f: CommonToken => CommonToken = identity): Unit = { + val parent = ctx.getParent + parent.removeLastChild() + val token = ctx.getChild(0).getPayload.asInstanceOf[Token] + parent.addChild(f(new CommonToken( + new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream), + SqlBaseParser.IDENTIFIER, + token.getChannel, + token.getStartIndex + stripMargins, + token.getStopIndex - stripMargins))) } } http://git-wip-us.apache.org/repos/asf/spark/blob/a9b93e07/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala deleted file mode 100644 index ce449b1..0000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.parser - -trait ParserConf { - def supportQuotedId: Boolean - def supportSQL11ReservedKeywords: Boolean -} - -case class SimpleParserConf( - supportQuotedId: Boolean = true, - supportSQL11ReservedKeywords: Boolean = false) extends ParserConf http://git-wip-us.apache.org/repos/asf/spark/blob/a9b93e07/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index 0c2e481..90b76dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -14,166 +14,105 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.spark.sql.catalyst.parser -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.trees.CurrentOrigin -import org.apache.spark.sql.types._ +import org.antlr.v4.runtime.{CharStream, ParserRuleContext, Token} +import org.antlr.v4.runtime.misc.Interval +import org.antlr.v4.runtime.tree.TerminalNode +import org.apache.spark.sql.catalyst.parser.ParseUtils.unescapeSQLString +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} /** - * A collection of utility methods and patterns for parsing query texts. + * A collection of utility methods for use during the parsing process. */ -// TODO: merge with ParseUtils object ParserUtils { - - object Token { - // Match on (text, children) - def unapply(node: ASTNode): Some[(String, List[ASTNode])] = { - CurrentOrigin.setPosition(node.line, node.positionInLine) - node.pattern - } + /** Get the command which created the token. */ + def command(ctx: ParserRuleContext): String = { + command(ctx.getStart.getInputStream) } - private val escapedIdentifier = "`(.+)`".r - private val doubleQuotedString = "\"([^\"]+)\"".r - private val singleQuotedString = "'([^']+)'".r - - // Token patterns - val COUNT = "(?i)COUNT".r - val SUM = "(?i)SUM".r - val AND = "(?i)AND".r - val OR = "(?i)OR".r - val NOT = "(?i)NOT".r - val TRUE = "(?i)TRUE".r - val FALSE = "(?i)FALSE".r - val LIKE = "(?i)LIKE".r - val RLIKE = "(?i)RLIKE".r - val REGEXP = "(?i)REGEXP".r - val IN = "(?i)IN".r - val DIV = "(?i)DIV".r - val BETWEEN = "(?i)BETWEEN".r - val WHEN = "(?i)WHEN".r - val CASE = "(?i)CASE".r - val INTEGRAL = "[+-]?\\d+".r - val DECIMAL = "[+-]?((\\d+(\\.\\d*)?)|(\\.\\d+))".r - - /** - * Strip quotes, if any, from the string. - */ - def unquoteString(str: String): String = { - str match { - case singleQuotedString(s) => s - case doubleQuotedString(s) => s - case other => other - } + /** Get the command which created the token. */ + def command(stream: CharStream): String = { + stream.getText(Interval.of(0, stream.size())) } - /** - * Strip backticks, if any, from the string. - */ - def cleanIdentifier(ident: String): String = { - ident match { - case escapedIdentifier(i) => i - case plainIdent => plainIdent - } + /** Get the code that creates the given node. */ + def source(ctx: ParserRuleContext): String = { + val stream = ctx.getStart.getInputStream + stream.getText(Interval.of(ctx.getStart.getStartIndex, ctx.getStop.getStopIndex)) } - def getClauses( - clauseNames: Seq[String], - nodeList: Seq[ASTNode]): Seq[Option[ASTNode]] = { - var remainingNodes = nodeList - val clauses = clauseNames.map { clauseName => - val (matches, nonMatches) = remainingNodes.partition(_.text.toUpperCase == clauseName) - remainingNodes = nonMatches ++ (if (matches.nonEmpty) matches.tail else Nil) - matches.headOption - } + /** Get all the text which comes after the given rule. */ + def remainder(ctx: ParserRuleContext): String = remainder(ctx.getStop) - if (remainingNodes.nonEmpty) { - sys.error( - s"""Unhandled clauses: ${remainingNodes.map(_.treeString).mkString("\n")}. - |You are likely trying to use an unsupported Hive feature."""".stripMargin) - } - clauses + /** Get all the text which comes after the given token. */ + def remainder(token: Token): String = { + val stream = token.getInputStream + val interval = Interval.of(token.getStopIndex + 1, stream.size()) + stream.getText(interval) } - def getClause(clauseName: String, nodeList: Seq[ASTNode]): ASTNode = { - getClauseOption(clauseName, nodeList).getOrElse(sys.error( - s"Expected clause $clauseName missing from ${nodeList.map(_.treeString).mkString("\n")}")) - } + /** Convert a string token into a string. */ + def string(token: Token): String = unescapeSQLString(token.getText) - def getClauseOption(clauseName: String, nodeList: Seq[ASTNode]): Option[ASTNode] = { - nodeList.filter { case ast: ASTNode => ast.text == clauseName } match { - case Seq(oneMatch) => Some(oneMatch) - case Seq() => None - case _ => sys.error(s"Found multiple instances of clause $clauseName") - } - } + /** Convert a string node into a string. */ + def string(node: TerminalNode): String = unescapeSQLString(node.getText) - def extractTableIdent(tableNameParts: ASTNode): TableIdentifier = { - tableNameParts.children.map { - case Token(part, Nil) => cleanIdentifier(part) - } match { - case Seq(tableOnly) => TableIdentifier(tableOnly) - case Seq(databaseName, table) => TableIdentifier(table, Some(databaseName)) - case other => sys.error("Hive only supports tables names like 'tableName' " + - s"or 'databaseName.tableName', found '$other'") - } + /** Get the origin (line and position) of the token. */ + def position(token: Token): Origin = { + Origin(Option(token.getLine), Option(token.getCharPositionInLine)) } - def nodeToDataType(node: ASTNode): DataType = node match { - case Token("TOK_DECIMAL", precision :: scale :: Nil) => - DecimalType(precision.text.toInt, scale.text.toInt) - case Token("TOK_DECIMAL", precision :: Nil) => - DecimalType(precision.text.toInt, 0) - case Token("TOK_DECIMAL", Nil) => DecimalType.USER_DEFAULT - case Token("TOK_BIGINT", Nil) => LongType - case Token("TOK_INT", Nil) => IntegerType - case Token("TOK_TINYINT", Nil) => ByteType - case Token("TOK_SMALLINT", Nil) => ShortType - case Token("TOK_BOOLEAN", Nil) => BooleanType - case Token("TOK_STRING", Nil) => StringType - case Token("TOK_VARCHAR", Token(_, Nil) :: Nil) => StringType - case Token("TOK_CHAR", Token(_, Nil) :: Nil) => StringType - case Token("TOK_FLOAT", Nil) => FloatType - case Token("TOK_DOUBLE", Nil) => DoubleType - case Token("TOK_DATE", Nil) => DateType - case Token("TOK_TIMESTAMP", Nil) => TimestampType - case Token("TOK_BINARY", Nil) => BinaryType - case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType)) - case Token("TOK_STRUCT", Token("TOK_TABCOLLIST", fields) :: Nil) => - StructType(fields.map(nodeToStructField)) - case Token("TOK_MAP", keyType :: valueType :: Nil) => - MapType(nodeToDataType(keyType), nodeToDataType(valueType)) - case _ => - noParseRule("DataType", node) - } - - def nodeToStructField(node: ASTNode): StructField = node match { - case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: Nil) => - StructField(cleanIdentifier(fieldName), nodeToDataType(dataType), nullable = true) - case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: comment :: Nil) => - val meta = new MetadataBuilder().putString("comment", unquoteString(comment.text)).build() - StructField(cleanIdentifier(fieldName), nodeToDataType(dataType), nullable = true, meta) - case _ => - noParseRule("StructField", node) + /** Assert if a condition holds. If it doesn't throw a parse exception. */ + def assert(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = { + if (!f) { + throw new ParseException(message, ctx) + } } /** - * Throw an exception because we cannot parse the given node for some unexpected reason. + * Register the origin of the context. Any TreeNode created in the closure will be assigned the + * registered origin. This method restores the previously set origin after completion of the + * closure. */ - def parseFailed(msg: String, node: ASTNode): Nothing = { - throw new AnalysisException(s"$msg: '${node.source}") + def withOrigin[T](ctx: ParserRuleContext)(f: => T): T = { + val current = CurrentOrigin.get + CurrentOrigin.set(position(ctx.getStart)) + try { + f + } finally { + CurrentOrigin.set(current) + } } - /** - * Throw an exception because there are no rules to parse the node. - */ - def noParseRule(msg: String, node: ASTNode): Nothing = { - throw new NotImplementedError( - s"[$msg]: No parse rules for ASTNode type: ${node.tokenType}, tree:\n${node.treeString}") - } + /** Some syntactic sugar which makes it easier to work with optional clauses for LogicalPlans. */ + implicit class EnhancedLogicalPlan(val plan: LogicalPlan) extends AnyVal { + /** + * Create a plan using the block of code when the given context exists. Otherwise return the + * original plan. + */ + def optional(ctx: AnyRef)(f: => LogicalPlan): LogicalPlan = { + if (ctx != null) { + f + } else { + plan + } + } + /** + * Map a [[LogicalPlan]] to another [[LogicalPlan]] if the passed context exists using the + * passed function. The original plan is returned when the context does not exist. + */ + def optionalMap[C <: ParserRuleContext]( + ctx: C)( + f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = { + if (ctx != null) { + f(ctx, plan) + } else { + plan + } + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
