http://git-wip-us.apache.org/repos/asf/zeppelin/blob/ca87f7d4/spark/interpreter/src/test/java/org/apache/zeppelin/spark/NewSparkInterpreterTest.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/NewSparkInterpreterTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/NewSparkInterpreterTest.java new file mode 100644 index 0000000..cfcf2a5 --- /dev/null +++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/NewSparkInterpreterTest.java @@ -0,0 +1,389 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zeppelin.spark; + +import org.apache.zeppelin.display.AngularObjectRegistry; +import org.apache.zeppelin.display.GUI; +import org.apache.zeppelin.display.ui.CheckBox; +import org.apache.zeppelin.display.ui.Select; +import org.apache.zeppelin.display.ui.TextBox; +import org.apache.zeppelin.interpreter.InterpreterContext; +import org.apache.zeppelin.interpreter.InterpreterException; +import org.apache.zeppelin.interpreter.InterpreterGroup; +import org.apache.zeppelin.interpreter.InterpreterOutput; +import org.apache.zeppelin.interpreter.InterpreterOutputListener; +import org.apache.zeppelin.interpreter.InterpreterResult; +import org.apache.zeppelin.interpreter.InterpreterResultMessageOutput; +import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; +import org.apache.zeppelin.user.AuthenticationInfo; +import org.junit.After; +import org.junit.Test; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.net.URL; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; +import java.util.HashMap; +import java.util.List; +import java.util.Properties; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; + + +public class NewSparkInterpreterTest { + + private SparkInterpreter interpreter; + + // catch the streaming output in onAppend + private volatile String output = ""; + // catch the interpreter output in onUpdate + private InterpreterResultMessageOutput messageOutput; + + @Test + public void testSparkInterpreter() throws IOException, InterruptedException, InterpreterException { + Properties properties = new Properties(); + properties.setProperty("spark.master", "local"); + properties.setProperty("spark.app.name", "test"); + properties.setProperty("zeppelin.spark.maxResult", "100"); + properties.setProperty("zeppelin.spark.test", "true"); + properties.setProperty("zeppelin.spark.useNew", "true"); + interpreter = new SparkInterpreter(properties); + assertTrue(interpreter.getDelegation() instanceof 
NewSparkInterpreter); + interpreter.setInterpreterGroup(mock(InterpreterGroup.class)); + interpreter.open(); + + InterpreterResult result = interpreter.interpret("val a=\"hello world\"", getInterpreterContext()); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + assertEquals("a: String = hello world\n", output); + + result = interpreter.interpret("print(a)", getInterpreterContext()); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + assertEquals("hello world", output); + + // incomplete + result = interpreter.interpret("println(a", getInterpreterContext()); + assertEquals(InterpreterResult.Code.INCOMPLETE, result.code()); + + // syntax error + result = interpreter.interpret("println(b)", getInterpreterContext()); + assertEquals(InterpreterResult.Code.ERROR, result.code()); + assertTrue(output.contains("not found: value b")); + + // multiple line + result = interpreter.interpret("\"123\".\ntoInt", getInterpreterContext()); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + + // single line comment + result = interpreter.interpret("/*comment here*/", getInterpreterContext()); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + + result = interpreter.interpret("/*comment here*/\nprint(\"hello world\")", getInterpreterContext()); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + + // multiple line comment + result = interpreter.interpret("/*line 1 \n line 2*/", getInterpreterContext()); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + + // test function + result = interpreter.interpret("def add(x:Int, y:Int)\n{ return x+y }", getInterpreterContext()); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + + result = interpreter.interpret("print(add(1,2))", getInterpreterContext()); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + + result = interpreter.interpret("/*line 1 \n line 2*/print(\"hello world\")", getInterpreterContext()); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + + // companion object + result = interpreter.interpret("class Counter {\n " + + "var value: Long = 0} \n" + + "object Counter {\n def apply(x: Long) = new Counter()\n}", getInterpreterContext()); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + + // spark rdd operation + result = interpreter.interpret("sc.range(1, 10).sum", getInterpreterContext()); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + assertTrue(output.contains("45")); + + // case class + result = interpreter.interpret("val bankText = sc.textFile(\"bank.csv\")", getInterpreterContext()); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + result = interpreter.interpret( + "case class Bank(age:Integer, job:String, marital : String, education : String, balance : Integer)\n", + getInterpreterContext()); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + result = interpreter.interpret( + "val bank = bankText.map(s=>s.split(\";\")).filter(s => s(0)!=\"\\\"age\\\"\").map(\n" + + " s => Bank(s(0).toInt, \n" + + " s(1).replaceAll(\"\\\"\", \"\"),\n" + + " s(2).replaceAll(\"\\\"\", \"\"),\n" + + " s(3).replaceAll(\"\\\"\", \"\"),\n" + + " s(5).replaceAll(\"\\\"\", \"\").toInt\n" + + " )\n" + + ")", getInterpreterContext()); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + + // spark version + result = interpreter.interpret("sc.version", getInterpreterContext()); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + + // spark sql test + String version 
= output.trim(); + if (version.contains("String = 1.")) { + result = interpreter.interpret("sqlContext", getInterpreterContext()); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + + result = interpreter.interpret( + "val df = sqlContext.createDataFrame(Seq((1,\"a\"),(2,\"b\")))\n" + + "df.show()", getInterpreterContext()); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + assertTrue(output.contains( + "+---+---+\n" + + "| _1| _2|\n" + + "+---+---+\n" + + "| 1| a|\n" + + "| 2| b|\n" + + "+---+---+")); + } else if (version.contains("String = 2.")) { + result = interpreter.interpret("spark", getInterpreterContext()); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + + result = interpreter.interpret( + "val df = spark.createDataFrame(Seq((1,\"a\"),(2,\"b\")))\n" + + "df.show()", getInterpreterContext()); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + assertTrue(output.contains( + "+---+---+\n" + + "| _1| _2|\n" + + "+---+---+\n" + + "| 1| a|\n" + + "| 2| b|\n" + + "+---+---+")); + } + + // ZeppelinContext + result = interpreter.interpret("z.show(df)", getInterpreterContext()); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + assertEquals(InterpreterResult.Type.TABLE, messageOutput.getType()); + messageOutput.flush(); + assertEquals("_1\t_2\n1\ta\n2\tb\n", messageOutput.toInterpreterResultMessage().getData()); + + InterpreterContext context = getInterpreterContext(); + result = interpreter.interpret("z.input(\"name\", \"default_name\")", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + assertEquals(1, context.getGui().getForms().size()); + assertTrue(context.getGui().getForms().get("name") instanceof TextBox); + TextBox textBox = (TextBox) context.getGui().getForms().get("name"); + assertEquals("name", textBox.getName()); + assertEquals("default_name", textBox.getDefaultValue()); + + context = getInterpreterContext(); + result = interpreter.interpret("z.checkbox(\"checkbox_1\", Seq(\"value_2\"), Seq((\"value_1\", \"name_1\"), (\"value_2\", \"name_2\")))", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + assertEquals(1, context.getGui().getForms().size()); + assertTrue(context.getGui().getForms().get("checkbox_1") instanceof CheckBox); + CheckBox checkBox = (CheckBox) context.getGui().getForms().get("checkbox_1"); + assertEquals("checkbox_1", checkBox.getName()); + assertEquals(1, checkBox.getDefaultValue().length); + assertEquals("value_2", checkBox.getDefaultValue()[0]); + assertEquals(2, checkBox.getOptions().length); + assertEquals("value_1", checkBox.getOptions()[0].getValue()); + assertEquals("name_1", checkBox.getOptions()[0].getDisplayName()); + assertEquals("value_2", checkBox.getOptions()[1].getValue()); + assertEquals("name_2", checkBox.getOptions()[1].getDisplayName()); + + context = getInterpreterContext(); + result = interpreter.interpret("z.select(\"select_1\", Seq(\"value_2\"), Seq((\"value_1\", \"name_1\"), (\"value_2\", \"name_2\")))", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + assertEquals(1, context.getGui().getForms().size()); + assertTrue(context.getGui().getForms().get("select_1") instanceof Select); + Select select = (Select) context.getGui().getForms().get("select_1"); + assertEquals("select_1", select.getName()); + // TODO(zjffdu) it seems a bug of GUI, the default value should be 'value_2', but it is List(value_2) + // assertEquals("value_2", select.getDefaultValue()); + assertEquals(2, 
select.getOptions().length); + assertEquals("value_1", select.getOptions()[0].getValue()); + assertEquals("name_1", select.getOptions()[0].getDisplayName()); + assertEquals("value_2", select.getOptions()[1].getValue()); + assertEquals("name_2", select.getOptions()[1].getDisplayName()); + + + // completions + List<InterpreterCompletion> completions = interpreter.completion("a.", 2, getInterpreterContext()); + assertTrue(completions.size() > 0); + + completions = interpreter.completion("a.isEm", 6, getInterpreterContext()); + assertEquals(1, completions.size()); + assertEquals("isEmpty", completions.get(0).name); + + completions = interpreter.completion("sc.ra", 5, getInterpreterContext()); + assertEquals(1, completions.size()); + assertEquals("range", completions.get(0).name); + + + // Zeppelin-Display + result = interpreter.interpret("import org.apache.zeppelin.display.angular.notebookscope._\n" + + "import AngularElem._", getInterpreterContext()); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + + result = interpreter.interpret("<div style=\"color:blue\">\n" + + "<h4>Hello Angular Display System</h4>\n" + + "</div>.display", getInterpreterContext()); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + assertEquals(InterpreterResult.Type.ANGULAR, messageOutput.getType()); + assertTrue(messageOutput.toInterpreterResultMessage().getData().contains("Hello Angular Display System")); + + result = interpreter.interpret("<div class=\"btn btn-success\">\n" + + " Click me\n" + + "</div>.onClick{() =>\n" + + " println(\"hello world\")\n" + + "}.display", getInterpreterContext()); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + assertEquals(InterpreterResult.Type.ANGULAR, messageOutput.getType()); + assertTrue(messageOutput.toInterpreterResultMessage().getData().contains("Click me")); + + // getProgress + final InterpreterContext context2 = getInterpreterContext(); + Thread interpretThread = new Thread() { + @Override + public void run() { + InterpreterResult result = null; + try { + result = interpreter.interpret( + "val df = sc.parallelize(1 to 10, 2).foreach(e=>Thread.sleep(1000))", context2); + } catch (InterpreterException e) { + e.printStackTrace(); + } + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + } + }; + interpretThread.start(); + boolean nonZeroProgress = false; + int progress = 0; + while(interpretThread.isAlive()) { + progress = interpreter.getProgress(context2); + assertTrue(progress >= 0); + if (progress != 0 && progress != 100) { + nonZeroProgress = true; + } + Thread.sleep(100); + } + assertTrue(nonZeroProgress); + + // cancel + final InterpreterContext context3 = getInterpreterContext(); + interpretThread = new Thread() { + @Override + public void run() { + InterpreterResult result = null; + try { + result = interpreter.interpret( + "val df = sc.parallelize(1 to 10, 2).foreach(e=>Thread.sleep(1000))", context3); + } catch (InterpreterException e) { + e.printStackTrace(); + } + assertEquals(InterpreterResult.Code.ERROR, result.code()); + assertTrue(output.contains("cancelled")); + } + }; + + interpretThread.start(); + // sleep 1 second to wait for the Spark job to start + Thread.sleep(1000); + interpreter.cancel(context3); + interpretThread.join(); + } + + @Test + public void testDependencies() throws IOException, InterpreterException { + Properties properties = new Properties(); + properties.setProperty("spark.master", "local"); + properties.setProperty("spark.app.name", "test"); +
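// minimal local-mode config; "zeppelin.spark.useNew" below routes this test through NewSparkInterpreter +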
properties.setProperty("zeppelin.spark.maxResult", "100"); + properties.setProperty("zeppelin.spark.useNew", "true"); + + // download spark-avro jar + URL website = new URL("http://repo1.maven.org/maven2/com/databricks/spark-avro_2.11/3.2.0/spark-avro_2.11-3.2.0.jar"); + ReadableByteChannel rbc = Channels.newChannel(website.openStream()); + File avroJarFile = new File("spark-avro_2.11-3.2.0.jar"); + FileOutputStream fos = new FileOutputStream(avroJarFile); + fos.getChannel().transferFrom(rbc, 0, Long.MAX_VALUE); + + properties.setProperty("spark.jars", avroJarFile.getAbsolutePath()); + + interpreter = new SparkInterpreter(properties); + assertTrue(interpreter.getDelegation() instanceof NewSparkInterpreter); + interpreter.setInterpreterGroup(mock(InterpreterGroup.class)); + interpreter.open(); + + InterpreterResult result = interpreter.interpret("import com.databricks.spark.avro._", getInterpreterContext()); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + } + + @After + public void tearDown() throws InterpreterException { + if (this.interpreter != null) { + this.interpreter.close(); + } + } + + private InterpreterContext getInterpreterContext() { + output = ""; + return new InterpreterContext( + "noteId", + "paragraphId", + "replName", + "paragraphTitle", + "paragraphText", + new AuthenticationInfo(), + new HashMap<String, Object>(), + new GUI(), + new GUI(), + new AngularObjectRegistry("spark", null), + null, + null, + new InterpreterOutput( + + new InterpreterOutputListener() { + @Override + public void onUpdateAll(InterpreterOutput out) { + + } + + @Override + public void onAppend(int index, InterpreterResultMessageOutput out, byte[] line) { + try { + output = out.toInterpreterResultMessage().getData(); + } catch (IOException e) { + e.printStackTrace(); + } + } + + @Override + public void onUpdate(int index, InterpreterResultMessageOutput out) { + messageOutput = out; + } + }) + ); + } +}
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/ca87f7d4/spark/interpreter/src/test/java/org/apache/zeppelin/spark/NewSparkSqlInterpreterTest.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/NewSparkSqlInterpreterTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/NewSparkSqlInterpreterTest.java new file mode 100644 index 0000000..42289ff --- /dev/null +++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/NewSparkSqlInterpreterTest.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zeppelin.spark; + +import java.util.HashMap; +import java.util.LinkedList; +import java.util.Properties; + +import com.google.common.io.Files; +import org.apache.zeppelin.display.AngularObjectRegistry; +import org.apache.zeppelin.resource.LocalResourcePool; +import org.apache.zeppelin.user.AuthenticationInfo; +import org.apache.zeppelin.display.GUI; +import org.apache.zeppelin.interpreter.*; +import org.apache.zeppelin.interpreter.InterpreterResult.Type; +import org.junit.*; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class NewSparkSqlInterpreterTest { + + private static SparkSqlInterpreter sqlInterpreter; + private static SparkInterpreter sparkInterpreter; + private static InterpreterContext context; + private static InterpreterGroup intpGroup; + + @BeforeClass + public static void setUp() throws Exception { + Properties p = new Properties(); + p.setProperty("spark.master", "local"); + p.setProperty("spark.app.name", "test"); + p.setProperty("zeppelin.spark.maxResult", "10"); + p.setProperty("zeppelin.spark.concurrentSQL", "false"); + p.setProperty("zeppelin.spark.sqlInterpreter.stacktrace", "false"); + p.setProperty("zeppelin.spark.useNew", "true"); + intpGroup = new InterpreterGroup(); + sparkInterpreter = new SparkInterpreter(p); + sparkInterpreter.setInterpreterGroup(intpGroup); + + sqlInterpreter = new SparkSqlInterpreter(p); + sqlInterpreter.setInterpreterGroup(intpGroup); + intpGroup.put("session_1", new LinkedList<Interpreter>()); + intpGroup.get("session_1").add(sparkInterpreter); + intpGroup.get("session_1").add(sqlInterpreter); + + sparkInterpreter.open(); + sqlInterpreter.open(); + + context = new InterpreterContext("note", "id", null, "title", "text", new AuthenticationInfo(), + new HashMap<String, Object>(), new GUI(), new GUI(), + new AngularObjectRegistry(intpGroup.getId(), null), + new LocalResourcePool("id"), + new LinkedList<InterpreterContextRunner>(), new InterpreterOutput(null)); + } + + @AfterClass + public static void tearDown() throws InterpreterException { + sqlInterpreter.close(); + sparkInterpreter.close(); 
+ } + + boolean isDataFrameSupported() { + return sparkInterpreter.getSparkVersion().hasDataFrame(); + } + + @Test + public void test() throws InterpreterException { + sparkInterpreter.interpret("case class Test(name:String, age:Int)", context); + sparkInterpreter.interpret("val test = sc.parallelize(Seq(Test(\"moon\", 33), Test(\"jobs\", 51), Test(\"gates\", 51), Test(\"park\", 34)))", context); + if (isDataFrameSupported()) { + sparkInterpreter.interpret("test.toDF.registerTempTable(\"test\")", context); + } else { + sparkInterpreter.interpret("test.registerTempTable(\"test\")", context); + } + + InterpreterResult ret = sqlInterpreter.interpret("select name, age from test where age < 40", context); + assertEquals(InterpreterResult.Code.SUCCESS, ret.code()); + assertEquals(Type.TABLE, ret.message().get(0).getType()); + assertEquals("name\tage\nmoon\t33\npark\t34\n", ret.message().get(0).getData()); + + ret = sqlInterpreter.interpret("select wrong syntax", context); + assertEquals(InterpreterResult.Code.ERROR, ret.code()); + assertTrue(ret.message().get(0).getData().length() > 0); + + assertEquals(InterpreterResult.Code.SUCCESS, sqlInterpreter.interpret("select case when name='aa' then name else name end from test", context).code()); + } + + @Test + public void testStruct() throws InterpreterException { + sparkInterpreter.interpret("case class Person(name:String, age:Int)", context); + sparkInterpreter.interpret("case class People(group:String, person:Person)", context); + sparkInterpreter.interpret( + "val gr = sc.parallelize(Seq(People(\"g1\", Person(\"moon\",33)), People(\"g2\", Person(\"sun\",11))))", + context); + if (isDataFrameSupported()) { + sparkInterpreter.interpret("gr.toDF.registerTempTable(\"gr\")", context); + } else { + sparkInterpreter.interpret("gr.registerTempTable(\"gr\")", context); + } + + InterpreterResult ret = sqlInterpreter.interpret("select * from gr", context); + assertEquals(InterpreterResult.Code.SUCCESS, ret.code()); + } + + @Test + public void test_null_value_in_row() throws InterpreterException { + sparkInterpreter.interpret("import org.apache.spark.sql._", context); + if (isDataFrameSupported()) { + sparkInterpreter.interpret( + "import org.apache.spark.sql.types.{StructType,StructField,StringType,IntegerType}", + context); + } + sparkInterpreter.interpret( + "def toInt(s:String): Any = {try { s.trim().toInt} catch {case e:Exception => null}}", + context); + sparkInterpreter.interpret( + "val schema = StructType(Seq(StructField(\"name\", StringType, false),StructField(\"age\" , IntegerType, true),StructField(\"other\" , StringType, false)))", + context); + sparkInterpreter.interpret( + "val csv = sc.parallelize(Seq((\"jobs, 51, apple\"), (\"gates, , microsoft\")))", + context); + sparkInterpreter.interpret( + "val raw = csv.map(_.split(\",\")).map(p => Row(p(0),toInt(p(1)),p(2)))", + context); + if (isDataFrameSupported()) { + sparkInterpreter.interpret("val people = sqlContext.createDataFrame(raw, schema)", + context); + sparkInterpreter.interpret("people.toDF.registerTempTable(\"people\")", context); + } else { + sparkInterpreter.interpret("val people = sqlContext.applySchema(raw, schema)", + context); + sparkInterpreter.interpret("people.registerTempTable(\"people\")", context); + } + + InterpreterResult ret = sqlInterpreter.interpret( + "select name, age from people where name = 'gates'", context); + assertEquals(InterpreterResult.Code.SUCCESS, ret.code()); + assertEquals(Type.TABLE, ret.message().get(0).getType()); + assertEquals("name\tage\ngates\tnull\n",
ret.message().get(0).getData()); + } + + @Test + public void testMaxResults() throws InterpreterException { + sparkInterpreter.interpret("case class P(age:Int)", context); + sparkInterpreter.interpret( + "val gr = sc.parallelize(Seq(P(1),P(2),P(3),P(4),P(5),P(6),P(7),P(8),P(9),P(10),P(11)))", + context); + if (isDataFrameSupported()) { + sparkInterpreter.interpret("gr.toDF.registerTempTable(\"gr\")", context); + } else { + sparkInterpreter.interpret("gr.registerTempTable(\"gr\")", context); + } + + InterpreterResult ret = sqlInterpreter.interpret("select * from gr", context); + assertEquals(InterpreterResult.Code.SUCCESS, ret.code()); + assertTrue(ret.message().get(1).getData().contains("alert-warning")); + } +} http://git-wip-us.apache.org/repos/asf/zeppelin/blob/ca87f7d4/spark/interpreter/src/test/java/org/apache/zeppelin/spark/OldSparkInterpreterTest.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/OldSparkInterpreterTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/OldSparkInterpreterTest.java new file mode 100644 index 0000000..14214a2 --- /dev/null +++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/OldSparkInterpreterTest.java @@ -0,0 +1,368 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.zeppelin.spark; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +import org.apache.zeppelin.display.AngularObjectRegistry; +import org.apache.zeppelin.display.GUI; +import org.apache.zeppelin.interpreter.Interpreter; +import org.apache.zeppelin.interpreter.InterpreterContext; +import org.apache.zeppelin.interpreter.InterpreterContextRunner; +import org.apache.zeppelin.interpreter.InterpreterException; +import org.apache.zeppelin.interpreter.InterpreterGroup; +import org.apache.zeppelin.interpreter.InterpreterOutput; +import org.apache.zeppelin.interpreter.InterpreterResult; +import org.apache.zeppelin.interpreter.InterpreterResult.Code; +import org.apache.zeppelin.interpreter.remote.RemoteEventClientWrapper; +import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; +import org.apache.zeppelin.resource.LocalResourcePool; +import org.apache.zeppelin.resource.WellKnownResourceName; +import org.apache.zeppelin.user.AuthenticationInfo; +import org.junit.AfterClass; +import org.junit.Assume; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.FixMethodOrder; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runners.MethodSorters; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Properties; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +@FixMethodOrder(MethodSorters.NAME_ASCENDING) +public class OldSparkInterpreterTest { + + @ClassRule + public static TemporaryFolder tmpDir = new TemporaryFolder(); + + static SparkInterpreter repl; + static InterpreterGroup intpGroup; + static InterpreterContext context; + static Logger LOGGER = LoggerFactory.getLogger(OldSparkInterpreterTest.class); + static Map<String, Map<String, String>> paraIdToInfosMap = + new HashMap<>(); + + /** + * Get spark version number as a numerical value. + * eg. 1.1.x => 11, 1.2.x => 12, 1.3.x => 13 ... 
+ */ + public static int getSparkVersionNumber(SparkInterpreter repl) { + if (repl == null) { + return 0; + } + + String[] split = repl.getSparkContext().version().split("\\."); + int version = Integer.parseInt(split[0]) * 10 + Integer.parseInt(split[1]); + return version; + } + + public static Properties getSparkTestProperties(TemporaryFolder tmpDir) throws IOException { + Properties p = new Properties(); + p.setProperty("master", "local[*]"); + p.setProperty("spark.app.name", "Zeppelin Test"); + p.setProperty("zeppelin.spark.useHiveContext", "true"); + p.setProperty("zeppelin.spark.maxResult", "1000"); + p.setProperty("zeppelin.spark.importImplicit", "true"); + p.setProperty("zeppelin.dep.localrepo", tmpDir.newFolder().getAbsolutePath()); + p.setProperty("zeppelin.spark.property_1", "value_1"); + return p; + } + + @BeforeClass + public static void setUp() throws Exception { + intpGroup = new InterpreterGroup(); + intpGroup.put("note", new LinkedList<Interpreter>()); + repl = new SparkInterpreter(getSparkTestProperties(tmpDir)); + repl.setInterpreterGroup(intpGroup); + intpGroup.get("note").add(repl); + repl.open(); + + final RemoteEventClientWrapper remoteEventClientWrapper = new RemoteEventClientWrapper() { + + @Override + public void onParaInfosReceived(String noteId, String paragraphId, + Map<String, String> infos) { + if (infos != null) { + paraIdToInfosMap.put(paragraphId, infos); + } + } + + @Override + public void onMetaInfosReceived(Map<String, String> infos) { + } + }; + context = new InterpreterContext("note", "id", null, "title", "text", + new AuthenticationInfo(), + new HashMap<String, Object>(), + new GUI(), + new GUI(), + new AngularObjectRegistry(intpGroup.getId(), null), + new LocalResourcePool("id"), + new LinkedList<InterpreterContextRunner>(), + new InterpreterOutput(null)) { + + @Override + public RemoteEventClientWrapper getClient() { + return remoteEventClientWrapper; + } + }; + // The first interpreted paragraph will set the EventClient wrapper + //SparkInterpreter.interpret(String, InterpreterContext) -> + //SparkInterpreter.populateSparkWebUrl(InterpreterContext) -> + //ZeppelinContext.setEventClient(RemoteEventClientWrapper) + //running a dummy paragraph to ensure that we don't have any race conditions among tests + repl.interpret("sc", context); + } + + @AfterClass + public static void tearDown() throws InterpreterException { + repl.close(); + } + + @Test + public void testBasicIntp() throws InterpreterException { + assertEquals(InterpreterResult.Code.SUCCESS, + repl.interpret("val a = 1\nval b = 2", context).code()); + + // when interpreting an incomplete expression + InterpreterResult incomplete = repl.interpret("val a = \"\"\"", context); + assertEquals(InterpreterResult.Code.INCOMPLETE, incomplete.code()); + assertTrue(incomplete.message().get(0).getData().length() > 0); // expecting some error + // message + + /* + * assertEquals(1, repl.getValue("a")); assertEquals(2, repl.getValue("b")); + * repl.interpret("val ver = sc.version"); + * assertNotNull(repl.getValue("ver")); assertEquals("HELLO\n", + * repl.interpret("println(\"HELLO\")").message()); + */ + } + + @Test + public void testNonStandardSparkProperties() throws IOException, InterpreterException { + // throw NoSuchElementException if no such property is found + InterpreterResult result = repl.interpret("sc.getConf.get(\"property_1\")", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + } + + @Test + public void testNextLineInvocation() throws InterpreterException {
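// a statement continued on the next line ("123"\n.toInt) must still be parsed as one snippet +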
assertEquals(InterpreterResult.Code.SUCCESS, repl.interpret("\"123\"\n.toInt", context).code()); + } + + @Test + public void testNextLineComments() throws InterpreterException { + assertEquals(InterpreterResult.Code.SUCCESS, repl.interpret("\"123\"\n/*comment here\n*/.toInt", context).code()); + } + + @Test + public void testNextLineCompanionObject() throws InterpreterException { + String code = "class Counter {\nvar value: Long = 0\n}\n // comment\n\n object Counter {\n def apply(x: Long) = new Counter()\n}"; + assertEquals(InterpreterResult.Code.SUCCESS, repl.interpret(code, context).code()); + } + + @Test + public void testEndWithComment() throws InterpreterException { + assertEquals(InterpreterResult.Code.SUCCESS, repl.interpret("val c=1\n//comment", context).code()); + } + + @Test + public void testListener() { + SparkContext sc = repl.getSparkContext(); + assertNotNull(OldSparkInterpreter.setupListeners(sc)); + } + + @Test + public void testCreateDataFrame() throws InterpreterException { + if (getSparkVersionNumber(repl) >= 13) { + repl.interpret("case class Person(name:String, age:Int)\n", context); + repl.interpret("val people = sc.parallelize(Seq(Person(\"moon\", 33), Person(\"jobs\", 51), Person(\"gates\", 51), Person(\"park\", 34)))\n", context); + repl.interpret("people.toDF.count", context); + assertEquals(new Long(4), context.getResourcePool().get( + context.getNoteId(), + context.getParagraphId(), + WellKnownResourceName.ZeppelinReplResult.toString()).get()); + } + } + + @Test + public void testZShow() throws InterpreterException { + String code = ""; + repl.interpret("case class Person(name:String, age:Int)\n", context); + repl.interpret("val people = sc.parallelize(Seq(Person(\"moon\", 33), Person(\"jobs\", 51), Person(\"gates\", 51), Person(\"park\", 34)))\n", context); + if (getSparkVersionNumber(repl) < 13) { + repl.interpret("people.registerTempTable(\"people\")", context); + code = "z.show(sqlc.sql(\"select * from people\"))"; + } else { + code = "z.show(people.toDF)"; + } + assertEquals(Code.SUCCESS, repl.interpret(code, context).code()); + } + + @Test + public void testSparkSql() throws IOException, InterpreterException { + repl.interpret("case class Person(name:String, age:Int)\n", context); + repl.interpret("val people = sc.parallelize(Seq(Person(\"moon\", 33), Person(\"jobs\", 51), Person(\"gates\", 51), Person(\"park\", 34)))\n", context); + assertEquals(Code.SUCCESS, repl.interpret("people.take(3)", context).code()); + + + if (getSparkVersionNumber(repl) <= 11) { // Spark 1.2 and later do not allow creating multiple + // SparkContexts in the same JVM by default.
+ // create new interpreter + SparkInterpreter repl2 = new SparkInterpreter(getSparkTestProperties(tmpDir)); + repl2.setInterpreterGroup(intpGroup); + intpGroup.get("note").add(repl2); + repl2.open(); + + repl2.interpret("case class Man(name:String, age:Int)", context); + repl2.interpret("val man = sc.parallelize(Seq(Man(\"moon\", 33), Man(\"jobs\", 51), Man(\"gates\", 51), Man(\"park\", 34)))", context); + assertEquals(Code.SUCCESS, repl2.interpret("man.take(3)", context).code()); + repl2.close(); + } + } + + @Test + public void testReferencingUndefinedVal() throws InterpreterException { + InterpreterResult result = repl.interpret("def category(min: Int) = {" + + " if (0 <= value) \"error\"" + "}", context); + assertEquals(Code.ERROR, result.code()); + } + + @Test + public void emptyConfigurationVariablesOnlyForNonSparkProperties() { + Properties intpProperty = repl.getProperties(); + SparkConf sparkConf = repl.getSparkContext().getConf(); + for (Object oKey : intpProperty.keySet()) { + String key = (String) oKey; + String value = (String) intpProperty.get(key); + LOGGER.debug(String.format("[%s]: [%s]", key, value)); + if (key.startsWith("spark.") && value.isEmpty()) { + assertTrue(String.format("configuration starting from 'spark.' should not be empty. [%s]", key), !sparkConf.contains(key) || !sparkConf.get(key).isEmpty()); + } + } + } + + @Test + public void shareSingleSparkContext() throws InterruptedException, IOException, InterpreterException { + // create another SparkInterpreter + SparkInterpreter repl2 = new SparkInterpreter(getSparkTestProperties(tmpDir)); + repl2.setInterpreterGroup(intpGroup); + intpGroup.get("note").add(repl2); + repl2.open(); + + assertEquals(Code.SUCCESS, + repl.interpret("print(sc.parallelize(1 to 10).count())", context).code()); + assertEquals(Code.SUCCESS, + repl2.interpret("print(sc.parallelize(1 to 10).count())", context).code()); + + repl2.close(); + } + + @Test + public void testEnableImplicitImport() throws IOException, InterpreterException { + if (getSparkVersionNumber(repl) >= 13) { + // Set option of importing implicits to "true", and initialize new Spark repl + Properties p = getSparkTestProperties(tmpDir); + p.setProperty("zeppelin.spark.importImplicit", "true"); + SparkInterpreter repl2 = new SparkInterpreter(p); + repl2.setInterpreterGroup(intpGroup); + intpGroup.get("note").add(repl2); + + repl2.open(); + String ddl = "val df = Seq((1, true), (2, false)).toDF(\"num\", \"bool\")"; + assertEquals(Code.SUCCESS, repl2.interpret(ddl, context).code()); + repl2.close(); + } + } + + @Test + public void testDisableImplicitImport() throws IOException, InterpreterException { + if (getSparkVersionNumber(repl) >= 13) { + // Set option of importing implicits to "false", and initialize new Spark repl + // this test should return error status when creating DataFrame from sequence + Properties p = getSparkTestProperties(tmpDir); + p.setProperty("zeppelin.spark.importImplicit", "false"); + SparkInterpreter repl2 = new SparkInterpreter(p); + repl2.setInterpreterGroup(intpGroup); + intpGroup.get("note").add(repl2); + + repl2.open(); + String ddl = "val df = Seq((1, true), (2, false)).toDF(\"num\", \"bool\")"; + assertEquals(Code.ERROR, repl2.interpret(ddl, context).code()); + repl2.close(); + } + } + + @Test + public void testCompletion() throws InterpreterException { + List<InterpreterCompletion> completions = repl.completion("sc.", "sc.".length(), null); + assertTrue(completions.size() > 0); + } + + @Test + public void testMultilineCompletion() throws 
InterpreterException { + String buf = "val x = 1\nsc."; + List<InterpreterCompletion> completions = repl.completion(buf, buf.length(), null); + assertTrue(completions.size() > 0); + } + + @Test + public void testMultilineCompletionNewVar() throws InterpreterException { + Assume.assumeFalse("this feature does not work with scala 2.10", Utils.isScala2_10()); + Assume.assumeTrue("This feature does not work with scala < 2.11.8", Utils.isCompilerAboveScala2_11_7()); + String buf = "val x = sc\nx."; + List<InterpreterCompletion> completions = repl.completion(buf, buf.length(), null); + assertTrue(completions.size() > 0); + } + + @Test + public void testParagraphUrls() throws InterpreterException { + String paraId = "test_para_job_url"; + InterpreterContext intpCtx = new InterpreterContext("note", paraId, null, "title", "text", + new AuthenticationInfo(), + new HashMap<String, Object>(), + new GUI(), + new GUI(), + new AngularObjectRegistry(intpGroup.getId(), null), + new LocalResourcePool("id"), + new LinkedList<InterpreterContextRunner>(), + new InterpreterOutput(null)); + repl.interpret("sc.parallelize(1 to 10).map(x => {x}).collect", intpCtx); + Map<String, String> paraInfos = paraIdToInfosMap.get(intpCtx.getParagraphId()); + String jobUrl = null; + if (paraInfos != null) { + jobUrl = paraInfos.get("jobUrl"); + } + String sparkUIUrl = repl.getSparkUIUrl(); + assertNotNull(jobUrl); + assertTrue(jobUrl.startsWith(sparkUIUrl + "/jobs/job/?id=")); + + } +} http://git-wip-us.apache.org/repos/asf/zeppelin/blob/ca87f7d4/spark/interpreter/src/test/java/org/apache/zeppelin/spark/OldSparkSqlInterpreterTest.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/OldSparkSqlInterpreterTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/OldSparkSqlInterpreterTest.java new file mode 100644 index 0000000..d0b0874 --- /dev/null +++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/OldSparkSqlInterpreterTest.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.zeppelin.spark; + +import org.apache.zeppelin.display.AngularObjectRegistry; +import org.apache.zeppelin.display.GUI; +import org.apache.zeppelin.interpreter.Interpreter; +import org.apache.zeppelin.interpreter.InterpreterContext; +import org.apache.zeppelin.interpreter.InterpreterContextRunner; +import org.apache.zeppelin.interpreter.InterpreterException; +import org.apache.zeppelin.interpreter.InterpreterGroup; +import org.apache.zeppelin.interpreter.InterpreterOutput; +import org.apache.zeppelin.interpreter.InterpreterResult; +import org.apache.zeppelin.interpreter.InterpreterResult.Type; +import org.apache.zeppelin.resource.LocalResourcePool; +import org.apache.zeppelin.user.AuthenticationInfo; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.HashMap; +import java.util.LinkedList; +import java.util.Properties; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class OldSparkSqlInterpreterTest { + + @ClassRule + public static TemporaryFolder tmpDir = new TemporaryFolder(); + + static SparkSqlInterpreter sql; + static SparkInterpreter repl; + static InterpreterContext context; + static InterpreterGroup intpGroup; + + @BeforeClass + public static void setUp() throws Exception { + Properties p = new Properties(); + p.putAll(OldSparkInterpreterTest.getSparkTestProperties(tmpDir)); + p.setProperty("zeppelin.spark.maxResult", "10"); + p.setProperty("zeppelin.spark.concurrentSQL", "false"); + p.setProperty("zeppelin.spark.sql.stacktrace", "false"); + + repl = new SparkInterpreter(p); + intpGroup = new InterpreterGroup(); + repl.setInterpreterGroup(intpGroup); + repl.open(); + OldSparkInterpreterTest.repl = repl; + OldSparkInterpreterTest.intpGroup = intpGroup; + + sql = new SparkSqlInterpreter(p); + + intpGroup = new InterpreterGroup(); + intpGroup.put("note", new LinkedList<Interpreter>()); + intpGroup.get("note").add(repl); + intpGroup.get("note").add(sql); + sql.setInterpreterGroup(intpGroup); + sql.open(); + + context = new InterpreterContext("note", "id", null, "title", "text", new AuthenticationInfo(), + new HashMap<String, Object>(), new GUI(), new GUI(), + new AngularObjectRegistry(intpGroup.getId(), null), + new LocalResourcePool("id"), + new LinkedList<InterpreterContextRunner>(), new InterpreterOutput(null)); + } + + @AfterClass + public static void tearDown() throws InterpreterException { + sql.close(); + repl.close(); + } + + boolean isDataFrameSupported() { + return OldSparkInterpreterTest.getSparkVersionNumber(repl) >= 13; + } + + @Test + public void test() throws InterpreterException { + repl.interpret("case class Test(name:String, age:Int)", context); + repl.interpret("val test = sc.parallelize(Seq(Test(\"moon\", 33), Test(\"jobs\", 51), Test(\"gates\", 51), Test(\"park\", 34)))", context); + if (isDataFrameSupported()) { + repl.interpret("test.toDF.registerTempTable(\"test\")", context); + } else { + repl.interpret("test.registerTempTable(\"test\")", context); + } + + InterpreterResult ret = sql.interpret("select name, age from test where age < 40", context); + assertEquals(InterpreterResult.Code.SUCCESS, ret.code()); + assertEquals(Type.TABLE, ret.message().get(0).getType()); + assertEquals("name\tage\nmoon\t33\npark\t34\n", ret.message().get(0).getData()); + + ret = sql.interpret("select wrong syntax", context); + assertEquals(InterpreterResult.Code.ERROR, ret.code()); + 
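// even with zeppelin.spark.sql.stacktrace=false (set in setUp above), some error text should still come back +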
assertTrue(ret.message().get(0).getData().length() > 0); + + assertEquals(InterpreterResult.Code.SUCCESS, sql.interpret("select case when name==\"aa\" then name else name end from test", context).code()); + } + + @Test + public void testStruct() throws InterpreterException { + repl.interpret("case class Person(name:String, age:Int)", context); + repl.interpret("case class People(group:String, person:Person)", context); + repl.interpret( + "val gr = sc.parallelize(Seq(People(\"g1\", Person(\"moon\",33)), People(\"g2\", Person(\"sun\",11))))", + context); + if (isDataFrameSupported()) { + repl.interpret("gr.toDF.registerTempTable(\"gr\")", context); + } else { + repl.interpret("gr.registerTempTable(\"gr\")", context); + } + + InterpreterResult ret = sql.interpret("select * from gr", context); + assertEquals(InterpreterResult.Code.SUCCESS, ret.code()); + } + + @Test + public void test_null_value_in_row() throws InterpreterException { + repl.interpret("import org.apache.spark.sql._", context); + if (isDataFrameSupported()) { + repl.interpret( + "import org.apache.spark.sql.types.{StructType,StructField,StringType,IntegerType}", + context); + } + repl.interpret( + "def toInt(s:String): Any = {try { s.trim().toInt} catch {case e:Exception => null}}", + context); + repl.interpret( + "val schema = StructType(Seq(StructField(\"name\", StringType, false),StructField(\"age\" , IntegerType, true),StructField(\"other\" , StringType, false)))", + context); + repl.interpret( + "val csv = sc.parallelize(Seq((\"jobs, 51, apple\"), (\"gates, , microsoft\")))", + context); + repl.interpret( + "val raw = csv.map(_.split(\",\")).map(p => Row(p(0),toInt(p(1)),p(2)))", + context); + if (isDataFrameSupported()) { + repl.interpret("val people = sqlContext.createDataFrame(raw, schema)", + context); + repl.interpret("people.toDF.registerTempTable(\"people\")", context); + } else { + repl.interpret("val people = sqlContext.applySchema(raw, schema)", + context); + repl.interpret("people.registerTempTable(\"people\")", context); + } + + InterpreterResult ret = sql.interpret( + "select name, age from people where name = 'gates'", context); + System.err.println("RET=" + ret.message()); + assertEquals(InterpreterResult.Code.SUCCESS, ret.code()); + assertEquals(Type.TABLE, ret.message().get(0).getType()); + assertEquals("name\tage\ngates\tnull\n", ret.message().get(0).getData()); + } + + @Test + public void testMaxResults() throws InterpreterException { + repl.interpret("case class P(age:Int)", context); + repl.interpret( + "val gr = sc.parallelize(Seq(P(1),P(2),P(3),P(4),P(5),P(6),P(7),P(8),P(9),P(10),P(11)))", + context); + if (isDataFrameSupported()) { + repl.interpret("gr.toDF.registerTempTable(\"gr\")", context); + } else { + repl.interpret("gr.registerTempTable(\"gr\")", context); + } + + InterpreterResult ret = sql.interpret("select * from gr", context); + assertEquals(InterpreterResult.Code.SUCCESS, ret.code()); + assertTrue(ret.message().get(1).getData().contains("alert-warning")); + } +} http://git-wip-us.apache.org/repos/asf/zeppelin/blob/ca87f7d4/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterMatplotlibTest.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterMatplotlibTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterMatplotlibTest.java new file mode 100644 index 0000000..2d40871 --- /dev/null +++ 
b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterMatplotlibTest.java @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zeppelin.spark; + +import org.apache.zeppelin.display.AngularObjectRegistry; +import org.apache.zeppelin.display.GUI; +import org.apache.zeppelin.interpreter.*; +import org.apache.zeppelin.interpreter.InterpreterResult.Type; +import org.apache.zeppelin.resource.LocalResourcePool; +import org.apache.zeppelin.user.AuthenticationInfo; +import org.junit.*; +import org.junit.rules.TemporaryFolder; +import org.junit.runners.MethodSorters; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import java.io.IOException; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Properties; + +import static org.junit.Assert.*; + +@FixMethodOrder(MethodSorters.NAME_ASCENDING) +public class PySparkInterpreterMatplotlibTest { + + @ClassRule + public static TemporaryFolder tmpDir = new TemporaryFolder(); + + static SparkInterpreter sparkInterpreter; + static PySparkInterpreter pyspark; + static InterpreterGroup intpGroup; + static Logger LOGGER = LoggerFactory.getLogger(PySparkInterpreterTest.class); + static InterpreterContext context; + + public static class AltPySparkInterpreter extends PySparkInterpreter { + /** + * Since pyspark output is sent to an outputstream rather than + * being directly provided by interpret(), this subclass is created to + * override interpret() to append the result from the outputStream + * for the sake of convenience in testing. + */ + public AltPySparkInterpreter(Properties property) { + super(property); + } + + /** + * This code is mainly copied from RemoteInterpreterServer.java which + * normally handles this in real use cases. 
+ */ + @Override + public InterpreterResult interpret(String st, InterpreterContext context) throws InterpreterException { + context.out.clear(); + InterpreterResult result = super.interpret(st, context); + List<InterpreterResultMessage> resultMessages = null; + try { + context.out.flush(); + resultMessages = context.out.toInterpreterResultMessage(); + } catch (IOException e) { + e.printStackTrace(); + } + resultMessages.addAll(result.message()); + + return new InterpreterResult(result.code(), resultMessages); + } + } + + private static Properties getPySparkTestProperties() throws IOException { + Properties p = new Properties(); + p.setProperty("spark.master", "local[*]"); + p.setProperty("spark.app.name", "Zeppelin Test"); + p.setProperty("zeppelin.spark.useHiveContext", "true"); + p.setProperty("zeppelin.spark.maxResult", "1000"); + p.setProperty("zeppelin.spark.importImplicit", "true"); + p.setProperty("zeppelin.pyspark.python", "python"); + p.setProperty("zeppelin.dep.localrepo", tmpDir.newFolder().getAbsolutePath()); + p.setProperty("zeppelin.pyspark.useIPython", "false"); + return p; + } + + /** + * Get spark version number as a numerical value. + * eg. 1.1.x => 11, 1.2.x => 12, 1.3.x => 13 ... + */ + public static int getSparkVersionNumber() { + if (sparkInterpreter == null) { + return 0; + } + + String[] split = sparkInterpreter.getSparkContext().version().split("\\."); + int version = Integer.parseInt(split[0]) * 10 + Integer.parseInt(split[1]); + return version; + } + + @BeforeClass + public static void setUp() throws Exception { + intpGroup = new InterpreterGroup(); + intpGroup.put("note", new LinkedList<Interpreter>()); + context = new InterpreterContext("note", "id", null, "title", "text", + new AuthenticationInfo(), + new HashMap<String, Object>(), + new GUI(), + new GUI(), + new AngularObjectRegistry(intpGroup.getId(), null), + new LocalResourcePool("id"), + new LinkedList<InterpreterContextRunner>(), + new InterpreterOutput(null)); + InterpreterContext.set(context); + + sparkInterpreter = new SparkInterpreter(getPySparkTestProperties()); + intpGroup.get("note").add(sparkInterpreter); + sparkInterpreter.setInterpreterGroup(intpGroup); + sparkInterpreter.open(); + + pyspark = new AltPySparkInterpreter(getPySparkTestProperties()); + intpGroup.get("note").add(pyspark); + pyspark.setInterpreterGroup(intpGroup); + pyspark.open(); + + context = new InterpreterContext("note", "id", null, "title", "text", + new AuthenticationInfo(), + new HashMap<String, Object>(), + new GUI(), + new GUI(), + new AngularObjectRegistry(intpGroup.getId(), null), + new LocalResourcePool("id"), + new LinkedList<InterpreterContextRunner>(), + new InterpreterOutput(null)); + } + + @AfterClass + public static void tearDown() throws InterpreterException { + pyspark.close(); + sparkInterpreter.close(); + } + + @Test + public void dependenciesAreInstalled() throws InterpreterException { + // matplotlib + InterpreterResult ret = pyspark.interpret("import matplotlib", context); + assertEquals(ret.message().toString(), InterpreterResult.Code.SUCCESS, ret.code()); + + // inline backend + ret = pyspark.interpret("import backend_zinline", context); + assertEquals(ret.message().toString(), InterpreterResult.Code.SUCCESS, ret.code()); + } + + @Test + public void showPlot() throws InterpreterException { + // Simple plot test + InterpreterResult ret; + ret = pyspark.interpret("import matplotlib.pyplot as plt", context); + ret = pyspark.interpret("plt.close()", context); + ret = 
pyspark.interpret("z.configure_mpl(interactive=False)", context); + ret = pyspark.interpret("plt.plot([1, 2, 3])", context); + ret = pyspark.interpret("plt.show()", context); + + assertEquals(ret.message().toString(), InterpreterResult.Code.SUCCESS, ret.code()); + assertEquals(ret.message().toString(), Type.HTML, ret.message().get(0).getType()); + assertTrue(ret.message().get(0).getData().contains("data:image/png;base64")); + assertTrue(ret.message().get(0).getData().contains("<div>")); + } + + @Test + // Test for when configuration is set to auto-close figures after show(). + public void testClose() throws InterpreterException { + InterpreterResult ret; + InterpreterResult ret1; + InterpreterResult ret2; + ret = pyspark.interpret("import matplotlib.pyplot as plt", context); + ret = pyspark.interpret("plt.close()", context); + ret = pyspark.interpret("z.configure_mpl(interactive=False, close=True, angular=False)", context); + ret = pyspark.interpret("plt.plot([1, 2, 3])", context); + ret1 = pyspark.interpret("plt.show()", context); + + // Second call to show() should print nothing, and Type should be TEXT. + // This is because when close=True, there should be no living instances + // of FigureManager, causing show() to return before setting the output + // type to HTML. + ret = pyspark.interpret("plt.show()", context); + assertEquals(0, ret.message().size()); + + // Now test that new plot is drawn. It should be identical to the + // previous one. + ret = pyspark.interpret("plt.plot([1, 2, 3])", context); + ret2 = pyspark.interpret("plt.show()", context); + assertEquals(ret1.message().get(0).getType(), ret2.message().get(0).getType()); + assertEquals(ret1.message().get(0).getData(), ret2.message().get(0).getData()); + } + + @Test + // Test for when configuration is set to not auto-close figures after show(). + public void testNoClose() throws InterpreterException { + InterpreterResult ret; + InterpreterResult ret1; + InterpreterResult ret2; + ret = pyspark.interpret("import matplotlib.pyplot as plt", context); + ret = pyspark.interpret("plt.close()", context); + ret = pyspark.interpret("z.configure_mpl(interactive=False, close=False, angular=False)", context); + ret = pyspark.interpret("plt.plot([1, 2, 3])", context); + ret1 = pyspark.interpret("plt.show()", context); + + // Second call to show() should print nothing, and Type should be HTML. + // This is because when close=False, there should be living instances + // of FigureManager, causing show() to set the output + // type to HTML even though the figure is inactive. + ret = pyspark.interpret("plt.show()", context); + assertEquals(ret.message().toString(), InterpreterResult.Code.SUCCESS, ret.code()); + + // Now test that plot can be reshown if it is updated. It should be + // different from the previous one because it will plot the same line + // again but in a different color. 
+ ret = pyspark.interpret("plt.plot([1, 2, 3])", context); + ret2 = pyspark.interpret("plt.show()", context); + assertNotSame(ret1.message().get(0).getData(), ret2.message().get(0).getData()); + } + + @Test + // Test angular mode + public void testAngular() throws InterpreterException { + InterpreterResult ret; + ret = pyspark.interpret("import matplotlib.pyplot as plt", context); + ret = pyspark.interpret("plt.close()", context); + ret = pyspark.interpret("z.configure_mpl(interactive=False, close=False, angular=True)", context); + ret = pyspark.interpret("plt.plot([1, 2, 3])", context); + ret = pyspark.interpret("plt.show()", context); + assertEquals(ret.message().toString(), InterpreterResult.Code.SUCCESS, ret.code()); + assertEquals(ret.message().toString(), Type.ANGULAR, ret.message().get(0).getType()); + + // Check if the figure data is in the Angular Object Registry + AngularObjectRegistry registry = context.getAngularObjectRegistry(); + String figureData = registry.getAll("note", null).get(0).toString(); + assertTrue(figureData.contains("data:image/png;base64")); + } +} http://git-wip-us.apache.org/repos/asf/zeppelin/blob/ca87f7d4/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java new file mode 100644 index 0000000..00972b4 --- /dev/null +++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.zeppelin.spark;
+
+import org.apache.zeppelin.display.AngularObjectRegistry;
+import org.apache.zeppelin.display.GUI;
+import org.apache.zeppelin.interpreter.*;
+import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion;
+import org.apache.zeppelin.resource.LocalResourcePool;
+import org.apache.zeppelin.user.AuthenticationInfo;
+import org.junit.*;
+import org.junit.rules.TemporaryFolder;
+import org.junit.runners.MethodSorters;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Properties;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+import static org.junit.Assert.*;
+
+@FixMethodOrder(MethodSorters.NAME_ASCENDING)
+public class PySparkInterpreterTest {
+
+  @ClassRule
+  public static TemporaryFolder tmpDir = new TemporaryFolder();
+
+  static SparkInterpreter sparkInterpreter;
+  static PySparkInterpreter pySparkInterpreter;
+  static InterpreterGroup intpGroup;
+  static InterpreterContext context;
+
+  private static Properties getPySparkTestProperties() throws IOException {
+    Properties p = new Properties();
+    p.setProperty("spark.master", "local");
+    p.setProperty("spark.app.name", "Zeppelin Test");
+    p.setProperty("zeppelin.spark.useHiveContext", "true");
+    p.setProperty("zeppelin.spark.maxResult", "1000");
+    p.setProperty("zeppelin.spark.importImplicit", "true");
+    p.setProperty("zeppelin.pyspark.python", "python");
+    p.setProperty("zeppelin.dep.localrepo", tmpDir.newFolder().getAbsolutePath());
+    p.setProperty("zeppelin.pyspark.useIPython", "false");
+    p.setProperty("zeppelin.spark.test", "true");
+    return p;
+  }
+
+  /**
+   * Get the Spark version number as a numerical value,
+   * e.g. 1.1.x => 11, 1.2.x => 12, 1.3.x => 13 ...
+   */
+  public static int getSparkVersionNumber() {
+    if (sparkInterpreter == null) {
+      return 0;
+    }
+
+    String[] split = sparkInterpreter.getSparkContext().version().split("\\.");
+    int version = Integer.parseInt(split[0]) * 10 + Integer.parseInt(split[1]);
+    return version;
+  }
+
+  @BeforeClass
+  public static void setUp() throws Exception {
+    intpGroup = new InterpreterGroup();
+    intpGroup.put("note", new LinkedList<Interpreter>());
+
+    context = new InterpreterContext("note", "id", null, "title", "text",
+        new AuthenticationInfo(),
+        new HashMap<String, Object>(),
+        new GUI(),
+        new GUI(),
+        new AngularObjectRegistry(intpGroup.getId(), null),
+        new LocalResourcePool("id"),
+        new LinkedList<InterpreterContextRunner>(),
+        new InterpreterOutput(null));
+    InterpreterContext.set(context);
+
+    sparkInterpreter = new SparkInterpreter(getPySparkTestProperties());
+    intpGroup.get("note").add(sparkInterpreter);
+    sparkInterpreter.setInterpreterGroup(intpGroup);
+    sparkInterpreter.open();
+
+    pySparkInterpreter = new PySparkInterpreter(getPySparkTestProperties());
+    intpGroup.get("note").add(pySparkInterpreter);
+    pySparkInterpreter.setInterpreterGroup(intpGroup);
+    pySparkInterpreter.open();
+  }
+
+  @AfterClass
+  public static void tearDown() throws InterpreterException {
+    pySparkInterpreter.close();
+    sparkInterpreter.close();
+  }
+
+  @Test
+  public void testBasicIntp() throws InterpreterException {
+    if (getSparkVersionNumber() > 11) {
+      assertEquals(InterpreterResult.Code.SUCCESS,
+          pySparkInterpreter.interpret("a = 1\n", context).code());
+    }
+
+    InterpreterResult result = pySparkInterpreter.interpret(
+        "from pyspark.streaming import StreamingContext\n" +
+        "import time\n" +
+        "ssc = StreamingContext(sc, 1)\n" +
+        "rddQueue = []\n" +
+        "for i in range(5):\n" +
+        " rddQueue += [ssc.sparkContext.parallelize([j for j in range(1, 1001)], 10)]\n" +
+        "inputStream = ssc.queueStream(rddQueue)\n" +
+        "mappedStream = inputStream.map(lambda x: (x % 10, 1))\n" +
+        "reducedStream = mappedStream.reduceByKey(lambda a, b: a + b)\n" +
+        "reducedStream.pprint()\n" +
+        "ssc.start()\n" +
+        "time.sleep(6)\n" +
+        "ssc.stop(stopSparkContext=False, stopGraceFully=True)", context);
+    assertEquals(InterpreterResult.Code.SUCCESS, result.code());
+  }
+
+  @Test
+  public void testCompletion() throws InterpreterException {
+    if (getSparkVersionNumber() > 11) {
+      List<InterpreterCompletion> completions = pySparkInterpreter.completion("sc.", "sc.".length(), null);
+      assertTrue(completions.size() > 0);
+    }
+  }
+
+  @Test
+  public void testRedefinitionZeppelinContext() throws InterpreterException {
+    if (getSparkVersionNumber() > 11) {
+      String redefinitionCode = "z = 1\n";
+      String restoreCode = "z = __zeppelin__\n";
+      String validCode = "z.input(\"test\")\n";
+
+      assertEquals(InterpreterResult.Code.SUCCESS, pySparkInterpreter.interpret(validCode, context).code());
+      assertEquals(InterpreterResult.Code.SUCCESS, pySparkInterpreter.interpret(redefinitionCode, context).code());
+      assertEquals(InterpreterResult.Code.ERROR, pySparkInterpreter.interpret(validCode, context).code());
+      assertEquals(InterpreterResult.Code.SUCCESS, pySparkInterpreter.interpret(restoreCode, context).code());
+      assertEquals(InterpreterResult.Code.SUCCESS, pySparkInterpreter.interpret(validCode, context).code());
+    }
+  }
+
+  // Runs an endless Python loop; testCancelIntp() cancels it and expects the
+  // interrupt to surface as a KeyboardInterrupt in the interpreter result.
+  private class InfinityPythonJob implements Runnable {
+    @Override
+    public void run() {
+      String code = "import time\nwhile True:\n time.sleep(1)";
+      InterpreterResult ret = null;
+      try {
+        ret = pySparkInterpreter.interpret(code, context);
+      } catch (InterpreterException e) {
+        e.printStackTrace();
+      }
+      assertNotNull(ret);
+      Pattern expectedMessage = Pattern.compile("KeyboardInterrupt");
+      Matcher m = expectedMessage.matcher(ret.message().toString());
+      assertTrue(m.find());
+    }
+  }
+
+  @Test
+  public void testCancelIntp() throws InterruptedException, InterpreterException {
+    if (getSparkVersionNumber() > 11) {
+      assertEquals(InterpreterResult.Code.SUCCESS,
+          pySparkInterpreter.interpret("a = 1\n", context).code());
+
+      Thread t = new Thread(new InfinityPythonJob());
+      t.start();
+      Thread.sleep(5000);
+      pySparkInterpreter.cancel(context);
+      // The cancel is delivered asynchronously: the job is still alive right
+      // after cancel(), but must exit within the join timeout.
+      assertTrue(t.isAlive());
+      t.join(2000);
+      assertFalse(t.isAlive());
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/zeppelin/blob/ca87f7d4/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkRInterpreterTest.java
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkRInterpreterTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkRInterpreterTest.java
new file mode 100644
index 0000000..2d585f5
--- /dev/null
+++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkRInterpreterTest.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.zeppelin.spark;
+
+import org.apache.zeppelin.display.AngularObjectRegistry;
+import org.apache.zeppelin.display.GUI;
+import org.apache.zeppelin.interpreter.InterpreterContext;
+import org.apache.zeppelin.interpreter.InterpreterException;
+import org.apache.zeppelin.interpreter.InterpreterGroup;
+import org.apache.zeppelin.interpreter.InterpreterResult;
+import org.apache.zeppelin.interpreter.LazyOpenInterpreter;
+import org.apache.zeppelin.user.AuthenticationInfo;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Properties;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+public class SparkRInterpreterTest {
+
+  private SparkRInterpreter sparkRInterpreter;
+  private SparkInterpreter sparkInterpreter;
+
+  @Test
+  public void testSparkRInterpreter() throws IOException, InterruptedException, InterpreterException {
+    Properties properties = new Properties();
+    properties.setProperty("spark.master", "local");
+    properties.setProperty("spark.app.name", "test");
+    properties.setProperty("zeppelin.spark.maxResult", "100");
+    properties.setProperty("zeppelin.spark.test", "true");
+    properties.setProperty("zeppelin.spark.useNew", "true");
+    properties.setProperty("zeppelin.R.knitr", "true");
+
+    sparkRInterpreter = new SparkRInterpreter(properties);
+    sparkInterpreter = new SparkInterpreter(properties);
+
+    InterpreterGroup interpreterGroup = new InterpreterGroup();
+    interpreterGroup.addInterpreterToSession(new LazyOpenInterpreter(sparkRInterpreter), "session_1");
+    interpreterGroup.addInterpreterToSession(new LazyOpenInterpreter(sparkInterpreter), "session_1");
+    sparkRInterpreter.setInterpreterGroup(interpreterGroup);
+    sparkInterpreter.setInterpreterGroup(interpreterGroup);
+
+    sparkRInterpreter.open();
+
+    InterpreterResult result = sparkRInterpreter.interpret("1+1", getInterpreterContext());
+    assertEquals(InterpreterResult.Code.SUCCESS, result.code());
+    assertTrue(result.message().get(0).getData().contains("2"));
+
+    result = sparkRInterpreter.interpret("sparkR.version()", getInterpreterContext());
+    assertEquals(InterpreterResult.Code.SUCCESS, result.code());
+    if (result.message().get(0).getData().contains("2.")) {
+      // spark 2.x
+      result = sparkRInterpreter.interpret("df <- as.DataFrame(faithful)\nhead(df)", getInterpreterContext());
+      assertEquals(InterpreterResult.Code.SUCCESS, result.code());
+      assertTrue(result.message().get(0).getData().contains("eruptions waiting"));
+    } else {
+      // spark 1.x
+      result = sparkRInterpreter.interpret("df <- createDataFrame(sqlContext, faithful)\nhead(df)", getInterpreterContext());
+      assertEquals(InterpreterResult.Code.SUCCESS, result.code());
+      assertTrue(result.message().get(0).getData().contains("eruptions waiting"));
+    }
+  }
+
+  private InterpreterContext getInterpreterContext() {
+    return new InterpreterContext(
+        "noteId",
+        "paragraphId",
+        "replName",
+        "paragraphTitle",
+        "paragraphText",
+        new AuthenticationInfo(),
+        new HashMap<String, Object>(),
+        new GUI(),
+        new GUI(),
+        new AngularObjectRegistry("spark", null),
+        null,
+        null,
+        null);
+  }
+}

http://git-wip-us.apache.org/repos/asf/zeppelin/blob/ca87f7d4/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkVersionTest.java
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkVersionTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkVersionTest.java
new file mode 100644
index 0000000..3dc8f4e
--- /dev/null
+++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/SparkVersionTest.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.zeppelin.spark;
+
+import static org.junit.Assert.*;
+
+import org.junit.Test;
+
+public class SparkVersionTest {
+
+  @Test
+  public void testUnknownSparkVersion() {
+    assertEquals(99999, SparkVersion.fromVersionString("DEV-10.10").toNumber());
+  }
+
+  @Test
+  public void testUnsupportedVersion() {
+    assertTrue(SparkVersion.fromVersionString("9.9.9").isUnsupportedVersion());
+    assertFalse(SparkVersion.fromVersionString("1.5.9").isUnsupportedVersion());
+    assertTrue(SparkVersion.fromVersionString("0.9.0").isUnsupportedVersion());
+    assertTrue(SparkVersion.UNSUPPORTED_FUTURE_VERSION.isUnsupportedVersion());
+    // should support the spark2 version of HDP 2.5
+    assertFalse(SparkVersion.fromVersionString("2.0.0.2.5.0.0-1245").isUnsupportedVersion());
+  }
+
+  @Test
+  public void testSparkVersion() {
+    // test equals
+    assertEquals(SparkVersion.SPARK_1_2_0, SparkVersion.fromVersionString("1.2.0"));
+    assertEquals(SparkVersion.SPARK_1_5_0, SparkVersion.fromVersionString("1.5.0-SNAPSHOT"));
+    // test the spark2 version of HDP 2.5
+    assertEquals(SparkVersion.SPARK_2_0_0, SparkVersion.fromVersionString("2.0.0.2.5.0.0-1245"));
+
+    // test newerThan
+    assertFalse(SparkVersion.SPARK_1_2_0.newerThan(SparkVersion.SPARK_1_2_0));
+    assertFalse(SparkVersion.SPARK_1_2_0.newerThan(SparkVersion.SPARK_1_3_0));
+    assertTrue(SparkVersion.SPARK_1_2_0.newerThan(SparkVersion.SPARK_1_1_0));
+
+    assertTrue(SparkVersion.SPARK_1_2_0.newerThanEquals(SparkVersion.SPARK_1_2_0));
+    assertFalse(SparkVersion.SPARK_1_2_0.newerThanEquals(SparkVersion.SPARK_1_3_0));
+    assertTrue(SparkVersion.SPARK_1_2_0.newerThanEquals(SparkVersion.SPARK_1_1_0));
+
+    // test olderThan
+    assertFalse(SparkVersion.SPARK_1_2_0.olderThan(SparkVersion.SPARK_1_2_0));
+    assertFalse(SparkVersion.SPARK_1_2_0.olderThan(SparkVersion.SPARK_1_1_0));
+    assertTrue(SparkVersion.SPARK_1_2_0.olderThan(SparkVersion.SPARK_1_3_0));
+
+    assertTrue(SparkVersion.SPARK_1_2_0.olderThanEquals(SparkVersion.SPARK_1_2_0));
+    assertFalse(SparkVersion.SPARK_1_2_0.olderThanEquals(SparkVersion.SPARK_1_1_0));
+    assertTrue(SparkVersion.SPARK_1_2_0.olderThanEquals(SparkVersion.SPARK_1_3_0));
+
+    // conversion
+    assertEquals(10200, SparkVersion.SPARK_1_2_0.toNumber());
+    assertEquals("1.2.0", SparkVersion.SPARK_1_2_0.toString());
+  }
+}
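
The assertions above pin down SparkVersion's parsing contract: vendor builds such as "2.0.0.2.5.0.0-1245" reduce to 2.0.0, a "-SNAPSHOT" suffix is ignored, unparseable strings such as "DEV-10.10" map to the unknown value 99999, and major.minor.patch encodes as major*10000 + minor*100 + patch (so 1.2.0 => 10200). A minimal sketch of parsing rules consistent with these tests follows; it is illustrative only, not the actual SparkVersion implementation:

  // Hypothetical parser satisfying the assertions in SparkVersionTest.
  static int toNumber(String versionString) {
    try {
      // Split on dots and dashes, then keep the first three numeric parts:
      // "2.0.0.2.5.0.0-1245" -> 2, 0, 0; "1.5.0-SNAPSHOT" -> 1, 5, 0.
      String[] parts = versionString.split("[.-]");
      int major = Integer.parseInt(parts[0]);
      int minor = Integer.parseInt(parts[1]);
      int patch = Integer.parseInt(parts[2]);
      return major * 10000 + minor * 100 + patch;
    } catch (Exception e) {
      // Unparseable strings such as "DEV-10.10" map to the "unknown" value.
      return 99999;
    }
  }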
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/ca87f7d4/spark/interpreter/src/test/resources/log4j.properties
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/test/resources/log4j.properties b/spark/interpreter/src/test/resources/log4j.properties
new file mode 100644
index 0000000..6958d4c
--- /dev/null
+++ b/spark/interpreter/src/test/resources/log4j.properties
@@ -0,0 +1,52 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Direct log messages to stdout
+log4j.appender.stdout=org.apache.log4j.ConsoleAppender
+log4j.appender.stdout.Target=System.out
+log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
+log4j.appender.stdout.layout.ConversionPattern=%d{ABSOLUTE} %5p %c:%L - %m%n
+# Alternative patterns, kept for reference:
+#log4j.appender.stdout.layout.ConversionPattern=%5p [%t] (%F:%L) - %m%n
+#log4j.appender.stdout.layout.ConversionPattern=%-4r [%t] %-5p %c %x - %m%n
+
+# Root logger option
+log4j.rootLogger=INFO, stdout
+
+# Mute some noisy loggers
+log4j.logger.org.apache.hadoop.mapred=WARN
+log4j.logger.org.apache.hadoop.hive.ql=WARN
+log4j.logger.org.apache.hadoop.hive.metastore=WARN
+log4j.logger.org.apache.hadoop.hive.service.HiveServer=WARN
+log4j.logger.org.apache.zeppelin.scheduler=WARN
+
+log4j.logger.org.quartz=WARN
+log4j.logger.DataNucleus=WARN
+log4j.logger.DataNucleus.MetaData=ERROR
+log4j.logger.DataNucleus.Datastore=ERROR
+
+# Log all JDBC parameters
+log4j.logger.org.hibernate.type=ALL
+
+log4j.logger.org.apache.zeppelin.interpreter=DEBUG
+log4j.logger.org.apache.zeppelin.spark=DEBUG
+
+log4j.logger.org.apache.zeppelin.python.IPythonInterpreter=DEBUG
+log4j.logger.org.apache.zeppelin.python.IPythonClient=DEBUG
+log4j.logger.org.apache.spark.repl.Main=INFO

http://git-wip-us.apache.org/repos/asf/zeppelin/blob/ca87f7d4/spark/interpreter/src/test/scala/org/apache/zeppelin/spark/utils/DisplayFunctionsTest.scala
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/test/scala/org/apache/zeppelin/spark/utils/DisplayFunctionsTest.scala b/spark/interpreter/src/test/scala/org/apache/zeppelin/spark/utils/DisplayFunctionsTest.scala
new file mode 100644
index 0000000..2638f17
--- /dev/null
+++ b/spark/interpreter/src/test/scala/org/apache/zeppelin/spark/utils/DisplayFunctionsTest.scala
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.zeppelin.spark.utils
+
+import java.io.ByteArrayOutputStream
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.{SparkContext, SparkConf}
+import org.scalatest._
+
+case class Person(login: String, name: String, age: Int)
+
+class DisplayFunctionsTest extends FlatSpec with BeforeAndAfter with BeforeAndAfterEach with Matchers {
+  var sc: SparkContext = null
+  var testTuples: List[(String, String, Int)] = null
+  var testPersons: List[Person] = null
+  var testRDDTuples: RDD[(String, String, Int)] = null
+  var testRDDPersons: RDD[Person] = null
+  var stream: ByteArrayOutputStream = null
+
+  before {
+    val sparkConf: SparkConf = new SparkConf(true)
+      .setAppName("test-DisplayFunctions")
+      .setMaster("local")
+    sc = new SparkContext(sparkConf)
+    testTuples = List(("jdoe", "John DOE", 32), ("hsue", "Helen SUE", 27), ("rsmith", "Richard SMITH", 45))
+    testRDDTuples = sc.parallelize(testTuples)
+    testPersons = List(Person("jdoe", "John DOE", 32), Person("hsue", "Helen SUE", 27), Person("rsmith", "Richard SMITH", 45))
+    testRDDPersons = sc.parallelize(testPersons)
+  }
+
+  override def beforeEach() {
+    stream = new java.io.ByteArrayOutputStream()
+    super.beforeEach() // To be stackable, must call super.beforeEach
+  }
+
+  "DisplayFunctions" should "generate correct column headers for tuples" in {
+    implicit val sparkMaxResult = new SparkMaxResult(100)
+    Console.withOut(stream) {
+      new DisplayRDDFunctions[(String, String, Int)](testRDDTuples).display("Login", "Name", "Age")
+    }
+
+    stream.toString("UTF-8") should be("%table Login\tName\tAge\n" +
+      "jdoe\tJohn DOE\t32\n" +
+      "hsue\tHelen SUE\t27\n" +
+      "rsmith\tRichard SMITH\t45\n")
+  }
+
+  "DisplayFunctions" should "generate correct column headers for case class" in {
+    implicit val sparkMaxResult = new SparkMaxResult(100)
+    Console.withOut(stream) {
+      new DisplayRDDFunctions[Person](testRDDPersons).display("Login", "Name", "Age")
+    }
+
+    stream.toString("UTF-8") should be("%table Login\tName\tAge\n" +
+      "jdoe\tJohn DOE\t32\n" +
+      "hsue\tHelen SUE\t27\n" +
+      "rsmith\tRichard SMITH\t45\n")
+  }
+
+  "DisplayFunctions" should "truncate exceeding column headers for tuples" in {
+    implicit val sparkMaxResult = new SparkMaxResult(100)
+    Console.withOut(stream) {
+      new DisplayRDDFunctions[(String, String, Int)](testRDDTuples).display("Login", "Name", "Age", "xxx", "yyy")
+    }
+
+    stream.toString("UTF-8") should be("%table Login\tName\tAge\n" +
+      "jdoe\tJohn DOE\t32\n" +
+      "hsue\tHelen SUE\t27\n" +
+      "rsmith\tRichard SMITH\t45\n")
+  }
+
+  "DisplayFunctions" should "pad missing column headers with ColumnXXX for tuples" in {
+    implicit val sparkMaxResult = new SparkMaxResult(100)
+    Console.withOut(stream) {
+      new DisplayRDDFunctions[(String, String, Int)](testRDDTuples).display("Login")
+    }
+
+    stream.toString("UTF-8") should be("%table Login\tColumn2\tColumn3\n" +
+      "jdoe\tJohn DOE\t32\n" +
+      "hsue\tHelen SUE\t27\n" +
+      "rsmith\tRichard SMITH\t45\n")
+  }
+
+  "DisplayFunctions" should "restrict an RDD to sparkMaxResult rows with an implicit limit" in {
+    implicit val sparkMaxResult = new SparkMaxResult(2)
+
+    Console.withOut(stream) {
+      new DisplayRDDFunctions[(String, String, Int)](testRDDTuples).display("Login")
+    }
+
+    stream.toString("UTF-8") should be("%table Login\tColumn2\tColumn3\n" +
+      "jdoe\tJohn DOE\t32\n" +
+      "hsue\tHelen SUE\t27\n")
+  }
+
+  "DisplayFunctions" should "restrict an RDD to sparkMaxResult rows with an explicit limit" in {
+    implicit val sparkMaxResult = new SparkMaxResult(2)
+
+    Console.withOut(stream) {
+      new DisplayRDDFunctions[(String, String, Int)](testRDDTuples).display(1, "Login")
+    }
+
+    stream.toString("UTF-8") should be("%table Login\tColumn2\tColumn3\n" +
+      "jdoe\tJohn DOE\t32\n")
+  }
+
+  "DisplayFunctions" should "display a traversable of tuples" in {
+    Console.withOut(stream) {
+      new DisplayTraversableFunctions[(String, String, Int)](testTuples).display("Login", "Name", "Age")
+    }
+
+    stream.toString("UTF-8") should be("%table Login\tName\tAge\n" +
+      "jdoe\tJohn DOE\t32\n" +
+      "hsue\tHelen SUE\t27\n" +
+      "rsmith\tRichard SMITH\t45\n")
+  }
+
+  "DisplayFunctions" should "display a traversable of case classes" in {
+    Console.withOut(stream) {
+      new DisplayTraversableFunctions[Person](testPersons).display("Login", "Name", "Age")
+    }
+
+    stream.toString("UTF-8") should be("%table Login\tName\tAge\n" +
+      "jdoe\tJohn DOE\t32\n" +
+      "hsue\tHelen SUE\t27\n" +
+      "rsmith\tRichard SMITH\t45\n")
+  }
+
+  "DisplayUtils" should "display HTML" in {
+    DisplayUtils.html() should be("%html ")
+    DisplayUtils.html("test") should be("%html test")
+  }
+
+  "DisplayUtils" should "display img" in {
+    DisplayUtils.img("http://www.google.com") should be("<img src='http://www.google.com' />")
+    DisplayUtils.img64() should be("%img ")
+    DisplayUtils.img64("abcde") should be("%img abcde")
+  }
+
+  override def afterEach() {
+    try super.afterEach() // To be stackable, must call super.afterEach
+    finally stream = null
+  }
+
+  after {
+    sc.stop()
+  }
+}
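
Taken together, the Scala assertions above fix the %table wire format the display functions emit: a header line of "%table" plus tab-separated column names (padded with ColumnN when too few names are given, extras dropped), followed by one tab-separated line per row, truncated to the sparkMaxResult limit. A hedged Java sketch of that formatting contract, using a hypothetical helper rather than Zeppelin's actual DisplayRDDFunctions:

  import java.util.List;

  // Hypothetical formatter reproducing the %table output asserted above.
  static String toZeppelinTable(List<String[]> rows, int maxResult, String... headers) {
    int width = rows.isEmpty() ? headers.length : rows.get(0).length;
    StringBuilder sb = new StringBuilder("%table ");
    for (int i = 0; i < width; i++) {
      // Pad missing headers with "ColumnN"; drop headers beyond the row width.
      sb.append(i < headers.length ? headers[i] : "Column" + (i + 1));
      sb.append(i < width - 1 ? "\t" : "\n");
    }
    for (int i = 0; i < Math.min(rows.size(), maxResult); i++) {
      sb.append(String.join("\t", rows.get(i))).append("\n");
    }
    return sb.toString();
  }

For example, toZeppelinTable(rows, 2, "Login") over the three test tuples would yield "%table Login\tColumn2\tColumn3\n" followed by only the first two rows, matching the implicit-limit test above.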