This is an automated email from the ASF dual-hosted git repository.

zjffdu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/zeppelin.git
The following commit(s) were added to refs/heads/master by this push:
     new 737d162  [ZEPPELIN-4522]. Support multiple sql statements for SparkSqlInterpreter
737d162 is described below

commit 737d1626d073351dca3c3cc508d0dbea773c4e43
Author: Jeff Zhang <zjf...@apache.org>
AuthorDate: Sun Dec 29 23:31:19 2019 +0800

[ZEPPELIN-4522]. Support multiple sql statements for SparkSqlInterpreter

### What is this PR for?
Use the `SqlSplitter` in `zeppelin-interpreter` to split the sql text and execute each statement in `SparkSqlInterpreter`. Nothing changes for paragraphs with a single sql statement; paragraphs with multiple sql statements simply display one result per statement.

### What type of PR is it?
[Feature]

### Todos
* [ ] - Task

### What is the Jira issue?
* https://issues.apache.org/jira/browse/ZEPPELIN-4522

### How should this be tested?
* CI pass

### Screenshots (if appropriate)

### Questions:
* Do the license files need to be updated? No
* Are there breaking changes for older versions? No
* Does this need documentation? No

Author: Jeff Zhang <zjf...@apache.org>

Closes #3579 from zjffdu/ZEPPELIN-4522 and squashes the following commits:

eda573649 [Jeff Zhang] fix failed test
68d5a30c8 [Jeff Zhang] Add test for no sql but just 2 comments
4ff15e4fb [Jeff Zhang] address comment
bc3c1feff [Jeff Zhang] [ZEPPELIN-4522]. Support multiple sql statements for SparkSqlInterpreter
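For readers unfamiliar with `SqlSplitter`, a minimal sketch of the splitting step this commit relies on. The class and its `splitSql(String)` method are taken from the diff below; the input string is borrowed from the new test, and the expected two-statement output is an assumption based on that test:

```java
import java.util.List;

import org.apache.zeppelin.interpreter.util.SqlSplitter;

public class SqlSplitterDemo {
  public static void main(String[] args) {
    // Same splitter the interpreter now creates in open().
    SqlSplitter splitter = new SqlSplitter();
    // One paragraph containing a line comment and two statements.
    List<String> sqls = splitter.splitSql(
        "select * --comment_1\nfrom gr;select count(1) from gr");
    // Per testMultipleStatements below, this should yield two statements.
    for (String sql : sqls) {
      System.out.println(sql);
    }
  }
}
```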
---
 .../zeppelin/python/IPythonInterpreterTest.java    |  6 ++--
 .../apache/zeppelin/spark/SparkSqlInterpreter.java | 42 ++++++++++++++--------
 .../zeppelin/spark/SparkSqlInterpreterTest.java    | 41 +++++++++++++++++++++
 .../org/apache/zeppelin/spark/Spark1Shims.java     |  2 +-
 .../org/apache/zeppelin/spark/Spark2Shims.java     |  2 +-
 .../zeppelin/interpreter/InterpreterOutput.java    |  7 ++++
 .../interpreter/InterpreterResultTest.java         |  2 +-
 7 files changed, 82 insertions(+), 20 deletions(-)

diff --git a/python/src/test/java/org/apache/zeppelin/python/IPythonInterpreterTest.java b/python/src/test/java/org/apache/zeppelin/python/IPythonInterpreterTest.java
index b0a8ba6..0f302e9 100644
--- a/python/src/test/java/org/apache/zeppelin/python/IPythonInterpreterTest.java
+++ b/python/src/test/java/org/apache/zeppelin/python/IPythonInterpreterTest.java
@@ -298,13 +298,13 @@ public class IPythonInterpreterTest extends BasePythonInterpreterTest {
         "df.hvplot()", context);
     assertEquals(InterpreterResult.Code.SUCCESS, result.code());
     interpreterResultMessages = context.out.toInterpreterResultMessage();
-    assertEquals(5, interpreterResultMessages.size());
+    assertEquals(4, interpreterResultMessages.size());
+    assertEquals(InterpreterResult.Type.HTML, interpreterResultMessages.get(0).getType());
     assertEquals(InterpreterResult.Type.HTML, interpreterResultMessages.get(1).getType());
     assertEquals(InterpreterResult.Type.HTML, interpreterResultMessages.get(2).getType());
     assertEquals(InterpreterResult.Type.HTML, interpreterResultMessages.get(3).getType());
-    assertEquals(InterpreterResult.Type.HTML, interpreterResultMessages.get(4).getType());
     // docs_json is the source data of plotting which bokeh would use to render the plotting.
-    assertTrue(interpreterResultMessages.get(4).getData().contains("docs_json"));
+    assertTrue(interpreterResultMessages.get(3).getData().contains("docs_json"));
   }

diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java
index 4e63760..f6372dd 100644
--- a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java
+++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java
@@ -28,6 +28,7 @@ import org.apache.zeppelin.interpreter.InterpreterException;
 import org.apache.zeppelin.interpreter.InterpreterResult;
 import org.apache.zeppelin.interpreter.InterpreterResult.Code;
 import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion;
+import org.apache.zeppelin.interpreter.util.SqlSplitter;
 import org.apache.zeppelin.scheduler.Scheduler;
 import org.apache.zeppelin.scheduler.SchedulerFactory;
 import org.slf4j.Logger;
@@ -44,6 +45,7 @@ public class SparkSqlInterpreter extends AbstractInterpreter {
   private Logger logger = LoggerFactory.getLogger(SparkSqlInterpreter.class);
 
   private SparkInterpreter sparkInterpreter;
+  private SqlSplitter sqlSplitter;
 
   public SparkSqlInterpreter(Properties property) {
     super(property);
@@ -52,6 +54,7 @@ public class SparkSqlInterpreter extends AbstractInterpreter {
   @Override
   public void open() throws InterpreterException {
     this.sparkInterpreter = getInterpreterInTheSameSessionByClassName(SparkInterpreter.class);
+    this.sqlSplitter = new SqlSplitter();
   }
 
   public boolean concurrentSQL() {
@@ -82,26 +85,37 @@ public class SparkSqlInterpreter extends AbstractInterpreter {
     sparkInterpreter.getZeppelinContext().setInterpreterContext(context);
     SQLContext sqlc = sparkInterpreter.getSQLContext();
     SparkContext sc = sqlc.sparkContext();
+
+    StringBuilder builder = new StringBuilder();
+    List<String> sqls = sqlSplitter.splitSql(st);
+    int maxResult = Integer.parseInt(context.getLocalProperties().getOrDefault("limit",
+        "" + sparkInterpreter.getZeppelinContext().getMaxResult()));
+
     sc.setLocalProperty("spark.scheduler.pool", context.getLocalProperties().get("pool"));
     sc.setJobGroup(Utils.buildJobGroupId(context), Utils.buildJobDesc(context), false);
-
+    String curSql = null;
     try {
-      Method method = sqlc.getClass().getMethod("sql", String.class);
-      int maxResult = Integer.parseInt(context.getLocalProperties().getOrDefault("limit",
-          "" + sparkInterpreter.getZeppelinContext().getMaxResult()));
-      String msg = sparkInterpreter.getZeppelinContext().showData(
-          method.invoke(sqlc, st), maxResult);
-      sc.clearJobGroup();
-      return new InterpreterResult(Code.SUCCESS, msg);
+      for (String sql : sqls) {
+        curSql = sql;
+        String result = sparkInterpreter.getZeppelinContext().showData(sqlc.sql(sql), maxResult);
+        builder.append(result);
+      }
     } catch (Exception e) {
-      if (Boolean.parseBoolean(getProperty("zeppelin.spark.sql.stacktrace"))) {
-        return new InterpreterResult(Code.ERROR, ExceptionUtils.getStackTrace(e));
+      builder.append("\n%text Error happens in sql: " + curSql + "\n");
+      if (Boolean.parseBoolean(getProperty("zeppelin.spark.sql.stacktrace", "false"))) {
+        builder.append(ExceptionUtils.getStackTrace(e));
+      } else {
+        logger.error("Invocation target exception", e);
+        String msg = e.getCause().getMessage()
+            + "\nset zeppelin.spark.sql.stacktrace = true to see full stacktrace";
+        builder.append(msg);
       }
-      logger.error("Invocation target exception", e);
-      String msg = e.getCause().getMessage()
-          + "\nset zeppelin.spark.sql.stacktrace = true to see full stacktrace";
-      return new InterpreterResult(Code.ERROR, msg);
+      return new InterpreterResult(Code.ERROR, builder.toString());
+    } finally {
+      sc.clearJobGroup();
     }
+
+    return new InterpreterResult(Code.SUCCESS, builder.toString());
   }
 
   @Override
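The behavioral contract of the loop above, fail fast on the first bad statement, keep the results already rendered, and always clear the Spark job group, can be seen in isolation in this self-contained toy. No Zeppelin or Spark classes are used; `run` is a made-up stand-in for `sqlc.sql` plus `showData`:

```java
import java.util.Arrays;
import java.util.List;

// Toy model of the new fail-fast loop: "bad" statements throw,
// partial results are kept, and later statements are skipped.
public class MultiStatementFlowDemo {

  static String run(String sql) {
    if (sql.contains("invalid")) {
      throw new IllegalArgumentException("cannot parse: " + sql);
    }
    return "\n%table result of [" + sql + "]\n";
  }

  public static void main(String[] args) {
    List<String> sqls = Arrays.asList("select 1", "invalid_sql", "select 2");
    StringBuilder builder = new StringBuilder();
    String curSql = null;
    try {
      for (String sql : sqls) {
        curSql = sql;
        builder.append(run(sql));
      }
    } catch (Exception e) {
      builder.append("\n%text Error happens in sql: " + curSql + "\n");
      builder.append(e.getMessage());
    }
    // Prints the first table, then the error block; "select 2" never ran.
    System.out.println(builder);
  }
}
```

This matches the third test case below, where the valid statement after the invalid one is skipped.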
diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkSqlInterpreterTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkSqlInterpreterTest.java
index cab5b1b..c3f245b 100644
--- a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkSqlInterpreterTest.java
+++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkSqlInterpreterTest.java
@@ -171,6 +171,47 @@ public class SparkSqlInterpreterTest {
   }
 
   @Test
+  public void testMultipleStatements() throws InterpreterException {
+    sparkInterpreter.interpret("case class P(age:Int)", context);
+    sparkInterpreter.interpret(
+        "val gr = sc.parallelize(Seq(P(1),P(2),P(3),P(4)))",
+        context);
+    sparkInterpreter.interpret("gr.toDF.registerTempTable(\"gr\")", context);
+
+    // Two correct sql
+    InterpreterResult ret = sqlInterpreter.interpret(
+        "select * --comment_1\nfrom gr;select count(1) from gr", context);
+    assertEquals(InterpreterResult.Code.SUCCESS, ret.code());
+    assertEquals(ret.message().toString(), 2, ret.message().size());
+    assertEquals(ret.message().toString(), Type.TABLE, ret.message().get(0).getType());
+    assertEquals(ret.message().toString(), Type.TABLE, ret.message().get(1).getType());
+
+    // One correct sql + One invalid sql
+    ret = sqlInterpreter.interpret("select * from gr;invalid_sql", context);
+    assertEquals(InterpreterResult.Code.ERROR, ret.code());
+    assertEquals(ret.message().toString(), 2, ret.message().size());
+    assertEquals(ret.message().toString(), Type.TABLE, ret.message().get(0).getType());
+    if (sparkInterpreter.getSparkVersion().isSpark2()) {
+      assertTrue(ret.message().toString(), ret.message().get(1).getData().contains("ParseException"));
+    }
+
+    // One correct sql + One invalid sql + One valid sql (skipped)
+    ret = sqlInterpreter.interpret("select * from gr;invalid_sql; select count(1) from gr", context);
+    assertEquals(InterpreterResult.Code.ERROR, ret.code());
+    assertEquals(ret.message().toString(), 2, ret.message().size());
+    assertEquals(ret.message().toString(), Type.TABLE, ret.message().get(0).getType());
+    if (sparkInterpreter.getSparkVersion().isSpark2()) {
+      assertTrue(ret.message().toString(), ret.message().get(1).getData().contains("ParseException"));
+    }
+
+    // Two comments, no sql
+    ret = sqlInterpreter.interpret(
+        "--comment_1\n--comment_2", context);
+    assertEquals(InterpreterResult.Code.SUCCESS, ret.code());
+    assertEquals(ret.message().toString(), 0, ret.message().size());
+  }
+
+  @Test
   public void testConcurrentSQL() throws InterpreterException, InterruptedException {
     if (sparkInterpreter.getSparkVersion().isSpark2()) {
       sparkInterpreter.interpret("spark.udf.register(\"sleep\", (e:Int) => {Thread.sleep(e*1000); e})", context);
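The two-TABLE-message shape asserted above follows from how results are concatenated: each rendered table now starts with `\n%table ` (see the shims change that follows), and the blank text in front of the first table is dropped (see the `InterpreterOutput` change further below). A sketch of that parsing; the table contents are made up and the message count is inferred from the tests:

```java
import org.apache.zeppelin.interpreter.InterpreterResult;

public class TwoTablesDemo {
  public static void main(String[] args) {
    // Two rendered results concatenated by the interpreter, each
    // prefixed with "\n%table "; the column names and rows are invented.
    InterpreterResult result = new InterpreterResult(
        InterpreterResult.Code.SUCCESS,
        "\n%table age\n1\n2\n" + "\n%table count(1)\n4\n");
    // Inferred from the tests: two TABLE messages, no blank TEXT message.
    System.out.println(result.message().size()); // expected: 2
  }
}
```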
diff --git a/spark/spark1-shims/src/main/scala/org/apache/zeppelin/spark/Spark1Shims.java b/spark/spark1-shims/src/main/scala/org/apache/zeppelin/spark/Spark1Shims.java
index 8e60ed0..6119647 100644
--- a/spark/spark1-shims/src/main/scala/org/apache/zeppelin/spark/Spark1Shims.java
+++ b/spark/spark1-shims/src/main/scala/org/apache/zeppelin/spark/Spark1Shims.java
@@ -70,7 +70,7 @@ public class Spark1Shims extends SparkShims {
     // fetch maxResult+1 rows so that we can check whether it is larger than zeppelin.spark.maxResult
     List<Row> rows = df.takeAsList(maxResult + 1);
     StringBuilder msg = new StringBuilder();
-    msg.append("%table ");
+    msg.append("\n%table ");
     msg.append(StringUtils.join(columns, "\t"));
     msg.append("\n");
     boolean isLargerThanMaxResult = rows.size() > maxResult;

diff --git a/spark/spark2-shims/src/main/scala/org/apache/zeppelin/spark/Spark2Shims.java b/spark/spark2-shims/src/main/scala/org/apache/zeppelin/spark/Spark2Shims.java
index a7304c5..b7b1cf9 100644
--- a/spark/spark2-shims/src/main/scala/org/apache/zeppelin/spark/Spark2Shims.java
+++ b/spark/spark2-shims/src/main/scala/org/apache/zeppelin/spark/Spark2Shims.java
@@ -71,7 +71,7 @@ public class Spark2Shims extends SparkShims {
     // fetch maxResult+1 rows so that we can check whether it is larger than zeppelin.spark.maxResult
     List<Row> rows = df.takeAsList(maxResult + 1);
     StringBuilder msg = new StringBuilder();
-    msg.append("%table ");
+    msg.append("\n%table ");
     msg.append(StringUtils.join(columns, "\t"));
     msg.append("\n");
     boolean isLargerThanMaxResult = rows.size() > maxResult;

diff --git a/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/InterpreterOutput.java b/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/InterpreterOutput.java
index 8853227..f85e535 100644
--- a/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/InterpreterOutput.java
+++ b/zeppelin-interpreter/src/main/java/org/apache/zeppelin/interpreter/InterpreterOutput.java
@@ -17,6 +17,7 @@
 
 package org.apache.zeppelin.interpreter;
 
+import org.apache.commons.lang.StringUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -328,6 +329,12 @@ public class InterpreterOutput extends OutputStream {
     List<InterpreterResultMessage> list = new LinkedList<>();
     synchronized (resultMessageOutputs) {
       for (InterpreterResultMessageOutput out : resultMessageOutputs) {
+        if (out.toInterpreterResultMessage().getType() == InterpreterResult.Type.TEXT &&
+            StringUtils.isBlank(out.toInterpreterResultMessage().getData())) {
+          // skip blank text, because when printing table data we usually need to print '%text \n'
+          // first to separate it from the preceding output of another type, e.g. z.show(df)
+          continue;
+        }
         list.add(out.toInterpreterResultMessage());
       }
     }

diff --git a/zeppelin-interpreter/src/test/java/org/apache/zeppelin/interpreter/InterpreterResultTest.java b/zeppelin-interpreter/src/test/java/org/apache/zeppelin/interpreter/InterpreterResultTest.java
index a8ff1bf..84805ac 100644
--- a/zeppelin-interpreter/src/test/java/org/apache/zeppelin/interpreter/InterpreterResultTest.java
+++ b/zeppelin-interpreter/src/test/java/org/apache/zeppelin/interpreter/InterpreterResultTest.java
@@ -33,7 +33,7 @@ public class InterpreterResultTest {
     result = new InterpreterResult(InterpreterResult.Code.SUCCESS, "%this is a TEXT type");
     assertEquals("No magic", InterpreterResult.Type.TEXT, result.message().get(0).getType());
     result = new InterpreterResult(InterpreterResult.Code.SUCCESS, "%\n");
-    assertEquals("No magic", InterpreterResult.Type.TEXT, result.message().get(0).getType());
+    assertEquals(0, result.message().size());
   }
 
   @Test
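Finally, the test change above as a minimal runnable check; the expected message count comes directly from the updated assertion:

```java
import org.apache.zeppelin.interpreter.InterpreterResult;

public class BlankTextDemo {
  public static void main(String[] args) {
    // After this commit, a magic marker followed only by a newline yields
    // no message at all, instead of a single empty TEXT message.
    InterpreterResult result =
        new InterpreterResult(InterpreterResult.Code.SUCCESS, "%\n");
    System.out.println(result.message().size()); // expected: 0 after this change
  }
}
```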