http://git-wip-us.apache.org/repos/asf/zeppelin/blob/d762b528/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java new file mode 100644 index 0000000..0703ad7 --- /dev/null +++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java @@ -0,0 +1,751 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zeppelin.spark; + +import com.google.gson.Gson; +import org.apache.commons.compress.utils.IOUtils; +import org.apache.commons.exec.CommandLine; +import org.apache.commons.exec.DefaultExecutor; +import org.apache.commons.exec.ExecuteException; +import org.apache.commons.exec.ExecuteResultHandler; +import org.apache.commons.exec.ExecuteWatchdog; +import org.apache.commons.exec.PumpStreamHandler; +import org.apache.commons.exec.environment.EnvironmentUtils; +import org.apache.commons.lang.StringUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; +import org.apache.zeppelin.interpreter.Interpreter; +import org.apache.zeppelin.interpreter.InterpreterContext; +import org.apache.zeppelin.interpreter.InterpreterException; +import org.apache.zeppelin.interpreter.InterpreterGroup; +import org.apache.zeppelin.interpreter.InterpreterHookRegistry.HookType; +import org.apache.zeppelin.interpreter.InterpreterResult; +import org.apache.zeppelin.interpreter.InterpreterResult.Code; +import org.apache.zeppelin.interpreter.InterpreterResultMessage; +import org.apache.zeppelin.interpreter.LazyOpenInterpreter; +import org.apache.zeppelin.interpreter.WrappedInterpreter; +import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; +import org.apache.zeppelin.interpreter.util.InterpreterOutputStream; +import org.apache.zeppelin.spark.dep.SparkDependencyContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import py4j.GatewayServer; + +import java.io.BufferedWriter; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.io.PipedInputStream; +import java.io.PipedOutputStream; +import java.net.MalformedURLException; +import java.net.ServerSocket; +import java.net.URL; +import java.net.URLClassLoader; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Properties; + +/** + * + */ +public class PySparkInterpreter extends Interpreter implements ExecuteResultHandler { + private static final 
Logger LOGGER = LoggerFactory.getLogger(PySparkInterpreter.class); + private GatewayServer gatewayServer; + private DefaultExecutor executor; + private int port; + private InterpreterOutputStream outputStream; + private BufferedWriter ins; + private PipedInputStream in; + private ByteArrayOutputStream input; + private String scriptPath; + boolean pythonscriptRunning = false; + private static final int MAX_TIMEOUT_SEC = 10; + private long pythonPid; + + private IPySparkInterpreter iPySparkInterpreter; + + public PySparkInterpreter(Properties property) { + super(property); + + pythonPid = -1; + try { + File scriptFile = File.createTempFile("zeppelin_pyspark-", ".py"); + scriptPath = scriptFile.getAbsolutePath(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private void createPythonScript() throws InterpreterException { + ClassLoader classLoader = getClass().getClassLoader(); + File out = new File(scriptPath); + + if (out.exists() && out.isDirectory()) { + throw new InterpreterException("Can't create python script " + out.getAbsolutePath()); + } + + try { + FileOutputStream outStream = new FileOutputStream(out); + IOUtils.copy( + classLoader.getResourceAsStream("python/zeppelin_pyspark.py"), + outStream); + outStream.close(); + } catch (IOException e) { + throw new InterpreterException(e); + } + + LOGGER.info("File {} created", scriptPath); + } + + @Override + public void open() throws InterpreterException { + // try IPySparkInterpreter first + iPySparkInterpreter = getIPySparkInterpreter(); + if (getProperty("zeppelin.pyspark.useIPython", "true").equals("true") && + StringUtils.isEmpty( + iPySparkInterpreter.checkIPythonPrerequisite(getPythonExec(getProperties())))) { + try { + iPySparkInterpreter.open(); + if (InterpreterContext.get() != null) { + // don't print it when it is in testing, just for easy output check in test. + InterpreterContext.get().out.write(("IPython is available, " + + "use IPython for PySparkInterpreter\n") + .getBytes()); + } + LOGGER.info("Use IPySparkInterpreter to replace PySparkInterpreter"); + return; + } catch (Exception e) { + LOGGER.warn("Fail to open IPySparkInterpreter", e); + } + } + iPySparkInterpreter = null; + if (getProperty("zeppelin.pyspark.useIPython", "true").equals("true")) { + // don't print it when it is in testing, just for easy output check in test. 
+ try { + InterpreterContext.get().out.write(("IPython is not available, " + + "use the native PySparkInterpreter\n") + .getBytes()); + } catch (IOException e) { + LOGGER.warn("Fail to write InterpreterOutput", e); + } + } + + // Add matplotlib display hook + InterpreterGroup intpGroup = getInterpreterGroup(); + if (intpGroup != null && intpGroup.getInterpreterHookRegistry() != null) { + registerHook(HookType.POST_EXEC_DEV, "__zeppelin__._displayhook()"); + } + DepInterpreter depInterpreter = getDepInterpreter(); + + // load libraries from Dependency Interpreter + URL [] urls = new URL[0]; + List<URL> urlList = new LinkedList<>(); + + if (depInterpreter != null) { + SparkDependencyContext depc = depInterpreter.getDependencyContext(); + if (depc != null) { + List<File> files = depc.getFiles(); + if (files != null) { + for (File f : files) { + try { + urlList.add(f.toURI().toURL()); + } catch (MalformedURLException e) { + LOGGER.error("Error", e); + } + } + } + } + } + + String localRepo = getProperty("zeppelin.interpreter.localRepo"); + if (localRepo != null) { + File localRepoDir = new File(localRepo); + if (localRepoDir.exists()) { + File[] files = localRepoDir.listFiles(); + if (files != null) { + for (File f : files) { + try { + urlList.add(f.toURI().toURL()); + } catch (MalformedURLException e) { + LOGGER.error("Error", e); + } + } + } + } + } + + urls = urlList.toArray(urls); + ClassLoader oldCl = Thread.currentThread().getContextClassLoader(); + try { + URLClassLoader newCl = new URLClassLoader(urls, oldCl); + Thread.currentThread().setContextClassLoader(newCl); + createGatewayServerAndStartScript(); + } catch (Exception e) { + LOGGER.error("Error", e); + throw new InterpreterException(e); + } finally { + Thread.currentThread().setContextClassLoader(oldCl); + } + } + + private Map setupPySparkEnv() throws IOException, InterpreterException { + Map env = EnvironmentUtils.getProcEnvironment(); + + // only set PYTHONPATH in local or yarn-client mode. + // yarn-cluster will setup PYTHONPATH automatically. 
+ SparkConf conf = getSparkConf(); + if (!conf.get("spark.submit.deployMode", "client").equals("cluster")) { + // override any existing PYTHONPATH with Zeppelin's pyspark libraries + env.put("PYTHONPATH", PythonUtils.sparkPythonPath()); + } + + // get additional class paths when using SPARK_SUBMIT and not using YARN-CLIENT + // also, add all packages to PYTHONPATH since there might be transitive dependencies + if (SparkInterpreter.useSparkSubmit() && + !getSparkInterpreter().isYarnMode()) { + + String sparkSubmitJars = getSparkConf().get("spark.jars").replace(",", ":"); + + if (!"".equals(sparkSubmitJars)) { + // join with ":" so the jars are not glued onto the last existing entry + env.put("PYTHONPATH", env.get("PYTHONPATH") + ":" + sparkSubmitJars); + } + } + + LOGGER.info("PYTHONPATH: " + env.get("PYTHONPATH")); + + // set PYSPARK_PYTHON + if (getSparkConf().contains("spark.pyspark.python")) { + env.put("PYSPARK_PYTHON", getSparkConf().get("spark.pyspark.python")); + } + return env; + } + + // Run python shell + // Choose python in the order of + // PYSPARK_DRIVER_PYTHON > PYSPARK_PYTHON > zeppelin.pyspark.python + public static String getPythonExec(Properties properties) { + String pythonExec = properties.getProperty("zeppelin.pyspark.python", "python"); + if (System.getenv("PYSPARK_PYTHON") != null) { + pythonExec = System.getenv("PYSPARK_PYTHON"); + } + if (System.getenv("PYSPARK_DRIVER_PYTHON") != null) { + pythonExec = System.getenv("PYSPARK_DRIVER_PYTHON"); + } + return pythonExec; + } + + private void createGatewayServerAndStartScript() throws InterpreterException { + // create python script + createPythonScript(); + + port = findRandomOpenPortOnAllLocalInterfaces(); + + gatewayServer = new GatewayServer(this, port); + gatewayServer.start(); + + String pythonExec = getPythonExec(getProperties()); + LOGGER.info("pythonExec: " + pythonExec); + CommandLine cmd = CommandLine.parse(pythonExec); + cmd.addArgument(scriptPath, false); + cmd.addArgument(Integer.toString(port), false); + cmd.addArgument(Integer.toString(getSparkInterpreter().getSparkVersion().toNumber()), false); + executor = new DefaultExecutor(); + outputStream = new InterpreterOutputStream(LOGGER); + PipedOutputStream ps = new PipedOutputStream(); + in = null; + try { + in = new PipedInputStream(ps); + } catch (IOException e1) { + throw new InterpreterException(e1); + } + ins = new BufferedWriter(new OutputStreamWriter(ps)); + + input = new ByteArrayOutputStream(); + + PumpStreamHandler streamHandler = new PumpStreamHandler(outputStream, outputStream, in); + executor.setStreamHandler(streamHandler); + executor.setWatchdog(new ExecuteWatchdog(ExecuteWatchdog.INFINITE_TIMEOUT)); + + try { + Map env = setupPySparkEnv(); + executor.execute(cmd, env, this); + pythonscriptRunning = true; + } catch (IOException e) { + throw new InterpreterException(e); + } + + try { + input.write("import sys, getopt\n".getBytes()); + ins.flush(); + } catch (IOException e) { + throw new InterpreterException(e); + } + } + + private int findRandomOpenPortOnAllLocalInterfaces() throws InterpreterException { + int port; + // try-with-resources closes the socket, so no explicit close() is needed + try (ServerSocket socket = new ServerSocket(0)) { + port = socket.getLocalPort(); + } catch (IOException e) { + throw new InterpreterException(e); + } + return port; + } + + @Override + public void close() throws InterpreterException { + if (iPySparkInterpreter != null) { + iPySparkInterpreter.close(); + return; + } + executor.getWatchdog().destroyProcess(); + new File(scriptPath).delete(); + gatewayServer.shutdown(); + } + + PythonInterpretRequest
pythonInterpretRequest = null; + + /** + * + */ + public class PythonInterpretRequest { + public String statements; + public String jobGroup; + public String jobDescription; + + public PythonInterpretRequest(String statements, String jobGroup, + String jobDescription) { + this.statements = statements; + this.jobGroup = jobGroup; + this.jobDescription = jobDescription; + } + + public String statements() { + return statements; + } + + public String jobGroup() { + return jobGroup; + } + + public String jobDescription() { + return jobDescription; + } + } + + Integer statementSetNotifier = new Integer(0); + + public PythonInterpretRequest getStatements() { + synchronized (statementSetNotifier) { + while (pythonInterpretRequest == null) { + try { + statementSetNotifier.wait(1000); + } catch (InterruptedException e) { + } + } + PythonInterpretRequest req = pythonInterpretRequest; + pythonInterpretRequest = null; + return req; + } + } + + String statementOutput = null; + boolean statementError = false; + Integer statementFinishedNotifier = new Integer(0); + + public void setStatementsFinished(String out, boolean error) { + synchronized (statementFinishedNotifier) { + LOGGER.debug("Setting python statement output: " + out + ", error: " + error); + statementOutput = out; + statementError = error; + statementFinishedNotifier.notify(); + } + } + + boolean pythonScriptInitialized = false; + Integer pythonScriptInitializeNotifier = new Integer(0); + + public void onPythonScriptInitialized(long pid) { + pythonPid = pid; + synchronized (pythonScriptInitializeNotifier) { + LOGGER.debug("onPythonScriptInitialized is called"); + pythonScriptInitialized = true; + pythonScriptInitializeNotifier.notifyAll(); + } + } + + public void appendOutput(String message) throws IOException { + LOGGER.debug("Output from python process: " + message); + outputStream.getInterpreterOutput().write(message); + } + + @Override + public InterpreterResult interpret(String st, InterpreterContext context) + throws InterpreterException { + SparkInterpreter sparkInterpreter = getSparkInterpreter(); + sparkInterpreter.populateSparkWebUrl(context); + if (sparkInterpreter.isUnsupportedSparkVersion()) { + return new InterpreterResult(Code.ERROR, "Spark " + + sparkInterpreter.getSparkVersion().toString() + " is not supported"); + } + + if (iPySparkInterpreter != null) { + return iPySparkInterpreter.interpret(st, context); + } + + if (!pythonscriptRunning) { + return new InterpreterResult(Code.ERROR, "python process not running" + + outputStream.toString()); + } + + outputStream.setInterpreterOutput(context.out); + + synchronized (pythonScriptInitializeNotifier) { + long startTime = System.currentTimeMillis(); + while (pythonScriptInitialized == false + && pythonscriptRunning + && System.currentTimeMillis() - startTime < MAX_TIMEOUT_SEC * 1000) { + try { + pythonScriptInitializeNotifier.wait(1000); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + } + + List<InterpreterResultMessage> errorMessage; + try { + context.out.flush(); + errorMessage = context.out.toInterpreterResultMessage(); + } catch (IOException e) { + throw new InterpreterException(e); + } + + + if (pythonscriptRunning == false) { + // python script failed to initialize and terminated + errorMessage.add(new InterpreterResultMessage( + InterpreterResult.Type.TEXT, "failed to start pyspark")); + return new InterpreterResult(Code.ERROR, errorMessage); + } + if (pythonScriptInitialized == false) { + // timeout. 
didn't get initialized message + errorMessage.add(new InterpreterResultMessage( + InterpreterResult.Type.TEXT, "pyspark is not responding")); + return new InterpreterResult(Code.ERROR, errorMessage); + } + + if (!sparkInterpreter.getSparkVersion().isPysparkSupported()) { + errorMessage.add(new InterpreterResultMessage( + InterpreterResult.Type.TEXT, + "pyspark " + sparkInterpreter.getSparkContext().version() + " is not supported")); + return new InterpreterResult(Code.ERROR, errorMessage); + } + String jobGroup = Utils.buildJobGroupId(context); + String jobDesc = "Started by: " + Utils.getUserName(context.getAuthenticationInfo()); + SparkZeppelinContext __zeppelin__ = sparkInterpreter.getZeppelinContext(); + __zeppelin__.setInterpreterContext(context); + __zeppelin__.setGui(context.getGui()); + __zeppelin__.setNoteGui(context.getNoteGui()); + pythonInterpretRequest = new PythonInterpretRequest(st, jobGroup, jobDesc); + statementOutput = null; + + synchronized (statementSetNotifier) { + statementSetNotifier.notify(); + } + + synchronized (statementFinishedNotifier) { + while (statementOutput == null) { + try { + statementFinishedNotifier.wait(1000); + } catch (InterruptedException e) { + } + } + } + + if (statementError) { + return new InterpreterResult(Code.ERROR, statementOutput); + } else { + + try { + context.out.flush(); + } catch (IOException e) { + throw new InterpreterException(e); + } + + return new InterpreterResult(Code.SUCCESS); + } + } + + public void interrupt() throws IOException, InterpreterException { + if (pythonPid > -1) { + LOGGER.info("Sending SIGINT signal to PID : " + pythonPid); + Runtime.getRuntime().exec("kill -SIGINT " + pythonPid); + } else { + LOGGER.warn("Non UNIX/Linux system, close the interpreter"); + close(); + } + } + + @Override + public void cancel(InterpreterContext context) throws InterpreterException { + if (iPySparkInterpreter != null) { + iPySparkInterpreter.cancel(context); + return; + } + SparkInterpreter sparkInterpreter = getSparkInterpreter(); + sparkInterpreter.cancel(context); + try { + interrupt(); + } catch (IOException e) { + LOGGER.error("Error", e); + } + } + + @Override + public FormType getFormType() { + return FormType.NATIVE; + } + + @Override + public int getProgress(InterpreterContext context) throws InterpreterException { + if (iPySparkInterpreter != null) { + return iPySparkInterpreter.getProgress(context); + } + SparkInterpreter sparkInterpreter = getSparkInterpreter(); + return sparkInterpreter.getProgress(context); + } + + + @Override + public List<InterpreterCompletion> completion(String buf, int cursor, + InterpreterContext interpreterContext) + throws InterpreterException { + if (iPySparkInterpreter != null) { + return iPySparkInterpreter.completion(buf, cursor, interpreterContext); + } + if (buf.length() < cursor) { + cursor = buf.length(); + } + String completionString = getCompletionTargetString(buf, cursor); + String completionCommand = "completion.getCompletion('" + completionString + "')"; + + //start code for completion + SparkInterpreter sparkInterpreter = getSparkInterpreter(); + if (sparkInterpreter.isUnsupportedSparkVersion() || pythonscriptRunning == false) { + return new LinkedList<>(); + } + + pythonInterpretRequest = new PythonInterpretRequest(completionCommand, "", ""); + statementOutput = null; + + synchronized (statementSetNotifier) { + statementSetNotifier.notify(); + } + + String[] completionList = null; + synchronized (statementFinishedNotifier) { + long startTime = System.currentTimeMillis(); + while 
(statementOutput == null + && pythonscriptRunning) { + try { + if (System.currentTimeMillis() - startTime > MAX_TIMEOUT_SEC * 1000) { + LOGGER.error("pyspark completion didn't have response for {}sec.", MAX_TIMEOUT_SEC); + break; + } + statementFinishedNotifier.wait(1000); + } catch (InterruptedException e) { + // not working + LOGGER.info("wait drop"); + return new LinkedList<>(); + } + } + if (statementError) { + return new LinkedList<>(); + } + Gson gson = new Gson(); + completionList = gson.fromJson(statementOutput, String[].class); + } + //end code for completion + + if (completionList == null) { + return new LinkedList<>(); + } + + List<InterpreterCompletion> results = new LinkedList<>(); + for (String name: completionList) { + results.add(new InterpreterCompletion(name, name, StringUtils.EMPTY)); + } + return results; + } + + private String getCompletionTargetString(String text, int cursor) { + String[] completionSeqCharaters = {" ", "\n", "\t"}; + int completionEndPosition = cursor; + int completionStartPosition = cursor; + int indexOfReverseSeqPostion = cursor; + + String resultCompletionText = ""; + String completionScriptText = ""; + try { + completionScriptText = text.substring(0, cursor); + } + catch (Exception e) { + LOGGER.error(e.toString()); + return null; + } + completionEndPosition = completionScriptText.length(); + + String tempReverseCompletionText = new StringBuilder(completionScriptText).reverse().toString(); + + for (String seqCharacter : completionSeqCharaters) { + indexOfReverseSeqPostion = tempReverseCompletionText.indexOf(seqCharacter); + + if (indexOfReverseSeqPostion < completionStartPosition && indexOfReverseSeqPostion > 0) { + completionStartPosition = indexOfReverseSeqPostion; + } + + } + + if (completionStartPosition == completionEndPosition) { + completionStartPosition = 0; + } + else + { + completionStartPosition = completionEndPosition - completionStartPosition; + } + resultCompletionText = completionScriptText.substring( + completionStartPosition , completionEndPosition); + + return resultCompletionText; + } + + + private SparkInterpreter getSparkInterpreter() throws InterpreterException { + LazyOpenInterpreter lazy = null; + SparkInterpreter spark = null; + Interpreter p = getInterpreterInTheSameSessionByClassName(SparkInterpreter.class.getName()); + + while (p instanceof WrappedInterpreter) { + if (p instanceof LazyOpenInterpreter) { + lazy = (LazyOpenInterpreter) p; + } + p = ((WrappedInterpreter) p).getInnerInterpreter(); + } + spark = (SparkInterpreter) p; + + if (lazy != null) { + lazy.open(); + } + return spark; + } + + private IPySparkInterpreter getIPySparkInterpreter() { + LazyOpenInterpreter lazy = null; + IPySparkInterpreter iPySpark = null; + Interpreter p = getInterpreterInTheSameSessionByClassName(IPySparkInterpreter.class.getName()); + + while (p instanceof WrappedInterpreter) { + if (p instanceof LazyOpenInterpreter) { + lazy = (LazyOpenInterpreter) p; + } + p = ((WrappedInterpreter) p).getInnerInterpreter(); + } + iPySpark = (IPySparkInterpreter) p; + return iPySpark; + } + + public SparkZeppelinContext getZeppelinContext() throws InterpreterException { + SparkInterpreter sparkIntp = getSparkInterpreter(); + if (sparkIntp != null) { + return getSparkInterpreter().getZeppelinContext(); + } else { + return null; + } + } + + public JavaSparkContext getJavaSparkContext() throws InterpreterException { + SparkInterpreter intp = getSparkInterpreter(); + if (intp == null) { + return null; + } else { + return new 
JavaSparkContext(intp.getSparkContext()); + } + } + + public Object getSparkSession() throws InterpreterException { + SparkInterpreter intp = getSparkInterpreter(); + if (intp == null) { + return null; + } else { + return intp.getSparkSession(); + } + } + + public SparkConf getSparkConf() throws InterpreterException { + JavaSparkContext sc = getJavaSparkContext(); + if (sc == null) { + return null; + } else { + // reuse the context fetched above instead of constructing a second JavaSparkContext + return sc.getConf(); + } + } + + public SQLContext getSQLContext() throws InterpreterException { + SparkInterpreter intp = getSparkInterpreter(); + if (intp == null) { + return null; + } else { + return intp.getSQLContext(); + } + } + + private DepInterpreter getDepInterpreter() { + Interpreter p = getInterpreterInTheSameSessionByClassName(DepInterpreter.class.getName()); + if (p == null) { + return null; + } + + while (p instanceof WrappedInterpreter) { + p = ((WrappedInterpreter) p).getInnerInterpreter(); + } + return (DepInterpreter) p; + } + + + @Override + public void onProcessComplete(int exitValue) { + pythonscriptRunning = false; + LOGGER.info("python process terminated. exit code " + exitValue); + } + + @Override + public void onProcessFailed(ExecuteException e) { + pythonscriptRunning = false; + LOGGER.error("python process failed", e); + } +}
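The interpret() path above hands each paragraph to the python process through a pair of blocking wait/notify handshakes exposed over py4j: the python side calls getStatements() and blocks until interpret() publishes a PythonInterpretRequest, then reports its output back through setStatementsFinished(), which wakes the interpret() call waiting on statementFinishedNotifier. A minimal self-contained sketch of that exchange, with illustrative names that are not part of the patch:

public class StatementExchange {
  private final Object lock = new Object();
  private String pendingStatement;   // set by the Zeppelin side
  private String result;             // set by the "python" side

  // Zeppelin side: publish a statement, then block until the other side reports back.
  public String run(String statement) throws InterruptedException {
    synchronized (lock) {
      pendingStatement = statement;
      lock.notifyAll();              // wake a thread blocked in take()
      while (result == null) {
        lock.wait(1000);             // bounded wait, mirroring the patch's wait(1000)
      }
      String out = result;
      result = null;
      return out;
    }
  }

  // python side (via py4j): block until a statement is published.
  public String take() throws InterruptedException {
    synchronized (lock) {
      while (pendingStatement == null) {
        lock.wait(1000);
      }
      String st = pendingStatement;
      pendingStatement = null;
      return st;
    }
  }

  // python side (via py4j): hand the output back and wake run().
  public void finish(String out) {
    synchronized (lock) {
      result = out;
      lock.notifyAll();
    }
  }

  // tiny demo: one consumer thread standing in for the python process
  public static void main(String[] args) throws Exception {
    StatementExchange ex = new StatementExchange();
    Thread python = new Thread(() -> {
      try {
        String st = ex.take();
        ex.finish("ran: " + st);
      } catch (InterruptedException ignored) { }
    });
    python.start();
    System.out.println(ex.run("print(1+1)"));  // prints "ran: print(1+1)"
    python.join();
  }
}

The patch itself uses two separate monitor objects (statementSetNotifier and statementFinishedNotifier) plus bounded one-second waits, so a python process that dies mid-statement cannot leave interpret() blocked forever; the single-lock sketch above only keeps the shape of the handshake.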
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/d762b528/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PythonUtils.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PythonUtils.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PythonUtils.java new file mode 100644 index 0000000..8182690 --- /dev/null +++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/PythonUtils.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.zeppelin.spark; + +import org.apache.commons.lang3.StringUtils; + +import java.io.File; +import java.io.FilenameFilter; +import java.util.ArrayList; +import java.util.List; + +/** + * Util class for PySpark + */ +public class PythonUtils { + + /** + * Get the PYTHONPATH for PySpark, either from SPARK_HOME, if it is set, or from ZEPPELIN_HOME + * when running in embedded mode. + * + * This method will be called in the zeppelin server process, and in the spark driver process + * when running in local or yarn-client mode. + */ + public static String sparkPythonPath() { + List<String> pythonPath = new ArrayList<>(); + String sparkHome = System.getenv("SPARK_HOME"); + String zeppelinHome = System.getenv("ZEPPELIN_HOME"); + if (zeppelinHome == null) { + zeppelinHome = new File("..").getAbsolutePath(); + } + if (sparkHome != null) { + // non-embedded mode when SPARK_HOME is specified.
+ File pyspark = new File(sparkHome, "python/lib/pyspark.zip"); + if (!pyspark.exists()) { + throw new RuntimeException("No pyspark.zip found under " + sparkHome + "/python/lib"); + } + pythonPath.add(pyspark.getAbsolutePath()); + File[] py4j = new File(sparkHome + "/python/lib").listFiles(new FilenameFilter() { + @Override + public boolean accept(File dir, String name) { + return name.startsWith("py4j"); + } + }); + // listFiles() returns null when the directory cannot be read + if (py4j == null || py4j.length == 0) { + throw new RuntimeException("No py4j files found under " + sparkHome + "/python/lib"); + } else if (py4j.length > 1) { + throw new RuntimeException("Multiple py4j files found under " + sparkHome + "/python/lib"); + } else { + pythonPath.add(py4j[0].getAbsolutePath()); + } + } else { + // embedded mode + File pyspark = new File(zeppelinHome, "interpreter/spark/pyspark/pyspark.zip"); + if (!pyspark.exists()) { + throw new RuntimeException("No pyspark.zip found: " + pyspark.getAbsolutePath()); + } + pythonPath.add(pyspark.getAbsolutePath()); + File[] py4j = new File(zeppelinHome, "interpreter/spark/pyspark").listFiles( + new FilenameFilter() { + @Override + public boolean accept(File dir, String name) { + return name.startsWith("py4j"); + } + }); + if (py4j == null || py4j.length == 0) { + throw new RuntimeException("No py4j files found under " + zeppelinHome + + "/interpreter/spark/pyspark"); + } else if (py4j.length > 1) { + throw new RuntimeException("Multiple py4j files found under " + zeppelinHome + + "/interpreter/spark/pyspark"); + } else { + pythonPath.add(py4j[0].getAbsolutePath()); + } + } + + // add ${ZEPPELIN_HOME}/interpreter/lib/python for all the cases + pythonPath.add(zeppelinHome + "/interpreter/lib/python"); + return StringUtils.join(pythonPath, ":"); + } +} http://git-wip-us.apache.org/repos/asf/zeppelin/blob/d762b528/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkInterpreter.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkInterpreter.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkInterpreter.java new file mode 100644 index 0000000..d9be573 --- /dev/null +++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkInterpreter.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.zeppelin.spark; + +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; +import org.apache.zeppelin.interpreter.Interpreter; +import org.apache.zeppelin.interpreter.InterpreterContext; +import org.apache.zeppelin.interpreter.InterpreterException; +import org.apache.zeppelin.interpreter.InterpreterResult; +import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.Properties; + +/** + * Wrapper of OldSparkInterpreter & NewSparkInterpreter. + * The property zeppelin.spark.useNew controls which one is used. + */ +public class SparkInterpreter extends AbstractSparkInterpreter { + + private static final Logger LOGGER = LoggerFactory.getLogger(SparkInterpreter.class); + + // either OldSparkInterpreter or NewSparkInterpreter + private AbstractSparkInterpreter delegation; + + + public SparkInterpreter(Properties properties) { + super(properties); + if (Boolean.parseBoolean(properties.getProperty("zeppelin.spark.useNew", "false"))) { + delegation = new NewSparkInterpreter(properties); + } else { + delegation = new OldSparkInterpreter(properties); + } + } + + @Override + public void open() throws InterpreterException { + delegation.setInterpreterGroup(getInterpreterGroup()); + delegation.setUserName(getUserName()); + delegation.setClassloaderUrls(getClassloaderUrls()); + + delegation.open(); + } + + @Override + public void close() throws InterpreterException { + delegation.close(); + } + + @Override + public InterpreterResult interpret(String st, InterpreterContext context) + throws InterpreterException { + return delegation.interpret(st, context); + } + + @Override + public void cancel(InterpreterContext context) throws InterpreterException { + delegation.cancel(context); + } + + @Override + public List<InterpreterCompletion> completion(String buf, + int cursor, + InterpreterContext interpreterContext) + throws InterpreterException { + return delegation.completion(buf, cursor, interpreterContext); + } + + @Override + public FormType getFormType() { + return FormType.NATIVE; + } + + @Override + public int getProgress(InterpreterContext context) throws InterpreterException { + return delegation.getProgress(context); + } + + public AbstractSparkInterpreter getDelegation() { + return delegation; + } + + + @Override + public SparkContext getSparkContext() { + return delegation.getSparkContext(); + } + + @Override + public SQLContext getSQLContext() { + return delegation.getSQLContext(); + } + + @Override + public Object getSparkSession() { + return delegation.getSparkSession(); + } + + @Override + public boolean isSparkContextInitialized() { + return delegation.isSparkContextInitialized(); + } + + @Override + public SparkVersion getSparkVersion() { + return delegation.getSparkVersion(); + } + + @Override + public JavaSparkContext getJavaSparkContext() { + return delegation.getJavaSparkContext(); + } + + @Override + public void populateSparkWebUrl(InterpreterContext ctx) { + delegation.populateSparkWebUrl(ctx); + } + + @Override + public SparkZeppelinContext getZeppelinContext() { + return delegation.getZeppelinContext(); + } + + @Override + public String getSparkUIUrl() { + return delegation.getSparkUIUrl(); + } + + public boolean isUnsupportedSparkVersion() { + return delegation.isUnsupportedSparkVersion(); + } + + public boolean isYarnMode() { + String master = getProperty("master");
+ if (master == null) { + master = getProperty("spark.master", "local[*]"); + } + return master.startsWith("yarn"); + } + + public static boolean useSparkSubmit() { + return null != System.getenv("SPARK_SUBMIT"); + } +} http://git-wip-us.apache.org/repos/asf/zeppelin/blob/d762b528/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkRInterpreter.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkRInterpreter.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkRInterpreter.java new file mode 100644 index 0000000..dbaeabe --- /dev/null +++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkRInterpreter.java @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zeppelin.spark; + +import static org.apache.zeppelin.spark.ZeppelinRDisplay.render; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +import org.apache.spark.SparkContext; +import org.apache.spark.SparkRBackend; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.zeppelin.interpreter.*; +import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; +import org.apache.zeppelin.scheduler.Scheduler; +import org.apache.zeppelin.scheduler.SchedulerFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Properties; + +/** + * R and SparkR interpreter with visualization support. 
+ */ +public class SparkRInterpreter extends Interpreter { + private static final Logger logger = LoggerFactory.getLogger(SparkRInterpreter.class); + + private static String renderOptions; + private SparkInterpreter sparkInterpreter; + private ZeppelinR zeppelinR; + private SparkContext sc; + private JavaSparkContext jsc; + + public SparkRInterpreter(Properties property) { + super(property); + } + + @Override + public void open() throws InterpreterException { + String rCmdPath = getProperty("zeppelin.R.cmd", "R"); + String sparkRLibPath; + + if (System.getenv("SPARK_HOME") != null) { + sparkRLibPath = System.getenv("SPARK_HOME") + "/R/lib"; + } else { + sparkRLibPath = System.getenv("ZEPPELIN_HOME") + "/interpreter/spark/R/lib"; + // workaround to make sparkr work without SPARK_HOME + System.setProperty("spark.test.home", System.getenv("ZEPPELIN_HOME") + "/interpreter/spark"); + } + synchronized (SparkRBackend.backend()) { + if (!SparkRBackend.isStarted()) { + SparkRBackend.init(); + SparkRBackend.start(); + } + } + + int port = SparkRBackend.port(); + + this.sparkInterpreter = getSparkInterpreter(); + this.sc = sparkInterpreter.getSparkContext(); + this.jsc = sparkInterpreter.getJavaSparkContext(); + SparkVersion sparkVersion = new SparkVersion(sc.version()); + ZeppelinRContext.setSparkContext(sc); + ZeppelinRContext.setJavaSparkContext(jsc); + if (Utils.isSpark2()) { + ZeppelinRContext.setSparkSession(sparkInterpreter.getSparkSession()); + } + ZeppelinRContext.setSqlContext(sparkInterpreter.getSQLContext()); + ZeppelinRContext.setZeppelinContext(sparkInterpreter.getZeppelinContext()); + + zeppelinR = new ZeppelinR(rCmdPath, sparkRLibPath, port, sparkVersion); + try { + zeppelinR.open(); + } catch (IOException e) { + logger.error("Exception while opening SparkRInterpreter", e); + throw new InterpreterException(e); + } + + if (useKnitr()) { + zeppelinR.eval("library('knitr')"); + } + renderOptions = getProperty("zeppelin.R.render.options"); + } + + String getJobGroup(InterpreterContext context){ + return "zeppelin-" + context.getParagraphId(); + } + + @Override + public InterpreterResult interpret(String lines, InterpreterContext interpreterContext) + throws InterpreterException { + + SparkInterpreter sparkInterpreter = getSparkInterpreter(); + sparkInterpreter.populateSparkWebUrl(interpreterContext); + if (sparkInterpreter.isUnsupportedSparkVersion()) { + return new InterpreterResult(InterpreterResult.Code.ERROR, "Spark " + + sparkInterpreter.getSparkVersion().toString() + " is not supported"); + } + + String jobGroup = Utils.buildJobGroupId(interpreterContext); + String jobDesc = "Started by: " + + Utils.getUserName(interpreterContext.getAuthenticationInfo()); + sparkInterpreter.getSparkContext().setJobGroup(jobGroup, jobDesc, false); + + String imageWidth = getProperty("zeppelin.R.image.width"); + + String[] sl = lines.split("\n"); + if (sl[0].contains("{") && sl[0].contains("}")) { + String jsonConfig = sl[0].substring(sl[0].indexOf("{"), sl[0].indexOf("}") + 1); + ObjectMapper m = new ObjectMapper(); + try { + JsonNode rootNode = m.readTree(jsonConfig); + JsonNode imageWidthNode = rootNode.path("imageWidth"); + if (!imageWidthNode.isMissingNode()) imageWidth = imageWidthNode.textValue(); + } + catch (Exception e) { + logger.warn("Can not parse json config: " + jsonConfig, e); + } + finally { + lines = lines.replace(jsonConfig, ""); + } + } + + String setJobGroup = ""; + // assign setJobGroup to dummy__, otherwise it would print NULL for this statement + if (Utils.isSpark2()) { + 
setJobGroup = "dummy__ <- setJobGroup(\"" + jobGroup + + "\", \" +" + jobDesc + "\", TRUE)"; + } else if (getSparkInterpreter().getSparkVersion().newerThanEquals(SparkVersion.SPARK_1_5_0)) { + setJobGroup = "dummy__ <- setJobGroup(sc, \"" + jobGroup + + "\", \"" + jobDesc + "\", TRUE)"; + } + logger.debug("set JobGroup:" + setJobGroup); + lines = setJobGroup + "\n" + lines; + + try { + // render output with knitr + if (useKnitr()) { + zeppelinR.setInterpreterOutput(null); + zeppelinR.set(".zcmd", "\n```{r " + renderOptions + "}\n" + lines + "\n```"); + zeppelinR.eval(".zres <- knit2html(text=.zcmd)"); + String html = zeppelinR.getS0(".zres"); + + RDisplay rDisplay = render(html, imageWidth); + + return new InterpreterResult( + rDisplay.code(), + rDisplay.type(), + rDisplay.content() + ); + } else { + // alternatively, stream the output (without knitr) + zeppelinR.setInterpreterOutput(interpreterContext.out); + zeppelinR.eval(lines); + return new InterpreterResult(InterpreterResult.Code.SUCCESS, ""); + } + } catch (Exception e) { + logger.error("Exception while connecting to R", e); + return new InterpreterResult(InterpreterResult.Code.ERROR, e.getMessage()); + } finally { + try { + } catch (Exception e) { + // Do nothing... + } + } + } + + @Override + public void close() { + zeppelinR.close(); + } + + @Override + public void cancel(InterpreterContext context) { + if (this.sc != null) { + sc.cancelJobGroup(getJobGroup(context)); + } + } + + @Override + public FormType getFormType() { + return FormType.NONE; + } + + @Override + public int getProgress(InterpreterContext context) throws InterpreterException { + if (sparkInterpreter != null) { + return sparkInterpreter.getProgress(context); + } else { + return 0; + } + } + + @Override + public Scheduler getScheduler() { + return SchedulerFactory.singleton().createOrGetFIFOScheduler( + SparkRInterpreter.class.getName() + this.hashCode()); + } + + @Override + public List<InterpreterCompletion> completion(String buf, int cursor, + InterpreterContext interpreterContext) { + return new ArrayList<>(); + } + + private SparkInterpreter getSparkInterpreter() throws InterpreterException { + LazyOpenInterpreter lazy = null; + SparkInterpreter spark = null; + Interpreter p = getInterpreterInTheSameSessionByClassName(SparkInterpreter.class.getName()); + + while (p instanceof WrappedInterpreter) { + if (p instanceof LazyOpenInterpreter) { + lazy = (LazyOpenInterpreter) p; + } + p = ((WrappedInterpreter) p).getInnerInterpreter(); + } + spark = (SparkInterpreter) p; + + if (lazy != null) { + lazy.open(); + } + return spark; + } + + private boolean useKnitr() { + try { + return Boolean.parseBoolean(getProperty("zeppelin.R.knitr")); + } catch (Exception e) { + return false; + } + } +} http://git-wip-us.apache.org/repos/asf/zeppelin/blob/d762b528/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java new file mode 100644 index 0000000..9709f9e --- /dev/null +++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkSqlInterpreter.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zeppelin.spark; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.List; +import java.util.Properties; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.spark.SparkContext; +import org.apache.spark.sql.SQLContext; +import org.apache.zeppelin.interpreter.Interpreter; +import org.apache.zeppelin.interpreter.InterpreterContext; +import org.apache.zeppelin.interpreter.InterpreterException; +import org.apache.zeppelin.interpreter.InterpreterResult; +import org.apache.zeppelin.interpreter.InterpreterResult.Code; +import org.apache.zeppelin.interpreter.LazyOpenInterpreter; +import org.apache.zeppelin.interpreter.WrappedInterpreter; +import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; +import org.apache.zeppelin.scheduler.Scheduler; +import org.apache.zeppelin.scheduler.SchedulerFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Spark SQL interpreter for Zeppelin. + */ +public class SparkSqlInterpreter extends Interpreter { + private Logger logger = LoggerFactory.getLogger(SparkSqlInterpreter.class); + + public static final String MAX_RESULTS = "zeppelin.spark.maxResult"; + + AtomicInteger num = new AtomicInteger(0); + + private int maxResult; + + public SparkSqlInterpreter(Properties property) { + super(property); + } + + @Override + public void open() { + this.maxResult = Integer.parseInt(getProperty(MAX_RESULTS)); + } + + private SparkInterpreter getSparkInterpreter() throws InterpreterException { + LazyOpenInterpreter lazy = null; + SparkInterpreter spark = null; + Interpreter p = getInterpreterInTheSameSessionByClassName(SparkInterpreter.class.getName()); + + while (p instanceof WrappedInterpreter) { + if (p instanceof LazyOpenInterpreter) { + lazy = (LazyOpenInterpreter) p; + } + p = ((WrappedInterpreter) p).getInnerInterpreter(); + } + spark = (SparkInterpreter) p; + + if (lazy != null) { + lazy.open(); + } + return spark; + } + + public boolean concurrentSQL() { + return Boolean.parseBoolean(getProperty("zeppelin.spark.concurrentSQL")); + } + + @Override + public void close() {} + + @Override + public InterpreterResult interpret(String st, InterpreterContext context) + throws InterpreterException { + SQLContext sqlc = null; + SparkInterpreter sparkInterpreter = getSparkInterpreter(); + + if (sparkInterpreter.isUnsupportedSparkVersion()) { + return new InterpreterResult(Code.ERROR, "Spark " + + sparkInterpreter.getSparkVersion().toString() + " is not supported"); + } + + sparkInterpreter.populateSparkWebUrl(context); + sparkInterpreter.getZeppelinContext().setInterpreterContext(context); + sqlc = sparkInterpreter.getSQLContext(); + SparkContext sc = sqlc.sparkContext(); + if (concurrentSQL()) { + sc.setLocalProperty("spark.scheduler.pool", "fair"); + } else { + sc.setLocalProperty("spark.scheduler.pool", null); + } + + String jobDesc = "Started by: " + 
Utils.getUserName(context.getAuthenticationInfo()); + sc.setJobGroup(Utils.buildJobGroupId(context), jobDesc, false); + Object rdd = null; + try { + // method signature of sqlc.sql() is changed + // from def sql(sqlText: String): SchemaRDD (1.2 and prior) + // to def sql(sqlText: String): DataFrame (1.3 and later). + // Therefore need to use reflection to keep binary compatibility for all spark versions. + Method sqlMethod = sqlc.getClass().getMethod("sql", String.class); + rdd = sqlMethod.invoke(sqlc, st); + } catch (InvocationTargetException ite) { + if (Boolean.parseBoolean(getProperty("zeppelin.spark.sql.stacktrace"))) { + throw new InterpreterException(ite); + } + logger.error("Invocation target exception", ite); + String msg = ite.getTargetException().getMessage() + + "\nset zeppelin.spark.sql.stacktrace = true to see full stacktrace"; + return new InterpreterResult(Code.ERROR, msg); + } catch (NoSuchMethodException | SecurityException | IllegalAccessException + | IllegalArgumentException e) { + throw new InterpreterException(e); + } + + String msg = sparkInterpreter.getZeppelinContext().showData(rdd); + sc.clearJobGroup(); + return new InterpreterResult(Code.SUCCESS, msg); + } + + @Override + public void cancel(InterpreterContext context) throws InterpreterException { + SparkInterpreter sparkInterpreter = getSparkInterpreter(); + SQLContext sqlc = sparkInterpreter.getSQLContext(); + SparkContext sc = sqlc.sparkContext(); + + sc.cancelJobGroup(Utils.buildJobGroupId(context)); + } + + @Override + public FormType getFormType() { + return FormType.SIMPLE; + } + + + @Override + public int getProgress(InterpreterContext context) throws InterpreterException { + SparkInterpreter sparkInterpreter = getSparkInterpreter(); + return sparkInterpreter.getProgress(context); + } + + @Override + public Scheduler getScheduler() { + if (concurrentSQL()) { + int maxConcurrency = 10; + return SchedulerFactory.singleton().createOrGetParallelScheduler( + SparkSqlInterpreter.class.getName() + this.hashCode(), maxConcurrency); + } else { + // getSparkInterpreter() calls open() inside. + // That means if SparkInterpreter is not opened, it'll wait until SparkInterpreter open. + // In this moment UI displays 'READY' or 'FINISHED' instead of 'PENDING' or 'RUNNING'. + // It's because of scheduler is not created yet, and scheduler is created by this function. + // Therefore, we can still use getSparkInterpreter() here, but it's better and safe + // to getSparkInterpreter without opening it. + + Interpreter intp = + getInterpreterInTheSameSessionByClassName(SparkInterpreter.class.getName()); + if (intp != null) { + return intp.getScheduler(); + } else { + return null; + } + } + } + + @Override + public List<InterpreterCompletion> completion(String buf, int cursor, + InterpreterContext interpreterContext) { + return null; + } +} http://git-wip-us.apache.org/repos/asf/zeppelin/blob/d762b528/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkVersion.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkVersion.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkVersion.java new file mode 100644 index 0000000..4b02798 --- /dev/null +++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkVersion.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.zeppelin.spark; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Provide reading comparing capability of spark version returned from SparkContext.version() + */ +public class SparkVersion { + Logger logger = LoggerFactory.getLogger(SparkVersion.class); + + public static final SparkVersion SPARK_1_0_0 = SparkVersion.fromVersionString("1.0.0"); + public static final SparkVersion SPARK_1_1_0 = SparkVersion.fromVersionString("1.1.0"); + public static final SparkVersion SPARK_1_2_0 = SparkVersion.fromVersionString("1.2.0"); + public static final SparkVersion SPARK_1_3_0 = SparkVersion.fromVersionString("1.3.0"); + public static final SparkVersion SPARK_1_4_0 = SparkVersion.fromVersionString("1.4.0"); + public static final SparkVersion SPARK_1_5_0 = SparkVersion.fromVersionString("1.5.0"); + public static final SparkVersion SPARK_1_6_0 = SparkVersion.fromVersionString("1.6.0"); + + public static final SparkVersion SPARK_2_0_0 = SparkVersion.fromVersionString("2.0.0"); + public static final SparkVersion SPARK_2_3_0 = SparkVersion.fromVersionString("2.3.0"); + + public static final SparkVersion MIN_SUPPORTED_VERSION = SPARK_1_0_0; + public static final SparkVersion UNSUPPORTED_FUTURE_VERSION = SPARK_2_3_0; + + private int version; + private String versionString; + + SparkVersion(String versionString) { + this.versionString = versionString; + + try { + int pos = versionString.indexOf('-'); + + String numberPart = versionString; + if (pos > 0) { + numberPart = versionString.substring(0, pos); + } + + String versions[] = numberPart.split("\\."); + int major = Integer.parseInt(versions[0]); + int minor = Integer.parseInt(versions[1]); + int patch = Integer.parseInt(versions[2]); + // version is always 5 digits. (e.g. 2.0.0 -> 20000, 1.6.2 -> 10602) + version = Integer.parseInt(String.format("%d%02d%02d", major, minor, patch)); + } catch (Exception e) { + logger.error("Can not recognize Spark version " + versionString + + ". 
Assume it's a future release", e); + + // assume it is future release + version = 99999; + } + } + + public int toNumber() { + return version; + } + + public String toString() { + return versionString; + } + + public boolean isUnsupportedVersion() { + return olderThan(MIN_SUPPORTED_VERSION) || newerThanEquals(UNSUPPORTED_FUTURE_VERSION); + } + + public static SparkVersion fromVersionString(String versionString) { + return new SparkVersion(versionString); + } + + public boolean isPysparkSupported() { + return this.newerThanEquals(SPARK_1_2_0); + } + + public boolean isSparkRSupported() { + return this.newerThanEquals(SPARK_1_4_0); + } + + public boolean hasDataFrame() { + return this.newerThanEquals(SPARK_1_4_0); + } + + public boolean getProgress1_0() { + return this.olderThan(SPARK_1_1_0); + } + + public boolean oldLoadFilesMethodName() { + return this.olderThan(SPARK_1_3_0); + } + + public boolean oldSqlContextImplicits() { + return this.olderThan(SPARK_1_3_0); + } + + public boolean equals(Object versionToCompare) { + return version == ((SparkVersion) versionToCompare).version; + } + + public boolean newerThan(SparkVersion versionToCompare) { + return version > versionToCompare.version; + } + + public boolean newerThanEquals(SparkVersion versionToCompare) { + return version >= versionToCompare.version; + } + + public boolean olderThan(SparkVersion versionToCompare) { + return version < versionToCompare.version; + } + + public boolean olderThanEquals(SparkVersion versionToCompare) { + return version <= versionToCompare.version; + } +} http://git-wip-us.apache.org/repos/asf/zeppelin/blob/d762b528/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkZeppelinContext.java ---------------------------------------------------------------------- diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkZeppelinContext.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkZeppelinContext.java new file mode 100644 index 0000000..8847039 --- /dev/null +++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/SparkZeppelinContext.java @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.zeppelin.spark; + +import com.google.common.collect.Lists; +import org.apache.spark.SparkContext; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.catalyst.expressions.Attribute; +import org.apache.zeppelin.annotation.ZeppelinApi; +import org.apache.zeppelin.display.AngularObjectWatcher; +import org.apache.zeppelin.display.Input; +import org.apache.zeppelin.display.ui.OptionInput; +import org.apache.zeppelin.interpreter.*; +import scala.Tuple2; +import scala.Unit; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.*; + +import static scala.collection.JavaConversions.asJavaCollection; +import static scala.collection.JavaConversions.asJavaIterable; +import static scala.collection.JavaConversions.collectionAsScalaIterable; + +/** + * ZeppelinContext for Spark + */ +public class SparkZeppelinContext extends BaseZeppelinContext { + + private SparkContext sc; + private List<Class> supportedClasses; + private Map<String, String> interpreterClassMap; + + public SparkZeppelinContext( + SparkContext sc, + InterpreterHookRegistry hooks, + int maxResult) { + super(hooks, maxResult); + this.sc = sc; + + interpreterClassMap = new HashMap<>(); + interpreterClassMap.put("spark", "org.apache.zeppelin.spark.SparkInterpreter"); + interpreterClassMap.put("sql", "org.apache.zeppelin.spark.SparkSqlInterpreter"); + interpreterClassMap.put("dep", "org.apache.zeppelin.spark.DepInterpreter"); + interpreterClassMap.put("pyspark", "org.apache.zeppelin.spark.PySparkInterpreter"); + + this.supportedClasses = new ArrayList<>(); + try { + supportedClasses.add(Class.forName("org.apache.spark.sql.Dataset")); + } catch (ClassNotFoundException e) { + // ignore: Dataset is not on the classpath for this Spark version + } + + try { + supportedClasses.add(Class.forName("org.apache.spark.sql.DataFrame")); + } catch (ClassNotFoundException e) { + // ignore: DataFrame is not on the classpath for this Spark version + } + + try { + supportedClasses.add(Class.forName("org.apache.spark.sql.SchemaRDD")); + } catch (ClassNotFoundException e) { + // ignore: SchemaRDD is not on the classpath for this Spark version + } + + if (supportedClasses.isEmpty()) { + throw new RuntimeException("Can not load Dataset/DataFrame/SchemaRDD class"); + } + } + + @Override + public List<Class> getSupportedClasses() { + return supportedClasses; + } + + @Override + public Map<String, String> getInterpreterClassMap() { + return interpreterClassMap; + } + + @Override + public String showData(Object df) { + Object[] rows = null; + Method take; + String jobGroup = Utils.buildJobGroupId(interpreterContext); + sc.setJobGroup(jobGroup, "Zeppelin", false); + + try { + // convert it to DataFrame if it is Dataset, as we will iterate all the records + // and assume it is type Row.
+
+  @Override
+  public List<Class> getSupportedClasses() {
+    return supportedClasses;
+  }
+
+  @Override
+  public Map<String, String> getInterpreterClassMap() {
+    return interpreterClassMap;
+  }
+
+  @Override
+  public String showData(Object df) {
+    Object[] rows = null;
+    Method take;
+    String jobGroup = Utils.buildJobGroupId(interpreterContext);
+    sc.setJobGroup(jobGroup, "Zeppelin", false);
+
+    try {
+      // convert a Dataset to a DataFrame first, as the loop below iterates the
+      // records and assumes each one is a Row
+      if (df.getClass().getCanonicalName().equals("org.apache.spark.sql.Dataset")) {
+        Method convertToDFMethod = df.getClass().getMethod("toDF");
+        df = convertToDFMethod.invoke(df);
+      }
+      take = df.getClass().getMethod("take", int.class);
+      rows = (Object[]) take.invoke(df, maxResult + 1);
+    } catch (NoSuchMethodException | SecurityException | IllegalAccessException
+        | IllegalArgumentException | InvocationTargetException | ClassCastException e) {
+      sc.clearJobGroup();
+      throw new RuntimeException(e);
+    }
+
+    List<Attribute> columns = null;
+    // get field names
+    try {
+      // Use reflection because the class name returned by queryExecution() differs by version:
+      //   Spark < 1.5.2:  org.apache.spark.sql.SQLContext$QueryExecution
+      //   Spark >= 1.6.0: org.apache.spark.sql.hive.HiveContext$QueryExecution
+      Object qe = df.getClass().getMethod("queryExecution").invoke(df);
+      Object a = qe.getClass().getMethod("analyzed").invoke(qe);
+      scala.collection.Seq seq = (scala.collection.Seq) a.getClass().getMethod("output").invoke(a);
+
+      columns = (List<Attribute>) scala.collection.JavaConverters.seqAsJavaListConverter(seq)
+          .asJava();
+    } catch (NoSuchMethodException | SecurityException | IllegalAccessException
+        | IllegalArgumentException | InvocationTargetException e) {
+      throw new RuntimeException(e);
+    }
+
+    StringBuilder msg = new StringBuilder();
+    msg.append("%table ");
+    for (Attribute col : columns) {
+      msg.append(col.name()).append("\t");
+    }
+    // drop the trailing tab, then terminate the header row
+    String trim = msg.toString().trim();
+    msg = new StringBuilder(trim);
+    msg.append("\n");
+
+    // ArrayType, BinaryType, BooleanType, ByteType, DecimalType, DoubleType, DynamicType,
+    // FloatType, FractionalType, IntegerType, IntegralType, LongType, MapType, NativeType,
+    // NullType, NumericType, ShortType, StringType, StructType
+
+    try {
+      for (int r = 0; r < maxResult && r < rows.length; r++) {
+        Object row = rows[r];
+        Method isNullAt = row.getClass().getMethod("isNullAt", int.class);
+        Method apply = row.getClass().getMethod("apply", int.class);
+
+        for (int i = 0; i < columns.size(); i++) {
+          if (!(Boolean) isNullAt.invoke(row, i)) {
+            msg.append(apply.invoke(row, i).toString());
+          } else {
+            msg.append("null");
+          }
+          if (i != columns.size() - 1) {
+            msg.append("\t");
+          }
+        }
+        msg.append("\n");
+      }
+    } catch (NoSuchMethodException | SecurityException | IllegalAccessException
+        | IllegalArgumentException | InvocationTargetException e) {
+      throw new RuntimeException(e);
+    }
+
+    if (rows.length > maxResult) {
+      msg.append("\n");
+      msg.append(ResultMessages.getExceedsLimitRowsMessage(maxResult,
+          SparkSqlInterpreter.MAX_RESULTS));
+    }
+
+    sc.clearJobGroup();
+    return msg.toString();
+  }
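Concretely, for a DataFrame with columns id and name and rows (1, "a") and (2, "b"), showData above emits a tab-separated display payload of this shape (tabs written as \t for legibility; illustrative):

    %table id\tname\n
    1\ta\n
    2\tb\n

The rows-exceeded notice is appended only when the take(maxResult + 1) probe brings back more than maxResult rows.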
+
+  @ZeppelinApi
+  public Object select(String name, scala.collection.Iterable<Tuple2<Object, String>> options) {
+    return select(name, "", options);
+  }
+
+  @ZeppelinApi
+  public Object select(String name, Object defaultValue,
+                       scala.collection.Iterable<Tuple2<Object, String>> options) {
+    return select(name, defaultValue, tuplesToParamOptions(options));
+  }
+
+  @ZeppelinApi
+  public scala.collection.Seq<Object> checkbox(
+      String name,
+      scala.collection.Iterable<Tuple2<Object, String>> options) {
+    List<Object> allChecked = new LinkedList<>();
+    for (Tuple2<Object, String> option : asJavaIterable(options)) {
+      allChecked.add(option._1());
+    }
+    return checkbox(name, collectionAsScalaIterable(allChecked), options);
+  }
+
+  @ZeppelinApi
+  public scala.collection.Seq<Object> checkbox(
+      String name,
+      scala.collection.Iterable<Object> defaultChecked,
+      scala.collection.Iterable<Tuple2<Object, String>> options) {
+    List<Object> defaultCheckedList = Lists.newArrayList(
+        asJavaIterable(defaultChecked).iterator());
+    Collection<Object> checkbox = checkbox(name, defaultCheckedList,
+        tuplesToParamOptions(options));
+    List<Object> checkboxList = Arrays.asList(checkbox.toArray());
+    return scala.collection.JavaConversions.asScalaBuffer(checkboxList).toSeq();
+  }
+
+  @ZeppelinApi
+  public Object noteSelect(String name, scala.collection.Iterable<Tuple2<Object, String>> options) {
+    return noteSelect(name, "", options);
+  }
+
+  @ZeppelinApi
+  public Object noteSelect(String name, Object defaultValue,
+                           scala.collection.Iterable<Tuple2<Object, String>> options) {
+    return noteSelect(name, defaultValue, tuplesToParamOptions(options));
+  }
+
+  @ZeppelinApi
+  public scala.collection.Seq<Object> noteCheckbox(
+      String name,
+      scala.collection.Iterable<Tuple2<Object, String>> options) {
+    List<Object> allChecked = new LinkedList<>();
+    for (Tuple2<Object, String> option : asJavaIterable(options)) {
+      allChecked.add(option._1());
+    }
+    return noteCheckbox(name, collectionAsScalaIterable(allChecked), options);
+  }
+
+  @ZeppelinApi
+  public scala.collection.Seq<Object> noteCheckbox(
+      String name,
+      scala.collection.Iterable<Object> defaultChecked,
+      scala.collection.Iterable<Tuple2<Object, String>> options) {
+    List<Object> defaultCheckedList = Lists.newArrayList(
+        asJavaIterable(defaultChecked).iterator());
+    Collection<Object> checkbox = noteCheckbox(name, defaultCheckedList,
+        tuplesToParamOptions(options));
+    List<Object> checkboxList = Arrays.asList(checkbox.toArray());
+    return scala.collection.JavaConversions.asScalaBuffer(checkboxList).toSeq();
+  }
+
+  private OptionInput.ParamOption[] tuplesToParamOptions(
+      scala.collection.Iterable<Tuple2<Object, String>> options) {
+    int n = options.size();
+    OptionInput.ParamOption[] paramOptions = new OptionInput.ParamOption[n];
+    Iterator<Tuple2<Object, String>> it = asJavaIterable(options).iterator();
+
+    int i = 0;
+    while (it.hasNext()) {
+      Tuple2<Object, String> valueAndDisplayValue = it.next();
+      paramOptions[i++] = new OptionInput.ParamOption(valueAndDisplayValue._1(),
+          valueAndDisplayValue._2());
+    }
+
+    return paramOptions;
+  }
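A sketch of driving the dynamic-forms overloads above from Java (a Scala note would pass a Seq of tuples directly; z and the option values are placeholders):

    List<Tuple2<Object, String>> opts = Arrays.asList(
        new Tuple2<Object, String>("small", "Small"),
        new Tuple2<Object, String>("large", "Large"));
    Object picked = z.select("size", "small",
        scala.collection.JavaConversions.collectionAsScalaIterable(opts));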
+
+  @ZeppelinApi
+  public void angularWatch(String name,
+                           final scala.Function2<Object, Object, Unit> func) {
+    angularWatch(name, interpreterContext.getNoteId(), func);
+  }
+
+  @Deprecated
+  public void angularWatchGlobal(String name,
+                                 final scala.Function2<Object, Object, Unit> func) {
+    angularWatch(name, null, func);
+  }
+
+  @ZeppelinApi
+  public void angularWatch(
+      String name,
+      final scala.Function3<Object, Object, InterpreterContext, Unit> func) {
+    angularWatch(name, interpreterContext.getNoteId(), func);
+  }
+
+  @Deprecated
+  public void angularWatchGlobal(
+      String name,
+      final scala.Function3<Object, Object, InterpreterContext, Unit> func) {
+    angularWatch(name, null, func);
+  }
+
+  private void angularWatch(String name, String noteId,
+                            final scala.Function2<Object, Object, Unit> func) {
+    AngularObjectWatcher w = new AngularObjectWatcher(getInterpreterContext()) {
+      @Override
+      public void watch(Object oldObject, Object newObject,
+                        InterpreterContext context) {
+        func.apply(oldObject, newObject);
+      }
+    };
+    angularWatch(name, noteId, w);
+  }
+
+  private void angularWatch(
+      String name,
+      String noteId,
+      final scala.Function3<Object, Object, InterpreterContext, Unit> func) {
+    AngularObjectWatcher w = new AngularObjectWatcher(getInterpreterContext()) {
+      @Override
+      public void watch(Object oldObject, Object newObject,
+                        InterpreterContext context) {
+        func.apply(oldObject, newObject, context);
+      }
+    };
+    angularWatch(name, noteId, w);
+  }
+}

http://git-wip-us.apache.org/repos/asf/zeppelin/blob/d762b528/spark/interpreter/src/main/java/org/apache/zeppelin/spark/Utils.java
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/Utils.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/Utils.java
new file mode 100644
index 0000000..82bf210
--- /dev/null
+++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/Utils.java
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.zeppelin.spark;
+
+import org.apache.zeppelin.interpreter.InterpreterContext;
+import org.apache.zeppelin.user.AuthenticationInfo;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.lang.reflect.Constructor;
+import java.lang.reflect.InvocationTargetException;
+import java.util.Properties;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+/**
+ * Utility and helper functions for the Spark Interpreter
+ */
+class Utils {
+  public static Logger logger = LoggerFactory.getLogger(Utils.class);
+  private static final String SCALA_COMPILER_VERSION = evaluateScalaCompilerVersion();
+
+  static Object invokeMethod(Object o, String name) {
+    return invokeMethod(o, name, new Class[]{}, new Object[]{});
+  }
+
+  static Object invokeMethod(Object o, String name, Class<?>[] argTypes, Object[] params) {
+    try {
+      return o.getClass().getMethod(name, argTypes).invoke(o, params);
+    } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
+      logger.error(e.getMessage(), e);
+    }
+    return null;
+  }
+
+  static Object invokeStaticMethod(Class<?> c, String name, Class<?>[] argTypes, Object[] params) {
+    try {
+      return c.getMethod(name, argTypes).invoke(null, params);
+    } catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException e) {
+      logger.error(e.getMessage(), e);
+    }
+    return null;
+  }
+
+  static Object invokeStaticMethod(Class<?> c, String name) {
+    return invokeStaticMethod(c, name, new Class[]{}, new Object[]{});
+  }
+
+  static Class<?> findClass(String name) {
+    return findClass(name, false);
+  }
+
+  static Class<?> findClass(String name, boolean silence) {
+    try {
+      return Class.forName(name);
+    } catch (ClassNotFoundException e) {
+      if (!silence) {
+        logger.error(e.getMessage(), e);
+      }
+      return null;
+    }
+  }
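A sketch of how these helpers keep the interpreter source-compatible across Spark versions; the SparkSession probe mirrors the isSpark2() check further down (illustrative only):

    Class<?> sessionClass = Utils.findClass("org.apache.spark.sql.SparkSession", true);
    if (sessionClass != null) {
      // Spark 2.x on the classpath: obtain a builder without a compile-time dependency
      Object builder = Utils.invokeStaticMethod(sessionClass, "builder");
    }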
+
+  static Object instantiateClass(String name, Class<?>[] argTypes, Object[] params) {
+    try {
+      Constructor<?> constructor = Utils.class.getClassLoader()
+          .loadClass(name).getConstructor(argTypes);
+      return constructor.newInstance(params);
+    } catch (NoSuchMethodException | ClassNotFoundException | IllegalAccessException |
+        InstantiationException | InvocationTargetException e) {
+      logger.error(e.getMessage(), e);
+    }
+    return null;
+  }
+
+  // only works after the interpreter has been initialized
+  static boolean isScala2_10() {
+    try {
+      Class.forName("org.apache.spark.repl.SparkIMain");
+      return true;
+    } catch (ClassNotFoundException | IncompatibleClassChangeError e) {
+      return false;
+    }
+  }
+
+  static boolean isScala2_11() {
+    return !isScala2_10();
+  }
+
+  static boolean isCompilerAboveScala2_11_7() {
+    if (isScala2_10() || SCALA_COMPILER_VERSION == null) {
+      return false;
+    }
+    Pattern p = Pattern.compile("([0-9]+)[.]([0-9]+)[.]([0-9]+)");
+    Matcher m = p.matcher(SCALA_COMPILER_VERSION);
+    if (m.matches()) {
+      int major = Integer.parseInt(m.group(1));
+      int minor = Integer.parseInt(m.group(2));
+      int bugfix = Integer.parseInt(m.group(3));
+      return (major > 2 || (major == 2 && minor > 11) || (major == 2 && minor == 11 && bugfix > 7));
+    }
+    return false;
+  }
+
+  private static String evaluateScalaCompilerVersion() {
+    String version = null;
+    try {
+      Properties p = new Properties();
+      Class<?> completionClass = findClass("scala.tools.nsc.interpreter.JLineCompletion");
+      if (completionClass != null) {
+        // resolve the resource through the scala-compiler class itself
+        // so the correct classloader is used
+        try (java.io.InputStream in = completionClass
+            .getResourceAsStream("/compiler.properties")) {
+          p.load(in);
+          version = p.getProperty("version.number");
+        } catch (java.io.IOException e) {
+          logger.error("Failed to evaluate Scala compiler version", e);
+        }
+      }
+    } catch (RuntimeException e) {
+      logger.error("Failed to evaluate Scala compiler version", e);
+    }
+    return version;
+  }
+
+  static boolean isSpark2() {
+    try {
+      Class.forName("org.apache.spark.sql.SparkSession");
+      return true;
+    } catch (ClassNotFoundException e) {
+      return false;
+    }
+  }
+
+  public static String buildJobGroupId(InterpreterContext context) {
+    return "zeppelin-" + context.getNoteId() + "-" + context.getParagraphId();
+  }
+
+  public static String getNoteId(String jobgroupId) {
+    int indexOf = jobgroupId.indexOf("-");
+    int secondIndex = jobgroupId.indexOf("-", indexOf + 1);
+    return jobgroupId.substring(indexOf + 1, secondIndex);
+  }
+
+  public static String getParagraphId(String jobgroupId) {
+    int indexOf = jobgroupId.indexOf("-");
+    int secondIndex = jobgroupId.indexOf("-", indexOf + 1);
+    return jobgroupId.substring(secondIndex + 1);
+  }
+
+  public static String getUserName(AuthenticationInfo info) {
+    String uName = "";
+    if (info != null) {
+      uName = info.getUser();
+    }
+    if (uName == null || uName.isEmpty()) {
+      uName = "anonymous";
+    }
+    return uName;
+  }
+}
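A round trip of the job-group helpers above, with made-up note and paragraph ids; paragraph ids may themselves contain hyphens, which getParagraphId tolerates because it splits only on the first two:

    String jobGroup = "zeppelin-2ABC123-20180101-123456_789";  // shape of buildJobGroupId output
    String noteId = Utils.getNoteId(jobGroup);                 // "2ABC123"
    String paragraphId = Utils.getParagraphId(jobGroup);       // "20180101-123456_789"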
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/d762b528/spark/interpreter/src/main/java/org/apache/zeppelin/spark/ZeppelinR.java
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/ZeppelinR.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/ZeppelinR.java
new file mode 100644
index 0000000..130d849
--- /dev/null
+++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/ZeppelinR.java
@@ -0,0 +1,394 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.zeppelin.spark;
+
+import org.apache.commons.exec.*;
+import org.apache.commons.exec.environment.EnvironmentUtils;
+import org.apache.commons.io.IOUtils;
+import org.apache.zeppelin.interpreter.InterpreterException;
+import org.apache.zeppelin.interpreter.InterpreterOutput;
+import org.apache.zeppelin.interpreter.InterpreterOutputListener;
+import org.apache.zeppelin.interpreter.InterpreterResultMessageOutput;
+import org.apache.zeppelin.interpreter.util.InterpreterOutputStream;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.*;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * R repl interaction
+ */
+public class ZeppelinR implements ExecuteResultHandler {
+  Logger logger = LoggerFactory.getLogger(ZeppelinR.class);
+  private final String rCmdPath;
+  private final SparkVersion sparkVersion;
+  private DefaultExecutor executor;
+  private InterpreterOutputStream outputStream;
+  private PipedOutputStream input;
+  private final String scriptPath;
+  private final String libPath;
+  static Map<Integer, ZeppelinR> zeppelinR = Collections.synchronizedMap(
+      new HashMap<Integer, ZeppelinR>());
+
+  private InterpreterOutput initialOutput;
+  private final int port;
+  private boolean rScriptRunning;
+
+  /**
+   * To be notified of R repl initialization; the flag is guarded by its notifier.
+   */
+  boolean rScriptInitialized = false;
+  final Object rScriptInitializeNotifier = new Object();
+
+  /**
+   * Request to the R repl
+   */
+  Request rRequestObject = null;
+  final Object rRequestNotifier = new Object();
+
+  /**
+   * Request object
+   *
+   * type  : "eval", "set", "get" or "getS"
+   * stmt  : statement to evaluate when type is "eval",
+   *         key when type is "set", "get" or "getS"
+   * value : value object when type is "set"
+   */
+  public static class Request {
+    String type;
+    String stmt;
+    Object value;
+
+    public Request(String type, String stmt, Object value) {
+      this.type = type;
+      this.stmt = stmt;
+      this.value = value;
+    }
+
+    public String getType() {
+      return type;
+    }
+
+    public String getStmt() {
+      return stmt;
+    }
+
+    public Object getValue() {
+      return value;
+    }
+  }
+
+  /**
+   * Response from the R repl
+   */
+  Object rResponseValue = null;
+  boolean rResponseError = false;
+  final Object rResponseNotifier = new Object();
+
+  /**
+   * Create a ZeppelinR instance
+   * @param rCmdPath R repl commandline path
+   * @param libPath sparkr library path
+   * @param sparkRBackendPort port of the SparkR backend
+   * @param sparkVersion detected Spark version
+   */
+  public ZeppelinR(String rCmdPath, String libPath, int sparkRBackendPort,
+                   SparkVersion sparkVersion) {
+    this.rCmdPath = rCmdPath;
+    this.libPath = libPath;
+    this.sparkVersion = sparkVersion;
+    this.port = sparkRBackendPort;
+    try {
+      File scriptFile = File.createTempFile("zeppelin_sparkr-", ".R");
+      scriptPath = scriptFile.getAbsolutePath();
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
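A hedged wiring sketch for the class (the paths, port, and version string are placeholders; real values come from the SparkR interpreter configuration):

    ZeppelinR zr = new ZeppelinR("/usr/bin/R", "/opt/spark/R/lib", 9999,
        SparkVersion.fromVersionString("2.2.0"));
    zr.open();                         // launches R and blocks until the script handshakes
    Object result = zr.eval("1 + 1");  // round-trips through the request/response protocol
    zr.close();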
+
+  /**
+   * Start the R repl
+   * @throws IOException
+   */
+  public void open() throws IOException, InterpreterException {
+    createRScript();
+
+    zeppelinR.put(hashCode(), this);
+
+    CommandLine cmd = CommandLine.parse(rCmdPath);
+    cmd.addArgument("--no-save");
+    cmd.addArgument("--no-restore");
+    cmd.addArgument("-f");
+    cmd.addArgument(scriptPath);
+    cmd.addArgument("--args");
+    cmd.addArgument(Integer.toString(hashCode()));
+    cmd.addArgument(Integer.toString(port));
+    cmd.addArgument(libPath);
+    cmd.addArgument(Integer.toString(sparkVersion.toNumber()));
+
+    // dump out the R command to facilitate manually running it, e.g. for fault diagnosis
+    logger.debug(cmd.toString());
+
+    executor = new DefaultExecutor();
+    outputStream = new InterpreterOutputStream(logger);
+
+    input = new PipedOutputStream();
+    PipedInputStream in = new PipedInputStream(input);
+
+    PumpStreamHandler streamHandler = new PumpStreamHandler(outputStream, outputStream, in);
+    executor.setWatchdog(new ExecuteWatchdog(ExecuteWatchdog.INFINITE_TIMEOUT));
+    executor.setStreamHandler(streamHandler);
+    Map<String, String> env = EnvironmentUtils.getProcEnvironment();
+
+    initialOutput = new InterpreterOutput(null);
+    outputStream.setInterpreterOutput(initialOutput);
+    executor.execute(cmd, env, this);
+    rScriptRunning = true;
+
+    // flush output
+    eval("cat('')");
+  }
+
+  /**
+   * Evaluate an expression
+   * @param expr expression to evaluate
+   * @return evaluation result
+   */
+  public Object eval(String expr) throws InterpreterException {
+    synchronized (this) {
+      rRequestObject = new Request("eval", expr, null);
+      return request();
+    }
+  }
+
+  /**
+   * Assign a value to a key
+   * @param key variable name on the R side
+   * @param value value to assign
+   */
+  public void set(String key, Object value) throws InterpreterException {
+    synchronized (this) {
+      rRequestObject = new Request("set", key, value);
+      request();
+    }
+  }
+
+  /**
+   * Get the value of a key
+   * @param key variable name on the R side
+   * @return value
+   */
+  public Object get(String key) throws InterpreterException {
+    synchronized (this) {
+      rRequestObject = new Request("get", key, null);
+      return request();
+    }
+  }
+
+  /**
+   * Get the value of a key, as a string
+   * @param key variable name on the R side
+   * @return value rendered as a string
+   */
+  public String getS0(String key) throws InterpreterException {
+    synchronized (this) {
+      rRequestObject = new Request("getS", key, null);
+      return (String) request();
+    }
+  }
+
+  /**
+   * Send the pending request to the R repl and wait for its response
+   * @return response value
+   */
+  private Object request() throws RuntimeException, InterpreterException {
+    if (!rScriptRunning) {
+      throw new RuntimeException("r repl is not running");
+    }
+
+    // wait until the R script has initialized
+    if (!rScriptInitialized) {
+      waitForRScriptInitialized();
+    }
+
+    rResponseValue = null;
+
+    synchronized (rRequestNotifier) {
+      rRequestNotifier.notify();
+    }
+
+    Object respValue = null;
+    boolean error = false;
+    synchronized (rResponseNotifier) {
+      while (rResponseValue == null && rScriptRunning) {
+        try {
+          rResponseNotifier.wait(1000);
+        } catch (InterruptedException e) {
+          logger.error(e.getMessage(), e);
+        }
+      }
+      respValue = rResponseValue;
+      // read the error flag inside the lock, together with the value
+      error = rResponseError;
+      rResponseValue = null;
+    }
+
+    if (error) {
+      throw new RuntimeException(String.valueOf(respValue));
+    } else {
+      return respValue;
+    }
+  }
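The request() handshake above in miniature (illustrative, not Zeppelin API): one monitor guards a single-slot response box, and the waiter loops on the condition to survive spurious wakeups, exactly as request() does:

    static Object requestOnce() throws InterruptedException {
      final Object lock = new Object();
      final Object[] responseBox = new Object[1];

      new Thread(new Runnable() {
        public void run() {
          synchronized (lock) {
            responseBox[0] = "pong";  // publish the response...
            lock.notify();            // ...then wake the waiting requester
          }
        }
      }).start();

      synchronized (lock) {
        while (responseBox[0] == null) {  // condition loop, like request()
          lock.wait(1000);
        }
        return responseBox[0];
      }
    }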
+
+  /**
+   * Wait until src/main/resources/R/zeppelin_sparkr.R is initialized
+   * and has called onScriptInitialized()
+   *
+   * @throws InterpreterException
+   */
+  private void waitForRScriptInitialized() throws InterpreterException {
+    synchronized (rScriptInitializeNotifier) {
+      long startTime = System.nanoTime();
+      while (!rScriptInitialized &&
+          rScriptRunning &&
+          System.nanoTime() - startTime < 10L * 1000 * 1000000) {  // 10-second timeout
+        try {
+          rScriptInitializeNotifier.wait(1000);
+        } catch (InterruptedException e) {
+          logger.error(e.getMessage(), e);
+        }
+      }
+    }
+
+    String errorMessage = "";
+    try {
+      initialOutput.flush();
+      errorMessage = new String(initialOutput.toByteArray());
+    } catch (IOException e) {
+      logger.error(e.getMessage(), e);
+    }
+
+    if (!rScriptInitialized) {
+      throw new InterpreterException("sparkr is not responding " + errorMessage);
+    }
+  }
+
+  /**
+   * Invoked by src/main/resources/R/zeppelin_sparkr.R
+   * @return the next pending request
+   */
+  public Request getRequest() {
+    synchronized (rRequestNotifier) {
+      while (rRequestObject == null) {
+        try {
+          rRequestNotifier.wait(1000);
+        } catch (InterruptedException e) {
+          logger.error(e.getMessage(), e);
+        }
+      }
+
+      Request req = rRequestObject;
+      rRequestObject = null;
+      return req;
+    }
+  }
+
+  /**
+   * Invoked by src/main/resources/R/zeppelin_sparkr.R
+   * @param value response value
+   * @param error whether the request failed
+   */
+  public void setResponse(Object value, boolean error) {
+    synchronized (rResponseNotifier) {
+      rResponseValue = value;
+      rResponseError = error;
+      rResponseNotifier.notify();
+    }
+  }
+
+  /**
+   * Invoked by src/main/resources/R/zeppelin_sparkr.R
+   */
+  public void onScriptInitialized() {
+    synchronized (rScriptInitializeNotifier) {
+      rScriptInitialized = true;
+      rScriptInitializeNotifier.notifyAll();
+    }
+  }
+
+  /**
+   * Create the R script in the tmp dir
+   */
+  private void createRScript() throws InterpreterException {
+    ClassLoader classLoader = getClass().getClassLoader();
+    File out = new File(scriptPath);
+
+    if (out.exists() && out.isDirectory()) {
+      throw new InterpreterException("Can't create r script " + out.getAbsolutePath());
+    }
+
+    // try-with-resources ensures the stream is closed even if the copy fails
+    try (FileOutputStream outStream = new FileOutputStream(out)) {
+      IOUtils.copy(
+          classLoader.getResourceAsStream("R/zeppelin_sparkr.R"),
+          outStream);
+    } catch (IOException e) {
+      throw new InterpreterException(e);
+    }
+
+    logger.info("File {} created", scriptPath);
+  }
+
+  /**
+   * Terminate this R repl
+   */
+  public void close() {
+    executor.getWatchdog().destroyProcess();
+    new File(scriptPath).delete();
+    zeppelinR.remove(hashCode());
+  }
+
+  /**
+   * Get an instance by its key.
+   * This method will be invoked from zeppelin_sparkr.R
+   * @param hashcode
+   * @return
+   */
+  public static ZeppelinR getZeppelinR(int hashcode) {
+    return zeppelinR.get(hashcode);
+  }
+
+  /**
+   * Pass an InterpreterOutput to capture the repl output
+   * @param out
+   */
+  public void setInterpreterOutput(InterpreterOutput out) {
+    outputStream.setInterpreterOutput(out);
+  }
+
+  @Override
+  public void onProcessComplete(int i) {
+    logger.info("process complete {}", i);
+    rScriptRunning = false;
+  }
+
+  @Override
+  public void onProcessFailed(ExecuteException e) {
+    logger.error(e.getMessage(), e);
+    rScriptRunning = false;
+  }
+}
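For completeness, how the callback thread driven from zeppelin_sparkr.R conceptually pumps this class; illustrative only, since the real dispatch lives in the R script and travels over the SparkR backend:

    static void pump(int hashcode) {            // hashcode: the value R received via --args
      ZeppelinR zr = ZeppelinR.getZeppelinR(hashcode);
      while (true) {
        ZeppelinR.Request req = zr.getRequest(); // blocks until eval()/set()/get() posts one
        Object value = null;                     // ...evaluate req on the R side...
        boolean error = false;
        // dispatch on req.getType(): "eval", "set", "get", "getS"
        zr.setResponse(value, error);            // unblocks the waiting request()
      }
    }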